In [1]:
# %load_ext autoreload
# %autoreload 2

In [2]:
# !pip install transformers
# !pip install loguru

In [3]:
from transformers import (EncoderDecoderModel,
                          PreTrainedModel,
                          BertTokenizer,
                          BertGenerationEncoder,
                          BertGenerationDecoder)

In [4]:
import os
import torch
from torch import nn
from tqdm import tqdm
from loguru import logger

In [5]:
# Register log.txt as an additional loguru sink; save_model() later copies this
# file into the model output directory.
logger.add("log.txt")

1

### Determining device to run on

In [6]:
# Select the devices used by the rest of the notebook:
#   main_device  - first CUDA device, or the CPU when CUDA is unavailable
#   src_device   - device the model and the batches are moved to
#                  (the second GPU when more than one is visible)
#   device_count - number of visible GPUs (1 on CPU); used to scale the batch size
if torch.cuda.is_available():
    device_count = torch.cuda.device_count()
    main_device = torch.device("cuda:0")
    src_device = torch.device("cuda:1") if device_count > 1 else main_device

    # loguru formats messages with str.format(): without a "{}" placeholder the
    # extra positional argument is silently dropped from the log line.
    logger.info("Running on the GPU {}", main_device)
else:
    device_count = 1
    main_device = torch.device("cpu")
    src_device = torch.device("cpu")
    logger.info("Running on the CPU {}", main_device)

2020-11-01 08:46:59.130 | INFO     | __main__:<module>:15 - Running on the CPU


### Reading hyperparameters from SageMaker env variables

In [7]:
# Hyperparameters are read from SM_HP_* environment variables; the second
# argument of each lookup is the fallback used when the variable is not set.
model_type = os.getenv('SM_HP_MODEL_TYPE', 'bert-base-uncased')
data_loc = os.getenv('SM_HP_DATA_LOC', '../data')
epochs = int(os.getenv('SM_HP_EPOCHS', 2))
# The configured batch size is per device, so scale it by the number of devices.
batch = device_count * int(os.getenv('SM_HP_BATCH', 32))
lr = float(os.getenv('SM_HP_LR', 1e-5))
# Integer flag (0/1) converted to a boolean.
train_remotely = int(os.getenv('SM_HP_TRAIN_REMOTELY', 1)) != 0
# True when this code is running on the remote SageMaker estimator machine.
is_sagemaker_estimator = 'TRAINING_JOB_NAME' in os.environ

# The notebook name is only required (and only read) on the estimator machine.
if is_sagemaker_estimator:
    notebook_name = os.environ['SM_HP_NOTEBOOK_NAME']
else:
    notebook_name = ''

In [8]:
# Token ids handed to the encoder/decoder as BOS and EOS. Per the note in
# create_model() these are meant to be BERT's [CLS] and [SEP] tokens —
# presumably ids 101/102 in the bert-base-uncased vocabulary; verify them
# whenever model_type changes.
BOS_TOKEN_ID = 101
EOS_TOKEN_ID = 102

### Initializing data loaders for Oxford2019 dataset

In [9]:
from dataset import Oxford2019Dataset
from torch.utils.data import DataLoader

def make_data_loader(filename: str,
                     file_loc: str = os.path.join(data_loc, 'Oxford-2019'),
                     shuffle: bool = True) -> DataLoader:
    """Create a ``DataLoader`` over one Oxford2019 data file.

    Args:
        filename: name of the data file inside ``file_loc`` (e.g. ``'train.txt'``).
        file_loc: directory that contains the data files.
        shuffle: whether batches are drawn in random order. Defaults to ``True``
            (the previous hard-coded behaviour); pass ``False`` for validation
            or test loaders that should be iterated deterministically.

    Returns:
        A ``DataLoader`` that yields batches of ``batch`` items.
    """
    dataset = Oxford2019Dataset(data_loc=os.path.join(file_loc, filename))
    return DataLoader(dataset,
                      batch_size=batch,
                      shuffle=shuffle,
                      # Pinned memory only helps host-to-GPU copies; do not
                      # request it on CPU-only machines.
                      pin_memory=torch.cuda.is_available())

# Build one loader per dataset split, in train / test / valid order.
train_set, test_set, valid_set = [
    make_data_loader(split_file)
    for split_file in ('train.txt', 'test.txt', 'valid.txt')
]


### Function to run through one epoch
This function is used in training, validation, and testing phases.

In [10]:
from typing import Callable

def run(model: nn.Module,
        data_loader: DataLoader,
        tokenizer: BertTokenizer,
        post_hook: Callable = None):
    """Run ``model`` once over every batch of ``data_loader`` and return the summed loss.

    The same loop serves training, validation and testing: whatever has to
    happen after a forward pass (optimizer step, progress logging, ...) is
    delegated to ``post_hook``.

    Args:
        model: model called with ``input_ids``, ``attention_mask``,
            ``decoder_input_ids``, ``decoder_attention_mask`` and ``labels``;
            its output must expose ``.loss``.
        data_loader: yields ``(words, examples, defs, _)`` batches. Only
            ``examples`` (encoder input) and ``defs`` (decoder target) are used.
        tokenizer: tokenizer applied to both the examples and the definitions.
        post_hook: optional callback invoked after every batch as
            ``post_hook(batch_index, device_index, num_batches, batch_loss)``.

    Returns:
        The sum of the per-batch loss values (not averaged); ``0`` when the
        loader is empty.
    """
    loss = 0
    num_batches = len(data_loader)
    progress = tqdm(data_loader, disable=is_sagemaker_estimator)
    for i, (_, examples, defs, _) in enumerate(progress):
        source = tokenizer(examples,
                           add_special_tokens=False,
                           padding=True,
                           truncation=True,
                           return_tensors="pt")
        target = tokenizer(defs,
                           padding=True,
                           truncation=True,
                           return_tensors="pt")

        input_ids = source.input_ids.to(src_device)
        attention_mask = source.attention_mask.to(src_device)
        output_ids = target.input_ids.to(src_device)
        decoder_attention_mask = target.attention_mask.to(src_device)

        # Batches are padded to a common length. Without the attention masks the
        # model attends to the padding, and without masking the labels the
        # padding positions are scored as if they were real target tokens.
        # -100 is the label value that is ignored by the loss.
        labels = output_ids.masked_fill(decoder_attention_mask == 0, -100)

        outputs = model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        decoder_input_ids=output_ids,
                        decoder_attention_mask=decoder_attention_mask,
                        labels=labels,
                        return_dict=True)
        # .sum() collapses the loss to a scalar even if the model returns one
        # loss value per replica.
        batch_loss = outputs.loss.sum()
        loss += batch_loss.item()

        if post_hook is not None:
            post_hook(i, src_device.index, num_batches, batch_loss)
    return loss

### Training loop function

In [42]:
from transformers import AdamW


def train(epochs: int, train_data_loader: DataLoader, valid_data_loader: DataLoader = None):
    """Create a fresh model, train it for ``epochs`` epochs and return it.

    Args:
        epochs: number of passes over ``train_data_loader``.
        train_data_loader: batches used for the weight updates.
        valid_data_loader: optional batches evaluated (without gradients) after
            every training epoch; when omitted the validation error is logged
            as ``'N/A'``.

    Returns:
        The trained model (unwrapped if it is an ``nn.DataParallel``).
    """
    model = create_model(model_type)
    optimizer = AdamW(model.parameters(), lr=lr)
    tokenizer = BertTokenizer.from_pretrained(model_type)

    def update_weights(bi, di, num_batches, batch_loss):
        """Training hook: backpropagate the batch loss and take one optimizer step."""
        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if bi % 100 == 0:
            logger.info(f'training: device={di}; batch={bi+1}/{num_batches}; batch_error={batch_loss.item()};')

    def valid_loss_progress_log(bi, di, num_batches, batch_loss):
        """Validation hook: only logs progress, never touches the weights."""
        if bi % 100 == 0:
            logger.info(f'validation: device={di}; batch={bi+1}/{num_batches}; val_batch_error={batch_loss.item()};')

    for i in range(epochs):
        model.train()
        train_loss = run(model, train_data_loader, tokenizer, update_weights)

        if valid_data_loader is not None:
            model.eval()
            with torch.no_grad():
                val_loss = run(model, valid_data_loader, tokenizer, valid_loss_progress_log)
        else:
            val_loss = 'N/A'

        # The exact "train_error=...;" / "valid_error=...;" format is parsed by
        # the SageMaker metric regexes, so keep it unchanged.
        logger.info(f'epoch={i}; train_error={train_loss};  valid_error={val_loss};')

    # Only an nn.DataParallel wrapper has a ``.module`` attribute. create_model()
    # returns the bare model, so reading ``model.module`` unconditionally raises
    # AttributeError.
    return model.module if isinstance(model, nn.DataParallel) else model

### Function for saving the model

In [43]:
def save_model(model: PreTrainedModel):
    """Save the model, the notebook sources and the log file to the output directory.

    The output directory is ``/opt/ml/model`` on a SageMaker estimator and the
    current directory otherwise. Besides the model weights, the files
    ``<notebook_name>.py``, ``<notebook_name>.ipynb`` and ``log.txt`` are copied
    there. The notebook files are skipped when ``notebook_name`` is empty (local
    runs); any file that does not exist is reported with a warning instead of
    failing silently.

    Args:
        model: model whose ``save_pretrained`` method writes the weights/config.
    """
    import shutil  # local import: only this function needs it

    out_loc = '/opt/ml/model' if is_sagemaker_estimator else '.'
    os.makedirs(out_loc, exist_ok=True)

    model.save_pretrained(out_loc)

    artifacts = []
    if notebook_name:
        artifacts += [f'{notebook_name}.py', f'{notebook_name}.ipynb']
    artifacts.append('log.txt')

    for artifact in artifacts:
        if not os.path.isfile(artifact):
            logger.warning("Cannot copy {} to {}: file not found", artifact, out_loc)
            continue
        try:
            shutil.copy(artifact, out_loc)
        except shutil.SameFileError:
            # out_loc is the directory the file already lives in: nothing to do.
            pass

In [56]:
def create_model(model_checkpoint_name: str):
    """Assemble an encoder-decoder model whose two halves come from one checkpoint.

    Both the encoder and the decoder are loaded from ``model_checkpoint_name``.
    Gradients are switched off for the whole encoder and for the decoder's
    embeddings; the assembled model is moved to ``src_device``.

    Args:
        model_checkpoint_name: name or path of the pretrained checkpoint.

    Returns:
        The ``EncoderDecoderModel`` placed on ``src_device``.
    """
    # Use BERT's cls token as BOS token and sep token as EOS token.
    special_token_ids = dict(bos_token_id=BOS_TOKEN_ID, eos_token_id=EOS_TOKEN_ID)

    encoder = BertGenerationEncoder.from_pretrained(model_checkpoint_name,
                                                    **special_token_ids)
    # The decoder additionally gets cross attention layers.
    decoder = BertGenerationDecoder.from_pretrained(model_checkpoint_name,
                                                    add_cross_attention=True,
                                                    is_decoder=True,
                                                    **special_token_ids)

    # Exclude the encoder and the decoder embeddings from gradient updates.
    for frozen_part in (encoder, decoder.bert.embeddings):
        frozen_part.requires_grad_(False)

    seq2seq = EncoderDecoderModel(encoder=encoder, decoder=decoder)
    return seq2seq.to(src_device)

### Quick sanity check for the training loop

In [55]:
if not is_sagemaker_estimator:
    # Local smoke test: train for one epoch on a tiny file, save the model and
    # generate one definition. tiny.txt is expected to exist already; the
    # disabled lines below show how it was cut from the training file.
    # model = create_model(model_type)
    # torch.cuda.empty_cache()
    # train_file = os.path.join(data_loc, 'Oxford-2019', 'train.txt')
    # tiny_size = batch * 5
    # tiny_file = os.path.join(data_loc, 'Oxford-2019', 'tiny.txt')
    # !head -n {tiny_size} {train_file} > {tiny_file}
    tiny_set = make_data_loader('tiny.txt')

    model = train(epochs=1, train_data_loader=tiny_set, valid_data_loader=tiny_set)
    save_model(model)

    model.eval()
    tokenizer = BertTokenizer.from_pretrained(model_type)
    sample_text = "Basketball Basketball's early adherents were dispatched to YMCAs throughout the United States, and it quickly spread through the United States and Canada"
    # Encode the input the same way run() encodes training examples:
    # without special tokens.
    input_ids = torch.tensor(tokenizer.encode(sample_text, add_special_tokens=False)).unsqueeze(0).to(src_device)
    # Training targets are tokenized with special tokens, i.e. they start with
    # BOS ([CLS]) and end with EOS ([SEP]); generation has to start from and
    # stop at those same tokens instead of starting from the padding token.
    generated = model.generate(input_ids,
                               decoder_start_token_id=BOS_TOKEN_ID,
                               eos_token_id=EOS_TOKEN_ID)
    print(tokenizer.decode(generated.squeeze(), skip_special_tokens=True))

AssertionError: Config has to be initialized with encoder and decoder config

### Training
Training can run either on the machine where this notebook is running or remotely on a SageMaker estimator.

In [None]:
import sagemaker
from sagemaker.pytorch import PyTorch

if is_sagemaker_estimator:
    # Already on the SageMaker training instance: train here and persist the result.
    model = train(epochs=epochs, train_data_loader=train_set, valid_data_loader=valid_set)
    save_model(model)
elif train_remotely:
    # Running in the notebook: launch the training as a remote SageMaker job.
    role = sagemaker.get_execution_role()
    output_path = 's3://chegg-ds-data/oboiko/wdm-output'

    # Hyperparameters handed to the remote job.
    remote_hyperparameters = {
        'model_type': 'bert-base-uncased',
        'data_loc': '/opt/data',
        'batch': 32,
        'epochs': 10,
        'lr': 1e-5,
        'train_remotely': 0,
        'notebook_name': 'main_bert_distributed',
    }

    # Regexes that extract metrics from the lines logged during training.
    remote_metric_definitions = [
        {'Name': 'train:error', 'Regex': 'train_error=(.*?);'},
        {'Name': 'validation:error', 'Regex': 'valid_error=(.*?);'},
        {'Name': 'batch:error', 'Regex': 'batch_error=(.*?);'},
    ]

    pytorch_estimator = PyTorch(
        entry_point='train.sh',
        base_job_name='wdm-1',
        role=role,
        train_instance_count=1,
        train_instance_type='ml.p2.8xlarge',  # GPU instance
        train_volume_size=50,
        train_max_run=86400,  # 24 hours
        hyperparameters=remote_hyperparameters,
        framework_version='1.6.0',
        py_version='py3',
        source_dir='.',  # This entire folder will be transferred to training instance
        debugger_hook_config=False,
        output_path=output_path,  # Model files will be uploaded here
        image_name='954558792927.dkr.ecr.us-west-2.amazonaws.com/sagemaker/wdm:latest',
        metric_definitions=remote_metric_definitions,
    )

    # wait=False: start the job and return without blocking the notebook.
    pytorch_estimator.fit('s3://chegg-ds-data/oboiko/wdm/dummy.txt', wait=False)

In [None]:
TODO: Try an alternative loss function: instead of log_softmax over the vocabulary, compute the MSE between the model output and the actual GloVe vector of the target word, and minimize that.
For BLEU evaluation this requires a function that finds the vocabulary vector closest to the one produced by the model.

It would be interesting to compare these results with the log_softmax results.

TODO: Try an alternative loss function: instead of log_softmax over the vocabulary, compute the MSE between the model output and the actual GloVe vector of the target word, and minimize that.
For BLEU evaluation this requires a function that finds the vocabulary vector closest to the one produced by the model.

It would be interesting to compare these results with the log_softmax results.