<a href="https://colab.research.google.com/github/maxmatical/ml-cheatsheet/blob/master/Pytorch_Lightning_BERT_Huggingface.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

https://curiousily.com/posts/multi-label-text-classification-with-bert-and-pytorch-lightning/

In [None]:
%%capture
# Install the two third-party libraries this notebook depends on (output is captured/silenced).
# NOTE(review): both installs are unpinned. The code below uses pytorch_lightning APIs that
# newer releases appear to have removed (pytorch_lightning.metrics, Trainer(progress_bar_refresh_rate=...,
# gpus=...), training_epoch_end / validation_epoch_end hooks) -- consider pinning a compatible
# release (e.g. a 1.4.x version); confirm against the pytorch_lightning changelog.
!pip install transformers
!pip install pytorch_lightning

In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import OneCycleLR, ReduceLROnPlateau
from torch.optim import AdamW

from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
  
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy, f1, auroc
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor, StochasticWeightAveraging
from pytorch_lightning.loggers import TensorBoardLogger

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, multilabel_confusion_matrix

In [None]:
# pretrained model and tokenizer
# Pretrained DistilRoBERTa: tokenizer plus the bare encoder (no task head -- the
# checkpoint's lm_head weights are reported as unused below). Both come from the same hub id.
PRETRAINED_NAME = "distilroberta-base"
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_NAME)
model = AutoModel.from_pretrained(PRETRAINED_NAME)

Some weights of the model checkpoint at distilroberta-base were not used when initializing RobertaModel: ['lm_head.bias', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


# Data

In [None]:
!gdown --id 1VuQ-U7TtggShMeuRSA_hzC8qGDl2LRkr

Downloading...
From: https://drive.google.com/uc?id=1VuQ-U7TtggShMeuRSA_hzC8qGDl2LRkr
To: /content/toxic_comments.csv
100% 68.8M/68.8M [00:00<00:00, 122MB/s]


In [None]:
# Raw comments table: id, comment_text, and one 0/1 column per toxicity label.
df = pd.read_csv("toxic_comments.csv")
df.head(n=5)

Unnamed: 0,id,comment_text,toxic,severe_toxic,obscene,threat,insult,identity_hate
0,0000997932d777bf,Explanation\nWhy the edits made under my usern...,0,0,0,0,0,0
1,000103f0d9cfb60f,D'aww! He matches this background colour I'm s...,0,0,0,0,0,0
2,000113f07ec002fd,"Hey man, I'm really not trying to edit war. It...",0,0,0,0,0,0
3,0001b41b1c6bb37e,"""\nMore\nI can't make any real suggestions on ...",0,0,0,0,0,0
4,0001d958c54c6e35,"You, sir, are my hero. Any chance you remember...",0,0,0,0,0,0


In [None]:
# Hold out 15% of rows for validation. The seed makes the split (and every metric
# computed downstream) reproducible under Restart & Run All.
train_df, val_df = train_test_split(df, test_size=0.15, random_state=42)

In [None]:
# subsample clean comments
# The six toxicity label columns, in the order they appear in the CSV header.
# Listed explicitly rather than sliced by position so a reordered/extra column
# cannot silently become a "label".
LABEL_COLUMNS = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]

# Rebalance the training split: keep every comment with at least one positive
# label plus a random subsample of clean comments. val_df is left untouched so
# it keeps the original class balance.
N_CLEAN_SAMPLES = 15_000

is_toxic = train_df[LABEL_COLUMNS].sum(axis=1) > 0
train_toxic = train_df[is_toxic]
train_clean = train_df[~is_toxic]

# NOTE(review): this cell overwrites train_df, so re-running it re-subsamples an
# already subsampled frame -- re-run the split cell first.
train_df = pd.concat([
  train_toxic,
  # cap at the available count so .sample() cannot raise; seeded for reproducibility
  train_clean.sample(min(N_CLEAN_SAMPLES, len(train_clean)), random_state=42)
])

train_df.shape, val_df.shape

((28795, 8), (23936, 8))

In [None]:
class ToxicCommentsDataset(Dataset):
  """Map-style dataset yielding tokenized comments with multi-hot label vectors.

  Each item is a dict with:
    comment_text:   the raw comment string
    input_ids:      1-D token-id tensor, padded/truncated to ``max_token_len``
    attention_mask: 1-D mask tensor, same length as ``input_ids``
    labels:         1-D float32 tensor, one entry per label column

  Args:
    data: frame with a ``comment_text`` column and the label columns.
    tokenizer: Hugging Face tokenizer exposing ``encode_plus``.
    max_token_len: every comment is padded/truncated to this many tokens.
    label_columns: names of the label columns, in output order. Defaults to the
      module-level ``LABEL_COLUMNS``.
  """

  def __init__(
    self,
    data: pd.DataFrame,
    tokenizer: "AutoTokenizer",
    max_token_len: int = 128,
    label_columns=None,
  ):
    self.tokenizer = tokenizer
    self.data = data
    self.max_token_len = max_token_len
    # Resolve the label list once; callers may pass their own instead of the global.
    self.label_columns = list(LABEL_COLUMNS if label_columns is None else label_columns)

  def __len__(self):
    return len(self.data)

  def __getitem__(self, index: int):
    # positional lookup, so a non-contiguous index (after split/sample) is fine
    data_row = self.data.iloc[index]
    comment_text = data_row["comment_text"]
    # A row taken from a mixed-dtype frame is an object Series; convert explicitly
    # so the label tensor is always float32 (what BCEWithLogitsLoss expects).
    labels = data_row[self.label_columns].to_numpy(dtype=np.float32)

    encoding = self.tokenizer.encode_plus(
      comment_text,
      add_special_tokens=True,
      max_length=self.max_token_len,
      return_token_type_ids=False,
      padding="max_length",
      truncation=True,
      return_attention_mask=True,
      return_tensors='pt',
    )

    return dict(
      comment_text=comment_text,
      # return_tensors='pt' yields shape (1, max_len); flatten to (max_len,)
      input_ids=encoding["input_ids"].flatten(),
      attention_mask=encoding["attention_mask"].flatten(),
      labels=torch.from_numpy(labels),
    )

In [None]:
# test
# Sanity check: wrap the training split at the full 512-token length and inspect one item.
# (train_dataset is reused further down to draw a sample batch.)
train_dataset = ToxicCommentsDataset(data=train_df, tokenizer=tokenizer, max_token_len=512)

sample_item = train_dataset[0]
sample_item.keys()

dict_keys(['comment_text', 'input_ids', 'attention_mask', 'labels'])

In [None]:
# Show the raw text with its label vector, then the token-id tensor shape.
text, label_vec = sample_item["comment_text"], sample_item["labels"]
print(text, label_vec)
print(sample_item["input_ids"].shape)

"I think all of you americans you have a real attitude and mental problem !
First you live in a World full of conspiracy and spyies . You STILL believe that America won Vietnam, that YOU (especially you) are a Americqan soldier hero ....pffff...
Anyway no need to show  you more proves of how crazy and idiots americans are - by the way ...what happend in Ferguson ? Is ""American democracy"" ? a ? THE Cowboy LAW ! All you you a bunch of stupid, incredible stupid cowboys !!!
One guy enter in a shop, he molest the cashier, still some products, push away the employee and YOU...AMERICANS ...what you do ?
You go in the street screaming - wasn't his  fault !
You are a JOKE OF people !
A bunch of iditos !
But anyway, from my point of view ...keep thinking and acting like this.Alredy the whole Europe, Asia and in specially Russia DON'T like you at all !
You will end up of being the ""paria"" of World ! -IF you know what that means !

It seems that what i was trying to explain to you was just som

In [None]:
class ToxicCommentsDataModule(pl.LightningDataModule):
  """LightningDataModule serving ToxicCommentsDataset splits.

  ``test_df`` backs both the validation and the test dataloader.

  Args:
    train_df: training frame (shuffled every epoch).
    test_df: held-out frame used for validation and testing.
    tokenizer: tokenizer forwarded to ToxicCommentsDataset.
    batch_size: batch size for every dataloader.
    max_token_len: pad/truncate length forwarded to ToxicCommentsDataset.
    num_workers: DataLoader worker processes for every split (default 1).
  """

  def __init__(self, train_df, test_df, tokenizer, batch_size=8, max_token_len=512, num_workers=1):
    super().__init__()
    self.train_df, self.test_df = train_df, test_df
    self.tokenizer = tokenizer
    self.batch_size = batch_size
    self.max_token_len = max_token_len
    self.num_workers = num_workers

  def setup(self, stage=None):
    """Build both datasets; ``stage`` is ignored (all splits are built on every call)."""
    self.train_dataset = ToxicCommentsDataset(self.train_df, self.tokenizer, self.max_token_len)
    self.test_dataset = ToxicCommentsDataset(self.test_df, self.tokenizer, self.max_token_len)

  def _make_loader(self, dataset, shuffle):
    """Shared DataLoader construction so all splits use the same batch/worker settings."""
    return DataLoader(
        dataset,
        batch_size=self.batch_size,
        shuffle=shuffle,
        num_workers=self.num_workers,
    )

  def train_dataloader(self):
    return self._make_loader(self.train_dataset, shuffle=True)

  def val_dataloader(self):
    return self._make_loader(self.test_dataset, shuffle=False)

  def test_dataloader(self):
    return self._make_loader(self.test_dataset, shuffle=False)

In [None]:
# Small batches (4) at the full 512-token length; validation uses the held-out val_df.
data_module = ToxicCommentsDataModule(
    train_df=train_df,
    test_df=val_df,
    tokenizer=tokenizer,
    batch_size=4,
    max_token_len=512,
)


In [None]:
data_module.setup(stage=None)  # builds train_dataset/test_dataset; required before train_dataloader() below

In [None]:
# Training batches per epoch -- should match OneCycleLR's steps_per_epoch.
train_loader = data_module.train_dataloader(); len(train_loader)

2400

# Model

In [None]:
# One un-shuffled batch of 8 from the standalone train_dataset, used for shape checks.
preview_loader = DataLoader(train_dataset, batch_size=8, num_workers=1)
sample_batch = next(iter(preview_loader))
(sample_batch["input_ids"].shape, sample_batch["attention_mask"].shape)

(torch.Size([8, 512]), torch.Size([8, 512]))

In [None]:
output = model(input_ids=sample_batch["input_ids"], attention_mask=sample_batch["attention_mask"])

In [None]:
output.pooler_output.size()  # (batch, hidden_size)

torch.Size([8, 768])

In [None]:
# TODO: try other lr schedules, see https://www.kaggle.com/isbhargav/guide-to-pytorch-learning-rate-scheduling
# TODO: try a flat-then-cosine ("flat cos") lr schedule

In [None]:
class BertModel(nn.Module):
  """Pretrained transformer encoder followed by a linear multi-label head.

  Args:
    n_classes: number of output logits (one per label).
    backbone: encoder to wrap. It must expose ``config.hidden_size`` and, when
      called as ``backbone(input_ids, attention_mask=...)``, return an object
      with a ``pooler_output`` attribute. Defaults to the module-level ``model``
      loaded at the top of the notebook -- every instance built with the
      default therefore *shares* that single encoder object and its weights.
  """

  def __init__(self, n_classes: int, backbone=None):
    super().__init__()
    self.model = model if backbone is None else backbone
    self.classifier = nn.Linear(self.model.config.hidden_size, n_classes)

  def forward(self, input_ids, attention_mask):
    """Return raw logits of shape (batch, n_classes); no sigmoid is applied."""
    out = self.model(input_ids, attention_mask=attention_mask)
    return self.classifier(out.pooler_output)

In [None]:
bert_model = BertModel(n_classes=len(LABEL_COLUMNS))  # one logit per toxicity label

In [None]:
# Smoke test: one logit per label for every example in the batch (bs x 6).
sample_logits = bert_model(sample_batch["input_ids"], sample_batch["attention_mask"])
assert sample_logits.shape == (sample_batch["input_ids"].shape[0], len(LABEL_COLUMNS))
sample_logits.shape

torch.Size([8, 6])

In [None]:
class ToxicCommentClassifier(pl.LightningModule):
  """LightningModule for multi-label toxicity classification with manual optimization.

  Each training step runs zero_grad/backward/step by hand and advances a
  per-batch OneCycleLR schedule. After each (non sanity-check) validation epoch,
  a ReduceLROnPlateau schedule is stepped on the mean validation loss.

  Args:
    pytorch_model: module called as ``pytorch_model(input_ids, attention_mask=...)``
      that returns one logit per label.
    lr: AdamW learning rate; also used as OneCycleLR's ``max_lr``.
    steps_per_epoch: training batches per epoch used to size OneCycleLR. OneCycleLR
      raises if stepped more than ``steps_per_epoch * epochs`` times, so this must be
      at least ``len(train_dataloader)``.
    epochs: epochs OneCycleLR is sized for (should match the Trainer's ``max_epochs``).
  """

  def __init__(self, pytorch_model: nn.Module, lr: float = 2e-5,
               steps_per_epoch: int = 2400, epochs: int = 2):
    super().__init__()
    self.lr = lr
    self.steps_per_epoch = steps_per_epoch
    self.epochs = epochs

    self.pytorch_model = pytorch_model
    # BCE on raw logits: each label is an independent binary decision (multi-label)
    self.criterion = nn.BCEWithLogitsLoss()

    # metrics
    self.accuracy = pl.metrics.Accuracy()

    # optimizer and scheduler steps are issued by hand in training_step
    self.automatic_optimization = False

  def forward(self, input_ids, attention_mask, labels=None):
    """Return ``(loss, logits, accuracy)``; without labels, loss is 0 and accuracy is None."""
    out = self.pytorch_model(input_ids, attention_mask=attention_mask)

    loss, acc = 0, None
    if labels is not None:
      loss = self.criterion(out, labels)
      # The metric expects probabilities and integer targets, while the model
      # emits logits and the dataset yields float labels.
      acc = self.accuracy(torch.sigmoid(out), labels.int())
    return loss, out, acc

  def _shared_step(self, batch):
    """Unpack a ToxicCommentsDataset batch dict and run forward with labels."""
    labels = batch["labels"]
    loss, out, acc = self(batch["input_ids"], batch["attention_mask"], labels)
    return loss, out, acc, labels

  def training_step(self, batch, batch_idx):
    loss, out, acc, labels = self._shared_step(batch)

    # manual optimization
    opt = self.optimizers()
    opt.zero_grad()
    self.manual_backward(loss)
    opt.step()

    # OneCycleLR is a per-batch schedule
    one_cycle_sch, _ = self.lr_schedulers()
    one_cycle_sch.step()

    # These outputs are held until training_epoch_end; detach the extras so each
    # step's autograd graph is not retained.
    return {"loss": loss, "accuracy": acc.detach(), "predictions": out.detach(), "labels": labels}

  def validation_step(self, batch, batch_idx):
    loss, _, acc, _ = self._shared_step(batch)
    return {"val_loss": loss, "val_accuracy": acc}

  def training_epoch_end(self, outputs):
    """Log the mean training loss and accuracy over the epoch's steps."""
    avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
    avg_acc = torch.stack([x["accuracy"] for x in outputs]).mean()
    self.log("train_loss", avg_loss, sync_dist=True, prog_bar=True)
    self.log("train_accuracy", avg_acc, sync_dist=True, prog_bar=True)

  def validation_epoch_end(self, outputs):
    """Log mean validation loss/accuracy, then step the plateau scheduler."""
    avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
    avg_acc = torch.stack([x["val_accuracy"] for x in outputs]).mean()
    self.log("val_loss", avg_loss, sync_dist=True, prog_bar=True)
    self.log("val_accuracy", avg_acc, sync_dist=True, prog_bar=True)

    # ReduceLROnPlateau is configured with mode="min", so it must track the loss
    # (tracking accuracy would cut the lr whenever accuracy *improves*). Skip the
    # sanity-check pass so it is not counted as an epoch.
    if not getattr(self.trainer, "sanity_checking", False):
      _, reduce_lr_on_plateau_sch = self.lr_schedulers()
      reduce_lr_on_plateau_sch.step(avg_loss)

  def configure_optimizers(self):
    # TODO:
    # 1. derive steps_per_epoch automatically from the datamodule
    # 2. add flat cos
    # 3. make lr scheduler configurable (one_cycle vs flat_cos)
    optimizer = AdamW(self.parameters(), lr=self.lr)

    one_cycle = OneCycleLR(
      optimizer,
      max_lr=self.lr,
      pct_start=0.3,
      steps_per_epoch=self.steps_per_epoch,
      epochs=self.epochs,
    )

    # NOTE(review): OneCycleLR sets the lr on every step, so a reduction made by
    # this scheduler is presumably overwritten on the next batch -- confirm the
    # two schedulers are really meant to be combined.
    plateau = ReduceLROnPlateau(
      optimizer,
      mode="min",
      patience=3,
    )
    return [optimizer], [one_cycle, plateau]

In [None]:
toxic_comment_model = ToxicCommentClassifier(pytorch_model=bert_model)  # default lr

In [None]:
# Shape check for the Lightning wrapper: forward() returns (loss, logits, acc), so [1] is the logits.
# NOTE(review): left commented out -- before enabling, confirm forward() handles labels=None.
# toxic_comment_model(sample_batch["input_ids"], sample_batch["attention_mask"])[1].shape  # should be bs x 6

In [None]:
# callbacks
lr_monitor_cb = LearningRateMonitor(logging_interval='step')

checkpoint_callback = ModelCheckpoint(
  dirpath="checkpoints",
  filename="best-checkpoint",
  save_top_k=1,
  verbose=True,
  monitor="val_loss",
  mode="min"
)

early_stopping_callback = EarlyStopping(monitor='val_loss', patience=6)

swa = StochasticWeightAveraging()

logger = TensorBoardLogger("lightning_logs", name="toxic-comments")

In [None]:
# One GPU, native fp16 mixed precision, 2 epochs.
# NOTE(review): `gpus` and `progress_bar_refresh_rate` are older Trainer arguments --
# confirm they exist in the installed pytorch_lightning version.
trainer = pl.Trainer(
    max_epochs=2,
    gpus=1,
    precision=16,
    logger=logger,
    callbacks=[checkpoint_callback, early_stopping_callback, lr_monitor_cb],
    progress_bar_refresh_rate=30,
)

Using native 16bit precision.
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [None]:
trainer.fit(toxic_comment_model, datamodule=data_module)  # setup() already ran above, so Lightning warns and skips it

  f"DataModule.{name} has already been called, so it will not be called again. "
  f"DataModule.{name} has already been called, so it will not be called again. "
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type              | Params
----------------------------------------------------
0 | pytorch_model | BertModel         | 82.1 M
1 | criterion     | BCEWithLogitsLoss | 0     
----------------------------------------------------
82.1 M    Trainable params
0         Non-trainable params
82.1 M    Total params
328.492   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: -1it [00:00, ?it/s]

RuntimeError: ignored

# Exporting the model as a plain PyTorch module

In [None]:
# Checkpoint written by `checkpoint_callback` (dirpath="checkpoints", filename="best-checkpoint").
saved_model_pth = "checkpoints/best-checkpoint.ckpt"
# NOTE(review): BertModel's default backbone is the shared module-level `model`, so this
# instance and `bert_model` use the same encoder object; loading below updates both.
new_bert_model = BertModel(len(LABEL_COLUMNS))

# NOTE(review): recent torch versions default torch.load to weights_only=True, which may
# reject a full Lightning checkpoint -- confirm for the installed torch version.
checkpoint = torch.load(saved_model_pth, map_location="cpu")
# A Lightning checkpoint nests the weights under "state_dict", keyed by the LightningModule
# attribute path ("pytorch_model.<...>"); strip that prefix so keys match BertModel's own.
# A plain state_dict (e.g. from torch.save(bert_model.state_dict(), path)) passes through unchanged.
state_dict = checkpoint.get("state_dict", checkpoint)
prefix = "pytorch_model."
state_dict = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()}

new_bert_model.load_state_dict(state_dict)

In [None]:
# check values are the same

# Run the exported model on the sample batch. eval() disables dropout so the outputs are
# deterministic and comparable; no_grad() avoids building an autograd graph.
new_bert_model.eval()
with torch.no_grad():
    inf_preds = new_bert_model(sample_batch["input_ids"], sample_batch["attention_mask"])
print(inf_preds.shape)  # should be 8 x 6
inf_preds

In [None]:
# Same batch through the trained Lightning module, for comparison with `inf_preds`.
# Call the wrapped nn.Module directly (forward() returns a (loss, logits, acc) tuple),
# disable dropout, and move the inputs to whatever device the module ended up on.
toxic_comment_model.eval()
with torch.no_grad():
    lightning_preds = toxic_comment_model.pytorch_model(
        sample_batch["input_ids"].to(toxic_comment_model.device),
        attention_mask=sample_batch["attention_mask"].to(toxic_comment_model.device),
    )
lightning_preds.cpu()