<a href="https://colab.research.google.com/github/gupta24789/multilabel-classification/blob/main/multilabel_rnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -q  pytorch-lightning

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m777.7/777.7 kB[0m [31m6.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m840.2/840.2 kB[0m [31m10.9 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import os
import random
import pandas as pd
import numpy as np
import itertools
import string
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import WordPunctTokenizer
from nltk.stem import PorterStemmer
nltk.download('stopwords')

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import pytorch_lightning as pl
import torchmetrics

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.


## Set Seed

In [3]:
seed = 121
random.seed(seed)
torch.manual_seed(seed)
pl.seed_everything(seed)

INFO:lightning_fabric.utilities.seed:Seed set to 121


121

## Utilities

In [4]:
tokenizer = WordPunctTokenizer()
stemmer = PorterStemmer()
STOPWORDS = stopwords.words('english')


def process_context(context):

  clean_context = []

  context = context.lower()
  words_list = tokenizer.tokenize(context)


  for w in words_list:
    if w not in STOPWORDS and w not in string.punctuation:
      stem_word = stemmer.stem(w)
      clean_context.append(stem_word)

  return clean_context


def convert_word_to_number_tensor(context):

  encoded_context = []
  for w in context:
    encoded_context.append(token2idx.get(w, UNK_ID))

  return torch.tensor(encoded_context)

## Load Data

In [5]:
train_df = pd.read_csv("https://raw.githubusercontent.com/gupta24789/multilabel-classification/main/data/train.csv")
val_df = pd.read_csv("https://raw.githubusercontent.com/gupta24789/multilabel-classification/main/data/test.csv")

print(f"Train shape : {train_df.shape}")
print(f"Val shape : {val_df.shape}")

train_df.columns = train_df.columns.str.lower()
val_df.columns = val_df.columns.str.lower()

Train shape : (16777, 9)
Val shape : (4195, 9)


In [6]:
train_df.head(3)

Unnamed: 0,id,title,abstract,computer science,physics,mathematics,statistics,quantitative biology,quantitative finance
0,1,Reconstructing Subject-Specific Effect Maps,Predictive models allow subject-specific inf...,1,0,0,0,0,0
1,2,Rotation Invariance Neural Network,Rotation invariance and translation invarian...,1,0,0,0,0,0
2,3,Spherical polyharmonics and Poisson kernels fo...,We introduce and develop the notion of spher...,0,0,1,0,0,0


## Data Prep

In [7]:
train_df['context'] = train_df.title + ". " + train_df.abstract
val_df['context'] = val_df.title + ". " + val_df.abstract

In [8]:
target_columns = ['computer science', 'physics', 'mathematics',
       'statistics', 'quantitative biology', 'quantitative finance']

In [9]:
train_df = train_df[['context'] + target_columns]
val_df = val_df[['context'] + target_columns]

In [10]:
train_df.head(3)

Unnamed: 0,context,computer science,physics,mathematics,statistics,quantitative biology,quantitative finance
0,Reconstructing Subject-Specific Effect Maps. ...,1,0,0,0,0,0
1,Rotation Invariance Neural Network. Rotation...,1,0,0,0,0,0
2,Spherical polyharmonics and Poisson kernels fo...,0,0,1,0,0,0


In [11]:
train_df['context_clean'] = train_df.context.apply(lambda x: process_context(x))
val_df['context_clean'] = val_df.context.apply(lambda x: process_context(x))

## Build Vocab

In [12]:
special_tokens = ['__PAD__','__UNK__']
vocab = list(set(itertools.chain.from_iterable(train_df.context_clean.tolist())))
vocab = vocab + special_tokens
token2idx = {w:i for i,w in enumerate(vocab)}
idx2token = {i:w for w,i in token2idx.items()}

PAD_ID = token2idx['__PAD__']
UNK_ID = token2idx['__UNK__']

print(f"Vocab size : {len(vocab)}")

Vocab size : 38181


## Convert word to number tensor

In [13]:
train_df['encoded_context'] = train_df.context_clean.apply(lambda x: convert_word_to_number_tensor(x))
val_df['encoded_context'] = val_df.context_clean.apply(lambda x: convert_word_to_number_tensor(x))

train_df['labels'] = train_df[target_columns].values.tolist()
val_df['labels'] = val_df[target_columns].values.tolist()

In [14]:
train_df.head(3)

Unnamed: 0,context,computer science,physics,mathematics,statistics,quantitative biology,quantitative finance,context_clean,encoded_context,labels
0,Reconstructing Subject-Specific Effect Maps. ...,1,0,0,0,0,0,"[reconstruct, subject, specif, effect, map, pr...","[tensor(18596), tensor(31367), tensor(24791), ...","[1, 0, 0, 0, 0, 0]"
1,Rotation Invariance Neural Network. Rotation...,1,0,0,0,0,0,"[rotat, invari, neural, network, rotat, invari...","[tensor(28906), tensor(864), tensor(6077), ten...","[1, 0, 0, 0, 0, 0]"
2,Spherical polyharmonics and Poisson kernels fo...,0,0,1,0,0,0,"[spheric, polyharmon, poisson, kernel, polyhar...","[tensor(37566), tensor(33389), tensor(12923), ...","[0, 0, 1, 0, 0, 0]"


In [15]:
train_data = train_df[['encoded_context','labels']].to_dict("records")
val_data = val_df[['encoded_context','labels']].to_dict("records")

In [16]:
train_data[:1]

[{'encoded_context': tensor([18596, 31367, 24791, 34856, 22348, 27528, 11431, 26648, 31367, 24791,
          32804, 27228, 23688, 12843, 22383, 13099, 21269, 12514, 31367, 21269,
          32804, 22260, 10848, 32448, 14328, 17245, 15785, 37197, 18730, 31367,
          32153, 17245, 12990, 37197, 34856,  6348, 13519, 13098, 31367, 21269,
          14328, 32804,  4145, 28208, 32153, 32804, 28208, 24522, 31367, 24791,
          34856, 22348, 15205, 28208,  3230, 11431,  9693, 15710,  8494, 12990,
           5210,   477,   595, 24133,  1182,  5904, 18596, 17821, 12169, 29292,
          24562, 31367, 24791, 12990, 27528, 11431, 21472,  4256, 22773, 15275,
          29292, 24791, 33737, 28730, 14058,  8973, 15128,  6452, 30223, 28208,
          11893, 15128,  3256, 37242, 15275,  5904, 17821, 26338,  4856, 37998,
          28208, 20815, 22773, 15275, 29563,  8415, 17245, 17944,  3251, 37197,
          18730, 18596, 31019, 14892, 29948, 16819,  5652, 11431, 25633, 34570,
          15395, 3724

## Data Loaders

In [17]:
train_df.encoded_context.str.len().describe([.99])

count    16777.000000
mean       107.690648
std         43.493761
min          5.000000
50%        104.000000
99%        218.000000
max        412.000000
Name: encoded_context, dtype: float64

In [18]:
def custom_collate(batch):

  context = [torch.tensor(item['encoded_context']) for item in batch]
  padded_context = nn.utils.rnn.pad_sequence(context, batch_first= True, padding_value= PAD_ID)

  labels = torch.tensor([item['labels'] for item in batch], dtype = torch.float)

  batch = {"context": padded_context, "label": labels}
  return batch

In [19]:
batch_size = 2
train_dl = DataLoader(train_data, batch_size = batch_size, shuffle = True, collate_fn= custom_collate)

In [20]:
example = next(iter(train_dl))
example['context'].shape, example['label'].shape

  context = [torch.tensor(item['encoded_context']) for item in batch]


(torch.Size([2, 257]), torch.Size([2, 6]))

In [21]:
example['context']

tensor([[22208,   252, 15747, 12877, 20435, 11431,  1182, 18607, 16119,  2457,
         17669, 20878,  2303,  8179,  2303, 30032, 33737,  1182,  5904, 33108,
         11816, 35605, 11529, 11431,  2840, 25424,  2612,  6853, 28500, 28175,
         36257, 21636,   252, 15747, 11816, 17242,  4413,  1937, 37384, 25037,
          3110, 12877, 20435, 11431, 22157, 34767, 15747, 30678, 31467, 16819,
         31019, 34739,  2477,  1855, 16821, 36823, 11074,  1360,  4145,  6708,
         34837, 15747, 11431, 34536, 13322, 25424, 19269,  6579,  6412, 24380,
         30067, 32502, 30042,  1805,  6159, 32847, 30835, 36589,  9365, 11529,
         11431, 36589, 36589, 32680, 11431,   245, 34472, 30369, 24380, 21226,
         33482, 11789,  2612, 19407, 15938,  5471, 25275, 11789, 34760, 25633,
         37384, 15329, 33705, 17242, 11403, 28835, 19745, 19159, 24380, 21226,
         16819,  6412, 37555, 24380, 30067, 18432, 27990, 24933,  6853, 28500,
         33491, 35828, 31378, 12762, 24380, 23710, 1

In [22]:
example['label']

tensor([[0., 0., 0., 0., 1., 0.],
        [1., 0., 1., 0., 0., 0.]])

In [23]:
## dataloaders
batch_size = 64
train_dl = DataLoader(train_data, batch_size = batch_size, shuffle = True, collate_fn= custom_collate)
val_dl = DataLoader(val_data, batch_size = batch_size, shuffle = False, collate_fn= custom_collate)

## Build Model

In [41]:
class MultiLabelRNN(pl.LightningModule):

  def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, learning_rate, dropout, num_layers = 1, bidirectional = False):
    super().__init__()
    self.learning_rate = learning_rate
    self.bidirectional = bidirectional

    ## define loss & accuracy
    self.loss_fn = nn.BCEWithLogitsLoss()
    self.train_f1 = torchmetrics.F1Score(task="multilabel", num_labels=output_dim)
    self.val_f1 = torchmetrics.F1Score(task="multilabel", num_labels=output_dim)
    self.train_ham = torchmetrics.HammingDistance(task="multilabel", num_labels=output_dim)
    self.val_ham = torchmetrics.HammingDistance(task="multilabel", num_labels=output_dim)

    ## define layers
    self.embedding = nn.Embedding(num_embeddings= vocab_size, embedding_dim= embedding_dim, padding_idx= PAD_ID)
    self.rnn = nn.RNN(embedding_dim, hidden_dim, batch_first = True, num_layers = num_layers, bidirectional = bidirectional, dropout=dropout)
    self.relu = nn.ReLU()
    self.linear1 = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, 32)
    self.linear2 = nn.Linear(32, output_dim)


  def forward(self, text):
    """
    No need to apply softmax at the end as crossentropy implicitly apply the softmax
    """
    embedded = self.embedding(text)
    output, hidden = self.rnn(embedded)

    if self.bidirectional:
       ## concatnate last hidden layer of forward & backward
      hidden_squeezed = torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1)
    else:
      hidden_squeezed = hidden[-1,:,:].squeeze(0)

    hidden_squeezed = self.relu(hidden_squeezed)
    out = self.linear1(hidden_squeezed)
    hidden = self.relu(out)
    logits = self.linear2(out)
    return logits

  def _shared_step(self, batch):
    text, label = batch['context'], batch['label']
    logits = self(text)
    loss = self.loss_fn(logits, label)
    return logits, loss, label

  def training_step(self, batch, batch_idx):
    logits, loss, label = self._shared_step(batch)
    self.train_f1(logits, label)
    self.train_ham(logits, label)
    self.log_dict({"train_loss": loss, "train_f1": self.train_f1,"train_ham" : self.train_ham}, on_step = False, on_epoch = True, prog_bar=True)
    return loss

  def validation_step(self,batch, batch_idx):
    logits, loss, label = self._shared_step(batch)
    self.val_f1(logits, label)
    self.val_ham(logits, label)
    self.log_dict({"val_loss": loss,  "val_f1": self.val_f1, "val_ham": self.val_ham}, on_step = False, on_epoch = True, prog_bar=True)
    return loss

  def on_training_epoch_end(self):
    self.train_f1.reset()
    self.train_ham.reset()

  def on_validation_epoch_end(self):
    print(f"Epoch : {self.current_epoch} Val F1 : {self.val_f1.compute()}  val ham : {self.val_ham.compute()}")
    self.val_f1.reset()
    self.val_ham.reset()

  def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr = self.learning_rate)
    return optimizer

In [43]:
## test model architecture
model = MultiLabelRNN(vocab_size = len(token2idx),
                      embedding_dim=100,
                      hidden_dim= 64,
                      output_dim= len(target_columns),
                      learning_rate= 1e-3,
                      dropout = 0.5,
                      num_layers= 2,
                      bidirectional = True
                      )

logits = model(example['context'])
model.loss_fn(logits, example['label'])

tensor(0.6947, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)

In [121]:
## Model Training

model = MultiLabelRNN(vocab_size = len(token2idx),
                      embedding_dim=100,
                      hidden_dim= 256,
                      output_dim= len(target_columns),
                      learning_rate= 1e-4,
                      dropout = 0.25,
                      num_layers= 2,
                      bidirectional = True
                      )

callbacks = pl.callbacks.ModelCheckpoint(dirpath = "checkpoints_logs",
                                         filename = '{epoch}-{val_loss:.2f}-{val_ham:.2f}',
                                          mode = "min",
                                          monitor = "val_ham",
                                          save_last = True,
                                          save_top_k=-1)


trainer = pl.Trainer(accelerator= "gpu",
           max_epochs=20,
           check_val_every_n_epoch = 2,
           callbacks = [callbacks])

trainer.fit(model, train_dl, val_dl)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name      | Type                      | Params
--------------------------------------------------------
0 | loss_fn   | BCEWithLogitsLoss         | 0     
1 | train_f1  | MultilabelF1Score         | 0     
2 | val_f1    | MultilabelF1Score         | 0     
3 | train_ham | MultilabelHammingDistance | 0     
4 | val_ham   | MultilabelHammingDistance | 0     
5 | embedding | Embedding                 | 3.8 M 
6 | rnn       | RNN                       | 577 K 
7 | relu      | ReLU                      | 0     
8 | linear1   | Linear       

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Epoch : 0 Val F1 : 0.359375  val ham : 0.3203125


  context = [torch.tensor(item['encoded_context']) for item in batch]


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Epoch : 1 Val F1 : 0.47309744358062744  val ham : 0.17274534702301025


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch : 3 Val F1 : 0.6377889513969421  val ham : 0.14318633079528809


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch : 5 Val F1 : 0.6396945118904114  val ham : 0.1386968493461609


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch : 7 Val F1 : 0.7104406952857971  val ham : 0.11668652296066284


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch : 9 Val F1 : 0.5559672117233276  val ham : 0.15713149309158325


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch : 11 Val F1 : 0.6647094488143921  val ham : 0.1315852403640747


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch : 13 Val F1 : 0.683581531047821  val ham : 0.1283273696899414


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch : 15 Val F1 : 0.7056596279144287  val ham : 0.11736196279525757


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch : 17 Val F1 : 0.7209969162940979  val ham : 0.11207789182662964


Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch : 19 Val F1 : 0.7277783155441284  val ham : 0.10889947414398193


## Predict

In [122]:
def predict(context):
  clean_context = process_context(context)
  encoded_context = convert_word_to_number_tensor(clean_context)
  encoded_context = encoded_context.view(1, -1)
  preds = model(encoded_context)
  preds = preds.detach().numpy().flatten()
  preds = np.array(preds>0.5).astype(int)
  print("Pred : ", [target_columns[i] for i, val in enumerate(preds) if val==1])

In [123]:
model = model.eval()

In [132]:
random_sample = val_df.sample().to_dict('records')[0]
context = random_sample['context']
label = random_sample['labels']
print("True : ",[target_columns[i] for i, val in enumerate(label) if val==1])
predict(context)

True :  ['computer science', 'statistics']
Pred :  ['computer science', 'statistics']
