In [1]:
import re

import numpy

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training
from chainer.training import extensions
import chainerx

In [2]:
# Network definition
class MLP(chainer.Chain):
    """Multi-layer perceptron built from three linear links.

    The first two links are followed by ReLU; the output of the third link
    is returned as-is.

    Args:
        n_units: Output size of each of the first two linear links.
        n_out: Output size of the last linear link.
    """

    def __init__(self, n_units, n_out):
        super(MLP, self).__init__()
        with self.init_scope():
            # Each link gets None as its input size, so the size of the
            # inputs to each layer will be inferred.
            self.l1 = L.Linear(None, n_units)  # n_in -> n_units
            self.l2 = L.Linear(None, n_units)  # n_units -> n_units
            self.l3 = L.Linear(None, n_out)  # n_units -> n_out

    def forward(self, x):
        """Apply l1 and l2 (each followed by ReLU), then l3, to ``x``."""
        hidden = x
        for layer in (self.l1, self.l2):
            hidden = F.relu(layer(hidden))
        # No activation is applied after the last link.
        return self.l3(hidden)

In [3]:
def parse_device(args):
    """Pick the chainer device described by ``args``.

    ``args.gpu`` takes precedence whenever it is not None.  Otherwise
    ``args.device`` is consulted: if it is an optionally signed run of
    digits it is converted to an integer id, and any other value is handed
    to ``chainer.get_device`` unchanged.

    Args:
        args: Object exposing ``gpu`` (an integer id or None) and ``device``
            (a string) attributes.

    Returns:
        The result of ``chainer.get_device``: called with ``numpy`` for a
        negative id, with ``(cupy, id)`` for a non-negative id, or with
        ``args.device`` itself when no integer id was given.
    """
    if args.gpu is not None:
        device_id = args.gpu
    elif re.match(r'(-|\+|)[0-9]+$', args.device):
        device_id = int(args.device)
    else:
        # Not a bare integer: let chainer interpret the specifier itself.
        return chainer.get_device(args.device)

    if device_id < 0:
        return chainer.get_device(numpy)

    # cupy is imported only on this path, so the other paths never need it.
    import cupy
    return chainer.get_device((cupy, device_id))

In [4]:
# ArgumentParser cannot be used from Jupyter, so the options are defined on a
# dummy class instead.
class ArgDummy:
    """Holds the training options that would otherwise come from ArgumentParser.

    Every option is a plain attribute initialised to the default listed in
    ``__init__``.  Any option can be overridden by keyword, for example
    ``ArgDummy(epoch=2, plot=True)``; calling ``ArgDummy()`` with no
    arguments yields the defaults unchanged.

    Raises:
        TypeError: If a keyword does not name one of the options below.
    """

    def __init__(self, **overrides):
        self.batchsize = 100  # Number of images in each mini-batch
        self.epoch = 10  # Number of sweeps over the dataset to train
        # Frequency of taking a snapshot, in epochs.  With -1 the training
        # script uses ``epoch`` as the interval instead.
        self.frequency = -1
        # Device specifier.  parse_device treats an optionally signed integer
        # string as a device id (negative -> numpy, otherwise cupy) and hands
        # any other string to chainer.get_device unchanged.
        self.device = "-1"
        self.out = "result"  # Directory to output the result
        # Snapshot to resume the training from; a falsy value skips resuming.
        self.resume = False
        self.unit = 1000  # Number of units
        # Whether to add the PlotReport extensions (False disables them).
        self.plot = False
        # GPU ID (negative value indicates CPU).  When this is not None,
        # parse_device uses it and never looks at ``device``.
        self.gpu = -1

        # Reject misspelled option names instead of silently adding a new
        # attribute that nothing reads.
        unknown = sorted(set(overrides) - set(vars(self)))
        if unknown:
            raise TypeError(
                'unknown option(s): {}'.format(', '.join(unknown)))
        for name, value in overrides.items():
            setattr(self, name, value)
        
# Build the option object and resolve the device it selects.
args = ArgDummy()
device = parse_device(args)

# Set up a neural network to train.
# The Classifier link reports softmax cross entropy loss and accuracy at
# every iteration; the PrintReport extension registered further down prints
# those values.
model = L.Classifier(MLP(args.unit, 10))
model.to_device(device)
device.use()

# Set up an optimizer (Adam with its default arguments).
optimizer = chainer.optimizers.Adam()
optimizer.setup(model)

# Load the MNIST dataset.
train, test = chainer.datasets.get_mnist()

# The test iterator is created with repeat=False and shuffle=False; the
# training iterator keeps SerialIterator's defaults.
train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
test_iter = chainer.iterators.SerialIterator(
    test, args.batchsize, repeat=False, shuffle=False)

# Set up a trainer that stops after args.epoch epochs and writes to args.out.
stop_trigger = (args.epoch, 'epoch')
updater = training.updaters.StandardUpdater(
    train_iter, optimizer, device=device)
trainer = training.Trainer(updater, stop_trigger, out=args.out)

# Evaluate the model with the test dataset for each epoch.
trainer.extend(extensions.Evaluator(test_iter, model, device=device))

# Dump a computational graph from the 'loss' variable at the first iteration.
# "main" refers to the target link of the "main" optimizer.
# TODO(niboshi): Temporarily disabled for chainerx. Fix it.
if device.xp is not chainerx:
    trainer.extend(extensions.dump_graph('main/loss'))

# Take a snapshot for each specified epoch: -1 means "use the epoch count",
# and any other value is clamped to at least 1.
if args.frequency == -1:
    frequency = args.epoch
else:
    frequency = max(1, args.frequency)
trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch'))

# Write a log of evaluation statistics for each epoch.
trainer.extend(extensions.LogReport())

# Report keys.  "main" is the target link of the "main" optimizer and
# "validation" is the default name of the Evaluator extension.
loss_keys = ['main/loss', 'validation/main/loss']
accuracy_keys = ['main/accuracy', 'validation/main/accuracy']

# Save two plot images to the result dir (only when plotting was requested
# and PlotReport reports itself as available).
if args.plot and extensions.PlotReport.available():
    trainer.extend(
        extensions.PlotReport(loss_keys, 'epoch', file_name='loss.png'))
    trainer.extend(
        extensions.PlotReport(
            accuracy_keys, 'epoch', file_name='accuracy.png'))

# Print selected entries of the log to stdout.
# Entries other than 'epoch' are reported by the Classifier link, called by
# either the updater or the evaluator.
report_entries = ['epoch'] + loss_keys + accuracy_keys + ['elapsed_time']
trainer.extend(extensions.PrintReport(report_entries))

# Print a progress bar to stdout.
# NOTE(review): the output recorded below this cell shows the bar's
# cursor-control sequences as literal text; confirm the bar is wanted when
# running inside a notebook.
trainer.extend(extensions.ProgressBar())

# Resume from a snapshot when one was given.
if args.resume:
    chainer.serializers.load_npz(args.resume, trainer)

# Run the training.
trainer.run()

epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy  elapsed_time
[J     total [..................................................]  1.67%
this epoch [########..........................................] 16.67%
       100 iter, 0 epoch / 10 epochs
       inf iters/sec. Estimated time to finish: 0:00:00.
[4A[J     total [#.................................................]  3.33%
this epoch [################..................................] 33.33%
       200 iter, 0 epoch / 10 epochs
    2.7865 iters/sec. Estimated time to finish: 0:34:41.474823.
[4A[J     total [##................................................]  5.00%
this epoch [#########################.........................] 50.00%
       300 iter, 0 epoch / 10 epochs
    2.6184 iters/sec. Estimated time to finish: 0:36:16.925105.
[4A[J     total [###...............................................]  6.67%
this epoch [#################################.................] 66.67%
       400 i

[4A[J     total [##########################........................] 53.33%
this epoch [################..................................] 33.33%
      3200 iter, 5 epoch / 10 epochs
    2.4264 iters/sec. Estimated time to finish: 0:19:13.952693.
[4A[J     total [###########################.......................] 55.00%
this epoch [#########################.........................] 50.00%
      3300 iter, 5 epoch / 10 epochs
    2.4359 iters/sec. Estimated time to finish: 0:18:28.429484.
[4A[J     total [############################......................] 56.67%
this epoch [#################################.................] 66.67%
      3400 iter, 5 epoch / 10 epochs
    2.4448 iters/sec. Estimated time to finish: 0:17:43.491299.
[4A[J     total [#############################.....................] 58.33%
this epoch [#########################################.........] 83.33%
      3500 iter, 5 epoch / 10 epochs
    2.4531 iters/sec. Estimated time to finish: 0:16:59.111812.
