In [1]:
import os
import sys
sys.path.insert(0, "../")

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim

from learning.treelstm.model import *
from learning.treelstm.vocab import Vocab
import learning.treelstm.Constants as Constants
from learning.treelstm.dataset import QGDataset
from learning.treelstm.scripts.preprocess_lcquad import build_vocab
from learning.treelstm.trainer import Trainer
from learning.treelstm.metrics import Metrics
from itertools import product


data_path = '../learning/treelstm/data/lc_quad/'
train_path = data_path + 'train/'
dev_path = data_path + 'dev/'
test_path = data_path + 'test/'
checkpoints_path = '../learning/treelstm/checkpoints'

In [2]:
class Struct: pass
args = Struct()
args.seed = 41
args.cuda = False
args.batchsize = 20
args.mem_dim = 150
args.hidden_dim = 50
args.num_classes = 2
args.input_dim = 300
args.sparse = False
args.lr = 0.01
args.wd = 1e-4

args.epochs = 15

# Testing vocabulary

In [3]:
# mapping words to indexes
vocab = Vocab(
    os.path.join(data_path, 'dataset.vocab'),
    [Constants.PAD_WORD, Constants.UNK_WORD, Constants.BOS_WORD, Constants.EOS_WORD]
)

In [4]:
#vocab.labelToIdx.keys()

In [5]:
# checking if the vocabulary generated using train, dev and test splits is equals to dataset.vocab
# yes it is
paths = [os.path.join(*pair) for pair in product([train_path, dev_path, test_path], ['a.toks', 'b.toks'])]
build_vocab(paths, 'teste.vocab', lowercase=False)

# Dataset

In [6]:
train_set = QGDataset(train_path, vocab, 2)
dev_set = QGDataset(dev_path, vocab, 2)
test_set = QGDataset(test_path, vocab, 2)

100%|██████████| 7896/7896 [00:00<00:00, 61360.27it/s]
100%|██████████| 7896/7896 [00:00<00:00, 89210.22it/s]
100%|██████████| 7896/7896 [00:00<00:00, 17281.04it/s]
100%|██████████| 7896/7896 [00:00<00:00, 35082.90it/s]
100%|██████████| 2265/2265 [00:00<00:00, 53764.91it/s]
100%|██████████| 2265/2265 [00:00<00:00, 56406.18it/s]
100%|██████████| 2265/2265 [00:00<00:00, 11074.85it/s]
100%|██████████| 2265/2265 [00:00<00:00, 56073.25it/s]
100%|██████████| 1090/1090 [00:00<00:00, 67599.05it/s]
100%|██████████| 1090/1090 [00:00<00:00, 52520.35it/s]
100%|██████████| 1090/1090 [00:00<00:00, 21166.39it/s]
100%|██████████| 1090/1090 [00:00<00:00, 30210.34it/s]


# Testing model

In [7]:
similarity = DASimilarity(mem_dim=150, hidden_dim=50, num_classes=2)

In [8]:
model = SimilarityTreeLSTM(
    vocab.size(),
    in_dim=300,
    mem_dim=150,
    similarity=similarity,
    sparsity=args.sparse
)

In [9]:
criterion = nn.KLDivLoss()
optimizer = torch.optim.Adagrad(model.parameters(), lr=0.01, weight_decay=1e-4)

In [10]:
emb = torch.load(os.path.join(data_path, 'dataset_embed.pth'))

In [11]:
emb.shape

torch.Size([8057, 300])

In [12]:
model.emb

Embedding(8057, 300, padding_idx=0)

In [13]:
model.emb.weight.data.copy_(emb)

tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.8932, -1.1132,  0.1972,  ...,  0.1354,  0.2262, -0.3589],
        ...,
        [-0.0313, -0.0095, -0.0752,  ..., -0.1731, -0.0141, -0.0466],
        [-0.1493, -0.2350, -0.4891,  ...,  0.4781, -0.3923,  0.7577],
        [ 0.2707, -0.0874, -0.3683,  ...,  0.0368, -0.2172, -0.3406]])

In [14]:
checkpoint = torch.load(os.path.join(checkpoints_path, 'lc_quad.pt'), map_location=lambda storage, loc: storage)

In [15]:
checkpoint

{'args': Namespace(batchsize=25, cuda=True, data='data/lc_quad/', epochs=15, expname='lc_quad', glove='data/glove/', hidden_dim=50, input_dim=300, load='checkpoints/', lr=0.01, mem_dim=150, mode='train', num_classes=2, optim='adagrad', save='checkpoints/', seed=123, sparse=False, wd=0.0001),
 'epoch': 14,
 'model': OrderedDict([('emb.weight',
               tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
                         0.0000e+00,  0.0000e+00],
                       [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
                         0.0000e+00,  0.0000e+00],
                       [ 5.5470e-02, -1.5601e-01,  3.0261e-20,  ...,  1.9638e-39,
                         6.4665e-16, -2.6980e-07],
                       ...,
                       [-2.2153e-03,  6.7445e-03, -2.8541e-02,  ..., -9.2654e-02,
                         7.4230e-03, -1.6443e-02],
                       [ 3.3374e-04, -1.8750e-04, -1.4594e-03,  ...,  4.4125e-03,
             

In [16]:
model.load_state_dict(checkpoint['model'])

<All keys matched successfully>

In [17]:
trainer = Trainer(args, model, criterion, optimizer)
metrics = Metrics(args.num_classes)

In [19]:
loss, pred = trainer.test(test_set)

Testing epoch  0: 100%|██████████| 1090/1090 [00:10<00:00, 104.19it/s]


In [27]:
print(metrics.f1(pred.numpy(), test_set.labels))

(0.8303520784377278, 0.8298609355246523, 0.8284286699118448)
