[Chainerでマルチタスクニューラルネットワークを実装した](http://szdr.hatenablog.com/entry/2017/03/06/233530)の改変

In [1]:
import numpy as np
import pandas as pd
import chainer
import chainer.links as L
import chainer.functions as F
import chainer.computational_graph as c
from sklearn import datasets

ネットワークの定義

In [2]:
class SharedNet(chainer.Chain):
    """Shared trunk used by every task.

    A single fully connected layer followed by a sigmoid activation.

    Args:
        n_out: number of units produced by the trunk.
    """

    def __init__(self, n_out):
        # in_size=None: the input width is left for L.Linear to determine.
        super(SharedNet, self).__init__(l1=L.Linear(None, n_out))

    def __call__(self, x):
        """Return sigmoid(l1(x))."""
        return F.sigmoid(self.l1(x))
class SeparatedNet(chainer.Chain):
    """Task-specific head: a single linear layer with no activation.

    Args:
        n_out: number of output units (one per class of the task).
    """

    def __init__(self, n_out):
        # in_size=None: the input width is left for L.Linear to determine.
        super(SeparatedNet, self).__init__(l1=L.Linear(None, n_out))

    def __call__(self, x):
        """Return the raw linear output l1(x) (no nonlinearity applied)."""
        return self.l1(x)
class CombinedNet(chainer.Chain):
    """Multi-task network: one shared trunk feeding two task-specific heads.

    Every sample belongs to exactly one task, so instead of returning both
    heads' outputs, ``__call__`` routes each row to the head of its own task
    and returns a single output.

    Args:
        shared: trunk applied to every sample.
        separated0: head used for samples of task 0.
        separated1: head used for samples of task 1.
    """

    def __init__(self, shared, separated0, separated1):
        super(CombinedNet, self).__init__(
            shared=shared,
            separated0=separated0,
            separated1=separated1,
        )

    def __call__(self, x1, x2):
        """Forward pass with per-sample task routing.

        Args:
            x1: input features, one row per sample.
            x2: task index per sample (0 -> task 0, anything else -> task 1).
                Either shape (N,) or (N, 1); a numpy array or a
                ``chainer.Variable`` wrapping one.

        Returns:
            Output with the shape of a head's output; row i is taken from
            ``separated0`` when x2[i] == 0 and from ``separated1`` otherwise.
        """
        s = self.shared(x1)
        s0 = self.separated0(s)  # task 0 output
        s1 = self.separated1(s)  # task 1 output

        # Tasks are mutually exclusive per sample, so both heads can be merged
        # into a single output. The elementwise comparison needs the raw
        # array, so unwrap a Variable first. Reshape to a column so that both
        # (N,) and (N, 1) task indices broadcast along the class axis. A bare
        # (N,) array would otherwise fail to broadcast, or silently broadcast
        # along the wrong axis when N equals the number of classes.
        task = x2.data if isinstance(x2, chainer.Variable) else np.asarray(x2)
        is_task0 = (task.reshape(-1, 1) == 0)
        condition = np.broadcast_to(is_task0, s0.shape)
        return F.where(condition, s0, s1)

最適化の設定

In [3]:
# Shared 3-unit trunk with two 2-output heads (one head per task).
net = CombinedNet(
    shared=SharedNet(3),
    separated0=SeparatedNet(2),
    separated1=SeparatedNet(2),
)
# Adam optimizer over all parameters of the combined network.
optimizer = chainer.optimizers.Adam()
optimizer.use_cleargrads()
optimizer.setup(net)

データの読込

In [4]:
# Iris features as float32, class ids as int32.
X, ys = datasets.load_iris(return_X_y=True)
X, ys = X.astype(np.float32), ys.astype(np.int32)
# One-vs-rest binary labels: y1 marks class 0, y2 marks class 2.
# NOTE(review): y1/y2 are not referenced by the cells shown here (v_out is
# rebuilt from ys and task_idx); drop them if no later cell uses them.
y1, y2 = ((ys == k).astype(np.int32) for k in (0, 2))

タスクに合わせた整形

In [5]:
np.random.seed(20160308)
# Draw a random task (0 or 1) for every sample, then pin the classes with a
# fixed task: class 0 -> task 0 and class 2 -> task 1. Only class 1 samples
# keep their random assignment.
task_idx = np.random.choice([0, 1], len(ys), replace=True)
task_idx = np.where(ys == 0, 0, np.where(ys == 2, 1, task_idx)).astype(np.int32)
# Fraction of samples assigned to task 0.
print((task_idx == 0).mean())

0.48


In [6]:
# Network inputs: the features plus the task index as an (N, 1) column.
v_in = [X, task_idx[:, np.newaxis]]
# Merge both tasks' binary labels into one target vector:
# task 0 asks "is class 0?", task 1 asks "is class 2?".
v_out = (((task_idx == 0) & (ys == 0)) | ((task_idx == 1) & (ys == 2))).astype(np.int32)

パラメータの更新

In [7]:
# A single optimization step on the whole dataset. The loss is computed by a
# closure, which optimizer.update calls before backprop and the Adam update.
net.cleargrads()
optimizer.update(lambda: F.softmax_cross_entropy(net(*v_in), v_out))