In [2]:
import torch
from torch import nn
from torch.autograd import Variable
from torchvision import datasets, transforms
import torchvision
import matplotlib.pyplot as plt
import numpy as np
from torch import optim
from sklearn.metrics import accuracy_score
import torch.nn.functional as F
%matplotlib inline

In [3]:
bsz = 10

train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=bsz, shuffle=True)
test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
batch_size=bsz, shuffle=True)

In [51]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return x

    
class MLPNet(nn.Module):
    def __init__(self):
        super(MLPNet, self).__init__()
        self.fc1 = nn.Linear(28*28, 32)
        self.fc2 = nn.Linear(32, 10)
    def forward(self, x, T=1):
        x = x.view(-1, 28*28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return x

    
def validate(model, batches_val):
    model.eval()
    y_pred = []
    y_true = []
    for batch in batches_val:
        features, targets = batch
        y_true += targets.tolist()
        out = model(Variable(features))
        y_pred += F.log_softmax(out, dim=1).topk(1)[1].squeeze().data.tolist()
    model.train()
    return accuracy_score(y_true, y_pred)

def validate_tree(model, batches_val):
    model.eval()
    y_pred = []
    y_true = []
    for batch in batches_val:
        features, targets = batch
        y_true += targets.tolist()
        y_pred += model(Variable(features.view(-1, 28*28))).topk(1)[1].squeeze().data.tolist()
    model.train()
    return accuracy_score(y_true, y_pred)

def distill_loss(out, labels, teacher, T, alpha=1):
    # KLD instead of cross-entropy
    f = nn.NLLLoss()
    out_sm = F.log_softmax(out/T, dim=1)
    teacher_sm = F.log_softmax(out/T, dim=1)
    term1 = f(F.log_softmax(out, dim=1), labels)
    term2 = nn.KLDivLoss()(out_sm, teacher_sm.detach())*alpha*T*T
    return term1 + term2

In [48]:
EPS = 1e-10

class Leaf(nn.Module):
    def __init__(self, i_size, o_size, h_size=128):
        super(Leaf, self).__init__()
        self.i2h = nn.Linear(i_size, h_size)
        self.h2o = nn.Linear(h_size, o_size)
        self.soft = nn.LogSoftmax(1)
        self.relu = nn.ReLU()
        self.is_leaf = True

    def forward(self, features):
        out = self.i2h(features)
        out = self.relu(out)
        out = self.h2o(out)
        return self.soft(out)

    def accum_probs(self, features, path_prob):
        return [[path_prob, self.forward(features)]]

    def calc_regularization(self, features, path_prob):
        return 0

class Node(nn.Module):
    def __init__(self, i_size, o_size):
        super(Node, self).__init__()
        self.o_size = o_size
        self.i_size = i_size
        self.i2o = nn.Linear(i_size, 1)
        self.sigmoid = nn.Sigmoid()
        self.is_leaf = False
    
    def build_tree(self, depth):
        if depth - 1 < 0:
            raise ValueError("Depth must be greater than zero.")
        if depth - 1 > 0:
            self.left = Node(self.i_size, self.o_size)
            self.right = Node(self.i_size, self.o_size)
            self.left.build_tree(depth - 1)
            self.right.build_tree(depth - 1)
        else:
            self.left = Leaf(self.i_size, self.o_size)
            self.right = Leaf(self.i_size, self.o_size)

    def forward(self, features):
        pr = self.prob_left(features)
        return pr*self.left(features) + (1 - pr)*self.right(features)

    def prob_left(self, features):
        return self.sigmoid(self.i2o(features))

    def accum_probs(self, features, path_prob):
        res = []
        p_l = self.sigmoid(self.i2o(features)).squeeze()
        res_l = self.left.accum_probs(features, p_l*path_prob)
        res_r = self.right.accum_probs(features, (1 - p_l)*path_prob)
        res.extend(res_l)
        res.extend(res_r)
        return res

    def calc_regularization(self, features, path_prob):
        p_l = self.prob_left(features).squeeze()
        alpha = (path_prob*p_l).sum()/(path_prob.sum())
        C_here = -0.5*torch.log(alpha + EPS) - 0.5*torch.log(1 - alpha + EPS)
        C = self.left.calc_regularization(features, p_l*path_prob) + \
                        self.right.calc_regularization(features, (1 - p_l)*path_prob)
        C = C + C_here
        return C
        

def tree_loss(path_probs, y_true, C, gamma):
    loss = 0
    criterion = nn.NLLLoss(reduce=False)
    for p, pred in path_probs:
        loss += (p*criterion(pred, y_true)).mean()
    return loss.mean() + gamma*C

def tree_logloss(path_probs, y_true):
    """
    Original loss from paper
    """
    loss = 0
    criterion = nn.NLLLoss()
    for p, pred in path_probs:
        loss -= (p.squeeze()*criterion(pred, y_true)).mean()
    return -torch.log(loss.mean())

def train(model, batches_train, batches_val, n_epoch=5, gamma=0.1,
          criterion=tree_loss, val_every=500, print_every=100):
    model.train()
    optimizer = optim.Adam(model.parameters())
    all_losses = np.zeros(print_every)
    plot_train = []
    plot_val = []
    for epoch in range(n_epoch):
        print('Epoch: {}'.format(epoch))
        for i, batch in enumerate(batches_train):
            optimizer.zero_grad()
            features, targets = batch
            features = Variable(features.view(-1, 28*28))
            targets = Variable(targets)
            prb = model.accum_probs(features, Variable(torch.Tensor([1]*targets.shape[0])))
            C = model.calc_regularization(features, Variable(torch.Tensor([1]*targets.shape[0])))
            loss = tree_loss(prb, targets, C, gamma=gamma)
            loss.backward()
            optimizer.step()
            plot_train += [loss.data[0]]
            all_losses[(i + 1)%print_every] = loss.data[0]
            if (i + 1) % print_every == 0:
                print(all_losses.max(), all_losses.mean())
            if (i + 1) % val_every == 0:
                plot_val += [validate(model, batches_val)]
    return plot_train, plot_val


In [11]:
model = MLPNet()

In [44]:
validate(model, test_loader)

0.0986

In [22]:
validate(model, test_loader)

0.102

In [58]:
dataiter = iter(train_loader)
images, labels = dataiter.next()

In [28]:
nn.KLDivLoss()(out, teacher.detach())

Variable containing:
-0.1736
[torch.FloatTensor of size 1]

In [27]:
teacher.detach()

RuntimeError: you can only change requires_grad flags of leaf variables. If you want to use a computed variable in a subgraph that doesn't require differentiation use var_no_grad = var.detach().

In [97]:
nn.NLLLoss()(model(Variable(images), T=1), Variable(labels))

Variable containing:
 0.1548
[torch.FloatTensor of size 1]

In [157]:
out = model(Variable(images), T=1)

In [122]:
tmp = torch.randn((10, 10))

In [124]:
tmp


 0.1539  0.0262 -1.1954  0.1588  0.7278  2.2301 -0.5055 -0.5056  0.1926  0.2100
 1.1154  0.0551  0.7055  0.3275  1.2062 -0.0442 -0.2309 -0.2658  0.5998 -0.8250
 1.2547  0.4871  0.5015 -1.1518  0.9102 -1.9922 -0.9561  0.9763  0.5878  0.4499
 1.0981 -0.4271 -0.3846 -0.0518 -0.2813  0.2450  0.1398 -1.0106 -0.3075 -0.2331
 0.2396 -1.1178 -0.8445 -0.1227 -0.0193 -1.4706 -0.7373  0.0439  0.2850  1.2408
-0.3649  1.8554  0.0346  1.3589 -2.3692 -1.1487  0.2179 -0.9799 -0.8932 -0.1461
 0.5386  1.1653 -1.5315 -1.1691 -0.0695 -0.4274  0.6384 -2.2958 -0.5693 -0.1638
 0.8861  2.8362 -0.7377  0.3162  0.0122  0.2033  0.3851  0.5844 -0.3663 -1.4459
 1.6428  0.1251 -1.1547 -1.7958 -0.5284  0.0292 -1.0785  0.1804  0.6491 -0.7858
-0.2267  0.5243  0.7828 -0.3112  0.0115  0.1925 -0.7195 -0.9484  0.1149 -1.3008
[torch.FloatTensor of size 10x10]

In [125]:
out

Variable containing:
-4.1462 -2.8970 -1.7497 -1.0912 -3.6110 -3.0325 -4.7618 -1.7315 -2.3107 -2.8245
-2.9200 -2.7526 -2.1662 -2.6158 -1.4263 -2.5391 -1.9825 -2.4884 -2.4612 -2.6684
-4.0901 -3.0911 -2.9071 -2.0496 -3.1610 -2.1145 -2.3699 -3.8722 -0.8469 -3.0309
-4.0136 -3.9463 -4.2655 -2.0009 -3.4587 -0.8159 -3.3900 -2.5985 -2.7108 -1.8015
-4.7107 -4.5514 -3.5822 -2.8478 -2.2337 -3.7733 -7.2636 -1.8880 -2.6121 -0.6179
-2.8180 -3.8386 -3.4167 -2.1056 -2.7148 -1.1374 -2.7607 -2.6737 -3.0516 -1.6207
-3.3092 -3.5563 -0.6738 -1.7274 -4.7652 -4.0080 -4.2577 -2.7040 -2.1067 -4.0201
-2.1608 -3.8429 -2.7477 -3.5961 -2.2231 -2.7197 -0.7823 -3.8699 -2.4188 -3.4923
-1.5345 -2.5138 -2.2990 -2.9358 -2.4746 -2.6045 -1.9151 -2.5914 -2.4913 -2.4437
-5.1462 -7.8088 -6.5441 -2.5508 -5.7172 -0.2723 -3.0963 -6.2423 -3.4774 -2.6397
[torch.FloatTensor of size 10x10]

In [131]:
nn.KLDivLoss()(model(Variable(images), T=1), Variable(out.data))

Variable containing:
 0
[torch.FloatTensor of size 1]

In [145]:
dist_loss(model(Variable(images), T=1) , Variable(labels), Variable(out.data), 0)

Variable containing:
 0.2022
[torch.FloatTensor of size 1]

In [89]:
Variable(targets)

RuntimeError: Variable data has to be a tensor, but got Variable

In [6]:

criterion = nn.NLLLoss()
n_epochs = 5
print_every = 1000

<h1>Train CNN

In [46]:
train_loss = []
val_loss = []
model = Net()
optimizer = optim.Adam(model.parameters())
for epoch in range(n_epochs):
    print('EPOCH: {}'.format(epoch))
    val_loss += [validate(model, test_loader)]
    print(val_loss[-1])
    for i, batch in enumerate(train_loader):
        optimizer.zero_grad()
        features, targets = batch
        features = Variable(features)
        targets = Variable(targets)
        out = model(features)
        loss = criterion(F.log_softmax(out, dim=1), targets)
        loss.backward()
        optimizer.step()
        train_loss += [loss.data[0]]
        if (i + 1) % 1000 == 0:
            print(np.max(train_loss), np.mean(train_loss))
            train_loss = []

EPOCH: 0
0.068
3.250565767288208 0.9064928530603648
3.8386521339416504 0.4244081753175706
1.914120078086853 0.3759179995940067
2.171348810195923 0.31199426844133993
2.8077845573425293 0.29984591309778624
2.394568920135498 0.2554379003227223
EPOCH: 1
0.9752
1.5989887714385986 0.24528537144861184
1.7577078342437744 0.2408153131652798
2.990764856338501 0.2377243892763654
2.6081299781799316 0.24577138381524127
2.079117774963379 0.21674371492350475
2.119439125061035 0.21550026029965375
EPOCH: 2
0.9814
2.019820213317871 0.20597821886582096
2.925922393798828 0.21297107608425722
4.348562717437744 0.21364294508202875
1.6269140243530273 0.20620368498855532
1.5557153224945068 0.20654489904981166
2.0109620094299316 0.20133193865249632
EPOCH: 3
0.9834
1.5962858200073242 0.18980152795968752
1.8202632665634155 0.1882713423241512
2.0586209297180176 0.19039900656670944
1.7367193698883057 0.17603074376726227
2.737109422683716 0.19075105903312214
2.298799753189087 0.1919629090946255
EPOCH: 4
0.9839
1.686

In [47]:
validate(model, test_loader)

0.9848

<h1>Train MLP

In [8]:
train_loss = []
val_loss = []
model_mlp = MLPNet()
optimizer = optim.Adam(model_mlp.parameters())
for epoch in range(n_epochs):
    print('EPOCH: {}'.format(epoch))
    val_loss += [validate(model_mlp, test_loader)]
    print(val_loss[-1])
    for i, batch in enumerate(train_loader):
        optimizer.zero_grad()
        features, targets = batch
        features = Variable(features)
        targets = Variable(targets)
        out = model_mlp(features)
        loss = criterion(F.log_softmax(out, dim=1), targets)
        loss.backward()
        optimizer.step()
        train_loss += [loss.data[0]]
        if (i + 1) % 1000 == 0:
            print(np.max(train_loss), np.mean(train_loss))
            train_loss = []

EPOCH: 0
0.106
2.6320433616638184 1.605462523818016
2.6214921474456787 1.526390792965889
2.5565125942230225 1.5039487244188785
2.7537596225738525 1.4267951794117688
2.78882098197937 1.338061695612967
2.969398021697998 1.2950184574052692
EPOCH: 1
0.5508
2.7672393321990967 1.2906241402179002
2.3240065574645996 1.2787590429186821
2.6269984245300293 1.2688658639788628
2.5018692016601562 1.2613810243895278
2.7148659229278564 1.2835158790051937
2.7108378410339355 1.2717669401466847
EPOCH: 2
0.5599
2.3090381622314453 1.2524490236639976
2.567070960998535 1.2493149854838848
2.50915789604187 1.2678505235612392
2.518629312515259 1.259647961884737
2.3025853633880615 1.2522857138365506
2.6223530769348145 1.1735659117400645
EPOCH: 3
0.6464
2.3260626792907715 1.0507244988897582
2.4671120643615723 1.042844795199111
2.797755718231201 1.042956533075543
2.5100579261779785 1.0464946614067303
2.163175344467163 1.033682730421424
2.5838637351989746 1.019644301354885
EPOCH: 4
0.6603
2.8707222938537598 1.04129

<h1>Train MLP with distillation loss

In [36]:
train_loss = []
val_loss = []
model_mlp = MLPNet()
optimizer = optim.Adam(model_mlp.parameters())
for epoch in range(n_epochs):
    print('EPOCH: {}'.format(epoch))
    val_loss += [validate(model_mlp, test_loader)]
    print(val_loss[-1])
    for i, batch in enumerate(train_loader):
        optimizer.zero_grad()
        features, targets = batch
        features = Variable(features)
        targets = Variable(targets)
        out = model_mlp(features)
        teacher = model(features)
        loss = distill_loss(out, targets, teacher, 5)
        loss.backward()
        optimizer.step()
        train_loss += [loss.data[0]]
        if (i + 1) % 1000 == 0:
            print(np.max(train_loss), np.mean(train_loss))
            train_loss = []

EPOCH: 0
0.1053
2.4506101608276367 1.0127332982420922
2.4707589149475098 0.8858727484801784
2.3847155570983887 0.8652417615484447
2.117615222930908 0.8130836763428524
2.5091567039489746 0.8332938201297074
2.2437922954559326 0.8160942573484499
EPOCH: 1
0.6932
2.991323471069336 0.8001047033295036
2.4121506214141846 0.7835619333037175
2.2176852226257324 0.7993163074050099
2.1255030632019043 0.770697611205047
2.687572717666626 0.7975997999751707
1.8874967098236084 0.7750216433053138
EPOCH: 2
0.6934
2.0765976905822754 0.7638222025652358
2.196305751800537 0.76588880770572
1.9965606927871704 0.7484665497506503
2.3403964042663574 0.7736654026816832
2.188589572906494 0.7683640420625161
2.149075984954834 0.7705404664249509
EPOCH: 3
0.6976
2.0743086338043213 0.737164663964184
2.232529878616333 0.7573174828453921
2.172527313232422 0.7608498331243172
2.3801751136779785 0.7467925194528653
2.496978521347046 0.7548395561763027
2.0883266925811768 0.753184190693195
EPOCH: 4
0.6966
1.9575586318969727 0.7

<h1>Distilling into tree

In [49]:
net = Node(28*28, 10)
net.build_tree(2)
training, testing = train(net, train_loader, test_loader, gamma=0, n_epoch=1)

Epoch: 0
2.336723566055298 0.9686341431736946
1.3108975887298584 0.4646588634699583
1.5012015104293823 0.3861819225549698
1.3249850273132324 0.33273567527532577
1.506148338317871 0.32575706465169785


RuntimeError: size mismatch, m1: [280 x 28], m2: [784 x 1] at /pytorch/torch/lib/TH/generic/THTensorMath.c:1434

In [50]:
validate(net, test_loader)

RuntimeError: size mismatch, m1: [280 x 28], m2: [784 x 1] at /pytorch/torch/lib/TH/generic/THTensorMath.c:1434