In [1]:
import os
import sys
sys.path.insert(0, "../")

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import random
from tqdm import tqdm_notebook as tqdm

from learning.treelstm.utils import load_word_vectors
from learning.treelstm.trainer import Trainer
from learning.treelstm.metrics import Metrics
from learning.treelstm.model import *
from learning.treelstm.vocab import Vocab
import learning.treelstm.Constants as Constants
from learning.treelstm.dataset import QGDataset
from learning.treelstm.scripts.preprocess_lcquad import build_vocab
from itertools import product

data_path = '../learning/treelstm/data/lc_quad/'
train_path = data_path + 'train/'
dev_path = data_path + 'dev/'
test_path = data_path + 'test/'
checkpoints_path = '../learning/treelstm/checkpoints'

In [2]:
class Struct: pass
args = Struct()
args.seed = 41
args.cuda = False
args.batchsize = 20
args.mem_dim = 150
args.hidden_dim = 50
args.num_classes = 2
args.input_dim = 300
args.sparse = False
args.lr = 0.01
args.wd = 1e-4

args.epochs = 15

In [3]:
torch.manual_seed(args.seed)
random.seed(args.seed)

In [4]:
torch.get_num_threads()
torch.set_num_threads(2)

### vocab

In [5]:
# mapping words to indexes
vocab = Vocab(
    os.path.join(data_path, 'dataset.vocab'),
    [Constants.PAD_WORD, Constants.UNK_WORD, Constants.BOS_WORD, Constants.EOS_WORD]
)

In [20]:
len(vocab.idxToLabel)

8057

In [17]:
try:
    emb = torch.load('glove_lc_merged_emb.pth')
except:
    emb = torch.Tensor(vocab.size(), 300).normal_(-0.05, 0.05)
    # zero out the embeddings for padding and other special words if they are absent in vocab
    for idx, item in enumerate([Constants.PAD_WORD, Constants.UNK_WORD, Constants.BOS_WORD, Constants.EOS_WORD]):
        emb[idx].zero_()

    with open('../learning/treelstm/data/glove.840B.300d.txt', 'r') as file:
        for line in tqdm(file):
            contents = line.rstrip('\n').split(' ')
            idx = vocab.getIndex(contents[0])
            if(idx):
                emb[idx] = torch.Tensor(list(map(float, contents[1:])))

    torch.save(emb, 'glove_lc_merged_emb.pth')

### sets

In [6]:
train_set = QGDataset(train_path, vocab, args.num_classes)
dev_set = QGDataset(dev_path, vocab, args.num_classes)
test_set = QGDataset(test_path, vocab, args.num_classes)

100%|██████████| 7896/7896 [00:00<00:00, 55352.86it/s]
100%|██████████| 7896/7896 [00:00<00:00, 83627.02it/s]
100%|██████████| 7896/7896 [00:00<00:00, 16924.59it/s]
100%|██████████| 7896/7896 [00:00<00:00, 30231.59it/s]
100%|██████████| 2265/2265 [00:00<00:00, 47000.18it/s]
100%|██████████| 2265/2265 [00:00<00:00, 59091.24it/s]
100%|██████████| 2265/2265 [00:00<00:00, 10590.26it/s]
100%|██████████| 2265/2265 [00:00<00:00, 58802.66it/s]
100%|██████████| 1090/1090 [00:00<00:00, 38241.35it/s]
100%|██████████| 1090/1090 [00:00<00:00, 52868.36it/s]
100%|██████████| 1090/1090 [00:00<00:00, 17461.38it/s]
100%|██████████| 1090/1090 [00:00<00:00, 48279.12it/s]


In [7]:
7896+2265+1090

11251

### model

In [8]:
similarity = DASimilarity(args.mem_dim, args.hidden_dim, args.num_classes)
#similarity = CosSimilarity(1)

In [9]:
# left - a - sent
# right - b - query

class SimilarityEncoders(nn.Module):
    def __init__(self, vocab_size, in_dim, mem_dim, similarity, sparsity):
        super(SimilarityEncoders, self).__init__()
        self.emb = nn.Embedding(vocab_size, in_dim, padding_idx=Constants.PAD, sparse=sparsity)
        self.sent_treelstm = ChildSumTreeLSTM(in_dim, mem_dim)
        self.query_treelstm = ChildSumTreeLSTM(in_dim, mem_dim)
        self.similarity = similarity

    def forward(self, ltree, linputs, rtree, rinputs):
        linputs = self.emb(linputs)
        rinputs = self.emb(rinputs)
        lstate, lhidden = self.sent_treelstm(ltree, linputs)
        rstate, rhidden = self.query_treelstm(rtree, rinputs)
        output = self.similarity(lstate, rstate)
        return output

In [10]:
model = SimilarityEncoders(vocab.size(), args.input_dim, args.mem_dim, similarity, args.sparse)

In [11]:
criterion = nn.KLDivLoss()
optimizer = torch.optim.Adagrad(model.parameters(), lr=args.lr, weight_decay=args.wd)

In [12]:
# model.emb.weight.data.copy_(emb)

NameError: name 'emb' is not defined

In [13]:
trainer = Trainer(args, model, criterion, optimizer)
metrics = Metrics(args.num_classes)

In [14]:
model.parameters()

<generator object Module.parameters at 0x7fa3a4580f68>

### Treino

In [20]:
# train_loss, train_pred = trainer.test(train_set)
# print('train_loss:', train_loss)
# print('train_pred:', train_pred)

In [21]:
for epoch in range(args.epochs):
    train_loss = trainer.train(train_set)
    train_loss, train_pred = trainer.test(train_set)
    print('train_loss:', train_loss)
    print('train_pred:', train_pred)
    checkpoint = {'model': trainer.model.state_dict(), 'optim': trainer.optimizer,
                  'args': args, 'epoch': epoch}
    torch.save(checkpoint, 'checkpoint_' + str(epoch) + '.pth')

  out = F.log_softmax(self.wp(out))
Training epoch 1: 100%|██████████| 7896/7896 [05:51<00:00, 22.44it/s]
  linput, rinput = Var(lsent, volatile=True), Var(rsent, volatile=True)
  target = Var(map_label_to_target(label, dataset.num_classes), volatile=True)
Testing epoch  1: 100%|██████████| 7896/7896 [01:37<00:00, 80.90it/s] 
Training epoch 2:   0%|          | 3/7896 [00:00<05:38, 23.34it/s]

train_loss: tensor(0.1514)
train_pred: tensor([1.9249, 1.1357, 1.9325,  ..., 1.9290, 1.1791, 1.9251])


Training epoch 2: 100%|██████████| 7896/7896 [07:12<00:00, 18.25it/s]
Testing epoch  2: 100%|██████████| 7896/7896 [01:23<00:00, 94.83it/s] 
Training epoch 3:   0%|          | 3/7896 [00:00<05:46, 22.78it/s]

train_loss: tensor(0.1163)
train_pred: tensor([1.9398, 1.0694, 1.9587,  ..., 1.9484, 1.0891, 1.9549])


Training epoch 3: 100%|██████████| 7896/7896 [07:00<00:00, 18.79it/s]
Testing epoch  3: 100%|██████████| 7896/7896 [01:28<00:00, 88.93it/s] 
Training epoch 4:   0%|          | 2/7896 [00:00<07:18, 18.02it/s]

train_loss: tensor(0.0931)
train_pred: tensor([1.9634, 1.0237, 1.9721,  ..., 1.9667, 1.0269, 1.9711])


Training epoch 4: 100%|██████████| 7896/7896 [07:29<00:00, 17.55it/s]
Testing epoch  4: 100%|██████████| 7896/7896 [01:30<00:00, 87.55it/s] 
Training epoch 5:   0%|          | 3/7896 [00:00<06:39, 19.77it/s]

train_loss: tensor(0.0772)
train_pred: tensor([1.9757, 1.0180, 1.9809,  ..., 1.9777, 1.0216, 1.9818])


Training epoch 5: 100%|██████████| 7896/7896 [07:17<00:00, 18.05it/s]
Testing epoch  5: 100%|██████████| 7896/7896 [01:31<00:00, 86.20it/s] 
Training epoch 6:   0%|          | 2/7896 [00:00<06:45, 19.46it/s]

train_loss: tensor(0.0652)
train_pred: tensor([1.9818, 1.0104, 1.9859,  ..., 1.9851, 1.0102, 1.9870])


Training epoch 6: 100%|██████████| 7896/7896 [07:39<00:00, 17.17it/s]
Testing epoch  6: 100%|██████████| 7896/7896 [01:36<00:00, 82.11it/s] 
Training epoch 7:   0%|          | 3/7896 [00:00<05:45, 22.86it/s]

train_loss: tensor(0.0574)
train_pred: tensor([1.9876, 1.0074, 1.9890,  ..., 1.9883, 1.0074, 1.9900])


Training epoch 7: 100%|██████████| 7896/7896 [07:18<00:00, 18.00it/s]
Testing epoch  7: 100%|██████████| 7896/7896 [01:15<00:00, 104.73it/s]
Training epoch 8:   0%|          | 2/7896 [00:00<06:48, 19.34it/s]

train_loss: tensor(0.0497)
train_pred: tensor([1.9899, 1.0061, 1.9901,  ..., 1.9905, 1.0059, 1.9921])


Training epoch 8: 100%|██████████| 7896/7896 [06:03<00:00, 21.75it/s]
Testing epoch  8: 100%|██████████| 7896/7896 [01:14<00:00, 106.18it/s]
Training epoch 9:   0%|          | 3/7896 [00:00<06:44, 19.50it/s]

train_loss: tensor(0.0434)
train_pred: tensor([1.9890, 1.0039, 1.9906,  ..., 1.9911, 1.0034, 1.9939])


Training epoch 9: 100%|██████████| 7896/7896 [06:17<00:00, 20.91it/s]
Testing epoch  9: 100%|██████████| 7896/7896 [01:20<00:00, 98.36it/s] 
Training epoch 10:   0%|          | 0/7896 [00:00<?, ?it/s]

train_loss: tensor(0.0390)
train_pred: tensor([1.9921, 1.0038, 1.9923,  ..., 1.9934, 1.0031, 1.9949])


Training epoch 10: 100%|██████████| 7896/7896 [06:00<00:00, 21.92it/s]
Testing epoch  10: 100%|██████████| 7896/7896 [01:13<00:00, 107.18it/s]
Training epoch 11:   0%|          | 3/7896 [00:00<04:46, 27.54it/s]

train_loss: tensor(0.0348)
train_pred: tensor([1.9930, 1.0028, 1.9938,  ..., 1.9950, 1.0024, 1.9959])


Training epoch 11: 100%|██████████| 7896/7896 [05:52<00:00, 22.42it/s]
Testing epoch  11: 100%|██████████| 7896/7896 [01:12<00:00, 109.63it/s]
Training epoch 12:   0%|          | 3/7896 [00:00<04:52, 27.00it/s]

train_loss: tensor(0.0317)
train_pred: tensor([1.9929, 1.0022, 1.9947,  ..., 1.9956, 1.0019, 1.9965])


Training epoch 12: 100%|██████████| 7896/7896 [05:56<00:00, 22.12it/s]
Testing epoch  12: 100%|██████████| 7896/7896 [01:12<00:00, 109.52it/s]
Training epoch 13:   0%|          | 3/7896 [00:00<05:26, 24.18it/s]

train_loss: tensor(0.0292)
train_pred: tensor([1.9954, 1.0019, 1.9964,  ..., 1.9967, 1.0017, 1.9970])


Training epoch 13: 100%|██████████| 7896/7896 [05:52<00:00, 22.41it/s]
Testing epoch  13: 100%|██████████| 7896/7896 [01:13<00:00, 107.48it/s]
Training epoch 14:   0%|          | 3/7896 [00:00<05:05, 25.85it/s]

train_loss: tensor(0.0268)
train_pred: tensor([1.9955, 1.0018, 1.9967,  ..., 1.9970, 1.0015, 1.9973])


Training epoch 14: 100%|██████████| 7896/7896 [05:57<00:00, 22.07it/s]
Testing epoch  14: 100%|██████████| 7896/7896 [01:10<00:00, 111.73it/s]
Training epoch 15:   0%|          | 3/7896 [00:00<05:12, 25.24it/s]

train_loss: tensor(0.0287)
train_pred: tensor([1.9968, 1.0020, 1.9971,  ..., 1.9975, 1.0014, 1.9977])


Training epoch 15: 100%|██████████| 7896/7896 [05:44<00:00, 22.94it/s]
Testing epoch  15: 100%|██████████| 7896/7896 [01:18<00:00, 100.45it/s]

train_loss: tensor(0.0259)
train_pred: tensor([1.9966, 1.0016, 1.9972,  ..., 1.9978, 1.0012, 1.9979])





### Teste

In [23]:
for epoch in range(args.epochs):
    checkpoint = torch.load('checkpoint_' + str(epoch) + '.pth')
    model.load_state_dict(checkpoint['model'])
    loss, pred = trainer.test(dev_set)
    print('epoch', epoch, metrics.f1(pred.numpy(), dev_set.labels))

  linput, rinput = Var(lsent, volatile=True), Var(rsent, volatile=True)
  target = Var(map_label_to_target(label, dataset.num_classes), volatile=True)
  out = F.log_softmax(self.wp(out))
Testing epoch  0: 100%|██████████| 2265/2265 [00:22<00:00, 100.73it/s]
Testing epoch  0:   0%|          | 5/2265 [00:00<00:46, 48.78it/s]

epoch 0 (0.8488633220603211, 0.848108381892653, 0.8484479832602387)


Testing epoch  0: 100%|██████████| 2265/2265 [00:22<00:00, 101.06it/s]
Testing epoch  0:   0%|          | 8/2265 [00:00<00:28, 79.61it/s]

epoch 1 (0.8441505672148859, 0.8456562407873196, 0.8445672589748006)


Testing epoch  0: 100%|██████████| 2265/2265 [00:22<00:00, 102.12it/s]
Testing epoch  0:   0%|          | 8/2265 [00:00<00:29, 77.72it/s]

epoch 2 (0.8390773901883333, 0.8409644382695528, 0.8386607628478607)


Testing epoch  0: 100%|██████████| 2265/2265 [00:21<00:00, 104.24it/s]
Testing epoch  0:   0%|          | 11/2265 [00:00<00:21, 103.27it/s]

epoch 3 (0.8270811079252638, 0.82842696030158, 0.8251082568463877)


Testing epoch  0: 100%|██████████| 2265/2265 [00:21<00:00, 104.78it/s]
Testing epoch  0:   0%|          | 6/2265 [00:00<00:37, 59.54it/s]

epoch 4 (0.8214998664473592, 0.8227546933706336, 0.8193752082609883)


Testing epoch  0: 100%|██████████| 2265/2265 [00:22<00:00, 102.47it/s]
Testing epoch  0:   0%|          | 11/2265 [00:00<00:22, 101.60it/s]

epoch 5 (0.825507234371117, 0.8265714872637633, 0.8229269933178343)


Testing epoch  0: 100%|██████████| 2265/2265 [00:30<00:00, 74.88it/s] 
Testing epoch  0:   0%|          | 7/2265 [00:00<00:36, 61.32it/s]

epoch 6 (0.8133604429750928, 0.8140606672646415, 0.8101411655357724)


Testing epoch  0: 100%|██████████| 2265/2265 [00:31<00:00, 72.08it/s] 
Testing epoch  0:   0%|          | 10/2265 [00:00<00:24, 91.82it/s]

epoch 7 (0.8293490318322873, 0.8310249361776862, 0.8296202222493589)


Testing epoch  0: 100%|██████████| 2265/2265 [00:25<00:00, 89.45it/s] 
Testing epoch  0:   0%|          | 5/2265 [00:00<00:45, 49.75it/s]

epoch 8 (0.8041416283728055, 0.8035774210140063, 0.7986688644334486)


Testing epoch  0: 100%|██████████| 2265/2265 [00:26<00:00, 86.19it/s] 
Testing epoch  0:   0%|          | 6/2265 [00:00<00:38, 59.05it/s]

epoch 9 (0.8055789858434201, 0.8067716729286758, 0.8034766935846939)


Testing epoch  0: 100%|██████████| 2265/2265 [00:25<00:00, 88.27it/s] 
Testing epoch  0:   0%|          | 7/2265 [00:00<00:33, 68.16it/s]

epoch 10 (0.8138371699522639, 0.815571938881118, 0.8138073479675143)


Testing epoch  0: 100%|██████████| 2265/2265 [00:25<00:00, 87.97it/s] 
Testing epoch  0:   0%|          | 5/2265 [00:00<00:50, 44.98it/s]

epoch 11 (0.8203684309921542, 0.8218800767749503, 0.818892726923815)


Testing epoch  0: 100%|██████████| 2265/2265 [00:25<00:00, 88.72it/s] 
Testing epoch  0:   0%|          | 9/2265 [00:00<00:25, 88.36it/s]

epoch 12 (0.8193450753729594, 0.8205679558669485, 0.817169754814719)


Testing epoch  0: 100%|██████████| 2265/2265 [00:27<00:00, 83.55it/s] 
Testing epoch  0:   0%|          | 9/2265 [00:00<00:26, 86.55it/s]

epoch 13 (0.7883769119029309, 0.7857785224585547, 0.7795747691720813)


Testing epoch  0: 100%|██████████| 2265/2265 [00:26<00:00, 85.88it/s] 

epoch 14 (0.8013806161492643, 0.7995218814880791, 0.7937609346158474)





In [15]:
checkpoint = torch.load('checkpoint_14.pth')
model.load_state_dict(checkpoint['model'])

<All keys matched successfully>

In [16]:
model.emb.weight.data

tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 3.1492e-06, -7.0746e-40,  6.9912e-40,  ..., -3.4392e-03,
         -4.6204e-04, -2.1826e-03],
        ...,
        [-7.7178e-02, -5.1056e-02, -8.5519e-02,  ..., -3.8988e-02,
         -4.5611e-03,  6.5343e-03],
        [-2.9133e-04,  3.0484e-04,  6.9033e-04,  ...,  3.2669e-04,
          4.2872e-05,  5.6230e-05],
        [ 8.6759e-02,  1.3279e-02, -2.2038e-01,  ...,  1.2827e-02,
         -1.1806e-01, -1.4734e-01]])

In [22]:
metrics.f1(pred.numpy(), dev_set.labels)

(0.8013806161492643, 0.7995218814880791, 0.7937609346158474)