In [1]:
model_type = 'bert-base-uncased'
epochs = 1
batch_size = 4
lr = 1e-5

In [2]:
%%capture
!pip install datasets==1.0.2
!pip install transformers
!pip install loguru

In [3]:
import argparse
import os
import shutil
import torch
from torch import nn
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.data.dataset import Dataset
import transformers
from transformers import AdamW
from transformers import (EncoderDecoderModel,
                          BertTokenizerFast,
                          BertGenerationEncoder,
                          BertGenerationDecoder)
from typing import Callable
from loguru import logger
import datasets
from torch.utils.data import DataLoader

In [4]:
tokenizer = BertTokenizerFast.from_pretrained(model_type)
tokenizer.bos_token = tokenizer.cls_token
tokenizer.eos_token = tokenizer.sep_token

In [5]:
def create_model(model_checkpoint_name):
    encoder = BertGenerationEncoder.from_pretrained(model_checkpoint_name,
                                                    bos_token_id=tokenizer.bos_token,
                                                    eos_token_id=tokenizer.eos_token)
    
    decoder = BertGenerationDecoder.from_pretrained(model_checkpoint_name,
                                                    add_cross_attention=True,
                                                    is_decoder=True,
                                                    bos_token_id=tokenizer.bos_token,
                                                    eos_token_id=tokenizer.eos_token)
    decoder.bert.encoder.requires_grad_(True)
    decoder.lm_head.requires_grad_(True)
    decoder.bert.embeddings.requires_grad_(False)

    encoder.requires_grad_(False)

    model = EncoderDecoderModel(encoder=encoder, decoder=decoder)

    return model

In [6]:
from tqdm import tqdm

encoder_max_length = 512
decoder_max_length = 128

def run_epoch(model: nn.Module,
              data_loader: DataLoader,
              tokenizer: BertTokenizerFast,
              post_hook: Callable = None):
    
    loss = 0
    num_batches = len(data_loader)
    for i, batch in enumerate(data_loader):
        input_ids = tokenizer(batch["article"], 
                              padding="max_length", 
                              truncation=True, 
                              max_length=encoder_max_length,
                              return_tensors="pt").input_ids
        
        output_ids = tokenizer(batch["highlights"], 
                               padding="max_length", 
                               truncation=True, 
                               max_length=decoder_max_length,
                               return_tensors="pt").input_ids

        outputs = model(input_ids=input_ids,
                        decoder_input_ids=output_ids,
                        labels=output_ids,
                        return_dict=True)
        batch_loss = outputs.loss.sum()
        loss += batch_loss.item()
        
        if post_hook is not None:
            post_hook(i, num_batches, batch_loss)
    return loss

In [7]:
import math

def train(epochs: int,
          lr: float,
          train_data_loader: DataLoader,
          valid_data_loader: DataLoader = None,
          rank = None):
    model = create_model(model_type)
    optimizer = AdamW(model.parameters(), lr=lr)
    tokenizer = BertTokenizerFast.from_pretrained(model_type)

    def update_weights_hook(bi, num_batches, batch_loss):
        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        pct10 = math.ceil(num_batches / 10)
        if bi % pct10 == 0 or bi == num_batches-1:
            logger.info(f'training: batch={bi+1}/{num_batches}; batch_error={batch_loss.item():.5f};')
                  
    for i in range(epochs):
        model.train()
        train_loss = run_epoch(model, train_data_loader, tokenizer, update_weights_hook)

        if valid_data_loader is not None:
            with torch.no_grad():
                model.eval()
                val_loss = run_epoch(model, valid_data_loader, tokenizer)
        else:
            val_loss = 'N/A'

        logger.info(f'epoch={i}; train_error={train_loss:.5f};  valid_error={val_loss:.5f};')

    return model

In [8]:
train_set = DataLoader(
    datasets.load_dataset("cnn_dailymail", "3.0.0", split="train").select(range(32)),
    batch_size=batch_size,
    shuffle=True)

valid_set = DataLoader(
    datasets.load_dataset("cnn_dailymail", "3.0.0", split="validation[:10%]").select(range(12)),
    batch_size=batch_size,
    shuffle=True)

Reusing dataset cnn_dailymail (/Users/oboiko/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/0128610a44e10f25b4af6689441c72af86205282d26399642f7db38fa7535602)
Reusing dataset cnn_dailymail (/Users/oboiko/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/0128610a44e10f25b4af6689441c72af86205282d26399642f7db38fa7535602)


In [9]:
dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="train")
dataset[0].keys()

Reusing dataset cnn_dailymail (/Users/oboiko/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/0128610a44e10f25b4af6689441c72af86205282d26399642f7db38fa7535602)


dict_keys(['article', 'highlights', 'id'])

In [10]:
model = train(epochs=1, lr=lr, train_data_loader=train_set, valid_data_loader=valid_set)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertGenerationEncoder: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'bert.embeddings.token_type_embeddings.weight']
- This IS expected if you are initializing BertGenerationEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertGenerationEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertGenerationDecoder: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.

2021-01-12 10:54:21.659 | INFO     | __main__:update_weights_hook:19 - training: batch=1/8; batch_error=10.18645;
2021-01-12 10:54:34.854 | INFO     | __main__:update_weights_hook:19 - training: batch=2/8; batch_error=10.05975;
2021-01-12 10:54:47.611 | INFO     | __main__:update_weights_hook:19 - training: batch=3/8; batch_error=9.99084;
2021-01-12 10:55:48.378 | INFO     | __main__:update_weights_hook:19 - training: batch=4/8; batch_error=9.92017;
2021-01-12 10:56:06.197 | INFO     | __main__:update_weights_hook:19 - training: batch=5/8; batch_error=9.78156;
2021-01-12 10:56:18.687 | INFO     | __main__:update_weights_hook:19 - training: batch=6/8; batch_error=9.58032;
2021-01-12 10:56:31.698 | INFO     | __main__:update_weights_hook:19 - training: batch=7/8; batch_error=9.51603;
2021-01-12 10:56:44.737 | INFO     | __main__:update_weights_hook:19 - training: batch=8/8; batch_error=9.27913;
2021-01-12 10:57:06.034 | INFO     | __main__:train:32 - epoch=0; train_error=78.31426;  valid