-
Notifications
You must be signed in to change notification settings - Fork 1.4k
/
variable_unchain.py
34 lines (26 loc) · 1.11 KB
/
variable_unchain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from chainer import configuration
from chainer.training import extension
from chainer import variable
import six
class unchain_variables(extension.Extension):

    """Trainer extension to unchain all computational graphs.

    This extension unchains the computational graph of every variable reported
    in ``trainer.observation`` after the other extensions have run, in order
    to release memory and to avoid memory leaks.

    It is intended as a last resort when some extension uses a variable graph
    but cannot release that graph by itself.

    Firing is controlled by the *previous* value of the
    ``chainer.config.keep_graph_on_report`` flag: the extension is triggered
    when the flag was on at the previous trigger check (or, for the very
    first check, at :meth:`initialize` time).
    """

    # Low priority: intended to make this run after the other extensions.
    priority = 0

    def __init__(self):
        # Snapshot of ``keep_graph_on_report`` from the last check; stays
        # ``None`` until ``initialize`` or ``trigger`` records a value.
        self._prev_flag = None

    def initialize(self, _):
        """Record the current ``keep_graph_on_report`` flag as the baseline."""
        self._prev_flag = configuration.config.keep_graph_on_report

    def trigger(self, _):
        """Decide whether to fire, based on the previously recorded flag.

        Returns:
            The ``keep_graph_on_report`` value recorded at the previous check
            (or by ``initialize``). The current value is stored for the next
            call, so the decision lags the flag by one check.
        """
        # The right-hand side is fully evaluated (old snapshot first, then the
        # live config value) before either target is rebound.
        flag, self._prev_flag = (
            self._prev_flag, configuration.config.keep_graph_on_report)
        return flag

    def __call__(self, trainer):
        """Call ``unchain_backward`` on each Variable in the observation."""
        # The observation may also hold non-Variable values; skip those.
        reported_vars = (
            value for value in six.itervalues(trainer.observation)
            if isinstance(value, variable.Variable))
        for var in reported_vars:
            var.unchain_backward()