In [9]:
import numpy as np
import os
import sys
import tabulate
import torch
import torch.nn.functional as F

from tqdm import tqdm_notebook as tqdm

from torchtext import data
from torchtext import datasets
from torchtext.vocab import Vectors, GloVe

import curves
import data
import load_data
import models
import utils

In [14]:
# ---- Output ----------------------------------------------------------------
# Directory that receives the saved interpolation results.
c_dir = './chain'

# ---- Interpolation grid ----------------------------------------------------
# Number of evaluation points on each segment between two consecutive
# checkpoints (both end points included).
num_points = 12

# ---- Data loading ----------------------------------------------------------
# Mini-batch size handed to the dataset loader and to the evaluation helper.
batch_size = 64
# Number of workers (not referenced by the cells visible in this notebook).
num_workers = 4

# ---- Model -----------------------------------------------------------------
# Name of the architecture looked up in the `models` module.
model_name = 'LSTMClassifier'
# Weight decay; only referenced by the disabled L2 regularizer.
wd = 1e-5

# ---- Checkpoints -----------------------------------------------------------
# Checkpoints to evaluate, in chain order; list every model of the chain here.
ckpts = [
    './saved_models/point4.pt',
    './saved_models/point5.pt',
]

In [15]:
# Build the vocabulary, the pretrained embedding matrix and the three data
# loaders in a single call to the project helper.
(
    TEXT,
    vocab_size,
    num_classes,
    word_embeddings,
    train_loader,
    valid_loader,
    test_loader,
) = load_data.load_dataset(batch_size=batch_size)

Length of Text Vocabulary: 135872
Vector size of Text Vocabulary:  torch.Size([135872, 300])
Label Length: 4


In [16]:
# Preview the interpolation coefficients used on each checkpoint-to-checkpoint
# segment: `num_points` evenly spaced values from 0 to 1 inclusive.
np.linspace(0.0, 1.0, num=num_points)

array([0.        , 0.09090909, 0.18181818, 0.27272727, 0.36363636,
       0.45454545, 0.54545455, 0.63636364, 0.72727273, 0.81818182,
       0.90909091, 1.        ])

In [17]:
# Enable cuDNN benchmark mode.
torch.backends.cudnn.benchmark = True

# Resolve the architecture by name and collect its constructor arguments.
architecture = getattr(models, model_name)
kwargs = dict(
    batch_size=batch_size,
    hidden_size=256,
    embedding_length=300,
    vocab_size=vocab_size,
    weights=word_embeddings,
)

# A single model instance is reused for every point on the chain; its
# parameters are overwritten in place before each evaluation.
base_model = architecture.base(num_classes=num_classes, **kwargs)
base_model.cuda()

# Plain cross-entropy is evaluated. The L2 term is switched off here; the
# alternative would be `regularizer = utils.l2_regularizer(wd)`.
criterion = torch.nn.CrossEntropyLoss()
regularizer = None

def get_weights(model):
    """Return all parameters of ``model`` flattened into one 1-D NumPy array.

    Parameters are visited in the order given by ``model.parameters()``; each
    one is moved to the CPU, flattened, and appended to the result.
    """
    flat_chunks = []
    for param in model.parameters():
        flat_chunks.append(param.data.cpu().numpy().ravel())
    return np.concatenate(flat_chunks)

# Total number of evaluated points: every segment contributes `num_points`
# values, and interior checkpoints are shared by two neighbouring segments.
T = (num_points - 1) * (len(ckpts) - 1) + 1

# Position along the chain: integer values coincide with the checkpoints.
ts = np.linspace(0.0, len(ckpts) - 1, T)

# One zero-initialised buffer of length T per recorded metric.
tr_loss, tr_nll, tr_acc, te_loss, te_nll, te_acc, tr_err, te_err = (
    np.zeros(T) for _ in range(8)
)

# Header of the progress table printed during evaluation.
columns = [
    't',
    'Train loss',
    'Train nll',
    'Train error (%)',
    'Test nll',
    'Test error (%)',
]

# Interpolation coefficients for a single segment.
alphas = np.linspace(0.0, 1.0, num_points)

# Echo the checkpoints that make up the chain.
for path in ckpts:
    print(path)

def _load_flat_weights(ckpt_path):
    """Load ``ckpt_path`` into ``base_model`` and return its flattened parameters."""
    base_model.load_state_dict(torch.load(ckpt_path)['model_state'])
    return get_weights(base_model)


# Walk the chain segment by segment. On each segment the model weights are set
# to the linear interpolation (1 - alpha) * w_1 + alpha * w_2 of the two end
# points and the model is evaluated on the train and test loaders.
step = 0
w_2 = None
for i in range(len(ckpts) - 1):
    # The end point of the previous segment is the start of this one, so each
    # checkpoint only needs to be read from disk once.
    w_1 = _load_flat_weights(ckpts[i]) if w_2 is None else w_2
    w_2 = _load_flat_weights(ckpts[i + 1])

    # alpha == 0 is skipped on every segment but the first: that point was
    # already evaluated as alpha == 1 of the previous segment.
    for alpha in alphas[1 if i > 0 else 0:]:
        w = (1.0 - alpha) * w_1 + alpha * w_2

        # Scatter the flat vector back into the model, parameter by parameter,
        # in the same order `get_weights` used to flatten it.
        offset = 0
        for parameter in base_model.parameters():
            # numel() is an exact Python int, also for 0-dim parameters.
            size = parameter.numel()
            value = w[offset:offset + size].reshape(parameter.size())
            parameter.data.copy_(torch.from_numpy(value))
            offset += size
        assert offset == w.size, 'flat weight vector does not match the model parameters'

        # NOTE: no BatchNorm statistics are recomputed here (the
        # `utils.update_bn` call is disabled).

        tr_res = utils.eval_model(train_loader, base_model, criterion, batch_size, regularizer)
        te_res = utils.eval_model(test_loader, base_model, criterion, batch_size, regularizer)

        tr_loss[step] = tr_res['loss']
        tr_nll[step] = tr_res['nll']
        tr_acc[step] = tr_res['acc']
        tr_err[step] = 100.0 - tr_acc[step]
        te_loss[step] = te_res['loss']
        te_nll[step] = te_res['nll']
        te_acc[step] = te_res['acc']
        te_err[step] = 100.0 - te_acc[step]

        # Print one table row per point; the header is repeated every 40 rows.
        values = [ts[step], tr_loss[step], tr_nll[step], tr_err[step], te_nll[step], te_err[step]]
        table = tabulate.tabulate([values], columns, tablefmt='simple', floatfmt='10.4f')
        if step % 40 == 0:
            table = table.split('\n')
            table = '\n'.join([table[1]] + table)
        else:
            table = table.split('\n')[2]
        print(table)
        step += 1


# Persist every recorded curve in a single .npz archive. The output directory
# is created first so that a missing folder cannot discard the results of the
# evaluation above.
os.makedirs(c_dir, exist_ok=True)
np.savez(
    os.path.join(c_dir, 'chain4_5.npz'),
    ts=ts,
    tr_loss=tr_loss,
    tr_nll=tr_nll,
    tr_acc=tr_acc,
    tr_err=tr_err,
    te_loss=te_loss,
    te_nll=te_nll,
    te_acc=te_acc,
    te_err=te_err,
)

./saved_models/point4.pt
./saved_models/point5.pt


HBox(children=(IntProgress(value=0, max=1594), HTML(value='')))




HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


----------  ------------  -----------  -----------------  ----------  ----------------
         t    Train loss    Train nll    Train error (%)    Test nll    Test error (%)
----------  ------------  -----------  -----------------  ----------  ----------------
    0.0000        0.0140       0.0140             0.6016      0.4701           10.5462


HBox(children=(IntProgress(value=0, max=1594), HTML(value='')))




HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


    0.0909        0.0217       0.0217             0.9561      0.4318           10.5462


HBox(children=(IntProgress(value=0, max=1594), HTML(value='')))




HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


    0.1818        0.0637       0.0637             2.8049      0.4164           10.6303


HBox(children=(IntProgress(value=0, max=1594), HTML(value='')))




HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


    0.2727        0.1944       0.1944             6.8118      0.4545           12.7143


HBox(children=(IntProgress(value=0, max=1594), HTML(value='')))




HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


    0.3636        0.4991       0.4991            16.1744      0.6505           19.7143


HBox(children=(IntProgress(value=0, max=1594), HTML(value='')))




HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


    0.4545        0.9623       0.9623            34.3356      1.0110           36.2017


HBox(children=(IntProgress(value=0, max=1594), HTML(value='')))




HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


    0.5455        1.3031       1.3031            50.9285      1.3062           51.7311


HBox(children=(IntProgress(value=0, max=1594), HTML(value='')))




HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


    0.6364        1.7384       1.7384            60.1255      1.7346           60.5546


HBox(children=(IntProgress(value=0, max=1594), HTML(value='')))




HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


    0.7273        0.4451       0.4451            12.8450      0.5646           16.2689


HBox(children=(IntProgress(value=0, max=1594), HTML(value='')))




HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


    0.8182        0.1815       0.1815             5.4592      0.4691           11.3529


HBox(children=(IntProgress(value=0, max=1594), HTML(value='')))




HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


    0.9091        0.0382       0.0382             1.5621      0.4493           10.0756


HBox(children=(IntProgress(value=0, max=1594), HTML(value='')))




HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


    1.0000        0.0125       0.0125             0.4981      0.4807            9.9748
