# Multi-label Text Classification with BERT and PyTorch Lightning

## Installing & importing Libraries

In [1]:
!nvidia-smi

Thu Dec 16 09:10:32 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.44       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   52C    P8    27W / 149W |      0MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
!pip install transformers --quiet

In [3]:
!pip install pytorch-lightning --quiet

In [4]:
!pip install torchmetrics --quiet

In [5]:
import pytorch_lightning as pl
print(pl.__version__)

1.5.6


In [6]:
# importing libraries
import pandas as pd
import numpy as np
from tqdm.auto import tqdm

import torch
import torchmetrics
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast as BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup

import pytorch_lightning as pl
from torchmetrics.functional import accuracy, f1, auroc
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, multilabel_confusion_matrix

import seaborn as sns
from pylab import rcParams
import matplotlib.pyplot as plt
from matplotlib import rc

In [7]:
%matplotlib inline
%config InlineBackend.figure_format='retina'

RANDOM_SEED = 42

sns.set(style='whitegrid', palette='muted', font_scale=1.2)
HAPPY_COLORS_PALETTE = ["#01BEFE", "#FFDD00", "#FF7D00", "#FF006D", "#ADFF02", "#8F00FF"]
sns.set_palette(sns.color_palette(HAPPY_COLORS_PALETTE))
rcParams['figure.figsize'] = 12, 8
pl.seed_everything(RANDOM_SEED)

Global seed set to 42


42

## Loading the Data

In [8]:
# data from a Toxic Comment Classification challenge from Kaggle
!gdown --id 1VuQ-U7TtggShMeuRSA_hzC8qGDl2LRkr

Downloading...
From: https://drive.google.com/uc?id=1VuQ-U7TtggShMeuRSA_hzC8qGDl2LRkr
To: /content/toxic_comments.csv
100% 68.8M/68.8M [00:00<00:00, 79.5MB/s]


In [9]:
df = pd.read_csv("toxic_comments.csv")

In [10]:
df.head(20)

Unnamed: 0,id,comment_text,toxic,severe_toxic,obscene,threat,insult,identity_hate
0,0000997932d777bf,Explanation\nWhy the edits made under my usern...,0,0,0,0,0,0
1,000103f0d9cfb60f,D'aww! He matches this background colour I'm s...,0,0,0,0,0,0
2,000113f07ec002fd,"Hey man, I'm really not trying to edit war. It...",0,0,0,0,0,0
3,0001b41b1c6bb37e,"""\nMore\nI can't make any real suggestions on ...",0,0,0,0,0,0
4,0001d958c54c6e35,"You, sir, are my hero. Any chance you remember...",0,0,0,0,0,0
5,00025465d4725e87,"""\n\nCongratulations from me as well, use the ...",0,0,0,0,0,0
6,0002bcb3da6cb337,COCKSUCKER BEFORE YOU PISS AROUND ON MY WORK,1,1,1,0,1,0
7,00031b1e95af7921,Your vandalism to the Matt Shirvington article...,0,0,0,0,0,0
8,00037261f536c51d,Sorry if the word 'nonsense' was offensive to ...,0,0,0,0,0,0
9,00040093b2687caa,alignment on this subject and which are contra...,0,0,0,0,0,0


## Preprocessing

In [11]:
train_df, val_df = train_test_split(df, test_size = 0.05)

In [12]:
train_df.shape, val_df.shape

((151592, 8), (7979, 8))

In [13]:
LABEL_COLUMNS = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]

In [14]:
train_df[LABEL_COLUMNS].sum()

toxic            14546
severe_toxic      1515
obscene           8028
threat             465
insult            7467
identity_hate     1334
dtype: int64

In [15]:
train_df[LABEL_COLUMNS].sum().sum()

33355

In [16]:
train_df[LABEL_COLUMNS].sum(axis=1).head()

41414     0
141338    0
145906    0
113950    0
136413    0
dtype: int64

In [17]:
train_toxic = train_df[train_df[LABEL_COLUMNS].sum(axis=1) > 0]

In [18]:
train_toxic.shape

(15427, 8)

In [19]:
train_clean = train_df[train_df[LABEL_COLUMNS].sum(axis=1) == 0]

In [20]:
train_clean.shape, train_toxic.shape

((136165, 8), (15427, 8))

In [80]:
# keeping equal samples of all classes since there's imbalance of data
train_df = pd.concat([
                      train_toxic,
                      train_clean.sample(10_000)
])

In [22]:
train_df[LABEL_COLUMNS].sum()

toxic            14546
severe_toxic      1515
obscene           8028
threat             465
insult            7467
identity_hate     1334
dtype: int64

In [23]:
train_df[LABEL_COLUMNS].sum()

toxic            14546
severe_toxic      1515
obscene           8028
threat             465
insult            7467
identity_hate     1334
dtype: int64

In [24]:
sample_row = df.iloc[16]
sample_comment = sample_row.comment_text
sample_labels = sample_row[LABEL_COLUMNS]

print(sample_comment)
print()
print(sample_labels.to_dict())

Bye! 

Don't look, come or think of comming back! Tosser.

{'toxic': 1, 'severe_toxic': 0, 'obscene': 0, 'threat': 0, 'insult': 0, 'identity_hate': 0}


## Tokenization

In [25]:
# loading tokenizer of bert base version
BERT_MODEL_NAME = "bert-base-cased"
tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_NAME)

In [26]:
encoding = tokenizer.encode_plus(
    sample_comment, 
    add_special_tokens=True,
    max_length=512,
    return_token_type_ids=False,
    padding='max_length',
    return_attention_mask=True,
    return_tensors="pt"
)

encoding.keys()

dict_keys(['input_ids', 'attention_mask'])

In [27]:
encoding["input_ids"].shape, encoding["attention_mask"].shape

(torch.Size([1, 512]), torch.Size([1, 512]))

In [28]:
encoding["input_ids"].squeeze()[:20]

tensor([  101, 17774,   106,  1790,   112,   189,  1440,   117,  1435,  1137,
         1341,  1104,  3254,  5031,  1171,   106,  1706, 14607,   119,   102])

In [29]:
encoding["attention_mask"].squeeze()[:20]

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

In [30]:
encoding["input_ids"].squeeze().shape

torch.Size([512])

In [31]:
tokenizer.convert_ids_to_tokens(encoding["input_ids"].squeeze())[:20]

['[CLS]',
 'Bye',
 '!',
 'Don',
 "'",
 't',
 'look',
 ',',
 'come',
 'or',
 'think',
 'of',
 'com',
 '##ming',
 'back',
 '!',
 'To',
 '##sser',
 '.',
 '[SEP]']

## Wrapping Tokenization process in a PyTorch Dataset

In [32]:
class ToxicCommentsDataset(Dataset):

    def __init__(self, data: pd.DataFrame, tokenizer: BertTokenizer, max_token_len: int = 128):

        self.data = data
        self.tokenizer = tokenizer
        self.max_token_len = max_token_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index: int):
        
        data_row = self.data.iloc[index]

        comment_text = data_row.comment_text
        labels = data_row[LABEL_COLUMNS]

        encoding = self.tokenizer.encode_plus(
            comment_text,
            add_special_tokens=True,
            max_length=self.max_token_len,
            return_token_type_ids=False,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt"
        )
        
        return dict(
            comment_text=comment_text,
            input_ids=encoding["input_ids"].flatten(),
            attention_mask=encoding["attention_mask"].flatten(),
            labels=torch.FloatTensor(labels)
        )


In [33]:
train_dataset = ToxicCommentsDataset(train_df, tokenizer)

In [34]:
sample_item = train_dataset[0]

In [35]:
sample_item.keys()

dict_keys(['comment_text', 'input_ids', 'attention_mask', 'labels'])

In [36]:
sample_item["comment_text"]

'Hi, ya fucking idiot. ^_^'

In [37]:
sample_item["labels"]

tensor([1., 0., 1., 0., 1., 0.])

In [38]:
sample_item["input_ids"].shape

torch.Size([128])

## Loading the Bert Model

In [39]:
bert_model = BertModel.from_pretrained(BERT_MODEL_NAME, return_dict=True)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [40]:
sample_item["input_ids"].unsqueeze(dim=0).shape

torch.Size([1, 128])

In [41]:
prediction = bert_model(sample_item["input_ids"].unsqueeze(dim=0), sample_item["attention_mask"].unsqueeze(dim=0))

In [42]:
prediction.last_hidden_state.shape, prediction.pooler_output.shape

(torch.Size([1, 128, 768]), torch.Size([1, 768]))

### Wrapping our custom dataset into a LightningDataModule

In [43]:
class ToxicCommentsDataModule(pl.LightningDataModule):

    def __init__(self, train_df, test_df, tokenizer, batch_size=0, max_token_len=128):
        super().__init__
        self.train_df = train_df
        self.test_df = test_df
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.max_token_len = max_token_len

    def setup(self):
        self.train_dataset = ToxicCommentsDataset(
            self.train_df,
            self.tokenizer,
            self.max_token_len
        )

        self.test_dataset = ToxicCommentsDataset(
            self.test_df,
            self.tokenizer,
            self.max_token_len
        )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=4
        )

    def val_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=4
        )

    def val_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=1, num_workers=4)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=1, num_workers=4)

In [44]:
N_EPOCHS = 10
BATCH_SIZE = 32

data_module = ToxicCommentsDataModule(train_df, val_df, tokenizer, batch_size=BATCH_SIZE)
data_module.setup()

## Modeling

### Evaluation

In [45]:
criterion = nn.BCELoss()

prediction = torch.FloatTensor(
    [10.95873564, 1.07321467, 1.58524066, 0.03839076, 15.72987556, 1.09513213]
)

labels = torch.FloatTensor(
  [1., 0., 0., 0., 1., 0.]
) 

In [46]:
torch.sigmoid(prediction)

tensor([1.0000, 0.7452, 0.8299, 0.5096, 1.0000, 0.7493])

In [47]:
output = criterion(torch.sigmoid(prediction), labels)
output

tensor(0.8725)

### Converting Bert representation to a classification task & packing it into LightningModule

In [48]:
class ToxicCommentTagger(pl.LightningModule):

  def __init__(self, n_classes: int, n_training_steps=None, n_warmup_steps=None):
    super().__init__()
    self.bert = BertModel.from_pretrained(BERT_MODEL_NAME, return_dict=True)
    self.classifier = nn.Linear(self.bert.config.hidden_size, n_classes)
    self.n_training_steps = n_training_steps
    self.n_warmup_steps = n_warmup_steps
    self.criterion = nn.BCELoss()

  def forward(self, input_ids, attention_mask, labels=None):
    output = self.bert(input_ids, attention_mask=attention_mask)
    output = self.classifier(output.pooler_output)
    output = torch.sigmoid(output)    
    loss = 0
    if labels is not None:
        loss = self.criterion(output, labels)
    return loss, output

  def training_step(self, batch, batch_idx):
    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    labels = batch["labels"]
    loss, outputs = self(input_ids, attention_mask, labels)
    self.log("train_loss", loss, prog_bar=True, logger=True)
    return {"loss": loss, "predictions": outputs, "labels": labels}

  def validation_step(self, batch, batch_idx):
    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    labels = batch["labels"]
    loss, outputs = self(input_ids, attention_mask, labels)
    self.log("val_loss", loss, prog_bar=True, logger=True)
    return loss

  def test_step(self, batch, batch_idx):
    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    labels = batch["labels"]
    loss, outputs = self(input_ids, attention_mask, labels)
    self.log("test_loss", loss, prog_bar=True, logger=True)
    return loss

  def training_epoch_end(self, outputs):
    
    labels = []
    predictions = []
    for output in outputs:
      for out_labels in output["labels"].detach().cpu():
        labels.append(out_labels)
      for out_predictions in output["predictions"].detach().cpu():
        predictions.append(out_predictions)

    labels = torch.stack(labels).int()
    predictions = torch.stack(predictions)

    for i, name in enumerate(LABEL_COLUMNS):
      class_roc_auc = auroc(predictions[:, i], labels[:, i])
      self.logger.experiment.add_scalar(f"{name}_roc_auc/Train", class_roc_auc, self.current_epoch)


  def configure_optimizers(self):

    optimizer = AdamW(self.parameters(), lr=2e-5)

    scheduler = get_linear_schedule_with_warmup(
      optimizer,
      num_warmup_steps=self.n_warmup_steps,
      num_training_steps=self.n_training_steps
    )

    return dict(
      optimizer=optimizer,
      lr_scheduler=dict(
        scheduler=scheduler,
        interval='step'
      )
    )

In [49]:
steps_per_epoch=len(train_df) // BATCH_SIZE
total_training_steps = steps_per_epoch * N_EPOCHS

In [50]:
warmup_steps = total_training_steps // 5
warmup_steps, total_training_steps

(1588, 7940)

In [51]:
model = ToxicCommentTagger(
  n_classes=len(LABEL_COLUMNS),
  n_warmup_steps=warmup_steps,
  n_training_steps=total_training_steps 
)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Training

In [55]:
checkpoint_callback = ModelCheckpoint(
  dirpath="checkpoints",
  filename="best-checkpoint",
  save_top_k=1,
  verbose=True,
  monitor="val_loss",
  mode="min"
)

In [56]:
early_stopping_callback = EarlyStopping(monitor='val_loss', patience=2)

In [57]:
trainer = pl.Trainer(
  logger=logger,
  checkpoint_callback=checkpoint_callback,
  callbacks=[early_stopping_callback],
  max_epochs=N_EPOCHS,
  gpus=1,
  progress_bar_refresh_rate=30
)

  f"Setting `Trainer(checkpoint_callback={checkpoint_callback})` is deprecated in v1.5 and will "
  f"Setting `Trainer(progress_bar_refresh_rate={progress_bar_refresh_rate})` is deprecated in v1.5 and"
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [58]:
trainer.fit(model, data_module)

  f"DataModule.{name} has already been called, so it will not be called again. "
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type      | Params
-----------------------------------------
0 | bert       | BertModel | 108 M 
1 | classifier | Linear    | 4.6 K 
2 | criterion  | BCELoss   | 0     
-----------------------------------------
108 M     Trainable params
0         Non-trainable params
108 M     Total params
433.260   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  cpuset_checked))
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
Global seed set to 42


Training: 0it [00:00, ?it/s]

  f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically"


Validating: 0it [00:00, ?it/s]

  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
  "Trying to infer the `batch_size` from an ambiguous collection.

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [62]:
model.freeze()
model.eval()

ToxicCommentTagger(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine

## Predictions

In [66]:
test_comment = "Hi, I'm Meredith and I'm an alch... good at supplier relations"

In [67]:
encoding = tokenizer.encode_plus(
  test_comment,
  add_special_tokens=True,
  max_length=512,
  return_token_type_ids=False,
  padding="max_length",
  return_attention_mask=True,
  return_tensors='pt',
)

In [68]:
_, test_prediction = model(encoding["input_ids"], encoding["attention_mask"])
test_prediction = test_prediction.flatten().numpy()

for label, prediction in zip(LABEL_COLUMNS, test_prediction):
  print(f"{label}: {prediction}")

toxic: 0.0020448840223252773
severe_toxic: 0.0013950353022664785
obscene: 0.002090828726068139
threat: 0.0016942467773333192
insult: 0.0040459479205310345
identity_hate: 0.0014846234116703272


In [70]:
THRESHOLD = 0.5

test_comment = "You are such a loser! You'll regret everything you've done to me!"
encoding = tokenizer.encode_plus(
  test_comment,
  add_special_tokens=True,
  max_length=512,
  return_token_type_ids=False,
  padding="max_length",
  return_attention_mask=True,
  return_tensors='pt',
)

_, test_prediction = model(encoding["input_ids"], encoding["attention_mask"])
test_prediction = test_prediction.flatten().numpy()

for label, prediction in zip(LABEL_COLUMNS, test_prediction):
  if prediction < THRESHOLD:
    continue
  print(f"{label}: {prediction}")

toxic: 0.9928410053253174
insult: 0.906752347946167


## Evaluation

In [73]:
MAX_TOKEN_COUNT = 512

In [74]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
trained_model = model.to(device)

val_dataset = ToxicCommentsDataset(
  val_df,
  tokenizer,
  max_token_len=MAX_TOKEN_COUNT
)

predictions = []
labels = []

for item in tqdm(val_dataset):
  _, prediction = trained_model(
    item["input_ids"].unsqueeze(dim=0).to(device), 
    item["attention_mask"].unsqueeze(dim=0).to(device)
  )
  predictions.append(prediction.flatten())
  labels.append(item["labels"].int())

predictions = torch.stack(predictions).detach().cpu()
labels = torch.stack(labels).detach().cpu()

  0%|          | 0/7979 [00:00<?, ?it/s]

In [77]:
accuracy(predictions, labels, threshold=THRESHOLD)

tensor(0.9720)

In [78]:
print("AUROC per tag")
for i, name in enumerate(LABEL_COLUMNS):
  tag_auroc = auroc(predictions[:, i], labels[:, i], pos_label=1)
  print(f"{name}: {tag_auroc}")

AUROC per tag
toxic: 0.9726083278656006
severe_toxic: 0.9910653829574585
obscene: 0.9901463985443115
threat: 0.9314780235290527
insult: 0.9838355183601379
identity_hate: 0.9806649088859558


In [79]:
y_pred = predictions.numpy()
y_true = labels.numpy()

upper, lower = 1, 0

y_pred = np.where(y_pred > THRESHOLD, upper, lower)

print(classification_report(
  y_true, 
  y_pred, 
  target_names=LABEL_COLUMNS, 
  zero_division=0
))

               precision    recall  f1-score   support

        toxic       0.50      0.96      0.66       748
 severe_toxic       0.45      0.56      0.50        80
      obscene       0.76      0.87      0.81       421
       threat       0.27      0.46      0.34        13
       insult       0.69      0.77      0.73       410
identity_hate       0.52      0.65      0.58        71

    micro avg       0.58      0.86      0.69      1743
    macro avg       0.53      0.71      0.60      1743
 weighted avg       0.60      0.86      0.70      1743
  samples avg       0.08      0.08      0.08      1743



We now know that the model makes mistakes on the tags with low amounts of samples.