Permalink
Switch branches/tags
Nothing to show
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
42 lines (33 sloc) 1.65 KB
"""
This code is copied from the official Chainer examples:
- https://github.com/chainer/chainer/blob/e2fe6f8023e635f8c1fc9c89e85d075ebd50c529/examples/ptb/train_ptb.py
"""
import chainer
from chainer import training
# Custom updater for truncated BackProp Through Time (BPTT)
class BPTTUpdater(training.StandardUpdater):
    """Updater performing truncated backpropagation through time (BPTT).

    Each update pulls ``bprop_len`` consecutive batches from the 'main'
    iterator, sums the losses returned by the optimizer's target, runs a
    single backward pass over the summed loss, truncates the computational
    graph and finally applies one optimizer step.

    Args:
        train_iter: Dataset iterator; passed to ``StandardUpdater`` and
            retrieved later under the name 'main'.
        optimizer: Optimizer; passed to ``StandardUpdater`` and retrieved
            later under the name 'main'. Its ``target`` is called as
            ``target(x, t)`` and the result is accumulated as the loss.
        bprop_len: Number of iterator steps accumulated per update.
            Must be at least 1.
        device: Device specifier forwarded to ``StandardUpdater`` and
            handed to the converter for every batch.

    Raises:
        ValueError: If ``bprop_len`` is smaller than 1.
    """

    def __init__(self, train_iter, optimizer, bprop_len, device):
        # With zero steps, ``loss`` in update_core() would remain the plain
        # int 0 and ``loss.backward()`` would fail with an opaque
        # AttributeError; reject the value up front instead.
        if bprop_len < 1:
            raise ValueError(
                'bprop_len must be a positive integer, got {!r}'.format(
                    bprop_len))
        super(BPTTUpdater, self).__init__(
            train_iter, optimizer, device=device)
        self.bprop_len = bprop_len

    # The core part of the update routine can be customized by overriding.
    def update_core(self):
        """Run one truncated-BPTT update over ``self.bprop_len`` batches."""
        loss = 0
        # When we pass one iterator and optimizer to StandardUpdater.__init__,
        # they are automatically named 'main'.
        train_iter = self.get_iterator('main')
        optimizer = self.get_optimizer('main')

        # Progress the dataset iterator for bprop_len words at each iteration.
        for _ in range(self.bprop_len):
            # Get the next batch (a list of tuples of two word IDs)
            batch = next(train_iter)

            # Concatenate the word IDs to matrices and send them to the device
            # self.converter does this job
            # (it is chainer.dataset.concat_examples by default)
            x, t = self.converter(batch, self.device)

            # Compute the loss at this time step and accumulate it
            loss += optimizer.target(chainer.Variable(x), chainer.Variable(t))

        optimizer.target.cleargrads()  # Clear the parameter gradients
        loss.backward()  # Backprop
        loss.unchain_backward()  # Truncate the graph
        optimizer.update()  # Update the parameters