# PyTorch BERT Multi-Model Trainer + KFolds🎯

This is a training script written in vanilla PyTorch with a custom Trainer class. I hope it helps you develop more sophisticated models.

Along with BERT Base Uncased and Cased, I have also included code for the DistilBERT base model. To use it, you'll have to make some small changes to the code; for example, its `forward()` takes no `token_type_ids`. Refer to the official Hugging Face documentation for more.

Think of this notebook as a skeleton for all BERT-based models (in fact, for any PyTorch Hugging Face model). You can change chunks of code to suit your needs, and it should still work in most cases.

📌 Inference (Submission) Notebook: https://www.kaggle.com/heyytanay/inference-0-6-lb-vanilla-pytorch-bert-starter

📌 My EDA and Multi Linear Models Notebook: https://www.kaggle.com/heyytanay/commonlit-readability-eda-multi-models


**Feel free to fork and change the models and do some preprocessing, but if you do please leave an upvote : )**

In [1]:
import platform
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import gc
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import StratifiedKFold

import torch
import transformers
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import Dataset, DataLoader

import warnings
warnings.simplefilter('ignore')

We define a Config class to store variables and functions that are to be used globally inside our training script.
This makes the code more modular and easy to approach at the same time.

In [2]:
class Config:
    NB_EPOCHS = 4
    LR = 1e-6
    MAX_LEN = 185
    N_SPLITS = 5
    TRAIN_BS = 32
    VALID_BS = 64
    BERT_MODEL = '../input/huggingface-bert/bert-large-uncased'
    FILE_NAME = '../input/commonlitreadabilityprize/train.csv'
    TOKENIZER = transformers.BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
    scaler = GradScaler()

Below is the custom dataset. It defines 3 methods:
* `__init__()`: Constructor. Stores the texts, the targets, the tokenizer, and the maximum sequence length.
* `__getitem__()`: Returns the tokenized sample at a given index; the `DataLoader` calls it while iterating over the dataset.
* `__len__()`: Returns the number of samples in the dataset.

In [3]:
# Custom dataset: https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
class BERTDataset(Dataset):
    def __init__(self, review, target=None, is_test=False):
        self.review = review
        self.target = target
        self.is_test = is_test
        self.tokenizer = Config.TOKENIZER
        self.max_len = Config.MAX_LEN
    
    def __len__(self):
        return len(self.review)
    
    def __getitem__(self, idx):
        review = str(self.review[idx])
        review = ' '.join(review.split())
        
        # Tokenize, truncate to max_len, and pad to a fixed length
        inputs = self.tokenizer.encode_plus(
            review,
            None,
            truncation=True,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length'  # replaces the deprecated pad_to_max_length=True
        )
        # https://huggingface.co/transformers/preprocessing.html
        # input_ids: the vocabulary indices of the tokens in the text
        # attention_mask: binary mask marking the padded positions so that the model does not attend to them
        # token_type_ids: tell the model which tokens belong to the first sentence and which to the second
        ids = torch.tensor(inputs['input_ids'], dtype=torch.long)
        mask = torch.tensor(inputs['attention_mask'], dtype=torch.long)
        token_type_ids = torch.tensor(inputs['token_type_ids'], dtype=torch.long)
        
        if self.is_test:
            # no target if is_test
            return {
                'ids': ids,
                'mask': mask,
                'token_type_ids': token_type_ids,
            }
        else:    
            targets = torch.tensor(self.target[idx], dtype=torch.float)
            return {
                'ids': ids,
                'mask': mask,
                'token_type_ids': token_type_ids,
                'targets': targets
            }
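
To make the three tensors returned by `__getitem__()` more concrete, here is a minimal pure-Python sketch of the padding and masking step. It is a simplification, not the tokenizer's actual implementation: `pad_and_mask` is a hypothetical helper, the token ids passed in are placeholders, and only the special-token ids (101 for `[CLS]`, 102 for `[SEP]`, 0 for `[PAD]` in `bert-base-uncased`) are taken from the real vocabulary.

```python
# Simplified illustration of what the tokenizer returns for a single text.
# Real tokenization uses WordPiece; the token ids passed in are placeholders.
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # special-token ids in bert-base-uncased


def pad_and_mask(token_ids, max_len):
    """Add [CLS]/[SEP], truncate to max_len, then pad and build the masks."""
    body = token_ids[: max_len - 2]                  # leave room for [CLS] and [SEP]
    ids = [CLS_ID] + body + [SEP_ID]
    n_pad = max_len - len(ids)
    attention_mask = [1] * len(ids) + [0] * n_pad    # 0 marks padded positions
    ids = ids + [PAD_ID] * n_pad
    token_type_ids = [0] * max_len                   # single-sentence input: all zeros
    return {
        "input_ids": ids,
        "attention_mask": attention_mask,
        "token_type_ids": token_type_ids,
    }


short = pad_and_mask([7592, 2088], max_len=6)               # shorter than max_len: padded
long = pad_and_mask(list(range(1000, 1010)), max_len=6)     # longer than max_len: truncated
print(short["input_ids"], short["attention_mask"])
print(long["input_ids"], long["attention_mask"])
```

Every sample ends up with exactly `max_len` entries, which is what lets the `DataLoader` stack samples into a batch.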

Below is a custom `Trainer` class that I wrote from scratch to handle the training and validation routines.

It provides a fastai-style interface for training.

In [4]:
class Trainer:
    def __init__(
        self, 
        model, 
        optimizer, 
        scheduler, 
        train_dataloader, 
        valid_dataloader,
        device
    ):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.train_data = train_dataloader
        self.valid_data = valid_dataloader
        self.loss_fn = self.yield_loss
        self.device = device
        
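    def yield_loss(self, outputs, targets):
        """
        RMSE loss for this task. The model returns shape (batch, 1) while the
        targets have shape (batch,), so both are flattened first; otherwise
        MSELoss would broadcast them to a (batch, batch) matrix.
        """
        return torch.sqrt(nn.MSELoss()(outputs.view(-1), targets.view(-1)))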
    
    def train_one_epoch(self):
        """
        This function trains the model for 1 epoch through all batches
        """
        prog_bar = tqdm(enumerate(self.train_data), total=len(self.train_data))
        self.model.train()
        # https://discuss.pytorch.org/t/what-step-backward-and-zero-grad-do/33301
        for idx, inputs in prog_bar:
            ids = inputs['ids'].to(self.device, dtype=torch.long)
            mask = inputs['mask'].to(self.device, dtype=torch.long)
            ttis = inputs['token_type_ids'].to(self.device, dtype=torch.long)
            targets = inputs['targets'].to(self.device, dtype=torch.float)

            # Autocast only the forward pass and the loss computation;
            # backward and the optimizer step run outside the autocast region
            with autocast():
                outputs = self.model(ids=ids, mask=mask, token_type_ids=ttis)
                loss = self.loss_fn(outputs, targets)
            prog_bar.set_description('loss: {:.2f}'.format(loss.item()))

            # Common AMP pattern: scale the loss, backprop, step, update the scaler
            self.optimizer.zero_grad()
            Config.scaler.scale(loss).backward()
            Config.scaler.step(self.optimizer)
            Config.scaler.update()
            self.scheduler.step()
    
    def valid_one_epoch(self):
        """
        This function validates the model for one epoch through all batches of the valid dataset.
        It also returns the validation root mean squared error for assessing model performance.
        """
        prog_bar = tqdm(enumerate(self.valid_data), total=len(self.valid_data))
        self.model.eval()
        all_targets = []
        all_predictions = []
        with torch.no_grad():
            for idx, inputs in prog_bar:
                ids = inputs['ids'].to(self.device, dtype=torch.long)
                mask = inputs['mask'].to(self.device, dtype=torch.long)
                ttis = inputs['token_type_ids'].to(self.device, dtype=torch.long)
                targets = inputs['targets'].to(self.device, dtype=torch.float)

                outputs = self.model(ids=ids, mask=mask, token_type_ids=ttis)
                
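                # .detach().cpu().numpy() => https://zhuanlan.zhihu.com/p/165219346
                # No sigmoid here: the targets are unbounded real values, so the raw
                # regression outputs (flattened to 1-D) are compared against them
                all_targets.extend(targets.cpu().detach().numpy().tolist())
                all_predictions.extend(outputs.view(-1).cpu().detach().numpy().tolist())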

        val_rmse_loss = np.sqrt(mean_squared_error(all_targets, all_predictions))
        print('Validation RMSE: {:.2f}'.format(val_rmse_loss))
        
        return val_rmse_loss
    
    def get_model(self):
        return self.model
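
The metric returned by `valid_one_epoch()` is the root mean squared error over all validation samples. As a small worked example of that formula, here is a pure-Python sketch; the `rmse` helper and the numbers are made up for illustration and are not taken from the competition data.

```python
import math


def rmse(targets, predictions):
    """Root mean squared error between two equally long 1-D sequences."""
    if len(targets) != len(predictions):
        raise ValueError("targets and predictions must have the same length")
    squared_errors = [(t - p) ** 2 for t, p in zip(targets, predictions)]
    return math.sqrt(sum(squared_errors) / len(squared_errors))


# Illustrative values only
targets = [0.0, -1.0, 1.0, -2.0]
predictions = [0.5, -1.0, 0.0, -2.5]

# Squared errors are 0.25, 0.0, 1.0, 0.25, so the mean is 0.375
score = rmse(targets, predictions)
print(f"RMSE: {score:.4f}")
```

Both inputs are flat lists here; keeping targets and predictions in the same 1-D shape is what makes the element-wise difference meaningful.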

Below are multiple model classes we can use for this task.

In this notebook, I am only training the model on `bert-base-uncased` but you can train it on whatever model you want.

In [5]:
# Model
class BERT_BASE_UNCASED(nn.Module):
    def __init__(self):
        super(BERT_BASE_UNCASED, self).__init__()
        self.bert = transformers.BertModel.from_pretrained('bert-base-uncased')
        self.drop = nn.Dropout(0.3)
        self.fc = nn.Linear(768, 128)
        self.out = nn.Linear(128, 1)
    
    def forward(self, ids, mask, token_type_ids):
        _, output = self.bert(ids, attention_mask=mask, token_type_ids=token_type_ids, return_dict=False)
        output = self.drop(output)
        output = self.fc(output)
        output = self.out(output)
        return output

class BERT_BASE_CASED(nn.Module):
    def __init__(self):
        super(BERT_BASE_CASED, self).__init__()
        self.bert = transformers.BertModel.from_pretrained('bert-base-cased')
        self.drop = nn.Dropout(0.5)
        self.fc = nn.Linear(768, 128)  # bert-base hidden size is 768
        self.out = nn.Linear(128, 1)
    
    def forward(self, ids, mask, token_type_ids):
        o1, o2 = self.bert(ids, attention_mask=mask, token_type_ids=token_type_ids, return_dict=False)
        o2 = self.drop(o2)
        o2 = self.fc(o2)
        o2 = self.out(o2)
        return o2
    
class DBERT_BASE_UNCASED(nn.Module):
    def __init__(self):
        super(DBERT_BASE_UNCASED, self).__init__()
        self.dbert = transformers.DistilBertModel.from_pretrained('distilbert-base-uncased')
        self.drop = nn.Dropout(0.2)
        self.out = nn.Linear(768, 1)
    
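    def forward(self, ids, mask):
        # DistilBERT has no pooler: with return_dict=False it returns a tuple whose
        # first element is the last hidden state of shape (batch, seq_len, 768)
        hidden = self.dbert(ids, attention_mask=mask, return_dict=False)[0]
        output = self.drop(hidden[:, 0])  # representation of the first ([CLS]) token
        output = self.out(output)
        return output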

Below is the function that builds the optimizer. It applies weight decay to every parameter except biases and LayerNorm weights.

In [6]:
def yield_optimizer(model):
    """
    Returns optimizer for specific parameters
    """
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_parameters = [
        {
            "params": [
                p for n, p in param_optimizer if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.003,
        },
        {
            "params": [
                p for n, p in param_optimizer if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]
    # torch.optim.AdamW is used because transformers.AdamW is deprecated
    return torch.optim.AdamW(optimizer_parameters, lr=Config.LR)
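
To see how the name matching in `yield_optimizer()` splits the parameters, here is a minimal sketch that applies the same substring test to a handful of parameter names. `split_by_weight_decay` is a hypothetical helper, and the names are illustrative examples of how BERT parameters are typically named, not a dump of the real model.

```python
def split_by_weight_decay(param_names, no_decay=("bias", "LayerNorm.bias", "LayerNorm.weight")):
    """Split parameter names into (decayed, not decayed) using substring matching."""
    decayed, not_decayed = [], []
    for name in param_names:
        if any(nd in name for nd in no_decay):
            not_decayed.append(name)
        else:
            decayed.append(name)
    return decayed, not_decayed


# Illustrative parameter names (assumed, for demonstration only)
names = [
    "bert.encoder.layer.0.attention.self.query.weight",
    "bert.encoder.layer.0.attention.self.query.bias",
    "bert.encoder.layer.0.output.LayerNorm.weight",
    "bert.encoder.layer.0.output.LayerNorm.bias",
    "fc.weight",
    "fc.bias",
]

decayed, not_decayed = split_by_weight_decay(names)
print("weight decay:", decayed)
print("no weight decay:", not_decayed)
```

Only the two dense weight matrices land in the decayed group; every bias and both LayerNorm parameters are excluded.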

Below is the main training code.

I am also printing useful information at multiple steps so that the code execution seems more lively instead of running in absolute silence. It is also handy for future debugging.

The code below for stratified K-folds is inspired by Abhishek Thakur's [Notebook](https://www.kaggle.com/abhishek/step-1-create-folds) on creating K-folds.
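
Stratification needs class labels, but the target here is continuous, so the training code first cuts it into equal-width bins and stratifies on the bin index. The sketch below shows the idea in pure Python. `sturges_bins` and `equal_width_bin_labels` are hypothetical helpers, the values are made up, and the binning only approximates `pd.cut` (which uses right-closed intervals, so values exactly on a bin edge may land in a different bin).

```python
import math


def sturges_bins(n_samples):
    """Number of bins from the 1 + log2(n) rule used in the training code."""
    return int(math.floor(1 + math.log2(n_samples)))


def equal_width_bin_labels(values, n_bins):
    """Assign each value the index of its equal-width bin (approximation of pd.cut)."""
    lo, hi = min(values), max(values)
    width = (hi - lo) / n_bins
    labels = []
    for v in values:
        idx = int((v - lo) / width) if width > 0 else 0
        labels.append(min(idx, n_bins - 1))  # the maximum falls into the last bin
    return labels


# Illustrative targets only
targets = [-3.0, -2.0, -1.0, 0.0, 1.0]
n_bins = sturges_bins(len(targets))
labels = equal_width_bin_labels(targets, n_bins)
print(n_bins, labels)
print(sturges_bins(1000))
```

The resulting integer labels play the role of classes: passing them as `y` to `StratifiedKFold.split` keeps the target distribution roughly similar across folds.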

In [7]:
# Training Code
if __name__ == '__main__':
    if torch.cuda.is_available():
        print("[INFO] Using GPU: {}\n".format(torch.cuda.get_device_name()))
        DEVICE = torch.device('cuda:0')
    else:
        print("\n[INFO] GPU not found. Using CPU: {}\n".format(platform.processor()))
        DEVICE = torch.device('cpu')
    
    data = pd.read_csv(Config.FILE_NAME)
    # https://stackoverflow.com/questions/29576430/shuffle-dataframe-rows
    # shuffle data
    data = data.sample(frac=1).reset_index(drop=True)
    data = data[['excerpt', 'target']]
    
    # Do Kfolds training and cross validation
    # StratifiedKFold preserves the percentage of samples of each class in every fold.
    # https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.StratifiedKFold.html
    kf = StratifiedKFold(n_splits=Config.N_SPLITS)
    # The target is continuous, so bin it first (Sturges' rule gives the bin count)
    nb_bins = int(np.floor(1 + np.log2(len(data))))
    # data.loc[:, 'bins'] => select every row, but only the "bins" column
    # add a new column "bins" holding the bin index of each target
    data.loc[:, 'bins'] = pd.cut(data['target'], bins=nb_bins, labels=False)
    
    for fold, (train_idx, valid_idx) in enumerate(kf.split(X=data, y=data['bins'].values)):
        print(f"Fold: {fold}")
        print('-'*20)
        
        train_data = data.loc[train_idx]
        valid_data = data.loc[valid_idx]
        
        train_set = BERTDataset(
            review = train_data['excerpt'].values,
            target = train_data['target'].values
        )

        valid_set = BERTDataset(
            review = valid_data['excerpt'].values,
            target = valid_data['target'].values
        )

        train = DataLoader(
            train_set,
            batch_size = Config.TRAIN_BS,
            shuffle = True,
            num_workers=8
        )

        valid = DataLoader(
            valid_set,
            batch_size = Config.VALID_BS,
            shuffle = False,
            num_workers=8
        )
        
        # https://stackoverflow.com/questions/63061779/pytorch-when-do-i-need-to-use-todevice-on-a-model-or-tensor
        # pytorch syntax model.to(device)
        model = BERT_BASE_UNCASED().to(DEVICE)
        # The scheduler steps once per batch, so count batches rather than samples
        nb_train_steps = len(train) * Config.NB_EPOCHS
        optimizer = yield_optimizer(model)
        scheduler = transformers.get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=0,
            num_training_steps=nb_train_steps
        )

        trainer = Trainer(model, optimizer, scheduler, train, valid, DEVICE)

        best_loss = float('inf')
        for epoch in range(1, Config.NB_EPOCHS+1):
            print(f"\n{'--'*5} EPOCH: {epoch} {'--'*5}\n")

            # Train for 1 epoch
            trainer.train_one_epoch()

            # Validate for 1 epoch
            current_loss = trainer.valid_one_epoch()

            if current_loss < best_loss:
                print(f"Saving best model in this fold: {current_loss:.4f}")
                torch.save(trainer.get_model().state_dict(), f"bert_base_uncased_fold_{fold}.pt")
                best_loss = current_loss
        
        print(f"Best RMSE in fold: {fold} was: {best_loss:.4f}")
        print(f"Final RMSE in fold: {fold} was: {current_loss:.4f}")
        
        del train_set, valid_set, train, valid, model, optimizer, scheduler, trainer, current_loss
        gc.collect()
        torch.cuda.empty_cache()

[INFO] Using GPU: Tesla P100-PCIE-16GB

Fold: 0
--------------------

---------- EPOCH: 1 ----------

Validation RMSE: 1.75
Saving best model in this fold: 1.7464

---------- EPOCH: 2 ----------

Validation RMSE: 1.69
Saving best model in this fold: 1.6892

---------- EPOCH: 3 ----------

Validation RMSE: 1.66
Saving best model in this fold: 1.6620

---------- EPOCH: 4 ----------

Validation RMSE: 1.66
Saving best model in this fold: 1.6554
Best RMSE in fold: 0 was: 1.6554
Final RMSE in fold: 0 was: 1.6554
Fold: 1
--------------------

---------- EPOCH: 1 ----------

Validation RMSE: 1.67
Saving best model in this fold: 1.6653

---------- EPOCH: 2 ----------

Validation RMSE: 1.63
Saving best model in this fold: 1.6346

---------- EPOCH: 3 ----------

Validation RMSE: 1.62
Saving best model in this fold: 1.6242

---------- EPOCH: 4 ----------

Validation RMSE: 1.62
Saving best model in this fold: 1.6217
Best RMSE in fold: 1 was: 1.6217
Final RMSE in fold: 1 was: 1.6217
Fold: 2
--------------------

---------- EPOCH: 1 ----------

Validation RMSE: 1.66
Saving best model in this fold: 1.6583

---------- EPOCH: 2 ----------

Validation RMSE: 1.62
Saving best model in this fold: 1.6241

---------- EPOCH: 3 ----------

Validation RMSE: 1.61
Saving best model in this fold: 1.6126

---------- EPOCH: 4 ----------

Validation RMSE: 1.61
Saving best model in this fold: 1.6101
Best RMSE in fold: 2 was: 1.6101
Final RMSE in fold: 2 was: 1.6101
Fold: 3
--------------------

---------- EPOCH: 1 ----------

Validation RMSE: 1.64
Saving best model in this fold: 1.6428

---------- EPOCH: 2 ----------

Validation RMSE: 1.62
Saving best model in this fold: 1.6191

---------- EPOCH: 3 ----------

Validation RMSE: 1.61
Saving best model in this fold: 1.6122

---------- EPOCH: 4 ----------

Validation RMSE: 1.61
Best RMSE in fold: 3 was: 1.6122
Final RMSE in fold: 3 was: 1.6123
Fold: 4
--------------------

---------- EPOCH: 1 ----------

Validation RMSE: 1.71
Saving best model in this fold: 1.7078

---------- EPOCH: 2 ----------

Validation RMSE: 1.67
Saving best model in this fold: 1.6698

---------- EPOCH: 3 ----------

Validation RMSE: 1.65
Saving best model in this fold: 1.6477

---------- EPOCH: 4 ----------

Validation RMSE: 1.64
Saving best model in this fold: 1.6417
Best RMSE in fold: 4 was: 1.6417
Final RMSE in fold: 4 was: 1.6417
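
When reading these numbers, keep in mind that the training loop steps a linear schedule with zero warmup once per batch, so the learning rate shrinks from `Config.LR` towards 0 over the course of each fold. Below is a rough sketch of the multiplier such a schedule applies, as I understand `get_linear_schedule_with_warmup`. `linear_schedule_factor` is a hypothetical stand-in rather than the library function, and the step counts are made up.

```python
def linear_schedule_factor(step, num_warmup_steps, num_training_steps):
    """Multiplier applied to the base learning rate at a given optimizer step."""
    if step < num_warmup_steps:
        # Linear warmup from 0 to 1
        return step / max(1, num_warmup_steps)
    # Linear decay from 1 to 0, clamped at 0 once the planned steps are exhausted
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))


base_lr = 1e-6       # same value as Config.LR
total_steps = 100    # hypothetical number of optimizer steps

for step in (0, 50, 100, 120):
    factor = linear_schedule_factor(step, 0, total_steps)
    print(step, factor, base_lr * factor)
```

With zero warmup the factor starts at 1.0 and halves by the midpoint, so later epochs update the weights with a much smaller learning rate than the first one.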


**That's it folks!**

**If you like my work, don't forget to leave an upvote!**