In [7]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [58]:
from embeddings import Embeddings, load_glove_embeddings

In [59]:
import torch
from torch import nn
from utils import *
from tqdm import tqdm

In [10]:
dim = 50
max_length = 50

In [11]:
embeddings = load_glove_embeddings(dim)
len(embeddings)

400003

In [12]:
embeddings['bar']

tensor([-9.4531e-01,  3.9686e-01, -8.0605e-01, -3.0215e-01,  2.7736e-01,
        -1.0019e-01, -4.0500e-01, -1.0095e-01, -6.5934e-02, -4.7258e-02,
        -2.0828e-01, -2.5721e-01,  6.8750e-02,  9.3751e-01, -8.1483e-02,
         1.3460e-01,  2.7302e-02, -1.8096e-01, -3.5638e-01, -8.8104e-01,
         1.1951e+00,  5.5556e-02, -3.1741e-01,  1.0244e+00, -8.4768e-01,
        -1.5959e+00,  2.1657e-02,  4.3628e-01,  8.8388e-04, -4.1820e-01,
         2.1247e+00, -4.3332e-01, -1.0816e+00,  3.3616e-01,  3.3399e-01,
        -2.0064e-01,  5.8633e-01,  9.0186e-02,  7.5054e-01,  4.8500e-01,
         1.7370e-01,  6.8129e-01, -1.6810e-01,  6.1265e-01,  7.6875e-02,
        -1.9797e-01, -9.9555e-02, -1.0231e+00,  9.5394e-01, -6.3500e-02])

In [13]:
dataset = Oxford2019Dataset()
dataset_loader = torch.utils.data.DataLoader(dataset,
                                             batch_size=10, shuffle=True)


In [14]:
data = [
    ("sequencer", "dna sequencing was carried out using the dye termination kit and an automatic sequencer", "an apparatus for determining the sequence of amino acids or other monomers in a biological polymer"),
    ("order", "the templars were also known as the order of christ", "a society of knights bound by a common rule of life and having a combined military and monastic character"),
    ("horde", "an army or tribe of nomadic warriors", "the viking hordes returned to york this weekend as fierce armoured warriors mingled with the city centre crowds ."),
    ("anaemic", "although it has been thought of as a symptom of iron deficiency , it is more commonly discovered in patients who are not anemic .", "suffering from anaemia")
]

In [15]:
from torch.nn.utils.rnn import (pad_sequence, 
                                pack_padded_sequence, 
                                pad_packed_sequence)

In [95]:
from itertools import chain

from typing import List

def to_packed_sequence(items: List[torch.Tensor]):
    lens = [len(s) for s in items]
    return pack_padded_sequence(pad_sequence(items), lens, enforce_sorted=False)

def sentence_to_list(sent):
    # return chain([Embeddings.SOS_STR], sent.split(), [Embeddings.EOS_STR])
    return chain([Embeddings.SOS_STR], sent.split(), [Embeddings.EOS_STR])

def strings_to_batch(strings: List[str]) -> torch.nn.utils.rnn.PackedSequence:
    sents_emb = [embeddings.sentence_to_tensor(sentence_to_list(sent)) for sent in strings]
    lens = [len(s) for s in sents_emb]
    batch = pack_padded_sequence(pad_sequence(sents_emb), lens, enforce_sorted=False)
    return batch, lens

def strings_to_ids(strings: List[str]) -> List[int]:
    ids = [embeddings.sentence_to_ids(sentence_to_list(sent)) for sent in strings]
    return to_packed_sequence(ids)


In [96]:
examples, examples_lens = strings_to_batch([d[1] for d in data])
defs, defs_lens = strings_to_batch([d[2] for d in data])
defs_ids = strings_to_ids([d[2] for d in data])

In [18]:
from wdm import LSTMEncoder, LSTMCellDecoder

In [19]:
encoder = LSTMEncoder(dim, dim)
# dim*2 because encoder is bidirectional
decoder = LSTMCellDecoder(dim, dim*2, len(embeddings))

In [98]:
epochs = 1
criterion = nn.NLLLoss()
encoder_optim = torch.optim.Adam(encoder.parameters())
decoder_optim = torch.optim.Adam(decoder.parameters())

for i in range(epochs):
    epoch_loss = 0
    for words, defs, examples in tqdm(dataset_loader):
        examples_ps, examples_lens = strings_to_batch(examples)
        defs_ps, defs_lens = strings_to_batch(defs)
        def_ids = strings_to_ids(defs)

#     for x, y in data_loader:
#         _, (contexts, _) = encoder(x)
#         outputs = decoder(contexts)
        
#         y = pad_sequence(y, outputs.shape[0])
        
#         loss = criterion(outputs, y)
#         loss.backward()
        
#         encoder_optimizer.step()
#         decoder_optimizer.step()
        
#         epoch_loss += loss.item() / len(x)
        
#     print(f'epoch_loss={epoch_loss})
# from utils import squash_packed
# criterion = nn.NLLLoss()
# log_softmax = nn.LogSoftmax(dim=1)
# loss = criterion(squash_packed(d_out, log_softmax).data, defs_ids.data)
# print(loss)

 13%|█▎        | 3740/29347 [00:07<00:50, 508.13it/s]


KeyboardInterrupt: 

In [36]:
e_out, e_hidden = encoder(examples)

In [37]:
decoder_input = torch.cat((e_hidden[0], e_hidden[1]), dim=1)
# decoder_input = decoder_input[examples.unsorted_indices]  # pytorch will do this automatically
# decoder_input = decoder_input[defs.sorted_indices]  # pytorch will do this automatically
decoder_input = decoder_input.unsqueeze(dim=0)

In [88]:
d_out = decoder(defs, decoder_input)

tensor(12.8941, grad_fn=<NllLossBackward>)


TODO: For loss function... instead of doing log_softmax, do MSE with actual GloVe vector and minimize this loss function.
Then for BLEU evaluation, you'll need a function to find the closest vector to the one produced by the model.

Interesting to compare these results to log_softmax