In [1]:
import numpy as np
import os
import sys
import tabulate
import torch
import torch.nn.functional as F

from tqdm import tqdm_notebook as tqdm

from torchtext import data
from torchtext import datasets
from torchtext.vocab import Vectors, GloVe

import curves
import data
import load_data
import models
import utils

In [2]:
# --- Experiment configuration ------------------------------------------------
# Directory that receives the saved chain results.
c_dir = './chain'
# Number of points between models.
num_points = 10
# Input batch size.
batch_size = 32
# Number of workers.
num_workers = 4
# Model name (attribute looked up on the `models` module).
model_name = 'LSTMClassifier'
# Weight decay.
wd = 1e-5
# Checkpoints to eval; pass all the models through this parameter.
ckpts = [
    './saved_models/LSTMClassifier-6.pt',
    './saved_models/LSTMClassifier2-4.pt',
]

In [3]:
# A single call to load_data.load_dataset yields the text field, vocabulary
# size, number of classes, pretrained embeddings and the three data loaders.
(
    TEXT,
    vocab_size,
    num_classes,
    word_embeddings,
    train_loader,
    valid_loader,
    test_loader,
) = load_data.load_dataset(batch_size=batch_size)

Length of Text Vocabulary: 135872
Vector size of Text Vocabulary:  torch.Size([135872, 300])
Label Length: 4


In [4]:
# Preview the interpolation coefficients used between two checkpoints.
np.linspace(start=0.0, stop=1.0, num=num_points)

array([0.        , 0.11111111, 0.22222222, 0.33333333, 0.44444444,
       0.55555556, 0.66666667, 0.77777778, 0.88888889, 1.        ])

In [5]:
# Turn on cuDNN benchmark mode.
torch.backends.cudnn.benchmark = True

# Loss used for evaluation. The L2 regularizer (utils.l2_regularizer(wd)) is
# intentionally left disabled, so no regularizer is passed to the evaluator.
criterion = torch.nn.CrossEntropyLoss()
regularizer = None

# Resolve the model class by name and collect its constructor arguments.
architecture = getattr(models, model_name)
kwargs = dict(
    batch_size=batch_size,
    hidden_size=256,
    embedding_length=300,
    vocab_size=vocab_size,
    weights=word_embeddings,
)

# Instantiate the base network and move it to the GPU.
base_model = architecture.base(num_classes=num_classes, **kwargs)
base_model.cuda()


def get_weights(model):
    """Return all parameters of ``model`` as one flat NumPy array.

    Parameters are visited in ``model.parameters()`` order; each one is moved
    to the CPU, converted to NumPy, flattened, and the pieces are concatenated
    end to end.
    """
    flat_chunks = []
    for param in model.parameters():
        flat_chunks.append(param.data.cpu().numpy().ravel())
    return np.concatenate(flat_chunks)

# Number of evaluation points on the whole chain: every one of the
# len(ckpts) - 1 segments adds num_points - 1 new points, plus the first one.
T = (len(ckpts) - 1) * (num_points - 1) + 1
# Chain coordinate of each evaluation point, running from 0 to len(ckpts) - 1.
ts = np.linspace(0.0, len(ckpts) - 1, T)

# One zero-initialised slot per evaluation point for each tracked metric.
tr_loss, tr_nll, tr_acc = (np.zeros(T) for _ in range(3))
te_loss, te_nll, te_acc = (np.zeros(T) for _ in range(3))
tr_err, te_err = (np.zeros(T) for _ in range(2))

# Header row of the progress table printed during evaluation.
columns = [
    't',
    'Train loss',
    'Train nll',
    'Train error (%)',
    'Test nll',
    'Test error (%)',
]

# Interpolation coefficients used inside a single segment.
alphas = np.linspace(0.0, 1.0, num_points)

# Echo the checkpoints that make up the chain, one per line.
for path in ckpts:
    print(path)

# Evaluate the model along the chain of checkpoints. For every pair of
# consecutive checkpoints (ckpts[i], ckpts[i + 1]) the flattened weights are
# linearly interpolated at each coefficient in `alphas`, written back into
# `base_model`, and the model is evaluated on the train and test loaders.
# Results are stored at index `step` of the metric arrays.
step = 0
for i in range(len(ckpts) - 1):
    # Flattened weights of the two endpoints of the current segment.
    base_model.load_state_dict(torch.load(ckpts[i])['model_state'])
    w_1 = get_weights(base_model)

    base_model.load_state_dict(torch.load(ckpts[i + 1])['model_state'])
    w_2 = get_weights(base_model)

    # After the first segment, alpha = 0 is the same point as alpha = 1 of the
    # previous segment, so it is skipped to avoid evaluating it twice.
    for alpha in alphas[1 if i > 0 else 0:]:
        w = (1.0 - alpha) * w_1 + alpha * w_2

        # Scatter the flat interpolated vector back into the parameters, in
        # the same order get_weights() used to flatten them.
        offset = 0
        for parameter in base_model.parameters():
            # numel() is always a Python int. np.prod() of an empty shape
            # (0-dim parameter) would be the float 1.0 and break the slice.
            size = parameter.numel()
            value = w[offset:offset + size].reshape(tuple(parameter.size()))
            parameter.data.copy_(torch.from_numpy(value))
            offset += size

        tr_res = utils.eval_model(train_loader, base_model, criterion, batch_size, regularizer)
        te_res = utils.eval_model(test_loader, base_model, criterion, batch_size, regularizer)

        tr_loss[step] = tr_res['loss']
        tr_nll[step] = tr_res['nll']
        tr_acc[step] = tr_res['acc']
        tr_err[step] = 100.0 - tr_acc[step]
        te_loss[step] = te_res['loss']
        te_nll[step] = te_res['nll']
        te_acc[step] = te_res['acc']
        te_err[step] = 100.0 - te_acc[step]

        # Print one table row per point; every 40 steps the header is
        # re-emitted, with an extra rule line placed above it.
        values = [ts[step], tr_loss[step], tr_nll[step], tr_err[step], te_nll[step], te_err[step]]
        table = tabulate.tabulate([values], columns, tablefmt='simple', floatfmt='10.4f')
        if step % 40 == 0:
            table = table.split('\n')
            table = '\n'.join([table[1]] + table)
        else:
            # Keep only the data row (line 0 is the header, line 1 the rule).
            table = table.split('\n')[2]
        print(table)
        step += 1


# Persist every metric array together with the chain coordinates `ts`.
# The output directory is created first: without it np.savez would raise
# FileNotFoundError and the evaluation results computed so far would be lost.
if c_dir:
    os.makedirs(c_dir, exist_ok=True)
np.savez(
    os.path.join(c_dir, 'chain1.npz'),
    ts=ts,
    tr_loss=tr_loss,
    tr_nll=tr_nll,
    tr_acc=tr_acc,
    tr_err=tr_err,
    te_loss=te_loss,
    te_nll=te_nll,
    te_acc=te_acc,
    te_err=te_err,
)

./saved_models/LSTMClassifier-6.pt
./saved_models/LSTMClassifier2-4.pt


HBox(children=(IntProgress(value=0, max=3188), HTML(value='')))




HBox(children=(IntProgress(value=0, max=238), HTML(value='')))

  return Variable(arr, volatile=not train), lengths
  return Variable(arr, volatile=not train)



----------  ------------  -----------  -----------------  ----------  ----------------
         t    Train loss    Train nll    Train error (%)    Test nll    Test error (%)
----------  ------------  -----------  -----------------  ----------  ----------------
    0.0000        0.1365       0.1365             4.9696      0.2443            9.4118


HBox(children=(IntProgress(value=0, max=3188), HTML(value='')))




HBox(children=(IntProgress(value=0, max=238), HTML(value='')))


    0.1111        0.1650       0.1650             5.4316      0.2464            9.5336


HBox(children=(IntProgress(value=0, max=3188), HTML(value='')))




HBox(children=(IntProgress(value=0, max=238), HTML(value='')))


    0.2222        0.2836       0.2836             7.6876      0.3342           10.6555


HBox(children=(IntProgress(value=0, max=3188), HTML(value='')))




HBox(children=(IntProgress(value=0, max=238), HTML(value='')))


    0.3333        0.7139       0.7139            20.4241      0.7366           22.2353


HBox(children=(IntProgress(value=0, max=3188), HTML(value='')))




HBox(children=(IntProgress(value=0, max=238), HTML(value='')))


    0.4444        1.5526       1.5526            57.2343      1.5536           57.8571


HBox(children=(IntProgress(value=0, max=3188), HTML(value='')))




HBox(children=(IntProgress(value=0, max=238), HTML(value='')))


    0.5556        1.5694       1.5694            74.8582      1.5645           75.1765


HBox(children=(IntProgress(value=0, max=3188), HTML(value='')))




HBox(children=(IntProgress(value=0, max=238), HTML(value='')))


    0.6667        0.9462       0.9462            34.9922      0.9526           35.3319


HBox(children=(IntProgress(value=0, max=3188), HTML(value='')))




HBox(children=(IntProgress(value=0, max=238), HTML(value='')))


    0.7778        0.3863       0.3863            13.3491      0.4234           14.7185


HBox(children=(IntProgress(value=0, max=3188), HTML(value='')))




HBox(children=(IntProgress(value=0, max=238), HTML(value='')))


    0.8889        0.2095       0.2095             7.7070      0.2786           10.5252


HBox(children=(IntProgress(value=0, max=3188), HTML(value='')))




HBox(children=(IntProgress(value=0, max=238), HTML(value='')))


    1.0000        0.1530       0.1530             5.7287      0.2474            9.4706
