<a href="https://colab.research.google.com/github/gupta24789/multiclass-classification/blob/main/multiclass_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -q pytorch-lightning

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m777.7/777.7 kB[0m [31m9.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m840.2/840.2 kB[0m [31m18.6 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import pandas as pd
import numpy as np
import itertools

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import pytorch_lightning as pl
import torchmetrics
from transformers import AutoTokenizer, AutoModel

from sklearn.metrics import classification_report
from sklearn.utils.class_weight import compute_class_weight

## Set Seed

In [3]:
SEED = 121
torch.manual_seed(SEED)
np.random.seed(SEED)
pl.seed_everything(SEED)

INFO:lightning_fabric.utilities.seed:Seed set to 121


121

## Load Read Data

In [4]:
train_df = pd.read_csv("https://raw.githubusercontent.com/gupta24789/multiclass-classification/main/data/train.txt", header = None, sep=';')
val_df = pd.read_csv("https://raw.githubusercontent.com/gupta24789/multiclass-classification/main/data/val.txt",header = None, sep=';')
train_df.columns = ['text','label']
val_df.columns = ['text','label']

In [5]:
print(train_df.label.value_counts())

joy         5362
sadness     4666
anger       2159
fear        1937
love        1304
surprise     572
Name: label, dtype: int64


In [6]:
print(val_df.label.value_counts())

joy         704
sadness     550
anger       275
fear        212
love        178
surprise     81
Name: label, dtype: int64


In [7]:
train_df.head()

Unnamed: 0,text,label
0,i didnt feel humiliated,sadness
1,i can go from feeling so hopeless to so damned...,sadness
2,im grabbing a minute to post i feel greedy wrong,anger
3,i am ever feeling nostalgic about the fireplac...,love
4,i am feeling grouchy,anger


## Encode Label

In [8]:
## Encode Labels
label2idx_map = {w:i for i,w in enumerate(train_df.label.unique().tolist())}
idx2label_map = {i:w for w,i in label2idx_map.items()}

train_df['encoded_label'] = train_df.label.apply(lambda x: label2idx_map[x])
val_df['encoded_label'] = val_df.label.apply(lambda x: label2idx_map[x])

## Class Weight

In [9]:
class_weights = compute_class_weight(class_weight='balanced',classes=np.unique(train_df.encoded_label), y=train_df.encoded_label)
class_weights

array([0.57151022, 1.23513973, 2.04498978, 4.66200466, 1.37669936,
       0.49732687])

## Transformer Model Exploration

In [75]:
# model_name = "albert-base-v2"
model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
transformer_model = AutoModel.from_pretrained(model_name)

tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

In [76]:
inputs = tokenizer("I hate you", return_tensors='pt')
inputs

{'input_ids': tensor([[ 101, 1045, 5223, 2017,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]])}

In [77]:
embedding = transformer_model(**inputs)

In [78]:
last_hidden_state, pooler_output = embedding['last_hidden_state'], embedding['pooler_output']

In [79]:
last_hidden_state.shape, pooler_output.shape

(torch.Size([1, 5, 768]), torch.Size([1, 768]))

In [80]:
last_hidden_state[:,0,:].shape

torch.Size([1, 768])

## Data Loaders

In [81]:
def custom_collate(batch):

  text = [item['text'] for item in batch]
  label = [item['encoded_label'] for item in batch]

  inputs = tokenizer(text, max_length= 50, truncation=True, padding='max_length', return_tensors='pt')
  label = torch.tensor(label, dtype = torch.long)

  batch = {"input_ids": inputs['input_ids'], "token_type_ids": inputs['token_type_ids'],"attention_mask": inputs['attention_mask'], "label": label}
  return batch

In [82]:
train_data = train_df[['text','encoded_label']].to_dict('records')
val_data = val_df[['text','encoded_label']].to_dict('records')

In [83]:
train_data[:2]

[{'text': 'i didnt feel humiliated', 'encoded_label': 0},
 {'text': 'i can go from feeling so hopeless to so damned hopeful just from being around someone who cares and is awake',
  'encoded_label': 0}]

In [84]:
batch_size = 2
train_dl = DataLoader(train_data, batch_size = batch_size, shuffle = True, collate_fn= custom_collate)

In [85]:
example = next(iter(train_dl))
example['input_ids'].shape, example['token_type_ids'].shape, example['attention_mask'].shape, example['label'].shape

(torch.Size([2, 50]),
 torch.Size([2, 50]),
 torch.Size([2, 50]),
 torch.Size([2]))

In [86]:
example['input_ids']

tensor([[  101,  1045,  2052,  7887,  2514, 21568,   102,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0],
        [  101,  1045,  2428,  2514, 12491,  2058,  2035,  2023, 26865,  2004,
          1045,  2031,  2042,  2000,  2023, 16581, 10380,  3807,  1998,  7714,
          2113,  2035,  1996,  2598, 22213,   102,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0]])

In [87]:
example['token_type_ids']

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0]])

In [88]:
example['attention_mask']

tensor([[1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0]])

In [89]:
example['label']

tensor([4, 0])

In [90]:
## dataloaders
batch_size = 64
train_dl = DataLoader(train_data, batch_size = batch_size, shuffle = True, collate_fn= custom_collate, num_workers = 2)
val_dl = DataLoader(val_data, batch_size = batch_size, shuffle = False, collate_fn= custom_collate, num_workers = 2)

## Build Model

In [91]:
class MultiClassTransformer(pl.LightningModule):

  def __init__(self, output_dim, learning_rate, freeze = False):
    super().__init__()
    self.learning_rate = learning_rate

    ## define loss & accuracy
    self.loss_fn = nn.CrossEntropyLoss(weight= torch.tensor(class_weights, dtype = torch.float))
    self.train_f1 = torchmetrics.F1Score(task="multiclass", num_classes=output_dim)
    self.val_f1 = torchmetrics.F1Score(task="multiclass", num_classes=output_dim)

    ## define layers
    self.transformer_model = AutoModel.from_pretrained(model_name)
    hidden_dim = self.transformer_model.config.hidden_size
    self.linear = nn.Linear(hidden_dim, output_dim)
    self.tanh = nn.Tanh()

    ## freeze layers
    if freeze:
      for name , param in self.transformer_model.named_parameters():
        param.requires_grad  = False


  def forward(self, inputs):
    """
    No need to apply softmax at the end as crossentropy implicitly apply the softmax
    """
    embeddings = self.transformer_model(**inputs)
    last_hidden_state, pooler_output = embeddings['last_hidden_state'], embeddings['pooler_output']

    hidden_state = last_hidden_state[:,0,:]
    out = self.linear(self.tanh(hidden_state))
    return out

  def _shared_step(self, batch):
    label = batch.pop('label')
    logits = self(batch)
    loss = self.loss_fn(logits, label)
    return logits, loss, label

  def training_step(self, batch, batch_idx):
    logits, loss, label = self._shared_step(batch)
    self.train_f1.update(logits, label)
    self.log_dict({"train_loss": loss, "train_f1": self.train_f1}, on_step = False, on_epoch = True, prog_bar=True)
    return loss

  def validation_step(self,batch, batch_idx):
    logits, loss, label = self._shared_step(batch)
    self.val_f1.update(logits, label)
    self.log_dict({"val_loss": loss,  "val_f1": self.val_f1}, on_step = False, on_epoch = True, prog_bar=True)
    return loss

  def on_training_epoch_end(self):
    self.train_f1.reset()

  def on_validation_epoch_end(self):
    if self.current_epoch!=0:
      print(f"Epoch : {self.current_epoch} Val F1 : {self.val_f1.compute()}")
    self.val_f1.reset()

  def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr = self.learning_rate)
    return optimizer

In [92]:
# ## test model architecture
# model = MultiClassTransformer(output_dim = len(label2idx_map), learning_rate = 1e-3, freeze = False)
# inputs = {
#     "input_ids": example['input_ids'],
#     "token_type_ids": example['token_type_ids'],
#     "attention_mask": example['attention_mask']
# }
# logits = model(inputs)
# model.loss_fn(logits, example['label'])

In [94]:
## Model Training

model = MultiClassTransformer(output_dim = len(label2idx_map), learning_rate = 1e-3, freeze = True)

callbacks = pl.callbacks.ModelCheckpoint(dirpath = "multiclass_logs",
                                         filename = '{epoch}-{val_loss:.2f}-{val_acc:.2f}',
                                          mode = "min",
                                          monitor = "val_loss",
                                          save_last = True,
                                          save_top_k=-1)

trainer = pl.Trainer(accelerator= "gpu",
           max_epochs=5,
           check_val_every_n_epoch = 1,
           callbacks = [callbacks])

trainer.fit(model, train_dl, val_dl)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name              | Type              | Params
--------------------------------------------------------
0 | loss_fn           | CrossEntropyLoss  | 0     
1 | train_f1          | MulticlassF1Score | 0     
2 | val_f1            | MulticlassF1Score | 0     
3 | transformer_model | BertModel         | 109 M 
4 | linear            | Linear            | 4.6 K 
5 | tanh              | Tanh              | 0     
--------------------------------------------------------
4.6 K     Trainable params
109 M     Non-trainable params
109 M     Total 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Epoch : 1 Val F1 : 0.49000000953674316


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch : 2 Val F1 : 0.44749999046325684


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch : 3 Val F1 : 0.48899999260902405


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch : 4 Val F1 : 0.45649999380111694


INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.


## Load Model

In [None]:
## Load model in cpu
MultiClassTransformer.load_from_checkpoint("multiclass_logs/",
                                            output_dim = len(label2idx_map),
                                           learning_rate = 1e-3,
                                           freeze = False,
                                           map_location = "cpu")

In [None]:
model = MultiClassTransformer.load_from_checkpoint("multiclass_logs/",
                                            output_dim = len(label2idx_map),
                                                   learning_rate = 1e-3,
                                                   freeze = False,
                                                   map_location = "cpu")

In [None]:
model.eval()

MultiClassLSTM(
  (loss_fn): CrossEntropyLoss()
  (train_f1): MulticlassF1Score()
  (val_f1): MulticlassF1Score()
  (embedding): Embedding(10379, 100, padding_idx=0)
  (lstm): LSTM(100, 256, num_layers=2, batch_first=True, dropout=0.25, bidirectional=True)
  (relu): ReLU()
  (linear1): Linear(in_features=512, out_features=32, bias=True)
  (linear2): Linear(in_features=32, out_features=6, bias=True)
)

In [None]:
def predict(text):
  model.eval()
  tokenized_text = process_text(text)
  token_tensor = convert_to_number_tensor(tokenized_text)
  token_tensor = token_tensor.view(1,-1)
  lengths = torch.tensor([token_tensor.shape[1]], dtype = torch.long)
  preds = model(token_tensor, lengths)
  value, index = torch.topk(preds, k = 1)
  return index.item()

In [None]:
index = predict("I love you")
print(f"Label : {idx2label_map[index]}")

Label : joy


In [None]:
index = predict("i hate you")
print(f"Label : {idx2label_map[index]}")

Label : anger


## Classification report

In [None]:
val_preds_index = [predict(text) for text in val_df.complaints]

In [None]:
print(classification_report(val_df.encoded_label, val_preds_index, target_names = list(label2idx_map.keys())))

              precision    recall  f1-score   support

     sadness       0.96      0.90      0.93       550
       anger       0.89      0.88      0.88       275
        love       0.77      0.72      0.74       178
    surprise       0.81      0.84      0.82        81
        fear       0.78      0.91      0.84       212
         joy       0.90      0.91      0.91       704

    accuracy                           0.89      2000
   macro avg       0.85      0.86      0.85      2000
weighted avg       0.89      0.89      0.89      2000



In [None]:
val_df.label.value_counts()

joy         704
sadness     550
anger       275
fear        212
love        178
surprise     81
Name: label, dtype: int64