# Chainerの利用例

In [1]:
from IPython.core.display import display, HTML
# Inject CSS so the notebook's `.container` element spans the full page width.
display(HTML("<style>.container{width:100% !important;}</style>"))

import numpy as np
import chainer
from chainer import cuda, Function, gradient_check, Variable, optimizers, serializers, utils
from chainer import Link, Chain, ChainList
import chainer.functions as F
import chainer.links as L

### irisデータセット

In [11]:
from sklearn import datasets

# Load the iris data: float32 features plus the class labels as given.
iris = datasets.load_iris()
X = iris.data.astype(np.float32)
Y = iris.target
N = Y.size

# Convert each label to a one-hot float32 row with three columns.
Y2 = np.zeros((N, 3), dtype=np.float32)
for row, label in enumerate(Y):
    Y2[row, label] = 1.0

# Odd-numbered rows form the training set, even-numbered rows the test set.
index = np.arange(N)
odd_rows = index[index % 2 != 0]
even_rows = index[index % 2 == 0]
xtrain = X[odd_rows, :]
ytrain = Y2[odd_rows, :]
xtest = X[even_rows, :]
yans = Y[even_rows]

### IrisChainクラス

In [40]:
class IrisChain(Chain):
    """Two-layer network: Linear(4, 6) -> sigmoid -> Linear(6, 3).

    Calling an instance returns the mean squared error between the
    forward output and the supplied target.
    """

    def __init__(self):
        layers = dict(
            l1=L.Linear(4, 6),
            l2=L.Linear(6, 3),
        )
        super(IrisChain, self).__init__(**layers)

    def __call__(self, x, y):
        """Return the mean squared error of ``fwd(x)`` against ``y``."""
        prediction = self.fwd(x)
        return F.mean_squared_error(prediction, y)

    def fwd(self, x):
        """Return ``l2(sigmoid(l1(x)))`` without any output activation."""
        hidden = F.sigmoid(self.l1(x))
        return self.l2(hidden)

### パラメータの学習

In [41]:
# Train a fresh IrisChain on the whole training set with SGD (default
# hyperparameters): one parameter update per iteration.
model = IrisChain()
optimizer = optimizers.SGD()
optimizer.setup(model)
for i in range(10000):
    x = Variable(xtrain)
    y = Variable(ytrain)
    # cleargrads() replaces the deprecated zerograds(): gradients are reset
    # before every backward pass so they do not carry over between iterations.
    model.cleargrads()
    loss = model(x, y)
    loss.backward()
    optimizer.update()

In [42]:
# Score the test rows: the predicted class of a row is the index of its
# largest output value, compared against the labels in `yans`.
xt = Variable(xtest)
yt = model.fwd(xt)
ans = yt.data
nrow, ncol = ans.shape
predicted = np.argmax(ans, axis=1)
ok = int(np.sum(predicted == yans))

In [43]:
# Report the number of correct rows, the total, and their ratio.
accuracy = ok / nrow
print("%d / %d = %f" % (ok, nrow, accuracy))

71 / 75 = 0.946667


### ミニバッチ

In [30]:
# Mini-batch training: every epoch walks the training rows in a fresh random
# order and updates the parameters once per batch of `bs` rows.
n = len(xtrain)  # derived from the data instead of the hard-coded 75
bs = 25
for j in range(5000):
    sffindx = np.random.permutation(n)
    for i in range(0, n, bs):
        # Slicing clamps at the end of the array, so the final batch is
        # simply shorter when `bs` does not divide `n`.
        batch = sffindx[i:i + bs]
        x = Variable(xtrain[batch])
        y = Variable(ytrain[batch])
        # cleargrads() replaces the deprecated zerograds().
        model.cleargrads()
        loss = model(x, y)
        loss.backward()
        optimizer.update()

### 誤差の累積

In [32]:
# Accumulated-loss training: the losses of all mini-batches in an epoch are
# summed, and a single backward pass / parameter update is driven by that sum.
# (Previously `accum_loss` was built but never used, and backward/update still
# ran per batch on `loss`, which made this cell behave like plain mini-batch
# training.)
n = len(xtrain)  # derived from the data instead of the hard-coded 75
bs = 25
for j in range(2000):
    accum_loss = None
    sffindx = np.random.permutation(n)
    for i in range(0, n, bs):
        # Slicing clamps at the end of the array, so the final batch is
        # simply shorter when `bs` does not divide `n`.
        batch = sffindx[i:i + bs]
        x = Variable(xtrain[batch])
        y = Variable(ytrain[batch])
        loss = model(x, y)
        accum_loss = loss if accum_loss is None else accum_loss + loss
    # Reset gradients, then backpropagate the summed loss once per epoch.
    model.cleargrads()
    accum_loss.backward()
    optimizer.update()

### Softmax

In [33]:
class IrisChain(Chain):
    """Two-layer network: Linear(4, 6) -> sigmoid -> Linear(6, 3) -> softmax.

    Calling an instance returns the mean squared error between the
    softmax output and the supplied target.
    """

    def __init__(self):
        layers = dict(
            l1=L.Linear(4, 6),
            l2=L.Linear(6, 3),
        )
        super(IrisChain, self).__init__(**layers)

    def __call__(self, x, y):
        """Return the mean squared error of ``fwd(x)`` against ``y``."""
        prediction = self.fwd(x)
        return F.mean_squared_error(prediction, y)

    def fwd(self, x):
        """Return ``softmax(l2(sigmoid(l1(x))))``."""
        hidden = F.sigmoid(self.l1(x))
        scores = self.l2(hidden)
        return F.softmax(scores)

In [38]:
# Show the model output for the first test row.
# NOTE(review): `ans` is only assigned in the evaluation cell, so that cell
# (and training) must be re-run with the softmax model before this display
# reflects it.
ans[0]

array([ 0.84821469,  0.13467893,  0.0171064 ], dtype=float32)

### softmax cross entropy

In [44]:
# Rebuild the split with int32 class labels instead of one-hot rows.
X = iris.data.astype(np.float32)
Y = iris.target.astype(np.int32)
N = Y.size

# Odd-numbered rows form the training set, even-numbered rows the test set.
index = np.arange(N)
odd_rows = index[index % 2 != 0]
even_rows = index[index % 2 == 0]
xtrain = X[odd_rows, :]
ytrain = Y[odd_rows]
xtest = X[even_rows, :]
yans = Y[even_rows]

In [46]:
class IrisChain(Chain):
    """Two-layer network: Linear(4, 6) -> sigmoid -> Linear(6, 3).

    Calling an instance returns the softmax cross entropy between the
    forward output and the supplied target.
    """

    def __init__(self):
        layers = dict(
            l1=L.Linear(4, 6),
            l2=L.Linear(6, 3),
        )
        super(IrisChain, self).__init__(**layers)

    def __call__(self, x, y):
        """Return the softmax cross entropy of ``fwd(x)`` against ``y``."""
        scores = self.fwd(x)
        return F.softmax_cross_entropy(scores, y)

    def fwd(self, x):
        """Return ``l2(sigmoid(l1(x)))``; no softmax is applied here."""
        hidden = F.sigmoid(self.l1(x))
        return self.l2(hidden)

### ロジスティック回帰

In [47]:
class IrisRogi(Chain):
    """Single linear layer (4 -> 3) followed by softmax.

    Calling an instance returns the mean squared error between the
    softmax output and the supplied target.
    """

    def __init__(self):
        layers = dict(l1=L.Linear(4, 3))
        super(IrisRogi, self).__init__(**layers)

    def __call__(self, x, y):
        """Return the mean squared error of ``fwd(x)`` against ``y``."""
        prediction = self.fwd(x)
        return F.mean_squared_error(prediction, y)

    def fwd(self, x):
        """Return ``softmax(l1(x))``."""
        scores = self.l1(x)
        return F.softmax(scores)

In [49]:
# Create a fresh IrisRogi model and attach an Adam optimizer constructed with
# no arguments. The setup() call is the cell's last expression, so its return
# value is what the notebook displays.
model = IrisRogi()
optimizer = optimizers.Adam()
optimizer.setup(model)

<chainer.optimizers.adam.Adam at 0x1133cdac8>