Skip to content

Commit

Permalink
Fix iterator syntax in MNIST custom loop example
Browse files Browse the repository at this point in the history
  • Loading branch information
niboshi committed Jan 27, 2019
1 parent 8a2f05a commit 2a2a1ff
Showing 1 changed file with 45 additions and 46 deletions.
91 changes: 45 additions & 46 deletions examples/mnist/train_mnist_custom_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import chainer
from chainer import configuration
from chainer.dataset import convert
from chainer.iterators import SerialIterator
import chainer.links as L
from chainer import serializers

Expand Down Expand Up @@ -91,51 +90,51 @@ def main():
train_count = len(train)
test_count = len(test)

with SerialIterator(train, args.batchsize) as train_iter, \
SerialIterator(
test, args.batchsize, repeat=False, shuffle=False) as test_iter:

sum_accuracy = 0
sum_loss = 0

while train_iter.epoch < args.epoch:
batch = train_iter.next()
x, t = convert.concat_examples(batch, device)
optimizer.update(model, x, t)
sum_loss += float(model.loss.array) * len(t)
sum_accuracy += float(model.accuracy.array) * len(t)

if train_iter.is_new_epoch:
print('epoch: {}'.format(train_iter.epoch))
print('train mean loss: {}, accuracy: {}'.format(
sum_loss / train_count, sum_accuracy / train_count))
# evaluation
sum_accuracy = 0
sum_loss = 0
# Enable evaluation mode.
with configuration.using_config('train', False):
# This is optional but can reduce computational overhead.
with chainer.using_config('enable_backprop', False):
for batch in test_iter:
x, t = convert.concat_examples(batch, device)
loss = model(x, t)
sum_loss += float(loss.array) * len(t)
sum_accuracy += float(
model.accuracy.array) * len(t)

test_iter.reset()
print('test mean loss: {}, accuracy: {}'.format(
sum_loss / test_count, sum_accuracy / test_count))
sum_accuracy = 0
sum_loss = 0

# Save the model and the optimizer
if not os.path.exists(args.out):
os.makedirs(args.out)
print('save the model')
serializers.save_npz('{}/mlp.model'.format(args.out), model)
print('save the optimizer')
serializers.save_npz('{}/mlp.state'.format(args.out), optimizer)
train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
test_iter = chainer.iterators.SerialIterator(test, args.batchsize,
repeat=False, shuffle=False)

sum_accuracy = 0
sum_loss = 0

while train_iter.epoch < args.epoch:
batch = train_iter.next()
x, t = convert.concat_examples(batch, device)
optimizer.update(model, x, t)
sum_loss += float(model.loss.array) * len(t)
sum_accuracy += float(model.accuracy.array) * len(t)

if train_iter.is_new_epoch:
print('epoch: {}'.format(train_iter.epoch))
print('train mean loss: {}, accuracy: {}'.format(
sum_loss / train_count, sum_accuracy / train_count))
# evaluation
sum_accuracy = 0
sum_loss = 0
# Enable evaluation mode.
with configuration.using_config('train', False):
# This is optional but can reduce computational overhead.
with chainer.using_config('enable_backprop', False):
for batch in test_iter:
x, t = convert.concat_examples(batch, device)
loss = model(x, t)
sum_loss += float(loss.array) * len(t)
sum_accuracy += float(
model.accuracy.array) * len(t)

test_iter.reset()
print('test mean loss: {}, accuracy: {}'.format(
sum_loss / test_count, sum_accuracy / test_count))
sum_accuracy = 0
sum_loss = 0

# Save the model and the optimizer
if not os.path.exists(args.out):
os.makedirs(args.out)
print('save the model')
serializers.save_npz('{}/mlp.model'.format(args.out), model)
print('save the optimizer')
serializers.save_npz('{}/mlp.state'.format(args.out), optimizer)


if __name__ == '__main__':
Expand Down

0 comments on commit 2a2a1ff

Please sign in to comment.