In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from transformers import (EncoderDecoderModel,
                          BertTokenizer,
                          BertGenerationEncoder,
                          BertGenerationDecoder)

In [3]:
import os
import torch
from torch import nn
from tqdm import tqdm

### Determining device to run on

In [4]:
# Use the first CUDA device when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Running on the GPU" if device.type == "cuda" else "Running on the CPU")

Running on the CPU


### Reading hyperparameters from SageMaker env variables

In [5]:
# Each setting is read from an SM_HP_* environment variable; the second argument
# is the fallback used when the variable is absent (e.g. on a local run).
model_type = os.getenv('SM_HP_MODEL_TYPE', 'bert-large-uncased')
data_loc = os.getenv('SM_HP_DATA_LOC', '../data')
epochs = int(os.getenv('SM_HP_EPOCHS', 2))
batch = int(os.getenv('SM_HP_BATCH', 32))
lr = float(os.getenv('SM_HP_LR', 1e-5))
# Parsed as an integer first, so the string "0" turns remote training off.
train_remotely = int(os.getenv('SM_HP_TRAIN_REMOTELY', 1)) != 0
# True when this code is running on the remote SageMaker estimator machine.
is_sagemaker_estimator = os.environ.get('TRAINING_JOB_NAME') is not None

In [6]:
# Token ids handed to the encoder and decoder as their BOS / EOS markers.
BOS_TOKEN_ID, EOS_TOKEN_ID = 101, 102

### Initializing data loaders for Oxford2019 dataset

In [7]:
from dataset import Oxford2019Dataset
from torch.utils.data import DataLoader

def make_data_loader(filename: str,
                     file_loc: str = os.path.join(data_loc, 'Oxford-2019'),
                     shuffle: bool = True) -> DataLoader:
    """Build a batched ``DataLoader`` over one Oxford-2019 split file.

    Args:
        filename: Name of the split file inside ``file_loc`` (e.g. ``'train.txt'``).
        file_loc: Directory that holds the split files. The default is computed
            once, when this function is defined, from the module-level ``data_loc``.
        shuffle: Whether the loader reshuffles the data. Defaults to ``True``,
            which is what this function always did before the parameter existed;
            pass ``False`` to keep a fixed order (e.g. for evaluation).

    Returns:
        A ``DataLoader`` whose batch size is the module-level ``batch``.
    """
    dataset = Oxford2019Dataset(data_loc=os.path.join(file_loc, filename))
    return DataLoader(dataset, batch_size=batch, shuffle=shuffle)

# One loader per split file, created in train -> test -> valid order.
train_set, test_set, valid_set = (
    make_data_loader(split_file)
    for split_file in ('train.txt', 'test.txt', 'valid.txt')
)


In [16]:
# BERT's cls token serves as the BOS token and its sep token as the EOS token,
# for the encoder and the decoder alike.
special_token_ids = dict(bos_token_id=BOS_TOKEN_ID, eos_token_id=EOS_TOKEN_ID)

# Encoder: loaded from the pretrained checkpoint named by `model_type`.
encoder = BertGenerationEncoder.from_pretrained(model_type, **special_token_ids)

# Decoder: same checkpoint, configured as a decoder with cross-attention layers added.
decoder = BertGenerationDecoder.from_pretrained(model_type,
                                                add_cross_attention=True,
                                                is_decoder=True,
                                                **special_token_ids)

model = EncoderDecoderModel(encoder=encoder, decoder=decoder)

Some weights of the model checkpoint at bert-large-uncased were not used when initializing BertGenerationEncoder: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'bert.embeddings.token_type_embeddings.weight']
- This IS expected if you are initializing BertGenerationEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertGenerationEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at bert-large-uncased were not used when initializing BertGenerationDecoder: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationshi

### Function to run through one epoch
This function is used in training, validation, and testing phases.

In [23]:
from typing import Callable

def run(model: nn.Module,
        data_loader: DataLoader,
        tokenizer: BertTokenizer,
        post_hook: Callable = lambda b: ''):
    """Run one pass over ``data_loader`` and return the accumulated loss.

    Each batch is tokenized (examples as encoder input, definitions as decoder
    input and labels), pushed through ``model``, and the resulting loss tensor
    is handed to ``post_hook``.

    Args:
        model: Encoder-decoder model called with ``input_ids``,
            ``decoder_input_ids`` and ``labels``; its output must expose ``.loss``.
        data_loader: Yields 4-tuples ``(words, examples, defs, _)``; only
            ``examples`` and ``defs`` are used here.
        tokenizer: Tokenizer applied to the raw example / definition strings.
        post_hook: Called with each batch's loss tensor (e.g. to backpropagate
            and step an optimizer). The default does nothing.

    Returns:
        The SUM (not the mean) of the per-batch loss values as a Python number.
    """
    # Tensors are created on the CPU by the tokenizer; move them to wherever
    # the model's weights live so the forward pass does not mix devices.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    total_loss = 0
    for words, examples, defs, _ in tqdm(data_loader, disable=is_sagemaker_estimator):
        source = tokenizer(examples,
                           add_special_tokens=False,
                           padding=True,
                           truncation=True,
                           return_tensors="pt")
        target = tokenizer(defs,
                           padding=True,
                           truncation=True,
                           return_tensors="pt")

        input_ids = source.input_ids.to(model_device)
        input_mask = source.attention_mask.to(model_device)
        output_ids = target.input_ids.to(model_device)
        output_mask = target.attention_mask.to(model_device)

        # `padding=True` pads every sequence to the longest one in the batch.
        # Padded positions must neither be attended to (masks below) nor count
        # towards the loss (-100 is the label index the loss ignores).
        labels = output_ids.masked_fill(output_mask == 0, -100)

        # NOTE(review): passing the same ids as `decoder_input_ids` and `labels`
        # relies on the decoder shifting the labels internally — confirm this
        # holds for the installed transformers version.
        outputs = model(input_ids=input_ids,
                        attention_mask=input_mask,
                        decoder_input_ids=output_ids,
                        decoder_attention_mask=output_mask,
                        labels=labels,
                        return_dict=True)
        batch_loss = outputs.loss
        total_loss += batch_loss.item()

        post_hook(batch_loss)
    return total_loss

### Training loop function

In [21]:
from transformers import AdamW


def _build_encoder_decoder() -> nn.Module:
    """Assemble a fresh encoder-decoder model from the ``model_type`` checkpoint."""
    # BERT's cls token is used as the BOS token and its sep token as the EOS token.
    special_token_ids = dict(bos_token_id=BOS_TOKEN_ID, eos_token_id=EOS_TOKEN_ID)

    encoder = BertGenerationEncoder.from_pretrained(model_type, **special_token_ids)

    # The decoder additionally gets cross-attention layers.
    decoder = BertGenerationDecoder.from_pretrained(model_type,
                                                    add_cross_attention=True,
                                                    is_decoder=True,
                                                    **special_token_ids)
    return EncoderDecoderModel(encoder=encoder, decoder=decoder)


def train(epochs: int, train_data_loader: DataLoader, valid_data_loader: DataLoader = None, model: nn.Module = None):
    """Train ``model`` for ``epochs`` epochs, logging one line per epoch.

    Args:
        epochs: Number of passes over ``train_data_loader``.
        train_data_loader: Batches used for the weight updates.
        valid_data_loader: Optional batches evaluated (without gradients) after
            every epoch. When omitted, the validation loss is reported as ``'N/A'``.
        model: Model to continue training. When ``None``, a new encoder-decoder
            is built from the ``model_type`` checkpoint.

    Returns:
        The trained model (the same object that was passed in, if any).

    Side effects:
        Prints a summary line per epoch and appends the same line to ``log.txt``
        in the current working directory.
    """
    if model is None:
        model = _build_encoder_decoder()

    # NOTE(review): the model is trained on whatever device it already occupies;
    # the module-level `device` is not applied anywhere in the visible code.
    optimizer = AdamW(model.parameters(), lr=lr)

    tokenizer = BertTokenizer.from_pretrained(model_type)

    def update_weights(batch_loss):
        """Backpropagate one batch loss, apply the step, then clear the gradients."""
        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    for i in range(epochs):
        model.train()
        train_loss = run(model, train_data_loader, tokenizer, update_weights)

        if valid_data_loader is not None:
            model.eval()
            with torch.no_grad():
                val_loss = run(model, valid_data_loader, tokenizer)
        else:
            val_loss = 'N/A'

        msg = f'Epoch {i} | train loss {train_loss}, val loss {val_loss}'
        print(msg)
        # Plain file append instead of a shell `echo`: works outside IPython
        # and does not depend on shell quoting of the message.
        with open('log.txt', 'a') as log_file:
            log_file.write(msg + '\n')
    return model

### Quick sanity check for the training loop

In [None]:
train_file = os.path.join(data_loc, 'Oxford-2019', 'train.txt')
tiny_size = batch * 5
tiny_file = os.path.join(data_loc, 'Oxford-2019', 'tiny.txt')
!head -n {tiny_size} {train_file} > {tiny_file}
tiny_set = make_data_loader('tiny.txt')
model = train(epochs=2, train_data_loader=tiny_set, valid_data_loader=tiny_set, model=model)

 80%|████████  | 4/5 [05:50<01:23, 83.60s/it]

### Function for saving the model

In [13]:
def save_model(model):
    """Write the model weights and the training artefacts to the output directory.

    The output directory is ``/opt/ml/model`` when running on the SageMaker
    estimator and the current directory otherwise. It is created if missing.

    Args:
        model: A ``torch.nn.Module``; its ``state_dict()`` is saved as
            ``model.pt`` inside the output directory.

    Side effects:
        Besides ``model.pt``, copies ``main.py``, ``main.ipynb`` and ``log.txt``
        from the working directory into the output directory. A file that does
        not exist, or that already sits in the output directory, is skipped
        rather than treated as an error.
    """
    import shutil  # local import so this cell does not rely on a new top-level import

    out_loc = '/opt/ml/model' if is_sagemaker_estimator else '.'
    os.makedirs(out_loc, exist_ok=True)

    # Persist the weights only; reloading requires rebuilding the same architecture.
    torch.save(model.state_dict(), os.path.join(out_loc, 'model.pt'))

    for artefact in ('main.py', 'main.ipynb', 'log.txt'):
        destination = os.path.join(out_loc, artefact)
        if not os.path.isfile(artefact):
            print(f'save_model: {artefact} not found, skipping')
            continue
        # When saving into the working directory, source and destination coincide.
        if os.path.abspath(artefact) == os.path.abspath(destination):
            continue
        shutil.copy(artefact, destination)

### Training
Training can be done either on the machine where the notebook is running or remotely on a SageMaker estimator.

In [None]:
import sagemaker
from sagemaker.pytorch import PyTorch

# NOTE(review): there is no branch for "not on the estimator and train_remotely
# is off" — in that case this cell does nothing. Confirm that is intended.
if is_sagemaker_estimator:
    # Already on the remote training machine: train on the full data and save.
    model = train(epochs=epochs, train_data_loader=train_set, valid_data_loader=valid_set)
    save_model(model)
elif train_remotely:
    # Submit a SageMaker training job and return without waiting for it.
    role = sagemaker.get_execution_role()
    output_path = 's3://chegg-ds-data/oboiko/wdm-output'

    # Presumably surfaced to the job as the SM_HP_* variables read near the top
    # of the notebook — verify against the SageMaker container.
    # NOTE(review): 'lr' here is 0.01 while the local default is 1e-5, and
    # 'batch' is 50 versus a local default of 32 — confirm these are deliberate.
    remote_hyperparameters = {
        'model_type': 'bert-large-uncased',
        'data_loc': '/opt/data',
        'batch': 50,
        'epochs': 15,
        'lr': 0.01,
        'train_remotely': 0
    }

    pytorch_estimator = PyTorch(
        entry_point='train.sh',
        base_job_name='wdm-1',
        role=role,
        train_instance_count=1,
        train_instance_type='ml.g4dn.2xlarge',  # GPU instance
        train_volume_size=50,
        train_max_run=86400,  # 24 hours
        hyperparameters=remote_hyperparameters,
        framework_version='1.6.0',
        py_version='py3',
        source_dir='.',  # This entire folder will be transferred to training instance
        debugger_hook_config=False,
        output_path=output_path,  # Model files will be uploaded here
        image_name='954558792927.dkr.ecr.us-west-2.amazonaws.com/sagemaker/wdm:latest',
    )

    pytorch_estimator.fit('s3://chegg-ds-data/oboiko/wdm/dummy.txt', wait=False)

TODO: For the loss function, instead of applying log_softmax, compute the MSE against the actual GloVe vector and minimize that loss.
Then, for BLEU evaluation, you'll need a function that finds the closest vector to the one produced by the model.

It would be interesting to compare these results with those obtained using log_softmax.