Merge pull request #4155 from kuenishi/iterator-with
Provide a cleaner way to collect threads and processes in MultiprocessIterator
okuta committed Apr 3, 2018
2 parents 5f3151a + dd6e88e commit 4b9bab9
Showing 4 changed files with 71 additions and 43 deletions.
chainer/dataset/iterator.py (18 additions, 0 deletions)
@@ -61,6 +61,24 @@ def finalize(self):
         """
         pass
 
+    def __enter__(self):
+        """With statement context manager method
+
+        This method does nothing by default. Implementation may override it to
+        better handle the internal resources by with statement.
+
+        """
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        """With statement context manager method
+
+        This method does nothing by default. Implementation may override it to
+        better handle the internal resources by with statement.
+
+        """
+        return None
+
     def serialize(self, serializer):
         """Serializes the internal state of the iterator.
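
With these default hooks in the base class, every Iterator subclass can be driven by a with statement, whether or not it holds external resources. A minimal usage sketch (the dataset and batch size below are made up for illustration; SerialIterator simply inherits the no-op hooks):

    import chainer

    # Illustrative toy dataset.
    dataset = list(range(20))

    # Any Iterator subclass now works as a context manager. SerialIterator
    # inherits the no-op __enter__/__exit__; iterators backed by threads or
    # processes override __exit__ to release their workers.
    with chainer.iterators.SerialIterator(dataset, 4, repeat=False) as it:
        for batch in it:
            print(batch)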
chainer/iterators/multiprocess_iterator.py (6 additions, 0 deletions)
@@ -116,6 +116,12 @@ def __del__(self):
 
     finalize = __del__
 
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.finalize()
+
     def __copy__(self):
         other = MultiprocessIterator(
             self.dataset, self.batch_size, self.repeat, self.shuffle,
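
Here __exit__ delegates to finalize(), so leaving the with block shuts down the prefetch workers deterministically rather than waiting for __del__ at garbage-collection time. A minimal sketch, assuming a small in-memory dataset (the data and batch size are illustrative):

    from chainer.iterators import MultiprocessIterator

    data = list(range(100))  # illustrative dataset

    with MultiprocessIterator(data, 10, repeat=False, shuffle=False) as it:
        for batch in it:
            pass  # consume one epoch
    # The worker processes have already been finalized at this point.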
chainer/iterators/multithread_iterator.py (4 additions, 1 deletion)
@@ -61,7 +61,10 @@ def reset(self):
         self._next = None
         self._previous_epoch_detail = None
 
-    def __del__(self):
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
         self.finalize()
 
     def finalize(self):
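
MultithreadIterator drops its __del__ in favor of the same explicit pattern: use a with statement, or call finalize() yourself. A sketch of the try/finally equivalent (illustrative dataset and batch size), for code where the iterator's lifetime does not fit a single with block:

    from chainer.iterators import MultithreadIterator

    data = list(range(100))  # illustrative dataset
    it = MultithreadIterator(data, 10, repeat=False, shuffle=False)
    try:
        for batch in it:
            pass
    finally:
        it.finalize()  # join the worker threads explicitly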
examples/mnist/train_mnist_custom_loop.py (43 additions, 42 deletions)
@@ -11,6 +11,7 @@
 
 import chainer
 from chainer.dataset import convert
+from chainer.iterators import MultiprocessIterator
 import chainer.links as L
 from chainer import serializers
 
@@ -62,48 +63,48 @@ def main():
     train_count = len(train)
     test_count = len(test)
 
-    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
-    test_iter = chainer.iterators.SerialIterator(test, args.batchsize,
-                                                 repeat=False, shuffle=False)
-
-    sum_accuracy = 0
-    sum_loss = 0
-
-    while train_iter.epoch < args.epoch:
-        batch = train_iter.next()
-        x_array, t_array = convert.concat_examples(batch, args.gpu)
-        x = chainer.Variable(x_array)
-        t = chainer.Variable(t_array)
-        optimizer.update(model, x, t)
-        sum_loss += float(model.loss.data) * len(t.data)
-        sum_accuracy += float(model.accuracy.data) * len(t.data)
-
-        if train_iter.is_new_epoch:
-            print('epoch: {}'.format(train_iter.epoch))
-            print('train mean loss: {}, accuracy: {}'.format(
-                sum_loss / train_count, sum_accuracy / train_count))
-            # evaluation
-            sum_accuracy = 0
-            sum_loss = 0
-            for batch in test_iter:
-                x_array, t_array = convert.concat_examples(batch, args.gpu)
-                x = chainer.Variable(x_array)
-                t = chainer.Variable(t_array)
-                loss = model(x, t)
-                sum_loss += float(loss.data) * len(t.data)
-                sum_accuracy += float(model.accuracy.data) * len(t.data)
-
-            test_iter.reset()
-            print('test mean loss: {}, accuracy: {}'.format(
-                sum_loss / test_count, sum_accuracy / test_count))
-            sum_accuracy = 0
-            sum_loss = 0
-
-    # Save the model and the optimizer
-    print('save the model')
-    serializers.save_npz('{}/mlp.model'.format(args.out), model)
-    print('save the optimizer')
-    serializers.save_npz('{}/mlp.state'.format(args.out), optimizer)
+    with MultiprocessIterator(train, args.batchsize) as train_iter, \
+        MultiprocessIterator(test, args.batchsize,
+                             repeat=False, shuffle=False) as test_iter:
+
+        sum_accuracy = 0
+        sum_loss = 0
+
+        while train_iter.epoch < args.epoch:
+            batch = train_iter.next()
+            x_array, t_array = convert.concat_examples(batch, args.gpu)
+            x = chainer.Variable(x_array)
+            t = chainer.Variable(t_array)
+            optimizer.update(model, x, t)
+            sum_loss += float(model.loss.data) * len(t.data)
+            sum_accuracy += float(model.accuracy.data) * len(t.data)
+
+            if train_iter.is_new_epoch:
+                print('epoch: {}'.format(train_iter.epoch))
+                print('train mean loss: {}, accuracy: {}'.format(
+                    sum_loss / train_count, sum_accuracy / train_count))
+                # evaluation
+                sum_accuracy = 0
+                sum_loss = 0
+                for batch in test_iter:
+                    x_array, t_array = convert.concat_examples(batch, args.gpu)
+                    x = chainer.Variable(x_array)
+                    t = chainer.Variable(t_array)
+                    loss = model(x, t)
+                    sum_loss += float(loss.data) * len(t.data)
+                    sum_accuracy += float(model.accuracy.data) * len(t.data)
+
+                test_iter.reset()
+                print('test mean loss: {}, accuracy: {}'.format(
+                    sum_loss / test_count, sum_accuracy / test_count))
+                sum_accuracy = 0
+                sum_loss = 0
+
+        # Save the model and the optimizer
+        print('save the model')
+        serializers.save_npz('{}/mlp.model'.format(args.out), model)
+        print('save the optimizer')
+        serializers.save_npz('{}/mlp.state'.format(args.out), optimizer)
 
 
 if __name__ == '__main__':
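
The rewritten example chains two context managers in a single with statement. When the number of iterators is not fixed in advance, contextlib.ExitStack (standard library, not part of this change) gives the same cleanup guarantee; a minimal sketch with made-up datasets:

    import contextlib

    from chainer.iterators import MultiprocessIterator

    datasets = [list(range(100)), list(range(50))]  # illustrative datasets

    with contextlib.ExitStack() as stack:
        iters = [stack.enter_context(MultiprocessIterator(d, 10, repeat=False))
                 for d in datasets]
        for it in iters:
            for batch in it:
                pass
    # finalize() has run for every iterator once the stack exits.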
