In [None]:
import os

# Hyperparameters are read from SM_HP_* environment variables; each one falls
# back to a small default when the variable is not set.
model_type = os.getenv('SM_HP_MODEL_TYPE', 'bert-base-uncased')
epochs = int(os.getenv('SM_HP_EPOCHS', 1))
batch_size = int(os.getenv('SM_HP_BATCH', 4))
lr = float(os.getenv('SM_HP_LR', 1e-5))

# Any non-zero integer (the default is 1) selects the remote-job branch of the
# launch cell; 0 makes the notebook train in-process.
train_remotely = int(os.getenv('SM_HP_TRAIN_REMOTELY', 1)) != 0  # Should be False for local training

In [None]:
%%capture
# Install the third-party packages imported by the cells below.
# NOTE(review): only `datasets` is version-pinned; `transformers` and `loguru`
# resolve to whatever is latest at install time, so runs may not be reproducible.
!pip install datasets==1.0.2
!pip install transformers
!pip install loguru

In [1]:
import argparse
import os
import shutil
import torch
from torch import nn
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.data.dataset import Dataset
import transformers
from transformers import AdamW
from transformers import (EncoderDecoderModel,
                          BertTokenizerFast,
                          BertGenerationEncoder,
                          BertGenerationDecoder)
from typing import Callable
from loguru import logger
import datasets
from torch.utils.data import DataLoader

In [None]:
# Shared tokenizer for the selected checkpoint. BERT's [CLS] / [SEP] tokens are
# reused as the begin- / end-of-sequence markers.
tokenizer = BertTokenizerFast.from_pretrained(model_type)
tokenizer.bos_token, tokenizer.eos_token = tokenizer.cls_token, tokenizer.sep_token

In [None]:
def create_model(model_checkpoint_name):
    """Build an encoder-decoder model whose two halves share one pretrained checkpoint.

    The encoder and the decoder's embeddings are frozen; the decoder's
    ``bert.encoder`` stack and ``lm_head`` remain trainable.

    Relies on the module-level ``tokenizer`` having ``bos_token`` and
    ``eos_token`` assigned, so that the matching ids can be looked up.

    Args:
        model_checkpoint_name: Checkpoint name or path handed to
            ``from_pretrained`` for both the encoder and the decoder.

    Returns:
        An ``EncoderDecoderModel`` wrapping the configured encoder and decoder.
    """
    # The model config expects integer vocabulary ids here. The original code
    # passed the token *strings* (tokenizer.bos_token / tokenizer.eos_token).
    special_token_ids = dict(bos_token_id=tokenizer.bos_token_id,
                             eos_token_id=tokenizer.eos_token_id)

    encoder = BertGenerationEncoder.from_pretrained(model_checkpoint_name,
                                                    **special_token_ids)

    decoder = BertGenerationDecoder.from_pretrained(model_checkpoint_name,
                                                    add_cross_attention=True,
                                                    is_decoder=True,
                                                    **special_token_ids)

    # Only the decoder's transformer stack and LM head are fine-tuned.
    decoder.bert.encoder.requires_grad_(True)
    decoder.lm_head.requires_grad_(True)
    decoder.bert.embeddings.requires_grad_(False)

    encoder.requires_grad_(False)

    model = EncoderDecoderModel(encoder=encoder, decoder=decoder)

    return model

In [None]:
# Maximum token lengths: articles feed the encoder, highlights feed the decoder.
encoder_max_length = 512
decoder_max_length = 128

def run_epoch(model: nn.Module,
              data_loader: DataLoader,
              tokenizer: BertTokenizerFast,
              post_hook: Callable = None):
    """Run the model over every batch of ``data_loader`` and total the loss.

    Each batch is a mapping with ``"article"`` and ``"highlights"`` text lists.
    Both are tokenized and padded to the fixed maximum lengths, so attention
    masks are passed to the model and padding positions are excluded from the
    loss by setting their label to ``-100``.

    Args:
        model: Model called with encoder inputs, decoder inputs and labels; its
            output must expose ``.loss``.
        data_loader: Iterable of batches; ``len(data_loader)`` is the batch count.
        tokenizer: Tokenizer applied to both the articles and the highlights.
        post_hook: Optional callable invoked after every batch as
            ``post_hook(batch_index, num_batches, batch_loss)``, where
            ``batch_loss`` is the loss tensor (still attached to the graph).

    Returns:
        The sum of the per-batch loss values as a Python number.
    """
    # Build inputs on whichever device the model currently lives on.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    loss = 0
    num_batches = len(data_loader)
    for i, batch in enumerate(data_loader):
        source = tokenizer(batch["article"],
                           padding="max_length",
                           truncation=True,
                           max_length=encoder_max_length,
                           return_tensors="pt")

        target = tokenizer(batch["highlights"],
                           padding="max_length",
                           truncation=True,
                           max_length=decoder_max_length,
                           return_tensors="pt")

        input_ids = source.input_ids.to(device)
        attention_mask = source.attention_mask.to(device)
        output_ids = target.input_ids.to(device)
        decoder_attention_mask = target.attention_mask.to(device)

        # -100 is the ignore index of the cross-entropy loss: padding positions
        # must not contribute to the training signal.
        labels = output_ids.masked_fill(decoder_attention_mask == 0, -100)

        outputs = model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        decoder_input_ids=output_ids,
                        decoder_attention_mask=decoder_attention_mask,
                        labels=labels,
                        return_dict=True)
        batch_loss = outputs.loss.sum()
        loss += batch_loss.item()

        if post_hook is not None:
            post_hook(i, num_batches, batch_loss)
    return loss

In [None]:
import math

def train(epochs: int,
          lr: float,
          train_data_loader: DataLoader,
          valid_data_loader: DataLoader = None,
          rank = None):
    """Create a fresh model and train it for ``epochs`` passes over the data.

    After every training batch the weights are updated, and roughly every 10%
    of the batches the batch loss is logged. After every epoch the summed
    training loss and, when a validation loader is given, the summed
    validation loss are logged.

    Args:
        epochs: Number of passes over ``train_data_loader``.
        lr: Learning rate for the AdamW optimizer.
        train_data_loader: Batches used for training.
        valid_data_loader: Optional batches evaluated after each epoch without
            gradient tracking. When omitted, the validation error is logged
            as ``N/A``.
        rank: Currently unused.

    Returns:
        The trained model.
    """
    model = create_model(model_type)
    optimizer = AdamW(model.parameters(), lr=lr)
    tokenizer = BertTokenizerFast.from_pretrained(model_type)

    def update_weights_hook(bi, num_batches, batch_loss):
        """Back-propagate one batch, step the optimizer, and log progress."""
        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Log about ten times per epoch, plus once on the final batch.
        pct10 = math.ceil(num_batches / 10)
        if bi % pct10 == 0 or bi == num_batches-1:
            logger.info(f'training: batch={bi+1}/{num_batches}; batch_error={batch_loss.item():.5f};')

    for i in range(epochs):
        model.train()
        train_loss = run_epoch(model, train_data_loader, tokenizer, update_weights_hook)

        if valid_data_loader is not None:
            model.eval()
            with torch.no_grad():
                val_loss = run_epoch(model, valid_data_loader, tokenizer)
            valid_error = f'{val_loss:.5f}'
        else:
            # A float format spec cannot be applied to the 'N/A' placeholder,
            # so the text is prepared before building the log line.
            valid_error = 'N/A'

        logger.info(f'epoch={i}; train_error={train_loss:.5f};  valid_error={valid_error};')

    return model

In [None]:
# Small fixed slices of CNN/DailyMail: the first 32 training examples and the
# first 12 examples from the leading 10% of the validation split.
train_subset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="train").select(range(32))
train_set = DataLoader(train_subset, batch_size=batch_size, shuffle=True)

valid_subset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="validation[:10%]").select(range(12))
valid_set = DataLoader(valid_subset, batch_size=batch_size, shuffle=True)

In [None]:
# Load the full training split and display the field names of its first record.
dataset = datasets.load_dataset(path="cnn_dailymail", name="3.0.0", split="train")
dataset[0].keys()

In [None]:
import sagemaker
from sagemaker.pytorch import PyTorch

# Either train in this process, or submit a SageMaker job that runs the training
# with `train_remotely` set to 0 (so the job itself takes the first branch).
if not train_remotely:
    # Use the `epochs` hyperparameter; it was previously hard-coded to 1, so the
    # value sent to the remote job was ignored.
    model = train(epochs=epochs, lr=lr, train_data_loader=train_set, valid_data_loader=valid_set)
else:
    role = sagemaker.get_execution_role()
    output_path = 's3://chegg-ds-data/oboiko/bert_demo'

    pytorch_estimator = PyTorch(entry_point='train.sh',
                                base_job_name='bert_demo',
                                role=role,
                                train_instance_count=1,
                                train_instance_type='ml.p2.xlarge',  # GPU instance
                                train_volume_size=50,
                                train_max_run=86400,  # 24 hours
                                hyperparameters={
                                  'model_type': 'bert-base-uncased',
                                  'batch': 32,
                                  'epochs': 10,
                                  'lr': 1e-5,

                                  'train_remotely': 0,
                                  'notebook_name': 'simple_bert2bert_SageMaker'  # Inconvenient and error prone >:(
                                },
                                framework_version='1.6.0',
                                py_version='py3',
                                source_dir='.',  # This entire folder will be transferred to training instance
                                debugger_hook_config=False,
                                output_path=output_path,  # Model files will be uploaded here
                                image_name='954558792927.dkr.ecr.us-west-2.amazonaws.com/pytorch-yolov5:latest',
                                # The regexes below parse the log lines emitted by train().
                                metric_definitions=[
                                    {'Name': 'train:error', 'Regex': 'train_error=(.*?);'},
                                    {'Name': 'validation:error', 'Regex': 'valid_error=(.*?);'},
                                    {'Name': 'batch:error', 'Regex': 'batch_error=(.*?);'}
                                ]
                     )
    # wait=False: submit the job and return without blocking on completion.
    pytorch_estimator.fit('s3://chegg-ds-data/oboiko/wdm/dummy.txt', wait=False)