<a href="https://colab.research.google.com/github/maxmatical/ml-cheatsheet/blob/master/Pytorch_Lightning_BERT_Huggingface.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Inspirations
https://curiousily.com/posts/multi-label-text-classification-with-bert-and-pytorch-lightning/

https://github.com/mgrankin/over9000/blob/master/train.py


https://medium.com/pytorch/getting-started-with-ray-lightning-easy-multi-node-pytorch-lightning-training-e639031aff8b


In [None]:
!nvidia-smi

Fri Nov 26 23:44:31 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.44       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   68C    P8    32W / 149W |      0MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
%%capture
!pip install transformers
!pip install pytorch_lightning
!pip install torchmetrics

In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import OneCycleLR # , ReduceLROnPlateau
from torch.optim import AdamW

from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
  
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor, StochasticWeightAveraging
from pytorch_lightning.loggers import TensorBoardLogger
import torchmetrics
from torchmetrics.functional import accuracy, f1, auroc

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, multilabel_confusion_matrix

In [None]:
# pretrained model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
model = AutoModel.from_pretrained("distilroberta-base")

Downloading:   0%|          | 0.00/480 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/878k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.29M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/316M [00:00<?, ?B/s]

Some weights of the model checkpoint at distilroberta-base were not used when initializing RobertaModel: ['lm_head.dense.weight', 'lm_head.decoder.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.bias', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


# Data

In [None]:
!gdown --id 1VuQ-U7TtggShMeuRSA_hzC8qGDl2LRkr

Downloading...
From: https://drive.google.com/uc?id=1VuQ-U7TtggShMeuRSA_hzC8qGDl2LRkr
To: /content/toxic_comments.csv
100% 68.8M/68.8M [00:01<00:00, 67.4MB/s]


In [None]:
df = pd.read_csv("toxic_comments.csv")

df.head()

Unnamed: 0,id,comment_text,toxic,severe_toxic,obscene,threat,insult,identity_hate
0,0000997932d777bf,Explanation\nWhy the edits made under my usern...,0,0,0,0,0,0
1,000103f0d9cfb60f,D'aww! He matches this background colour I'm s...,0,0,0,0,0,0
2,000113f07ec002fd,"Hey man, I'm really not trying to edit war. It...",0,0,0,0,0,0
3,0001b41b1c6bb37e,"""\nMore\nI can't make any real suggestions on ...",0,0,0,0,0,0
4,0001d958c54c6e35,"You, sir, are my hero. Any chance you remember...",0,0,0,0,0,0


In [None]:
# Fixed seed so the train/validation split is identical on every fresh run.
train_df, val_df = train_test_split(df, test_size=0.15, random_state=42)

In [None]:
# subsample clean comments
# Everything after the first two columns is treated as a label column
# (assumes those two are the id and the comment text — verify against the CSV).
LABEL_COLUMNS = df.columns.tolist()[2:]

# Number of labels set per row, computed once and reused for both masks.
label_totals = train_df[LABEL_COLUMNS].sum(axis=1)
train_toxic = train_df[label_totals > 0]
train_clean = train_df[label_totals == 0]

# Keep every toxic comment plus a seeded sample of clean ones so the rebalanced
# training set is reproducible; cap the sample size so a small frame cannot
# make `.sample` raise.
n_clean = min(15_000, len(train_clean))
train_df = pd.concat([
  train_toxic,
  train_clean.sample(n_clean, random_state=42)
])

train_df.shape, val_df.shape

((28852, 8), (23936, 8))

## Creating Dataset and Lightning Data Module

In [None]:
# set batch size and max seq_len
bs = 8
seq_len = 256

In [None]:
class ToxicCommentsDataset(Dataset):
  """Map-style dataset that tokenizes one comment per item and returns its labels.

  Args:
    data: frame with a `comment_text` column plus one integer column per label.
    tokenizer: a Hugging Face tokenizer *instance*, callable on a single string
      (`AutoTokenizer` itself is only the factory that creates such instances).
    max_token_len: every item is padded/truncated to exactly this many tokens.
    label_columns: names of the label columns. When omitted, the module-level
      `LABEL_COLUMNS` is read once, at construction time.
  """

  def __init__(
    self,
    data: pd.DataFrame,
    tokenizer,
    max_token_len: int = 128,
    label_columns=None,
  ):
    self.tokenizer = tokenizer
    self.data = data
    self.max_token_len = max_token_len
    # Resolve the label schema up front so items no longer depend on a global
    # that may be rebound after the dataset was created.
    self.label_columns = list(LABEL_COLUMNS if label_columns is None else label_columns)

  def __len__(self):
    return len(self.data)

  def __getitem__(self, index: int):
    """Return a dict with the raw text, token ids, attention mask and int32 labels."""
    row = self.data.iloc[index]
    comment_text = row.comment_text

    # Calling the tokenizer directly replaces the deprecated `encode_plus`.
    encoding = self.tokenizer(
      comment_text,
      add_special_tokens=True,
      max_length=self.max_token_len,
      return_token_type_ids=False,
      padding="max_length",
      truncation=True,
      return_attention_mask=True,
      return_tensors="pt",
    )

    # A row of a mixed-type frame is an object-dtype Series; convert the label
    # slice to a numeric array explicitly instead of relying on tensor
    # construction to guess the element type.
    label_values = row[self.label_columns].to_numpy(dtype="int64")

    return dict(
      comment_text=comment_text,
      # the tokenizer returns (1, max_token_len); drop the leading batch axis
      input_ids=encoding["input_ids"].flatten(),
      attention_mask=encoding["attention_mask"].flatten(),
      labels=torch.tensor(label_values, dtype=torch.int32),
    )

In [None]:
# test
train_dataset = ToxicCommentsDataset(
  train_df,
  tokenizer,
  max_token_len=seq_len
)

sample_item = train_dataset[0]
sample_item.keys()

dict_keys(['comment_text', 'input_ids', 'attention_mask', 'labels'])

In [None]:
print(sample_item["comment_text"], sample_item["labels"])
print(sample_item["input_ids"].shape)

please 

please don't be a cunt tensor([1, 0, 1, 0, 0, 0], dtype=torch.int32)
torch.Size([256])


In [None]:
class ToxicCommentsDataModule(pl.LightningDataModule):
  """Lightning data module serving a training frame and a held-out frame.

  The held-out frame (`test_df`) backs both the validation and the test loader.
  """

  def __init__(self, train_df, test_df, tokenizer, batch_size=8, max_token_len=512):
    super().__init__()
    self.train_df = train_df
    self.test_df = test_df
    self.tokenizer = tokenizer
    self.batch_size = batch_size
    self.max_token_len = max_token_len

  def setup(self, stage=None):
    """Create the tokenizing datasets; the loaders below read these attributes."""
    def build(frame):
      return ToxicCommentsDataset(frame, self.tokenizer, self.max_token_len)

    self.train_dataset = build(self.train_df)
    self.test_dataset = build(self.test_df)

  def _loader(self, dataset, shuffle):
    """Shared DataLoader construction; only the dataset and shuffling differ."""
    return DataLoader(
        dataset,
        batch_size=self.batch_size,
        shuffle=shuffle,
        num_workers=1,
    )

  def train_dataloader(self):
    return self._loader(self.train_dataset, shuffle=True)

  def val_dataloader(self):
    return self._loader(self.test_dataset, shuffle=False)

  def test_dataloader(self):
    return self._loader(self.test_dataset, shuffle=False)

In [None]:
data_module = ToxicCommentsDataModule(
    train_df,
    val_df,
    tokenizer,
    batch_size = bs,
    max_token_len = seq_len
)


In [None]:
data_module.setup() # call this before getting len of dataloader

In [None]:
len(data_module.train_dataloader())

3607

# Model

In [None]:
sample_batch = next(iter(DataLoader(train_dataset, batch_size=8, num_workers=1)))
sample_batch["input_ids"].shape, sample_batch["attention_mask"].shape

(torch.Size([8, 256]), torch.Size([8, 256]))

In [None]:
# Run the bare backbone on the sample batch so the next cell can inspect
# `output.pooler_output`; inference only, so autograd is switched off.
with torch.no_grad():
  output = model(sample_batch["input_ids"], attention_mask=sample_batch["attention_mask"])

In [None]:
output.pooler_output.shape

torch.Size([8, 768])

In [None]:
# test lr schedule with https://www.kaggle.com/isbhargav/guide-to-pytorch-learning-rate-scheduling
# test with flat cos lr

In [None]:
class BertModel(nn.Module):
  """Transformer encoder followed by a single linear classification head.

  Args:
    n_classes: number of output logits (one per label).
    backbone: encoder module exposing `config.hidden_size` and returning an
      object with a `pooler_output` attribute. Defaults to the notebook-level
      `model`; instances built that way all share that one encoder object.
  """

  def __init__(self, n_classes: int, backbone=None):
    super().__init__()
    # Fall back to the global encoder only when the caller did not supply one,
    # so an independent copy can be wrapped without touching the global.
    self.model = model if backbone is None else backbone
    self.classifier = nn.Linear(self.model.config.hidden_size, n_classes)

  def forward(self, input_ids, attention_mask):
    """Return the classification head's raw output for the pooled encoding."""
    encoded = self.model(input_ids, attention_mask=attention_mask)
    return self.classifier(encoded.pooler_output)

In [None]:
bert_model = BertModel(len(LABEL_COLUMNS))

In [None]:
# bert_model(sample_batch["input_ids"], sample_batch["attention_mask"]).shape  # should be bs x 6

In [None]:
class ToxicCommentClassifier(pl.LightningModule):
  """
  Lightning wrapper that fine-tunes a multi-label classifier with manual
  optimization: the optimizer and the OneCycleLR scheduler are both stepped
  once per training batch inside `training_step`.

  Args:
    pytorch_model: module called as `pytorch_model(input_ids, attention_mask=...)`
      whose output is fed to `BCEWithLogitsLoss` (i.e. raw, pre-sigmoid logits).
    total_steps: number of scheduler steps the one-cycle schedule spans.
    lr: learning rate given to AdamW and used as the schedule's peak `max_lr`.
  """

  def __init__(self, pytorch_model: nn.Module, total_steps: int, lr: float = 2e-5):
    super().__init__()
    self.lr = lr
    self.total_steps = total_steps # for lr schedule

    self.pytorch_model = pytorch_model
    self.criterion = nn.BCEWithLogitsLoss() # bce for multi-label

    # metrics
    self.accuracy = torchmetrics.Accuracy()

    # manually define opt step
    self.automatic_optimization = False

  def forward(self, input_ids, attention_mask):
    """
    forward step for lightning module: delegates to the wrapped model
    """
    return self.pytorch_model(input_ids, attention_mask=attention_mask)

  def _shared_step(self, batch):
    """Run the model on one batch and return `(logits, loss, accuracy)`."""
    labels = batch["labels"]
    logits = self(batch["input_ids"], batch["attention_mask"])

    # labels arrive as integers; BCEWithLogitsLoss needs float targets
    loss = self.criterion(logits, labels.to(dtype=torch.float32))
    # The metric's default 0.5 threshold is meant for probabilities, so squash
    # the logits first; thresholding raw logits at 0.5 would under-count positives.
    acc = self.accuracy(torch.sigmoid(logits), labels)
    return logits, loss, acc

  def training_step(self, batch, batch_idx):
    """
    training step
    define manual optimization with opt.step()
    and learning rate scheduler step
    """
    # batch comes from dataset
    # see https://pytorch-lightning.readthedocs.io/en/latest/common/optimizers.html#learning-rate-scheduling-manual
    logits, loss, acc = self._shared_step(batch)

    # manual optimization
    opt = self.optimizers()
    opt.zero_grad() # note: if using sam, don't call zero_grad() since first_step and second_step has (zero_grad=True)
    self.manual_backward(loss)
    opt.step()

    # lr schedule: one scheduler step per optimizer step
    self.lr_schedulers().step()

    # log loss and metrics each k steps and each epoch
    log_values = {"train_loss": loss, "train_acc": acc}
    self.log_dict(log_values, sync_dist=True, prog_bar=True, on_step=True, on_epoch=True)

    # Detach what is kept for `training_epoch_end` so the stored outputs do not
    # hold on to the autograd graph of every batch.
    return {
      "loss": loss.detach(),
      "accuracy": acc,
      "predictions": logits.detach(),
      "labels": batch["labels"],
    }

  def validation_step(self, batch, batch_idx):
    """
    validation step
    get validation loss and accuracy metrics
    """
    _, loss, acc = self._shared_step(batch)

    # log loss and metrics each k steps and each epoch
    log_values = {"val_loss": loss, "val_acc": acc}
    self.log_dict(log_values, sync_dist=True, prog_bar=True, on_step=True, on_epoch=True)
    return {"val_loss": loss, "val_accuracy": acc}

  def training_epoch_end(self, outputs):
    """
    after every training epoch, log train metrics

    note: this might not be needed if logging is done for each step and epoch
    """
    # average the per-batch values returned by training_step
    avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
    avg_acc = torch.stack([x["accuracy"] for x in outputs]).mean()

    self.log("avg_train_loss", avg_loss, sync_dist=True, prog_bar=True)
    self.log("train_accuracy", avg_acc, sync_dist=True, prog_bar=True)

  def validation_epoch_end(self, outputs):
    """
    after every validation epoch, log val metrics

    note: this might not be needed if logging is done for each step and epoch
    """
    # average the per-batch values returned by validation_step
    avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
    avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean()
    self.log("avg_val_loss", avg_loss, sync_dist=True, prog_bar=True)
    self.log("val_accuracy", avg_acc, sync_dist=True, prog_bar=True)

    # if a reduce-lr-on-plateau scheduler is ever added, step it here, e.g.
    # one_cycle_sch, reduce_lr_on_plateau_sch = self.lr_schedulers()
    # reduce_lr_on_plateau_sch.step(self.trainer.callback_metrics["val_accuracy"])

  def configure_optimizers(self):
    """
    Build AdamW plus a OneCycleLR schedule that peaks at `self.lr`.

    Returns the `([optimizer], [scheduler])` pair; `training_step` fetches both
    and steps them itself.
    """
    # TODO:
    # 1. automatically find steps for onecycleLR
    # 2. add flat cos
    # 3. make lr scheduler configurable(one_cycle vs flat_cos)
    optimizer = AdamW(self.parameters(), lr=self.lr)

    # note if using sam, eg
    #   base_optimizer = AdamW
    #   optimizer = SAM(model.parameters(), base_optimizer, lr=0.1, betas = (0.9, 0.99))
    # then the optimizer handed to the scheduler needs to become
    # optimizer.base_optimizer instead of optimizer.
    scheduler = OneCycleLR(
      optimizer, # note: if using sam, change this to optimizer.base_optimizer
      # peak follows the configured learning rate instead of a second,
      # hard-coded 2e-5 that silently ignored the `lr` argument
      max_lr=self.lr,
      pct_start=0.3,
      total_steps=self.total_steps,
    )
    # maybe dont want to use ReduceLROnPlateau
    # scheduler2 = ReduceLROnPlateau(
    #     optimizer,
    #     patience=3
    # )
    return [optimizer], [scheduler]

In [None]:
# training epoch related hyperparams
n_epochs = 1
steps_per_epoch = len(data_module.train_dataloader())

# the scheduler is stepped once per training batch, so its horizon is
# (batches per epoch) x (epochs)
total_steps = n_epochs * steps_per_epoch
print(f"total_steps: {total_steps}")

total_steps: 3607


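Note: when training on multiple GPUs (or TPU cores), or with gradient accumulation, the number of scheduler steps is smaller than the number of batches. The total-steps calculation then looks something like: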
```
steps_per_epoch = len(data_module.train_dataloader())

# for gpu or tpu cores
num_devices = 4
if tpu_cores:
    num_devices = max(num_devices, tpu_cores)

accumulate_grad_batches = 1 # for no gradient accumulation

effective_accum = accumulate_grad_batches * num_devices
total_steps = (steps_per_epoch // effective_accum) * n_epochs

```

In [None]:
toxic_comment_model = ToxicCommentClassifier(bert_model, total_steps = total_steps)

In [None]:
# toxic_comment_model(sample_batch["input_ids"], sample_batch["attention_mask"])[1].shape  # should be bs x 6

# Trainer

defining the trainer and running `trainer.fit` 

In [None]:
# callbacks
lr_monitor_cb = LearningRateMonitor(logging_interval='step')

checkpoint_callback = ModelCheckpoint(
  dirpath="checkpoints",
  filename="best-checkpoint",
  save_top_k=1,
  verbose=True,
  monitor="val_loss",
  mode="min"
)

early_stopping_callback = EarlyStopping(monitor='val_loss', patience=6)

swa = StochasticWeightAveraging(swa_epoch_start=0.8, swa_lrs=None)

logger = TensorBoardLogger("lightning_logs", name="toxic-comments")

In [None]:
# try with SWA, don't use early stopping or model checkpoint right now
# callbacks = [swa, lr_monitor_cb]

# only log lr 
# not using SWA because switches out one cycle lr for SWALR scheduler
callbacks = [lr_monitor_cb]

In [None]:
trainer = pl.Trainer(
    logger=logger,
    callbacks = callbacks,
    precision=16,
    max_epochs = n_epochs,
    accelerator="gpu",
    devices=1,
)

Using 16bit native Automatic Mixed Precision (AMP)
Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [None]:
# optional: currently not working?
# lr_finder = trainer.tuner.lr_find(model=toxic_comment_model, datamodule=data_module)

In [None]:
trainer.fit(toxic_comment_model, data_module)

  f"DataModule.{name} has already been called, so it will not be called again. "
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type              | Params
----------------------------------------------------
0 | pytorch_model | BertModel         | 82.1 M
1 | criterion     | BCEWithLogitsLoss | 0     
2 | accuracy      | Accuracy          | 0     
----------------------------------------------------
82.1 M    Trainable params
0         Non-trainable params
82.1 M    Total params
164.246   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

  f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically"


Validating: 0it [00:00, ?it/s]

In [None]:
trainer.logged_metrics

{'train_accuracy': 0.9270631074905396,
 'train_loss': 0.17185573279857635,
 'val_accuracy': 0.9817500114440918,
 'val_loss': 0.05154530331492424}

In [None]:
%reload_ext tensorboard
%tensorboard --logdir lightning_logs/


# Saving only the pytorch model

In [None]:
# saving model weights
saved_model_pth = "saved_model.pth"

torch.save(toxic_comment_model.pytorch_model.state_dict(), saved_model_pth)

# Exporting model as pure pytorch 

In [None]:
# One logit per label column, matching how `bert_model` was built.
# NOTE: BertModel wraps the notebook-level `model`, so this instance shares its
# encoder object with `bert_model`.
new_bert_model = BertModel(len(LABEL_COLUMNS))

# map_location="cpu" lets the weights load on a machine without a GPU,
# regardless of the device the tensors were on when saved.
new_bert_model.load_state_dict(torch.load(saved_model_pth, map_location="cpu"))

<All keys matched successfully>

In [None]:
# check values are the same

# eval() switches off dropout and no_grad() skips graph construction, so the
# logits are deterministic and can be compared across the two wrappers.
new_bert_model.eval()
with torch.no_grad():
  inf_preds = new_bert_model(sample_batch["input_ids"], sample_batch["attention_mask"])
print(inf_preds.shape) # should be 8 x 6
inf_preds

torch.Size([8, 6])


tensor([[ 3.3002, -2.6973,  2.7221, -5.2559,  0.7052, -3.9604],
        [ 4.4156, -0.5025,  3.8863, -3.7945,  2.5557, -2.4770],
        [ 3.4849, -2.9588,  0.6141, -4.2794,  1.5000, -1.5067],
        [ 1.9291, -5.8827, -1.8313, -5.9951, -1.6118, -4.0301],
        [ 0.1542, -5.0230, -2.0128, -4.9881, -2.5206, -4.2967],
        [ 2.1084, -3.1527, -2.2322,  0.3907, -2.0325, -3.0269],
        [ 0.9454, -6.3798, -3.6319, -5.6751, -1.6457, -4.8039],
        [ 1.1208, -5.3446, -0.0648, -5.9639, -2.0950, -5.3173]],
       grad_fn=<AddmmBackward0>)

In [None]:
# Logits from the Lightning wrapper for the same batch; show the second
# sample's row to eyeball against the matching row of `inf_preds`.
toxic_comment_model(
  input_ids=sample_batch["input_ids"],
  attention_mask=sample_batch["attention_mask"],
)[1]