In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from embeddings import Embeddings, load_glove_embeddings

In [3]:
import os
import torch
from torch import nn
from utils import squash_packed
from tqdm import tqdm

In [4]:
# Pick the first CUDA device when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Report which kind of device was selected.
print("Running on the GPU" if device.type == "cuda" else "Running on the CPU")

Running on the CPU


In [5]:
# Hyperparameters are read from SM_HP_* environment variables, with local defaults
# used when a variable is not set.
dim = int(os.getenv('SM_HP_GLOVE_DIM', 50))            # embedding width
max_length = int(os.getenv('SM_HP_MAX_LENGTH', 30))
data_loc = os.getenv('SM_HP_DATA_LOC', './data')       # root directory of the data files
epochs = int(os.getenv('SM_HP_EPOCHS', 2))
batch = int(os.getenv('SM_HP_BATCH', 32))              # DataLoader batch size
lr = float(os.getenv('SM_HP_LR', 0.01))
# The flag arrives as "0"/"1"; any non-zero integer enables remote training.
train_remotely = int(os.getenv('SM_HP_TRAIN_REMOTELY', 0)) != 0

In [6]:
# Load the GloVe embeddings for the configured width (`dim`) from `data_loc`.
embeddings = load_glove_embeddings(dim, data_loc)
# Show how many entries were loaded.
len(embeddings)

400003

In [7]:
# Spot-check a lookup: indexing by a word displays its embedding tensor.
embeddings['car']

tensor([ 0.4769, -0.0846,  1.4641,  0.0470,  0.1469,  0.5082, -1.2228, -0.2261,
         0.1931, -0.2976,  0.2060, -0.7128, -1.6288,  0.1710,  0.7480, -0.0619,
        -0.6577,  1.3786, -0.6804, -1.7551,  0.5832,  0.2516, -1.2114,  0.8134,
         0.0948, -1.6819, -0.6450,  0.6322,  1.1211,  0.1611,  2.5379,  0.2485,
        -0.2682,  0.3282,  1.2916,  0.2355,  0.6147, -0.1344, -0.1324,  0.2740,
        -0.1182,  0.1354,  0.0743, -0.6195,  0.4547, -0.3032, -0.2188, -0.5605,
         1.1177, -0.3659])

In [8]:
from dataset import Oxford2019Dataset
from torch.utils.data import DataLoader

def make_data_loader(filename: str, file_loc: str = os.path.join(data_loc, 'Oxford-2019')) -> DataLoader:
    """Build a shuffling ``DataLoader`` over one Oxford-2019 data file.

    Args:
        filename: Name of the data file inside ``file_loc``.
        file_loc: Directory containing the data files. The default is computed
            once, when this function is defined, from the current ``data_loc``.

    Returns:
        A ``DataLoader`` wrapping an ``Oxford2019Dataset`` for that file. It uses
        the notebook-level ``batch`` as the batch size and has shuffling enabled.
    """
    split_path = os.path.join(file_loc, filename)
    return DataLoader(Oxford2019Dataset(data_loc=split_path), batch_size=batch, shuffle=True)

# One loader per dataset split.
# The training loader is bound to `train_loader` because the name `train` is rebound
# to the training function further down the notebook, which would otherwise leave
# the loader unreachable. `train` is still assigned for backward compatibility.
train_loader = make_data_loader('train.txt')
train = train_loader
test = make_data_loader('test.txt')
valid = make_data_loader('valid.txt')

In [9]:
from itertools import chain
from typing import Iterable, List, Tuple
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, PackedSequence

def sentence_to_list(sent: str) -> Iterable[str]:
    """Split a sentence on whitespace and wrap it in start/end markers.

    Args:
        sent: Raw sentence text.

    Returns:
        A lazy iterable yielding ``Embeddings.SOS_STR``, then each
        whitespace-separated token of ``sent``, then ``Embeddings.EOS_STR``.
    """
    start_marker = (Embeddings.SOS_STR,)
    tokens = sent.split()
    end_marker = (Embeddings.EOS_STR,)
    return chain(start_marker, tokens, end_marker)

def strings_to_batch(strings: Iterable[str]) -> Tuple[PackedSequence, List[int]]:
    """Embed a batch of sentences and pack them into one ``PackedSequence``.

    Each sentence is wrapped in SOS/EOS markers by ``sentence_to_list`` and
    converted with ``embeddings.sentence_to_tensor``.

    Args:
        strings: Sentences to embed. Any iterable is accepted (the training loop
            passes one built from a ``zip``); it is consumed exactly once.

    Returns:
        A ``(packed, lens)`` pair: the packed batch moved to ``device``, and the
        length of each embedded sentence in input order.
    """
    sents_emb = [embeddings.sentence_to_tensor(sentence_to_list(sent)) for sent in strings]
    lens = [len(s) for s in sents_emb]
    # enforce_sorted=False: the sentences are not pre-sorted by length, so let
    # pack_padded_sequence do the sorting itself.
    # The result is named `packed` so it does not shadow the notebook-level `batch`.
    packed = pack_padded_sequence(pad_sequence(sents_emb), lens, enforce_sorted=False).to(device)
    return packed, lens

def strings_to_ids(strings: List[str]) -> PackedSequence:
    """Convert a batch of sentences to packed vocabulary ids.

    Each sentence is wrapped in SOS/EOS markers by ``sentence_to_list`` and
    converted with ``embeddings.sentence_to_ids``.

    Args:
        strings: Sentences to convert.

    Returns:
        The id sequences padded, packed (without requiring length-sorted input)
        and moved to ``device``.
    """
    id_seqs = []
    seq_lens = []
    for sent in strings:
        sent_ids = embeddings.sentence_to_ids(sentence_to_list(sent))
        id_seqs.append(sent_ids)
        seq_lens.append(len(sent_ids))

    padded = pad_sequence(id_seqs)
    packed = pack_padded_sequence(padded, seq_lens, enforce_sorted=False)
    return packed.to(device)

In [10]:
from wdm import LSTMEncoder, LSTMCellDecoder

# Build both models and move them to the selected device.
encoder = LSTMEncoder(dim, dim).to(device)
# dim*2 because encoder is bidirectional
# (the training loop concatenates its two hidden states along the feature axis).
# The last argument is the number of entries in `embeddings` — presumably the size
# of the output vocabulary; confirm against wdm.LSTMCellDecoder.
decoder = LSTMCellDecoder(dim, dim*2, len(embeddings)).to(device)

In [11]:
def train(epochs: int, data_loader: DataLoader):
    """Train the notebook-level ``encoder`` and ``decoder`` on definition data.

    For every batch, each word is prepended to its usage example and the result
    is encoded. The decoder is then run over the reference definitions, seeded
    with the encoder's concatenated hidden states. The loss is ``NLLLoss`` over
    the log-softmaxed decoder outputs against the definition token ids.

    Args:
        epochs: Number of full passes over ``data_loader``.
        data_loader: Yields ``(words, defs, examples)`` batches of strings.

    Side effects:
        Updates the parameters of the global ``encoder`` and ``decoder`` in place
        and prints the summed batch loss after every epoch.
    """
    criterion = nn.NLLLoss()
    log_softmax = nn.LogSoftmax(dim=1)

    # Pass the configured learning rate explicitly. Without it the `lr`
    # hyperparameter is ignored and Adam runs with its library default.
    encoder_optim = torch.optim.Adam(encoder.parameters(), lr=lr)
    decoder_optim = torch.optim.Adam(decoder.parameters(), lr=lr)

    for i in range(epochs):
        epoch_loss = 0
        for words, defs, examples in tqdm(data_loader, disable=train_remotely):
            encoder_optim.zero_grad()
            decoder_optim.zero_grad()

            # Join with a space so the word stays a separate token. Sentences are
            # tokenised with str.split(), so an empty separator would glue the word
            # onto the first token of the example.
            word_first_examples = [' '.join((w, e)) for w, e in zip(words, examples)]

            # Only the packed batches are needed; the per-sentence lengths are unused.
            examples_ps, _ = strings_to_batch(word_first_examples)
            defs_ps, _ = strings_to_batch(defs)
            def_ids = strings_to_ids(defs)

            _, e_hidden = encoder(examples_ps)

            # Concatenate the two hidden states (presumably forward/backward of the
            # bidirectional encoder — confirm against wdm.LSTMEncoder) per example.
            decoder_input = torch.cat((e_hidden[0], e_hidden[1]), dim=1)
            # No manual re-indexing between the examples' and definitions' packing
            # orders is done here; per the original note, PyTorch handles it.
            decoder_input = decoder_input.unsqueeze(dim=0)

            d_out = decoder(defs_ps, decoder_input)
            d_out_lsm = squash_packed(d_out, log_softmax).data

            loss = criterion(d_out_lsm, def_ids.data)
            epoch_loss += loss.item()

            loss.backward()

            encoder_optim.step()
            decoder_optim.step()
        print(f'Epoch {i}, loss {epoch_loss}')

In [25]:
# Quick check: train for two epochs on a tiny file made of the first
# `batch * 5` lines of the training data.
train_file = os.path.join(data_loc, 'Oxford-2019', 'train.txt')
tiny_size = batch * 5
tiny_file = os.path.join(data_loc, 'Oxford-2019', 'tiny.txt')
# Copy the leading lines in pure Python rather than shelling out to `head`.
# This works for paths containing spaces, does not depend on a POSIX shell, and
# raises if the training file is missing instead of leaving an empty tiny file.
# Binary mode copies the bytes untouched.
with open(train_file, 'rb') as src, open(tiny_file, 'wb') as dst:
    # zip stops after `tiny_size` lines (or at end of file, whichever is first).
    for _, line in zip(range(tiny_size), src):
        dst.write(line)
tiny = make_data_loader('tiny.txt')
train(epochs=2, data_loader=tiny)

In [None]:
import sagemaker
from sagemaker.pytorch import PyTorch

if 'TRAINING_JOB_NAME' in os.environ:  # This code is running on the remote SageMaker estimator machine
    # Build the training loader here. By this point the name `train` refers to the
    # training function (its `def` rebinds the earlier loader variable), so passing
    # `train` as `data_loader` would hand the function to itself.
    train(epochs=epochs, data_loader=make_data_loader('train.txt'))
elif train_remotely:
    # Launch a remote training job from the local notebook.
    role = sagemaker.get_execution_role()
    output_path = f's3://chegg-ds-data/oboiko/wdm-output'

    # NOTE(review): these hyperparameters are presumably what the training container
    # sees as the SM_HP_* variables read in the configuration cell — confirm.
    pytorch_estimator = PyTorch(entry_point='train.sh',
                                base_job_name='wdm-1',
                                role=role,
                                train_instance_count=1,
                                train_instance_type='ml.g4dn.2xlarge',  # GPU instance
                                train_volume_size=50,
                                train_max_run=86400,  # 24 hours
                                hyperparameters={
                                  'glove_dim': 50,
                                  'max_length': 30,
                                  'data_loc': '/opt/data',
                                  'batch': 50,
                                  'epochs': 10,
                                  'lr': 0.01,
                                  'train_remotely': 0
                                },
                                framework_version='1.6.0',
                                py_version='py3',
                                source_dir='.',  # This entire folder will be transferred to training instance
                                debugger_hook_config=False,
                                output_path=output_path,  # Model files will be uploaded here
                                image_name='954558792927.dkr.ecr.us-west-2.amazonaws.com/sagemaker/wdm:latest'
                     )

    # Block until the remote job finishes (wait=True).
    pytorch_estimator.fit('s3://chegg-ds-data/oboiko/wdm/dummy.txt', wait=True)

TODO: Try an alternative loss function: instead of applying log_softmax, compute the MSE between the model output and the actual GloVe vector of the target word, and minimize that.
For BLEU evaluation, this will require a function that finds the closest GloVe vector (and hence the word) to the vector produced by the model.

It would be interesting to compare those results with the ones from the log_softmax approach.