Example of fine-tuning Bert with Catalyst for https://www.kaggle.com/c/nlp-getting-started

Some pieces of code are taken from https://github.com/Yorko/bert-finetuning-catalyst

In [1]:
!pip install -U catalyst
!pip install transformers

Requirement already up-to-date: catalyst in /usr/local/lib/python3.6/dist-packages (20.12)


In [2]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
import os
import torch
import torch.nn as nn
import numpy
from transformers import BertTokenizer, AutoConfig, AutoModel
from torch.utils.data import Dataset, DataLoader

In [4]:
SEED = 42
BATCH_SIZE = 16
ACCUM_STEPS = 4  # one optimization step for that many backward passes
LEARNING_RATE = 3e-5   # learning rate is typically ~1e-5 for transformers
EPOCHS = 4 # smth around 2-6 epochs is typically fine when finetuning transformers
MAX_LEN = 256 # TODO maybe better in flow
BERT_SHORTCUT_NAME = 'bert-base-multilingual-cased'

DATA_PATH =  "/content/drive/My Drive/bell/fine_tune_bench/disaster_tweet/"
LOG_DIR = './logdir/'
TRAIN_DATA = "train.csv"
VALID_DATA = "valid.csv"
TEST_DATA = "test.csv"

SENTENCE_LABEL = 'text'
TARGET_LABEL = 'target'

In [5]:
# TODO 
# - split input train for train and valid if valid doesn't exist
# - slanted traingle learning rate or similar

In [6]:
class BertClassificationDataset(Dataset):
  def __init__(self, 
               texts, 
               labels = None, 
               label2class = None,
               max_len = 512, 
               bert_model_name = BERT_SHORTCUT_NAME,
      ):
    self.texts = texts
    self.labels = labels
    self.label2class = label2class
    self.max_len = max_len
    if self.label2class is None and labels is not None:
      # using this instead of `sklearn.preprocessing.LabelEncoder`
      # no easily handle unknown target values
      self.label2class = dict(zip(sorted(set(labels)), range(len(set(labels)))))    
    self.tokenizer = BertTokenizer.from_pretrained(bert_model_name)
    # suppresses tokenizer warnings
    # logging.getLogger("transformers.tokenization_utils").setLevel(logging.FATAL)


  def __len__(self):
    return len(self.texts)

  def __getitem__(self, index):
        # encoding the text
        x = self.texts[index]

        # a dictionary with `input_ids` and `attention_mask` as keys
        output_dict = self.tokenizer.encode_plus(
            x,
            add_special_tokens=True,
            padding="max_length",
            max_length=self.max_len,
            return_tensors="pt",
            truncation=True,
            return_attention_mask=True,
        )

        # for Catalyst, there needs to be a key called features
        output_dict["features"] = output_dict["input_ids"].squeeze(0)
        del output_dict["input_ids"]

        # encoding target
        if self.labels is not None:
            y = self.labels[index]
            y_encoded = torch.Tensor([self.label2class.get(y, -1)]).long().squeeze(0)
            output_dict["targets"] = y_encoded

        return output_dict

In [7]:
import pandas as pd

train_df = pd.read_csv(os.path.join(DATA_PATH, TRAIN_DATA))
valid_df = pd.read_csv(os.path.join(DATA_PATH, VALID_DATA))
test_df = pd.read_csv(os.path.join(DATA_PATH, TEST_DATA))

train_dataset = BertClassificationDataset(
        texts=train_df[SENTENCE_LABEL].values.tolist(),
        labels=train_df[TARGET_LABEL].values,
        max_len=MAX_LEN,
    )

valid_dataset = BertClassificationDataset(
        texts=valid_df[SENTENCE_LABEL].values.tolist(),
        labels=valid_df[TARGET_LABEL].values,
        max_len=MAX_LEN,
    )

test_dataset = BertClassificationDataset(
        texts=test_df[SENTENCE_LABEL].values.tolist(),
        max_len=MAX_LEN,
    )

train_val_loaders = {
        "train": DataLoader(
            dataset=train_dataset,
            batch_size=BATCH_SIZE,
            shuffle=True,
        ),
        "valid": DataLoader(
            dataset=valid_dataset,
            batch_size=BATCH_SIZE,
            shuffle=False,
        ),
    }

test_loaders = {
        "test": DataLoader(
            dataset=test_dataset,
            batch_size=BATCH_SIZE,
            shuffle=False,
        )
    }

In [8]:
# TODO compare classifcation with pooling and CLS token
class BertClassifierModel(nn.Module):
  def __init__(self, num_classes, bert_model_name=BERT_SHORTCUT_NAME, freeze_bert = False, dropout = 0.3):
    super().__init__()
    
    config = AutoConfig.from_pretrained(bert_model_name, num_labels=num_classes)
    self.model = AutoModel.from_pretrained(bert_model_name, config=config)
    #Freeze bert layers
    if freeze_bert:
      for p in self.model.parameters():
        p.requires_grad = False

    self.classifier = nn.Linear(config.hidden_size, num_classes)
    self.dropout = nn.Dropout(dropout)    
  
  def forward(self, features, attention_mask=None, head_mask=None):
    assert attention_mask is not None, "attention mask is none"

    # taking BERTModel output
    # see https://huggingface.co/transformers/model_doc/bert.html#transformers.BertModel
    bert_output = self.model(
            input_ids=features, attention_mask=attention_mask, head_mask=head_mask
        )
    # we only need the hidden state here and don't need
    # transformer output, so index 0
    seq_output = bert_output[0]  # (bs, seq_len, dim)
    # mean pooling, i.e. getting average representation of all tokens
    pooled_output = seq_output.mean(axis=1)  # (bs, dim)
    pooled_output = self.dropout(pooled_output)  # (bs, dim)
    scores = self.classifier(pooled_output)  # (bs, num_classes)

    return scores

In [9]:
class BertForSequenceClassification(nn.Module):
    """
    Simplified version of the same class by HuggingFace.
    See transformers/modeling_distilbert.py in the transformers repository.
    """

    def __init__(
        self, pretrained_model_name: str, num_classes: int = None, dropout: float = 0.3
    ):
        """
        Args:
            pretrained_model_name (str): HuggingFace model name.
                See transformers/modeling_auto.py
            num_classes (int): the number of class labels
                in the classification task
        """
        super().__init__()

        config = AutoConfig.from_pretrained(
            pretrained_model_name, num_labels=num_classes
        )

        self.model = AutoModel.from_pretrained(pretrained_model_name, config=config)
        self.classifier = nn.Linear(config.hidden_size, num_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, features, attention_mask=None, head_mask=None):

        assert attention_mask is not None, "attention mask is none"

        # taking BERTModel output
        # see https://huggingface.co/transformers/model_doc/bert.html#transformers.BertModel
        bert_output = self.model(
            input_ids=features, attention_mask=attention_mask, head_mask=head_mask
        )
        # we only need the hidden state here and don't need
        # transformer output, so index 0
        seq_output = bert_output[0]  # (bs, seq_len, dim)
        
        # mean pooling, i.e. getting average representation of all tokens
        #pooled_output = seq_output.mean(axis=1)  # (bs, dim)
        #pooled_output = self.dropout(pooled_output)  # (bs, dim)
        #scores = self.classifier(pooled_output)  # (bs, num_classes)
        
        cls_rep = seq_output[:, 0]
        cls_rep = self.dropout(cls_rep)  # (bs, dim)
        scores = self.classifier(cls_rep)  # (bs, num_classes)

        return scores

In [10]:
# from transformers import BertForSequenceClassification

# num_classes = len(set(train_df[TARGET_LABEL].values))
# config = AutoConfig.from_pretrained(BERT_SHORTCUT_NAME, num_labels=num_classes)
# model = BertForSequenceClassification.from_pretrained(BERT_SHORTCUT_NAME, config=config)

In [11]:
# class BertClassifierModel(nn.Module):
  # def __init__(self, num_classes, bert_model_name=BERT_SHORTCUT_NAME, freeze_bert = False, dropout = 0.3):
num_classes = len(set(train_df[TARGET_LABEL].values))
model = BertClassifierModel(num_classes)

In [12]:
from catalyst.dl import SupervisedRunner
from catalyst.dl.callbacks import (
    AccuracyCallback,
    CheckpointCallback,
    InferCallback,
    OptimizerCallback,
)
from catalyst.utils import prepare_cudnn, set_global_seed


criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(), lr=float(LEARNING_RATE)
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)


set_global_seed(SEED)
prepare_cudnn(deterministic=True)

runner = SupervisedRunner(input_key=("features", "attention_mask"))

runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    loaders=train_val_loaders,
    callbacks=[
        AccuracyCallback(num_classes=int(num_classes)),
        OptimizerCallback(accumulation_steps=int(ACCUM_STEPS)),
    ],
    logdir=LOG_DIR,
    num_epochs=EPOCHS,
    verbose=True,
)

# and running inference
torch.cuda.empty_cache()
runner.infer(
    model=model,
    loaders=test_loaders,
    callbacks=[
        CheckpointCallback(
            resume=f"{LOG_DIR}/checkpoints/best.pth"
        ),
        InferCallback(),
    ],
    verbose=True,
)

# lastly, saving predicted scores for the test set
predicted_scores = runner.callbacks[0].predictions["logits"]

1/4 * Epoch (train): 100% 381/381 [02:36<00:00,  2.43it/s, accuracy01=0.900, loss=0.299]
1/4 * Epoch (valid): 100% 96/96 [00:14<00:00,  6.78it/s, accuracy01=0.500, loss=0.866]
[2021-02-07 23:37:15,613] 
1/4 * Epoch 1 (_base): lr=3.000e-05 | momentum=0.9000
1/4 * Epoch 1 (train): accuracy01=0.7901 | loss=0.4641
1/4 * Epoch 1 (valid): accuracy01=0.8338 | loss=0.3913
2/4 * Epoch (train): 100% 381/381 [02:36<00:00,  2.44it/s, accuracy01=0.800, loss=0.365]
2/4 * Epoch (valid): 100% 96/96 [00:14<00:00,  6.76it/s, accuracy01=0.500, loss=1.067]
[2021-02-07 23:41:32,901] 
2/4 * Epoch 2 (_base): lr=3.000e-05 | momentum=0.9000
2/4 * Epoch 2 (train): accuracy01=0.8580 | loss=0.3449
2/4 * Epoch 2 (valid): accuracy01=0.8325 | loss=0.3981
3/4 * Epoch (train): 100% 381/381 [02:36<00:00,  2.44it/s, accuracy01=0.900, loss=0.184]
3/4 * Epoch (valid): 100% 96/96 [00:14<00:00,  6.76it/s, accuracy01=0.500, loss=0.938]
[2021-02-07 23:45:16,367] 
3/4 * Epoch 3 (_base): lr=3.000e-05 | momentum=0.9000
3/4 * Epo

In [13]:
predicted_labels = predicted_scores.argmax(-1)
subm = pd.read_csv(os.path.join(DATA_PATH, 'sample_submission.csv'))
subm[TARGET_LABEL] = predicted_labels
subm.to_csv(os.path.join(DATA_PATH, 'submission_catalyst.csv'), index = None)