# This code is copied from the official Chainer examples.
import chainer
from chainer import training
# Custom updater for truncated BackProp Through Time (BPTT)
class BPTTUpdater(training.StandardUpdater):
    """Updater that performs truncated backpropagation through time (BPTT).

    Each call to ``update_core`` pulls ``bprop_len`` consecutive batches from
    the training iterator, sums the loss over those steps, backpropagates
    once through the accumulated loss, cuts the computational graph, and
    then lets the optimizer update the parameters.

    Args:
        train_iter: Training iterator; passed to ``StandardUpdater`` and
            retrieved later under the name ``'main'``.
        optimizer: Optimizer; passed to ``StandardUpdater`` and retrieved
            later under the name ``'main'``. Its ``target`` is called with
            the inputs and targets to obtain the loss.
        bprop_len: Number of iterator steps accumulated per update. It is
            expected to be at least 1: with 0 steps ``loss`` stays the
            integer ``0`` and ``loss.backward()`` cannot be called.
        device: Device specifier forwarded to ``StandardUpdater`` and used
            by ``self.converter`` when preparing each batch.
    """

    def __init__(self, train_iter, optimizer, bprop_len, device):
        super(BPTTUpdater, self).__init__(
            train_iter, optimizer, device=device)
        self.bprop_len = bprop_len

    # The core part of the update routine can be customized by overriding.
    def update_core(self):
        """Run one truncated-BPTT update over ``self.bprop_len`` steps."""
        loss = 0
        # When we pass one iterator and optimizer to StandardUpdater.__init__,
        # they are automatically named 'main'.
        train_iter = self.get_iterator('main')
        optimizer = self.get_optimizer('main')

        # Progress the dataset iterator for bprop_len words at each iteration.
        for _ in range(self.bprop_len):
            # Get the next batch (a list of tuples of two word IDs)
            batch = next(train_iter)

            # Concatenate the word IDs to matrices and send them to the device
            # self.converter does this job
            # (it is chainer.dataset.concat_examples by default)
            x, t = self.converter(batch, self.device)

            # Compute the loss at this time step and accumulate it.
            # NOTE(review): this call was reconstructed from a garbled line;
            # it assumes optimizer.target is the loss-returning model, as in
            # the usual Chainer BPTT pattern - confirm against the original.
            loss += optimizer.target(chainer.Variable(x), chainer.Variable(t))

        # Gradients must be cleared before backward(), otherwise they would
        # accumulate across successive updates.
        optimizer.target.cleargrads()  # Clear the parameter gradients
        loss.backward()  # Backprop
        loss.unchain_backward()  # Truncate the graph
        optimizer.update()  # Update the parameters