In [1]:
%load_ext autoreload
%autoreload 2

In [15]:
from embeddings import Embeddings, load_glove_embeddings

In [16]:
import os
from utils import *
from tqdm import tqdm
from torch.nn.utils.rnn import (pad_sequence,
                                pack_padded_sequence,
                                pad_packed_sequence)

In [44]:
# Run configuration. Each value can be overridden through an SM_HP_* environment
# variable (the names mirror the hyperparameter keys handed to the SageMaker
# estimator later in this notebook); otherwise the local default is used.
dim = int(os.getenv('SM_HP_GLOVE_DIM', 50))            # embedding dimensionality
max_length = int(os.getenv('SM_HP_MAX_LENGTH', 30))
data_loc = os.getenv('SM_HP_DATA_LOC', '../data')      # root folder of the data files
epochs = int(os.getenv('SM_HP_EPOCHS', 1))
batch = int(os.getenv('SM_HP_BATCH', 100))             # DataLoader batch size
lr = float(os.getenv('SM_HP_LR', 0.01))                # optimiser learning rate

# Integer flag (0/1) converted to a real bool.
train_remotely = int(os.getenv('SM_HP_TRAIN_REMOTELY', 0)) != 0

In [27]:
# Load the GloVe embedding table for the configured dimensionality from data_loc.
embeddings = load_glove_embeddings(dim, data_loc)
# Show the number of entries in the embedding table as the cell output.
len(embeddings)

400003

In [28]:
# Sanity check: look up the vector of a single word (the output below is a 1-D tensor).
embeddings['car']

tensor([ 0.4769, -0.0846,  1.4641,  0.0470,  0.1469,  0.5082, -1.2228, -0.2261,
         0.1931, -0.2976,  0.2060, -0.7128, -1.6288,  0.1710,  0.7480, -0.0619,
        -0.6577,  1.3786, -0.6804, -1.7551,  0.5832,  0.2516, -1.2114,  0.8134,
         0.0948, -1.6819, -0.6450,  0.6322,  1.1211,  0.1611,  2.5379,  0.2485,
        -0.2682,  0.3282,  1.2916,  0.2355,  0.6147, -0.1344, -0.1324,  0.2740,
        -0.1182,  0.1354,  0.0743, -0.6195,  0.4547, -0.3032, -0.2188, -0.5605,
         1.1177, -0.3659])

In [29]:
# Training corpus: the complete Oxford-2019 file, served in shuffled
# mini-batches of `batch` entries.
oxford_all_path = os.path.join(data_loc, 'Oxford-2019', 'ALL.txt')
dataset = Oxford2019Dataset(data_loc=oxford_all_path)
dataset_loader = torch.utils.data.DataLoader(
    dataset,
    shuffle=True,
    batch_size=batch,
)

In [30]:
# Hand-written samples for smoke-testing the batching helpers.
# Every tuple is (word, example sentence, definition): the cell that builds the
# demo batches reads index 1 as the example and index 2 as the definition.
# The "horde" entry previously had its example and definition swapped.
data = [
    ("sequencer", "dna sequencing was carried out using the dye termination kit and an automatic sequencer", "an apparatus for determining the sequence of amino acids or other monomers in a biological polymer"),
    ("order", "the templars were also known as the order of christ", "a society of knights bound by a common rule of life and having a combined military and monastic character"),
    ("horde", "the viking hordes returned to york this weekend as fierce armoured warriors mingled with the city centre crowds .", "an army or tribe of nomadic warriors"),
    ("anaemic", "although it has been thought of as a symptom of iron deficiency , it is more commonly discovered in patients who are not anemic .", "suffering from anaemia")
]

In [31]:
from itertools import chain

from typing import List, Tuple

def to_packed_sequence(items: List[torch.Tensor]):
    """Pack a list of variable-length tensors into one PackedSequence.

    The tensors are first padded to a common length with ``pad_sequence`` and
    then packed using their original lengths. They do not have to be sorted
    by length (``enforce_sorted=False``).
    """
    lengths = list(map(len, items))
    padded = pad_sequence(items)
    return pack_padded_sequence(padded, lengths, enforce_sorted=False)

def sentence_to_list(sent):
    """Split a sentence on whitespace and wrap it in start/end markers.

    Returns a lazy, single-use iterator (``itertools.chain``), not a list:
    ``Embeddings.SOS_STR``, then the whitespace-separated tokens of ``sent``,
    then ``Embeddings.EOS_STR``.
    """
    return chain((Embeddings.SOS_STR,), sent.split(), (Embeddings.EOS_STR,))

def strings_to_batch(strings: List[str]) -> Tuple[torch.nn.utils.rnn.PackedSequence, List[int]]:
    """Embed a batch of sentences and pack them into a PackedSequence.

    Every sentence is tokenised and wrapped in SOS/EOS markers by
    ``sentence_to_list`` and converted to a tensor by the module-level
    ``embeddings.sentence_to_tensor``.

    Args:
        strings: the sentences to embed; they need not be sorted by length.

    Returns:
        A ``(packed, lens)`` tuple: the packed batch and the length of each
        embedded sentence, in the same order as ``strings``.
    """
    sents_emb = [embeddings.sentence_to_tensor(sentence_to_list(sent)) for sent in strings]
    lens = [len(s) for s in sents_emb]
    # Named `packed` rather than `batch` so it does not shadow the module-level batch size.
    packed = pack_padded_sequence(pad_sequence(sents_emb), lens, enforce_sorted=False)
    return packed, lens

def strings_to_ids(strings: List[str]) -> torch.nn.utils.rnn.PackedSequence:
    """Convert a batch of sentences to vocabulary ids, packed like the embedded batch.

    Every sentence is tokenised and wrapped in SOS/EOS markers by
    ``sentence_to_list`` and mapped to ids by the module-level
    ``embeddings.sentence_to_ids``; the per-sentence id sequences are then
    packed with ``to_packed_sequence``.

    Args:
        strings: the sentences to convert; they need not be sorted by length.

    Returns:
        A PackedSequence of ids (not a plain list of ints).
    """
    ids = [embeddings.sentence_to_ids(sentence_to_list(sent)) for sent in strings]
    return to_packed_sequence(ids)


In [32]:
# Build demo batches from the hand-written samples: position 1 of each tuple is
# used as the example sentence and position 2 as the definition.
example_texts = [sample[1] for sample in data]
definition_texts = [sample[2] for sample in data]

examples, examples_lens = strings_to_batch(example_texts)
defs, defs_lens = strings_to_batch(definition_texts)
defs_ids = strings_to_ids(definition_texts)

In [33]:
from wdm import LSTMEncoder, LSTMCellDecoder

In [34]:
# Encoder over the example sentences; both arguments are the embedding
# dimensionality (presumably input size and hidden size — confirm against wdm.LSTMEncoder).
encoder = LSTMEncoder(dim, dim)
# dim*2 because encoder is bidirectional
# The last argument is the number of entries in the embedding table
# (presumably the decoder's output vocabulary size — confirm against wdm.LSTMCellDecoder).
decoder = LSTMCellDecoder(dim, dim*2, len(embeddings))

In [40]:
def train(epochs, lr, max_batches_per_epoch=10):
    """Train the module-level ``encoder`` and ``decoder`` on ``dataset_loader``.

    For every batch the example sentences are encoded, the encoder's hidden
    state seeds the decoder, and the decoder output over the definition
    sentences is scored against the definition token ids with
    log-softmax + NLL loss. Both networks are updated with plain SGD.

    Args:
        epochs: number of passes over ``dataset_loader``.
        lr: learning rate used for both SGD optimisers.
        max_batches_per_epoch: stop each epoch after this many batches. The
            default of 10 keeps the original quick-iteration behaviour; pass
            ``None`` to consume the whole loader.

    Side effects:
        Updates the parameters of ``encoder`` and ``decoder`` in place and
        prints the loss of the last processed batch once per epoch.
    """
    criterion = nn.NLLLoss()
    log_softmax = nn.LogSoftmax(dim=1)

    # Plain SGD for both networks (Adam was tried earlier and is currently not used).
    encoder_optim = torch.optim.SGD(encoder.parameters(), lr=lr)
    decoder_optim = torch.optim.SGD(decoder.parameters(), lr=lr)

    for i in range(epochs):
        # Stays None if the loader yields nothing, so the report below cannot
        # crash on a missing loss tensor.
        loss = None

        for j, (words, defs, examples) in enumerate(tqdm(dataset_loader, disable=True), start=1):
            encoder_optim.zero_grad()
            decoder_optim.zero_grad()

            # The per-sentence lengths returned alongside the packed batches are not needed here.
            examples_ps, _ = strings_to_batch(examples)
            defs_ps, _ = strings_to_batch(defs)
            def_ids = strings_to_ids(defs)

            _, e_hidden = encoder(examples_ps)

            # Join the first two entries of the encoder hidden state along the
            # feature axis (presumably the forward and backward directions of the
            # bidirectional encoder — confirm against wdm.LSTMEncoder).
            # Original author's note: no manual re-indexing with the packed
            # sequences' sorted/unsorted indices is required; PyTorch does it.
            decoder_input = torch.cat((e_hidden[0], e_hidden[1]), dim=1)
            decoder_input = decoder_input.unsqueeze(dim=0)

            d_out = decoder(defs_ps, decoder_input)
            # Log-probabilities for every packed time step.
            d_out_lsm = squash_packed(d_out, log_softmax).data

            # NOTE(review): the targets are the ids of the same definition tokens
            # that are fed to the decoder; confirm whether LSTMCellDecoder shifts
            # inputs/targets internally.
            loss = criterion(d_out_lsm, def_ids.data)

            loss.backward()

            encoder_optim.step()
            decoder_optim.step()

            if max_batches_per_epoch is not None and j >= max_batches_per_epoch:
                break

        if loss is None:
            print(f'Epoch {i}, no batches processed')
        else:
            # Loss of the last processed batch, not an epoch average.
            print(f'Epoch {i}, loss {loss.item()}')

In [None]:
import sagemaker
from sagemaker.pytorch import PyTorch

if train_remotely:
    # Launch the training run as a SageMaker job instead of training in this kernel.
    role = sagemaker.get_execution_role()
    output_path = f'{DATASET_S3_PATH}-output'

    # NOTE(review): the train_* / image_name keyword names below are the ones this
    # notebook was written against; confirm they match the installed SageMaker SDK version.
    pytorch_estimator_l = PyTorch(
        entry_point='train.sh',
        base_job_name='wdm-1',
        role=role,
        train_instance_count=1,
        train_instance_type='ml.g4dn.2xlarge',  # GPU instance
        train_volume_size=50,
        train_max_run=86400,  # 24 hours
        hyperparameters={
            'glove_dim': 50,
            'max_length': 30,
            # The comma after this entry was missing, which made the whole
            # cell a SyntaxError.
            'data_loc': '/opt/data',
            'epochs': 10,
            'lr': 0.01,
            # 0 inside the job, so the remote run trains instead of launching another job.
            'train_remotely': 0
        },
        framework_version='1.6.0',
        py_version='py3',
        source_dir='.',  # This entire folder will be transferred to training instance
        debugger_hook_config=False,
        output_path=output_path,  # Model files will be uploaded here
        image_name='954558792927.dkr.ecr.us-west-2.amazonaws.com/sagemaker/wdm:latest'
    )

    # Start the job without blocking the notebook.
    pytorch_estimator_l.fit('', wait=False)
    

In [42]:
# Train in this process when running locally, or when this notebook is itself
# executing with TRAINING_JOB_NAME set in the environment.
if not train_remotely or 'TRAINING_JOB_NAME' in os.environ:
    train(epochs, lr)

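TODO: For the loss function, instead of applying log_softmax over the vocabulary (with NLL loss), try regressing onto the target word's actual GloVe vector and minimizing the MSE between that vector and the model's output.
For BLEU evaluation, this will then need a function that finds the closest embedding vector (i.e. the nearest word) to each vector produced by the model.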

It would be interesting to compare these results with those of the log_softmax approach.