In [1]:
import os
from json import load
import numpy as np
import cupy as cp
import matplotlib.pyplot as plt
import seaborn as sns

import chainer
import chainer.links as L
import chainer.functions as F
from chainer import cuda
from chainer import optimizers
from chainer import iterators
from chainer import training
from chainer.training import extensions
from chainer import datasets
from chainer.datasets import TransformDataset
from chainer.datasets import LabeledImageDataset
from functools import partial
from chainercv import transforms

In [22]:
train, test = datasets.get_mnist(withlabel=True, ndim=3)

In [24]:
class Model(chainer.Chain):
    def __init__(self, out_dim=10):
        super(Model, self).__init__()
        with self.init_scope():
            self.conv1_1 = L.Convolution2D(None, 32, ksize=5, stride=1, pad=2)
            self.conv1_2 = L.Convolution2D(None, 32, ksize=5, stride=1, pad=2)
            self.conv2_1 = L.Convolution2D(None, 64, ksize=5, stride=1, pad=2)
            self.conv2_2 = L.Convolution2D(None, 64, ksize=5, stride=1, pad=2)
            self.conv3_1 = L.Convolution2D(None, 128, ksize=5, stride=1, pad=2)
            self.conv3_2 = L.Convolution2D(None, 128, ksize=5, stride=1, pad=2)
            self.fc1 = L.Linear(None, 2)
            self.fc2 = L.Linear(None, out_dim)
            
    def __call__(self, x):
        h = self.conv1_1(x)
        h = F.max_pooling_2d(self.conv1_2(h), 2, stride=2, pad=0)
        h = self.conv2_1(h)
        h = F.max_pooling_2d(self.conv2_2(h), 2, stride=2, pad=0)
        h = self.fc1(h)
        h = F.relu(h)
        y = self.fc2(h)
        return y
        
    
class SoftmaxLoss(chainer.Chain):
    def __init__(self, model):
        super(SoftmaxLoss, self).__init__()
        with self.init_scope():
            self.model = model
            
    def __call__(self, x, t):
        y=self.model(x)
        loss=F.softmax_cross_entropy(y, t)
        accuracy=F.accuracy(y, t)
        chainer.report({'loss':loss, 'accuracy':accuracy})
        return loss
        

In [25]:
model=Model()
loss=SoftmaxLoss(model)

In [26]:
n_epoch = 3
batchsize = 32
out_dir = './result/'
report_interval = (10, 'iteration')

train_iter = iterators.MultithreadIterator\
(train, batchsize, repeat=True, shuffle=True)

In [27]:
optimizer = chainer.optimizers.Adam(alpha=1e-3, amsgrad=True)
optimizer.setup(loss)

<chainer.optimizers.adam.Adam at 0x7f6210eebbe0>

In [28]:
updater = training.StandardUpdater(train_iter, optimizer, device=0)

In [29]:
trainer = training.Trainer(updater, (n_epoch, 'epoch'), out=out_dir)

trainer.extend(extensions.LogReport(trigger=report_interval))
trainer.extend(extensions.PrintReport(['epoch', 
                                       'iteration', 
                                       'main/loss',
                                      'main/accuracy']),
               trigger=report_interval)
trainer.extend(extensions.PlotReport(y_keys='main/loss', trigger=report_interval, file_name='MNIST_loss.png'))
trainer.extend(extensions.PlotReport(y_keys='main/accuracy', trigger=report_interval, file_name='MNIST_accuracy.png'))

In [30]:
trainer.run()

epoch       iteration   main/loss 
[J0           10                      
[J0           20                      
[J0           30                      
[J0           40                      
[J0           50                      
[J0           60                      
[J0           70                      
[J0           80                      
[J0           90                      
[J0           100                     
[J0           110                     
[J0           120                     
[J0           130                     
[J0           140                     
[J0           150                     
[J0           160                     
[J0           170                     
[J0           180                     
[J0           190                     
[J0           200                     
[J0           210                     
[J0           220                     
[J0           230                     
[J0           240                     
[J0 