In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import time
import six
import numpy as np
from sklearn.datasets import fetch_mldata
import chainer
from chainer import cuda, Function, gradient_check, Variable, optimizers, serializers, utils
import chainer.functions as F
import chainer.links as L

In [3]:
# Hyper-parameters: mini-batch size, number of epochs, and the hidden-layer
# width that is later handed to MLPModel.
batchsize, n_epoch = 100, 20
n_units = 1000

# Fetch MNIST, convert the pixel matrix to float32 and divide it by 255,
# then cast the labels to int32.
mnist = fetch_mldata('MNIST original')
mnist.data = mnist.data.astype(np.float32) / 255
mnist.target = mnist.target.astype(np.int32)

In [4]:
# The first N rows form the training set; everything after them is the test set.
N = 60000
x_train, x_test = mnist.data[:N], mnist.data[N:]
y_train, y_test = mnist.target[:N], mnist.target[N:]
N_test = y_test.size

In [5]:
class MLPModel(chainer.Chain):
    """Fully connected network: two ReLU hidden layers and a linear output layer."""

    def __init__(self, n_in, n_units, n_out):
        """Register the three linear links.

        n_in    -- input size of the first layer
        n_units -- output size of the first and second layers
        n_out   -- output size of the last layer
        """
        links = dict(
            l1=L.Linear(n_in, n_units),
            l2=L.Linear(n_units, n_units),
            l3=L.Linear(n_units, n_out),
        )
        super(MLPModel, self).__init__(**links)

    def __call__(self, x):
        """Run the forward pass on x and return the output of the last layer."""
        hidden = x
        for layer in (self.l1, self.l2):
            hidden = F.relu(layer(hidden))
        # No activation is applied after the final linear layer.
        return self.l3(hidden)

In [9]:
# Build the network: 784 inputs, n_units units per hidden layer, 10 outputs.
net = MLPModel(784, n_units, 10)

In [10]:
# Wrap the network in chainer's Classifier link.
model = L.Classifier(net)

In [11]:
# Create an Adam optimizer with its default settings and attach it to `model`.
optimizer = optimizers.Adam()
optimizer.setup(model)

In [12]:
# Load the parameters stored in the file "mlp.model" into `model`.
# NOTE(review): load_npz presumably fills `model` in place and returns None,
# so `mod` likely holds nothing useful — confirm before relying on it.
mod = serializers.load_npz("mlp.model", model)

In [13]:
def draw_digit(data, size=28):
    """Render one flattened square image in grayscale with matplotlib.

    Parameters
    ----------
    data : array with a ``reshape`` method
        Flat pixel vector holding ``size * size`` values.
    size : int, optional
        Side length of the image in pixels (default 28).
    """
    plt.figure(figsize=(2.5, 3))

    X, Y = np.meshgrid(range(size), range(size))
    # Turn the flat vector into a size x size matrix, then reverse the rows:
    # row 0 is plotted at y = 0 (the bottom), so flipping keeps the image upright.
    Z = data.reshape(size, size)[::-1, :]
    plt.xlim(0, size - 1)
    plt.ylim(0, size - 1)
    plt.pcolor(X, Y, Z)
    plt.gray()
    # tick_params expects booleans; the string "off" is truthy in current
    # matplotlib and would leave the tick labels visible.
    plt.tick_params(labelbottom=False, labelleft=False)

    plt.show()

In [14]:
def plotdata(num):
    """Print the predicted and true label of one training sample, then draw it.

    Parameters
    ----------
    num : int
        Row index into ``x_train`` / ``y_train``.
    """
    # Indexing with a list keeps a leading batch axis of length 1.
    x_data = Variable(x_train[[num]])
    y_data = y_train[num]
    # Use `net`, the MLPModel instance built in this notebook; the name `mlp`
    # used here previously is not defined by any visible cell.
    scores = net(x_data).data[0]
    # Index of the largest output score is the predicted class.
    ans = int(scores.argmax())
    print("ans=%d, reg=%d" % (ans, y_data))
    draw_digit(x_train[num])

In [17]:
def get_uncorrect(num):
    """Show `num` randomly drawn training samples that the network misclassifies.

    Samples are drawn uniformly with replacement from the first ``N`` rows of
    ``x_train``, so the same sample may be shown more than once. The loop only
    ends once `num` misclassified samples have been found; it does not
    terminate if the network classifies every training sample correctly.

    Parameters
    ----------
    num : int
        Number of misclassified samples to print and draw.
    """
    d = 0
    # `<` (not `<=`) so that exactly `num` samples are shown.
    while d < num:
        rand = np.random.randint(0, N)
        # Indexing with a list keeps a leading batch axis of length 1.
        x_data = Variable(x_train[[rand]])
        y_data = y_train[rand]
        # Use `net`, the MLPModel instance built in this notebook; the name
        # `mlp` used here previously is not defined by any visible cell.
        ans = int(net(x_data).data[0].argmax())
        if ans != y_data:
            print("ans=%d, reg=%d" % (ans, y_data))
            draw_digit(x_train[rand])
            d += 1

In [18]:
def get_correct(num):
    """Show `num` randomly drawn training samples that the network classifies correctly.

    Samples are drawn uniformly with replacement from the first ``N`` rows of
    ``x_train``, so the same sample may be shown more than once. The loop only
    ends once `num` correctly classified samples have been found.

    Parameters
    ----------
    num : int
        Number of correctly classified samples to print and draw.
    """
    d = 0
    # `<` (not `<=`) so that exactly `num` samples are shown.
    while d < num:
        rand = np.random.randint(0, N)
        # Indexing with a list keeps a leading batch axis of length 1.
        x_data = Variable(x_train[[rand]])
        y_data = y_train[rand]
        # Use `net`, the MLPModel instance built in this notebook; the name
        # `mlp` used here previously is not defined by any visible cell.
        ans = int(net(x_data).data[0].argmax())
        if ans == y_data:
            print("ans=%d, reg=%d" % (ans, y_data))
            draw_digit(x_train[rand])
            d += 1