In [5]:
import numpy as np
import os
import tabulate
import torch
import torch.nn.functional as F

import load_data
import models
import curves
import utils


# --- Evaluation configuration -------------------------------------------

# Directory the summary .npz archive is written into by the final cell.
c_dir = './eval_polychain'

# Checkpoint holding the trained curve; its 'model_state' entry is loaded
# into the CurveNet below.
ckpt = './saved_models/LSTMClassifier_curve_polychain4_5-22.pt'
# Alternative checkpoint kept for reference:
#ckpt = './saved_models/LSTMClassifier_curve2-35.pt'

# Name of the architecture looked up in `models`, and of the curve
# parametrisation looked up in `curves`.
model_name = 'LSTMClassifier'
curve_type = 'PolyChain'
num_bends = 3        # number of curve bends passed to CurveNet

num_points = 61      # number of evenly spaced t values evaluated in [0, 1]
batch_size = 64      # batch size for the data loaders and the evaluation
wd = 1e-4            # weight decay coefficient for the L2 regulariser

# Let cuDNN auto-tune its algorithms.
torch.backends.cudnn.benchmark = True

In [6]:
# Load the dataset once; the call returns the text field, vocabulary/label
# sizes, the pretrained embedding matrix and the three data loaders.
(
    TEXT,
    vocab_size,
    num_classes,
    word_embeddings,
    train_loader,
    valid_loader,
    test_loader,
) = load_data.load_dataset(batch_size=batch_size)

Length of Text Vocabulary: 135872
Vector size of Text Vocabulary:  torch.Size([135872, 300])
Label Length: 4


In [7]:
# Keyword arguments forwarded to the base architecture through CurveNet.
kwargs = dict(
    batch_size=batch_size,
    hidden_size=256,
    embedding_length=300,
    vocab_size=vocab_size,
    weights=word_embeddings,
)

# Resolve the architecture and the curve parametrisation by name, then wrap
# the architecture's curve variant in a CurveNet with `num_bends` bends.
architecture = getattr(models, model_name)
curve = getattr(curves, curve_type)
model = curves.CurveNet(
    num_classes,
    curve,
    architecture.curve,
    num_bends,
    architecture_kwargs=kwargs,
)
model.cuda()

# Restore the trained curve parameters from the checkpoint.
checkpoint = torch.load(ckpt)
model.load_state_dict(checkpoint['model_state'])

# Loss used for evaluation, plus the L2 regulariser built from `wd`.
criterion = torch.nn.CrossEntropyLoss()
regularizer = curves.l2_regularizer(wd)

# Evenly spaced curve parameters t in [0, 1] at which the model is evaluated.
T = num_points
ts = np.linspace(0.0, 1.0, T)

# One zero-initialised buffer of length T per recorded metric
# (train / test: loss, nll, accuracy, error).
tr_loss, tr_nll, tr_acc, tr_err = (np.zeros(T) for _ in range(4))
te_loss, te_nll, te_acc, te_err = (np.zeros(T) for _ in range(4))

# dl[i] holds the weight-space distance between point i-1 and point i.
dl = np.zeros(T)

# Weights of the previously evaluated point; None until the first iteration.
previous_weights = None

# Column headers of the per-point progress table.
columns = [
    't',
    'Train loss',
    'Train nll',
    'Train error (%)',
    'Test nll',
    'Test error (%)',
]

# Evaluate the curve at every t in `ts`, recording train/test metrics and the
# weight-space distance travelled between consecutive points.
#
# `t` is a single-element CUDA tensor that is refilled in place each
# iteration and handed to the model / evaluation helper.
t = torch.FloatTensor([0.0]).cuda()
for i, t_value in enumerate(ts):
    # `t` does not require grad, so the in-place fill needs no `.data` access.
    t.fill_(float(t_value))

    # Euclidean distance between the weights at this point and the previous
    # one; dl[0] stays 0 because there is no previous point.
    weights = model.weights(t)
    if previous_weights is not None:
        dl[i] = np.sqrt(np.sum(np.square(weights - previous_weights)))
    previous_weights = weights.copy()

    tr_res = utils.eval_model(train_loader, model, criterion, batch_size, regularizer, t=t)
    te_res = utils.eval_model(test_loader, model, criterion, batch_size, regularizer, t=t)
    tr_loss[i] = tr_res['loss']
    tr_nll[i] = tr_res['nll']
    tr_acc[i] = tr_res['acc']
    tr_err[i] = 100.0 - tr_acc[i]
    te_loss[i] = te_res['loss']
    te_nll[i] = te_res['nll']
    te_acc[i] = te_res['acc']
    te_err[i] = 100.0 - te_acc[i]

    # Report the plain scalar t rather than the CUDA tensor, so the table
    # formatter never has to convert a device tensor implicitly.
    values = [float(t_value), tr_loss[i], tr_nll[i], tr_err[i], te_nll[i], te_err[i]]
    table = tabulate.tabulate([values], columns, tablefmt='simple', floatfmt='10.4f')
    if i % 40 == 0:
        # Every 40th row: repeat the separator line above the header so the
        # header is framed on both sides.
        table = table.split('\n')
        table = '\n'.join([table[1]] + table)
    else:
        # Otherwise print only the data row (line 2 of header/separator/row).
        table = table.split('\n')[2]
    print(table)

HBox(children=(IntProgress(value=0, max=1594), HTML(value='')))




HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


----------  ------------  -----------  -----------------  ----------  ----------------
         t    Train loss    Train nll    Train error (%)    Test nll    Test error (%)
----------  ------------  -----------  -----------------  ----------  ----------------
    0.0000      127.1520       0.0140             0.6016      0.4701           10.5462


HBox(children=(IntProgress(value=0, max=1594), HTML(value='')))




HBox(children=(IntProgress(value=0, max=119), HTML(value='')))


    0.0167      127.1130       0.0109             0.4649      0.4789           10.4202


HBox(children=(IntProgress(value=0, max=1594), HTML(value='')))

KeyboardInterrupt: 

In [20]:
def stats(values, dl):
    """Summarise a metric sampled along the curve.

    Args:
        values: 1-D sequence of metric values, one per evaluated curve point.
        dl: 1-D sequence of the same length; ``dl[i]`` is the distance between
            point ``i - 1`` and point ``i`` (``dl[0]`` is ignored).

    Returns:
        A tuple ``(min, max, avg, int)`` where ``avg`` is the plain mean over
        the points and ``int`` is the trapezoidal, length-weighted mean along
        the curve. When the total length ``sum(dl[1:])`` is zero (single point
        or coincident points) ``int`` falls back to ``avg`` instead of NaN.

    Raises:
        ValueError: if ``values`` is empty (raised by ``np.min``).
    """
    # Accept plain sequences too; slicing arithmetic below needs arrays.
    values = np.asarray(values)
    dl = np.asarray(dl)

    # Local names avoid shadowing the builtins min / max / int.
    v_min = np.min(values)
    v_max = np.max(values)
    v_avg = np.mean(values)

    total_length = np.sum(dl[1:])
    if total_length > 0:
        # Trapezoid rule: mean of neighbouring values weighted by segment length.
        segment_means = 0.5 * (values[:-1] + values[1:])
        v_int = np.sum(segment_means * dl[1:]) / total_length
    else:
        # Degenerate curve: no length to weight by, so use the plain mean.
        v_int = v_avg
    return v_min, v_max, v_avg, v_int


# Per-metric summary statistics (min, max, mean, length-weighted mean) over
# the evaluated curve points.
tr_loss_min, tr_loss_max, tr_loss_avg, tr_loss_int = stats(tr_loss, dl)
tr_nll_min, tr_nll_max, tr_nll_avg, tr_nll_int = stats(tr_nll, dl)
tr_err_min, tr_err_max, tr_err_avg, tr_err_int = stats(tr_err, dl)

te_loss_min, te_loss_max, te_loss_avg, te_loss_int = stats(te_loss, dl)
te_nll_min, te_nll_max, te_nll_avg, te_nll_int = stats(te_nll, dl)
te_err_min, te_err_max, te_err_avg, te_err_int = stats(te_err, dl)

# Total weight-space length of the curve, followed by the summary table.
# The test loss is computed and saved below but not shown in the table.
print('Length: %.2f' % np.sum(dl))
print(tabulate.tabulate([
        ['train loss', tr_loss[0], tr_loss[-1], tr_loss_min, tr_loss_max, tr_loss_avg, tr_loss_int],
        ['train nll', tr_nll[0], tr_nll[-1], tr_nll_min, tr_nll_max, tr_nll_avg, tr_nll_int],
        ['train error (%)', tr_err[0], tr_err[-1], tr_err_min, tr_err_max, tr_err_avg, tr_err_int],
        ['test nll', te_nll[0], te_nll[-1], te_nll_min, te_nll_max, te_nll_avg, te_nll_int],
        ['test error (%)', te_err[0], te_err[-1], te_err_min, te_err_max, te_err_avg, te_err_int],
    ], [
        '', 'start', 'end', 'min', 'max', 'avg', 'int'
    ], tablefmt='simple', floatfmt='10.4f'))

# np.savez does not create missing parent directories, so make sure the
# output directory exists before writing the archive.
os.makedirs(c_dir, exist_ok=True)

# Persist the raw per-point series together with their summary statistics.
np.savez(
    os.path.join(c_dir, 'curve4_5.npz'),
    ts=ts,
    dl=dl,
    tr_loss=tr_loss,
    tr_loss_min=tr_loss_min,
    tr_loss_max=tr_loss_max,
    tr_loss_avg=tr_loss_avg,
    tr_loss_int=tr_loss_int,
    tr_nll=tr_nll,
    tr_nll_min=tr_nll_min,
    tr_nll_max=tr_nll_max,
    tr_nll_avg=tr_nll_avg,
    tr_nll_int=tr_nll_int,
    tr_acc=tr_acc,
    tr_err=tr_err,
    tr_err_min=tr_err_min,
    tr_err_max=tr_err_max,
    tr_err_avg=tr_err_avg,
    tr_err_int=tr_err_int,
    te_loss=te_loss,
    te_loss_min=te_loss_min,
    te_loss_max=te_loss_max,
    te_loss_avg=te_loss_avg,
    te_loss_int=te_loss_int,
    te_nll=te_nll,
    te_nll_min=te_nll_min,
    te_nll_max=te_nll_max,
    te_nll_avg=te_nll_avg,
    te_nll_int=te_nll_int,
    te_acc=te_acc,
    te_err=te_err,
    te_err_min=te_err_min,
    te_err_max=te_err_max,
    te_err_avg=te_err_avg,
    te_err_int=te_err_int,
)

Length: 227.92
                      start         end         min         max         avg         int
---------------  ----------  ----------  ----------  ----------  ----------  ----------
train loss         126.5277    126.4215    126.0718    126.5277    126.1584    126.1808
train nll            0.1489      0.1054      0.0973      0.1679      0.1400      0.1352
train error (%)      5.9125      3.8877      3.6647      6.8814      5.5700      5.3511
test nll             0.2413      0.2365      0.2321      0.2677      0.2514      0.2490
test error (%)       9.3109      8.8866      8.6050      9.9832      9.3545      9.2545


Length: 276.64
                      start         end         min         max         avg         int
---------------  ----------  ----------  ----------  ----------  ----------  ----------
train loss         126.6103    126.5062    126.1005    126.6103    126.2208    126.2160
train nll            0.1365      0.1529      0.1332      0.1785      0.1552      0.1555
train error (%)      4.9589      5.7280      4.9043      6.7911      5.9461      5.9605
test nll             0.2443      0.2474      0.2327      0.2589      0.2423      0.2423
test error (%)       9.4118      9.4706      8.6218      9.7731      9.1926      9.1982


Length: 260.22
                      start         end         min         max         avg         int
---------------  ----------  ----------  ----------  ----------  ----------  ----------
train loss         126.6103    126.5062    126.1128    126.6103    126.2265    126.2219
train error (%)      4.9589      5.7280      4.8802      8.0267      6.3322      6.3456
test nll             0.2443      0.2474      0.2294      0.2629      0.2422      0.2422
test error (%)       9.4118      9.4706      8.7521      9.7773      9.2828      9.2867
