Switch branches/tags
Nothing to show
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
136 lines (113 sloc) 5.06 KB
from __future__ import print_function
import argparse
import time
import os
import numpy as np
import six
import chainer
import chainer.functions as F
import chainer.links as L
from chainer import cuda
from chainer import computational_graph
from chainer import serializers
import mlp
def _make_batch(dataset, indices, device):
    """Stack ``dataset[j]`` for every ``j`` in ``indices`` into input/label Variables.

    Examples are fetched one scalar index at a time so each one is an
    ``(image, label)`` tuple. Indexing with an index array and taking ``[0]``
    is NOT a batch of images on Chainer versions whose ``TupleDataset``
    returns a list of per-example tuples for array indices.
    ``concat_examples`` stacks the tuples and, for a non-negative ``device``,
    moves the arrays to that GPU.

    Returns:
        tuple: ``(x, t)`` as ``chainer.Variable`` objects.
    """
    batch = [dataset[j] for j in indices]
    x, t = chainer.dataset.concat_examples(batch, device)
    return chainer.Variable(x), chainer.Variable(t)


def _train_epoch(classifier_model, optimizer, train, batchsize, device,
                 graph_path=None):
    """Run one shuffled pass of mini-batch updates over ``train``.

    Args:
        classifier_model: ``L.Classifier`` wrapper whose ``loss`` and
            ``accuracy`` attributes are read after every update.
        optimizer: optimizer already ``setup()`` with ``classifier_model``.
        train: training dataset indexed by integer position.
        batchsize: maximum number of examples per update.
        device: GPU id (negative means CPU) passed to ``concat_examples``.
        graph_path: if not None, the computational graph of the first
            batch's loss is written to this path.

    Returns:
        tuple: ``(mean_loss, mean_accuracy)`` over the whole epoch.
    """
    n = len(train)
    perm = np.random.permutation(n)
    sum_loss = 0.0
    sum_accuracy = 0.0
    for i in six.moves.range(0, n, batchsize):
        x, t = _make_batch(train, perm[i:i + batchsize], device)
        # Pass the loss function (Classifier defines it) and its arguments;
        # Classifier computes softmax cross entropy loss and accuracy.
        optimizer.update(classifier_model, x, t)
        if graph_path is not None and i == 0:
            with open(graph_path, 'w') as o:
                g = computational_graph.build_computational_graph(
                    (classifier_model.loss,))
                o.write(g.dump())
            print('graph generated')
        # Weight by the actual batch length: the last batch may be smaller.
        sum_loss += float(classifier_model.loss.data) * len(t.data)
        sum_accuracy += float(classifier_model.accuracy.data) * len(t.data)
    return sum_loss / n, sum_accuracy / n


def _evaluate(classifier_model, test, batchsize, device):
    """Return ``(mean_loss, mean_accuracy)`` of ``classifier_model`` on ``test``."""
    n = len(test)
    sum_loss = 0.0
    sum_accuracy = 0.0
    for i in six.moves.range(0, n, batchsize):
        # Clamp the final batch so indices never run past the dataset end.
        indices = six.moves.range(i, min(i + batchsize, n))
        x, t = _make_batch(test, indices, device)
        # Back propagation is not needed here, so skip building the graph
        # (`no_backprop_mode()` replaced the v1 `volatile` flag), and run
        # links in inference mode (`train=False` config).
        with chainer.no_backprop_mode(), chainer.using_config('train', False):
            loss = classifier_model(x, t)
        sum_loss += float(loss.data) * len(t.data)
        sum_accuracy += float(classifier_model.accuracy.data) * len(t.data)
    return sum_loss / n, sum_accuracy / n


def main():
    """Train ``mlp.MLP`` on MNIST with a hand-written training loop.

    Parses command-line options, optionally restores model/optimizer state,
    trains for ``--epoch`` epochs (printing train/test loss and accuracy each
    epoch), and saves the classifier, the bare MLP and the optimizer state
    under ``--out``.
    """
    parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    parser.add_argument('--initmodel', '-m', default='',
                        help='Initialize the model from given file')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch', '-e', type=int, default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='result/2',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--unit', '-u', type=int, default=50,
                        help='Number of units')
    args = parser.parse_args()

    print('GPU: {}'.format(args.gpu))
    print('# unit: {}'.format(args.unit))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))

    # Set up a neural network to train. Classifier computes softmax cross
    # entropy loss and accuracy on every call and keeps them as attributes,
    # which the training loop below reads for reporting.
    model = mlp.MLP(args.unit, 10)
    classifier_model = L.Classifier(model)
    if args.gpu >= 0:
        chainer.cuda.get_device(args.gpu).use()  # Make a specified GPU current
        classifier_model.to_gpu()  # Copy the model to the GPU

    # Setup an optimizer. It must be attached to the model before
    # update() is called or its saved state is loaded.
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(classifier_model)

    # Load the MNIST dataset
    train, test = chainer.datasets.get_mnist()

    # Init/Resume
    if args.initmodel:
        print('Load model from', args.initmodel)
        serializers.load_npz(args.initmodel, classifier_model)
    if args.resume:
        print('Load optimizer state from', args.resume)
        serializers.load_npz(args.resume, optimizer)

    if not os.path.exists(args.out):
        os.makedirs(args.out)

    # Learning loop
    for epoch in six.moves.range(1, args.epoch + 1):
        print('epoch', epoch)

        # training; dump the computational graph once, on the first epoch
        graph_path = os.path.join(args.out, 'graph.dot') if epoch == 1 else None
        start = time.time()
        train_loss, train_accuracy = _train_epoch(
            classifier_model, optimizer, train, args.batchsize, args.gpu,
            graph_path)
        elapsed_time = time.time() - start
        throughput = len(train) / elapsed_time
        print('train mean loss={}, accuracy={}, throughput={} images/sec'.format(
            train_loss, train_accuracy, throughput))

        # evaluation
        test_loss, test_accuracy = _evaluate(
            classifier_model, test, args.batchsize, args.gpu)
        print('test mean loss={}, accuracy={}'.format(
            test_loss, test_accuracy))

    # Save the model and the optimizer
    print('save the model')
    serializers.save_npz('{}/classifier.model'.format(args.out), classifier_model)
    serializers.save_npz('{}/mlp.model'.format(args.out), model)
    print('save the optimizer')
    serializers.save_npz('{}/mlp.state'.format(args.out), optimizer)
if __name__ == '__main__':
    # Run training only when executed as a script, not on import.
    main()