# Chainer example on MNIST

## References

* [chainer/examples/mnist at master · chainer/chainer](https://github.com/chainer/chainer/tree/master/examples/mnist)

In [None]:
import argparse
import chainer
import chainer.functions as F


## Training

### Network definition

In [None]:
class MNISTConvNet(chainer.Chain):
    """Two-convolution, two-fully-connected classifier network.

    Layer stack (see ``__call__``): conv1 -> ReLU -> conv2 -> ReLU ->
    max-pool -> dropout -> flatten (per sample) -> fullconn1 -> ReLU ->
    dropout -> fullconn2 -> log-softmax.

    Args:
        n_units: Accepted for interface compatibility; not used by this class.
        n_out: Accepted for interface compatibility; not used by this class.
        conv1_in_channels, conv1_out_channels, conv1_kernel_size, conv1_stride:
            Forwarded to the first ``Convolution2D`` link.
        conv2_in_channels, conv2_out_channels, conv2_kernel_size, conv2_stride:
            Forwarded to the second ``Convolution2D`` link.
        pool1_kernel_size: ``ksize`` of the max-pooling step.
        dropout1_p: Dropout ratio applied after pooling.
        dropout2_p: Dropout ratio applied after the first fully-connected layer.
        fullconn1_in_features, fullconn1_out_features:
            ``in_size`` / ``out_size`` of the first ``Linear`` link.
        fullconn2_in_features, fullconn2_out_features:
            ``in_size`` / ``out_size`` of the second ``Linear`` link.
    """

    def __init__(self, n_units, n_out,
            conv1_in_channels: int, conv1_out_channels: int, conv1_kernel_size: int, conv1_stride: int,
            conv2_in_channels: int, conv2_out_channels: int, conv2_kernel_size: int, conv2_stride: int,
            pool1_kernel_size: int, dropout1_p: float, dropout2_p: float,
            fullconn1_in_features: int, fullconn1_out_features: int, fullconn2_in_features: int, fullconn2_out_features: int
            ) -> None:
        super().__init__()
        # Only links (which own parameters) need to be registered in init_scope.
        with self.init_scope():
            self.conv1 = chainer.links.Convolution2D(in_channels=conv1_in_channels, out_channels=conv1_out_channels, ksize=conv1_kernel_size, stride=conv1_stride)
            self.conv2 = chainer.links.Convolution2D(in_channels=conv2_in_channels, out_channels=conv2_out_channels, ksize=conv2_kernel_size, stride=conv2_stride)
            self.fullconn1 = chainer.links.Linear(in_size=fullconn1_in_features, out_size=fullconn1_out_features)
            self.fullconn2 = chainer.links.Linear(in_size=fullconn2_in_features, out_size=fullconn2_out_features)

        # Plain hyper-parameters consumed in the forward pass.
        self.pool1_kernel_size = pool1_kernel_size
        self.dropout1_p = dropout1_p
        self.dropout2_p = dropout2_p

    def __call__(self, x):  # NOTE: ...Or, def forward(self, x):
        """Run the forward pass and return per-class log-probabilities.

        Args:
            x: Mini-batch of inputs; presumably shaped
                (batch, channels, height, width) as ``Convolution2D``
                expects -- confirm against the caller.

        Returns:
            The result of ``F.log_softmax`` on the final layer's output.
        """
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pooling_2d(x, ksize=self.pool1_kernel_size)
        x = F.dropout(x, ratio=self.dropout1_p)
        # Flatten each sample separately. F.flatten would merge the whole
        # mini-batch into one 1-D vector and lose the batch axis that the
        # following Linear link needs.
        x = F.reshape(x, (x.shape[0], -1))
        x = self.fullconn1(x)
        x = F.relu(x)
        x = F.dropout(x, ratio=self.dropout2_p)
        x = self.fullconn2(x)
        return F.log_softmax(x)


In [None]:
def set_printing_parameters(trainer):
    """Attach the reporting extensions to ``trainer`` and return it.

    Registers, in order: a ``LogReport``, a loss ``PlotReport`` written to
    ``loss.png``, an accuracy ``PlotReport`` written to ``accuracy.png``,
    a ``PrintReport`` and a ``ProgressBar``. All but the progress bar are
    registered with ``call_before_training=True``.

    Args:
        trainer: Object exposing ``extend``; used as a Chainer trainer.

    Returns:
        The same ``trainer`` object that was passed in.
    """
    extensions = chainer.training.extensions

    loss_keys = ['main/loss', 'validation/main/loss']
    accuracy_keys = ['main/accuracy', 'validation/main/accuracy']
    report_columns = ['epoch', *loss_keys, *accuracy_keys, 'elapsed_time']

    reporters = (
        extensions.LogReport(),
        extensions.PlotReport(y_keys=loss_keys, x_key='epoch', filename='loss.png'),
        extensions.PlotReport(y_keys=accuracy_keys, x_key='epoch', filename='accuracy.png'),
        extensions.PrintReport(entries=report_columns),
    )
    for reporter in reporters:
        trainer.extend(reporter, call_before_training=True)

    trainer.extend(extensions.ProgressBar())
    return trainer


In [None]:
def get_trainer(model, dataset, batchsize, device, epoch, out):
    """Build a trainer that optimises ``model`` on ``dataset`` with AdaDelta.

    Args:
        model: Link handed to the optimizer's ``setup``.
        dataset: Training dataset wrapped in a repeating ``SerialIterator``.
        batchsize: Mini-batch size for the iterator.
        device: Device passed to the ``StandardUpdater``.
        epoch: Number of epochs after which the trainer stops.
        out: Output directory passed to the ``Trainer``.

    Returns:
        The constructed ``chainer.training.Trainer``.
    """
    adadelta = chainer.optimizers.AdaDelta()
    adadelta.setup(link=model)

    batches = chainer.iterators.SerialIterator(
        dataset=dataset,
        batch_size=batchsize,
        repeat=True,
        shuffle=None,
        order_sampler=None,
    )
    step_updater = chainer.training.updaters.StandardUpdater(
        iterator=batches,
        optimizer=adadelta,
        device=device,
        loss_func=None,
        loss_scale=None,
        auto_new_epoch=True,
    )
    return chainer.training.Trainer(
        updater=step_updater,
        stop_trigger=(epoch, 'epoch'),
        out=out,
    )

In [None]:
def prepare_trainer(trainer, device, model, dataset, batchsize, epoch, frequency, resume=None, autoload=True):
    """Add evaluation, snapshot and reporting extensions to ``trainer``.

    Args:
        trainer: Trainer to extend.
        device: Device passed to the ``Evaluator``.
        model: Evaluation target.
        dataset: Dataset iterated once, unshuffled, for evaluation.
        batchsize: Mini-batch size of the evaluation iterator.
        epoch: Snapshot interval (in epochs) used when ``frequency`` is -1.
        frequency: Snapshot interval in epochs; -1 selects ``epoch``, any
            other value is clamped to at least 1.
        resume: Optional NPZ file loaded into the trainer after all
            extensions are registered.
        autoload: Forwarded to the ``snapshot`` extension.

    Returns:
        The trainer returned by ``set_printing_parameters``.
    """
    extensions = chainer.training.extensions

    eval_batches = chainer.iterators.SerialIterator(
        dataset=dataset,
        batch_size=batchsize,
        repeat=False,
        shuffle=False,
        order_sampler=None,
    )
    evaluator = extensions.Evaluator(
        iterator=eval_batches,
        target=model,
        device=device,
        eval_hook=None,
        eval_func=None,
    )
    trainer.extend(extension=evaluator, call_before_training=True)

    # -1 means "snapshot once, at the end of the run".
    if frequency == -1:
        frequency = epoch
    else:
        frequency = max(1, frequency)
    trainer.extend(
        extensions.snapshot(n_retains=1, autoload=autoload),
        name=None,
        trigger=(frequency, 'epoch'),
        priority=None,
        call_before_training=False,
    )

    trainer = set_printing_parameters(trainer=trainer)

    if resume is not None:
        chainer.serializers.load_npz(file=resume, obj=trainer)

    return trainer


In [None]:
def train(predictor, device, unit, batchsize, epoch, out, frequency, resume=None, autoload=True) -> None:
    """Train ``predictor`` on MNIST and run the trainer to completion.

    Args:
        predictor: Network wrapped in a ``chainer.links.Classifier``.
        device: Device specifier resolved through ``chainer.get_device``.
        unit: Only printed in the run summary; not otherwise used here.
        batchsize: Mini-batch size for training and evaluation.
        epoch: Number of epochs to train.
        out: Output directory for the trainer.
        frequency: Snapshot frequency, forwarded to ``prepare_trainer``.
        resume: Optional snapshot file to resume from.
        autoload: Forwarded to ``prepare_trainer``.
    """
    device = chainer.get_device(device_spec=device)

    print('Device: {}'.format(device))
    print('# unit: {}'.format(unit))
    print('# Minibatch-size: {}'.format(batchsize))
    print('# epoch: {}'.format(epoch))
    print('')

    model = chainer.links.Classifier(predictor=predictor)
    model.to_device(device=device)
    device.use()

    # ndim=3 yields image-shaped samples. The default (ndim=1) yields flat
    # vectors, which a convolutional predictor cannot consume.
    train_dataset, test_dataset = chainer.datasets.get_mnist(ndim=3)

    trainer = get_trainer(model=model, dataset=train_dataset, batchsize=batchsize, device=device, epoch=epoch, out=out)
    trainer = prepare_trainer(trainer=trainer, device=device, model=model, dataset=test_dataset, batchsize=batchsize, epoch=epoch, frequency=frequency, resume=resume, autoload=autoload)

    trainer.run()


## Inference

In [None]:
def predict(predictor, device, unit, snapshot) -> None:
    """Load weights from ``snapshot`` and classify the first MNIST test image.

    Prints the predicted class index and the ground-truth label.

    Args:
        predictor: Network whose parameters are loaded from the snapshot.
        device: Device specifier resolved through ``chainer.get_device``.
        unit: Only printed in the run summary; not otherwise used here.
        snapshot: Path to an NPZ file, either a trainer snapshot or a
            saved classifier.
    """
    device = chainer.get_device(device_spec=device)

    print('Device: {}'.format(device))
    print('# unit: {}'.format(unit))
    print('')

    device.use()

    model = predictor

    # Try the key prefix of a trainer snapshot first, then fall back to the
    # prefix of a directly saved classifier.
    try:
        chainer.serializers.load_npz(file=snapshot, obj=model, path='updater/model:main/predictor/', strict=True, ignore_names=None)
    except Exception:
        chainer.serializers.load_npz(file=snapshot, obj=model, path='predictor/', strict=True, ignore_names=None)

    model.to_device(device=device)

    # ndim=3 yields image-shaped samples. The default (ndim=1) yields flat
    # vectors, which a convolutional predictor cannot consume.
    _, test_dataset = chainer.datasets.get_mnist(ndim=3)

    x, answer = test_dataset[0]
    x = device.send(arrays=x)
    # Disable train mode for inference and prepend a batch axis of size 1.
    with chainer.using_config(name='train', value=False):
        prediction = model(x[None, ...])[0].array.argmax()

    print('Prediction:', prediction)
    print('Answer:', answer)


## Main

In [None]:
def get_argparser():
    """Build the command-line parser for the MNIST example.

    Returns:
        A ``(parser, group)`` tuple: the ``ArgumentParser`` and the
        'deprecated arguments' group that holds ``--gpu``.
    """
    parser = argparse.ArgumentParser(description='Chainer example: MNIST')

    # For the model: (flag, value type, default)
    model_options = (
        ('--conv1-in-channels', int, 1),
        ('--conv1-out-channels', int, 32),
        ('--conv1-kernel-size', int, 3),
        ('--conv1-stride', int, 1),
        ('--conv2-in-channels', int, 32),
        ('--conv2-out-channels', int, 64),
        ('--conv2-kernel-size', int, 3),
        ('--conv2-stride', int, 1),
        ('--pool1-kernel-size', int, 2),
        ('--dropout1-p', float, 0.25),
        ('--dropout2-p', float, 0.5),
        ('--fullconn1-in-features', int, 12*12*64),
        ('--fullconn1-out-features', int, 128),
        ('--fullconn2-in-features', int, 128),
        ('--fullconn2-out-features', int, 10),
    )
    for flag, value_type, default_value in model_options:
        parser.add_argument(flag, type=value_type, default=default_value)

    # Both for training and inference
    device_help = (
        'Device specifier. Either ChainerX device specifier or an integer. '
        'If non-negative integer, CuPy arrays with specified device id are '
        'used. If negative integer, NumPy arrays are used'
    )
    parser.add_argument('--device', '-d', type=str, default='-1', help=device_help)
    parser.add_argument('--unit', '-u', type=int, default=128, help='Number of units')
    # --gpu writes to the same destination as --device.
    group = parser.add_argument_group('deprecated arguments')
    group.add_argument(
        '--gpu', '-g',
        dest='device', type=int, nargs='?', const=0,
        help='GPU ID (negative value indicates CPU)',
    )

    # For training
    parser.add_argument(
        '--batchsize', '-b', type=int, default=32,
        help='Number of images in each mini-batch',
    )
    parser.add_argument(
        '--epoch', '-e', type=int, default=1,
        help='Number of sweeps over the dataset to train',
    )
    parser.add_argument(
        '--frequency', '-f', type=int, default=-1,
        help='Frequency of taking a snapshot',
    )
    parser.add_argument(
        '--out', '-o', default='result',
        help='Directory to output the result',
    )
    parser.add_argument(
        '--resume', '-r', type=str,
        help='Resume the training from snapshot',
    )
    autoload_help = (
        'Automatically load trainer snapshots in case of preemption or '
        'other temporary system failure'
    )
    parser.add_argument('--autoload', action='store_true', help=autoload_help)

    # For inference
    parser.add_argument(
        '--snapshot', '-s', default='result/snapshot_iter_12000',
        help='The path to a saved snapshot (NPZ)',
    )

    return parser, group


In [None]:
def main(args=None) -> None:
    """Build the network from parsed options, train it, then run inference.

    Args:
        args: Parsed argument namespace as produced by the parser from
            ``get_argparser``. When falsy, the process command line is
            parsed instead.
    """
    if not args:
        argparser, _ = get_argparser()
        args = argparser.parse_args()

    predictor = MNISTConvNet(n_units=args.unit, n_out=10,
        conv1_in_channels=args.conv1_in_channels, conv1_out_channels=args.conv1_out_channels, conv1_kernel_size=args.conv1_kernel_size, conv1_stride=args.conv1_stride,
        conv2_in_channels=args.conv2_in_channels, conv2_out_channels=args.conv2_out_channels, conv2_kernel_size=args.conv2_kernel_size, conv2_stride=args.conv2_stride,
        pool1_kernel_size=args.pool1_kernel_size, dropout1_p=args.dropout1_p, dropout2_p=args.dropout2_p,
        fullconn1_in_features=args.fullconn1_in_features, fullconn1_out_features=args.fullconn1_out_features, fullconn2_in_features=args.fullconn2_in_features, fullconn2_out_features=args.fullconn2_out_features
        )

    # Forward --resume and --autoload so the parsed options take effect
    # rather than train()'s defaults.
    train(predictor=predictor, device=args.device, unit=args.unit, batchsize=args.batchsize, epoch=args.epoch, out=args.out, frequency=args.frequency,
        resume=args.resume, autoload=args.autoload)
    predict(predictor=predictor, device=args.device, unit=args.unit, snapshot=args.snapshot)


In [None]:
# Notebook driver: parse a fixed option list (instead of the process command
# line) and run training followed by inference.
argparser, _ = get_argparser()
args = argparser.parse_args([
    "--device", "-1",
    "--unit", "128",
    # "--gpu", "0",
    "--batchsize", "32",
    "--epoch", "1",
    "--frequency", "1",
    "--out", "result",
    # "--resume", "",
    "--autoload",
    "--snapshot", "result/snapshot_iter_3750",
])
main(args)