In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split 
from torchmetrics.classification import F1Score, Accuracy

from transformers import BartForSequenceClassification, BartTokenizer, AdamW, BartConfig
import pytorch_lightning as pl
from pytorch_lightning import callbacks
from pytorch_lightning.loggers import TensorBoardLogger

import pandas as pd
import numpy as np
import os

from typing import Any

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision('medium')

In [3]:
DATA_PATH = '../data/VUA18/'

In [4]:
train_df = pd.read_csv(DATA_PATH+'train.tsv', sep='\t', encoding='utf-8')
test_df = pd.read_csv(DATA_PATH+'test.tsv', sep='\t', encoding='utf-8')

In [5]:
# print(f"The number of training samples is: {len(train_df)} and the number of test samples is: {len(test_df)}")

In [6]:
train_df.head()

Unnamed: 0,index,label,sentence,POS,w_index
0,b1g-fragment02 841,0,If it now seems self-evident that monitoring o...,ADP,42
1,fef-fragment03 667,0,Which equation should we use in a practical ca...,NOUN,10
2,as6-fragment01 76,0,It was initiated partly in response to the fur...,ADP,10
3,ew1-fragment01 108,1,You fully know as an old pressman the difficul...,VERB,10
4,fpb-fragment01 1152,0,It was a condition of her gift to you the ten ...,ADJ,5


In [20]:
train_df['label'].value_counts()

0    85177
1    12481
Name: label, dtype: int64

In [21]:
from sklearn.utils.class_weight import compute_class_weight

# Assuming you have access to the training dataset
# Replace `train_dataset` with your actual training dataset
class_weights = compute_class_weight('balanced', classes=[0, 1], y=train_df['label'])
class_weights = torch.tensor(class_weights, dtype=torch.float32) # Convert to tensor


In [22]:
class_weights

tensor([0.5733, 3.9123])

In [7]:
class MetaphorDataset(Dataset):
    def __init__(self, df, tokenizer, max_length):
        self.data = df
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx) -> dict:
        sentence = self.data.iloc[idx]['sentence']
        label = self.data.iloc[idx]['label']
        pos_tag = self.data.iloc[idx]['POS']

        sentence_encoding = self.tokenizer(sentence, 
                                           truncation = True, 
                                           padding='max_length', 
                                           max_length=self.max_length,
                                           return_tensors= 'pt')
        # pos_encoding = self.tokenizer(pos_tag,
        #                               truncation = True,
        #                               padding = 'max_length',
        #                               max_length = 5,
        #                               return_tensors= 'pt')
        
        return {
            'input_ids': sentence_encoding['input_ids'],
            'attention_mask': sentence_encoding['attention_mask'],
            # 'pos_ids': pos_encoding['input_ids'],
            'label': torch.tensor(label, dtype = torch.long)
        }

In [8]:
class CustomBartForSequenceClassification(BartForSequenceClassification):
    def __init__(self, config):
        super().__init__(config)
        self.classification_head = nn.Linear(config.d_model, 2)  # Add your custom classification head

config = BartConfig.from_pretrained('facebook/bart-large')
model = CustomBartForSequenceClassification(config)

In [9]:
# model_name = "facebook/bart-large-mnli"
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
# model = BartForSequenceClassification.from_pretrained(model_name, num_labels=2)

In [10]:
max_length = 128  # You can adjust this based on your dataset
batch_size = 32   # You can adjust this as well
train_dataset = MetaphorDataset(train_df, tokenizer, max_length)
test_dataset = MetaphorDataset(test_df, tokenizer, max_length)


In [11]:
val_size = int(0.2 * len(train_dataset))
train_size = int(len(train_dataset)- val_size)
train_set, val_set = random_split(train_dataset, [train_size, val_size])

print(f"samples in train set: {len(train_set)}")
print(f"samples in test set: {len(test_dataset)}")
print(f"samples in val set: {len(val_set)}")

samples in train set: 78127
samples in test set: 43947
samples in val set: 19531


In [12]:
train_loader = DataLoader(train_set, batch_size = batch_size, shuffle = True)
val_loader = DataLoader(val_set, batch_size = batch_size, shuffle = False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [13]:
for batch in test_loader:
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']
    # pos_ids = batch['pos_ids']
    labels = batch['label']

    print("Input IDs:", input_ids.shape)
    print("Attention Mask:", attention_mask.shape)
    # print("POS IDs:", pos_ids.shape)
    print("Labels:", labels.shape)
    break

Input IDs: torch.Size([32, 1, 128])
Attention Mask: torch.Size([32, 1, 128])
Labels: torch.Size([32])


In [14]:
config = {
    'lr' : 2e-5
}

In [26]:
class MetaphorClassifier(pl.LightningModule):
    def __init__(self, num_classes, model_name='facebook/bart-base', config=config, class_weights = None):
        super(MetaphorClassifier, self).__init__()
        self.save_hyperparameters()
        self.config = config
        self.class_weights = class_weights

        # Load pre-trained BART model and tokenizer
        self.bart = BartForSequenceClassification.from_pretrained(model_name, num_labels=num_classes)
        self.tokenizer = BartTokenizer.from_pretrained(model_name)

    def forward(self, input_ids, attention_mask):
        return self.bart(input_ids=input_ids, attention_mask=attention_mask)[0]

    def common_step(self, batch, step_type: str):
        input_ids = batch['input_ids'].squeeze(1)
        attention_mask = batch['attention_mask'].squeeze(1)
        labels = batch['label']

        logits = self(input_ids, attention_mask)

        device = logits.device

        loss_fn = nn.CrossEntropyLoss(weight = self.class_weights.to(device))
        loss = loss_fn(logits, labels)

        device = logits.device
        predictions = torch.argmax(logits, dim=1).to(device)
        labels = labels.to(device)

        accuracy = Accuracy(task='binary').to(device)

        self.log(f'{step_type}_loss', loss)
        self.log(f'{step_type}_acc', accuracy(predictions, labels))
        return loss, logits, labels

    def training_step(self, batch, batch_idx):
        loss, logits, labels = self.common_step(batch, step_type='train')
        return loss

    def validation_step(self, batch, batch_idx):
        loss, logits, labels = self.common_step(batch, step_type='val')
        # device = logits.device
        predictions = torch.argmax(logits, dim=1).to(logits.device)
        labels = labels.to(logits.device)

        f1 = F1Score(task='binary').to(logits.device)
        self.log('F1_score', f1(predictions, labels))

        return loss

    def test_step(self, batch, batch_idx):
        loss, logits, labels = self.common_step(batch, step_type='test')
        predictions = torch.argmax(logits, dim=1).to(logits.device)
        labels = labels.to(logits.device)

        accuracy = Accuracy().to(logits.device)
        f1 = F1Score(task='binary').to(logits.device)

        self.log('test_acc', accuracy(logits, labels))
        self.log('test_f1', f1(predictions, labels))

    def configure_optimizers(self):
        return optim.AdamW(self.parameters(), lr=self.config['lr'])


In [27]:
def train_distilbert(config, num_epochs = 8, num_classes=2, checkpoint = None):

    model = MetaphorClassifier(config = config , num_classes = num_classes, class_weights=class_weights)
    tlogger = TensorBoardLogger(save_dir="metaphor-logs", name ="bart-full", version = 'v1')

    db_callbacks = [
        callbacks.ModelCheckpoint(monitor= 'val_loss',
                                  save_top_k = 1,
                                   save_on_train_epoch_end= False,
                                    filename = '{epoch}-{val_loss:.2f}' )

    ]

    trainer = pl.Trainer(accelerator="gpu",
                         logger= tlogger,
                         log_every_n_steps =2,
                         precision = 16,
                         enable_checkpointing= True,
                         callbacks= db_callbacks,
                         devices = 1,
                         enable_progress_bar= True,
                         max_epochs= num_epochs)
    
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader, ckpt_path=checkpoint)
    return trainer

In [28]:
db_trainer = train_distilbert(config= config, num_epochs= 10)

Some weights of BartForSequenceClassification were not initialized from the model checkpoint at facebook/bart-base and are newly initialized: ['classification_head.dense.weight', 'classification_head.out_proj.bias', 'classification_head.out_proj.weight', 'classification_head.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  rank_zero_warn(
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type                          | Params
-------------------------------------------------------
0 | bart | BartForSequenceClassification | 140 M 
-------------------------------------------------------
140 M     Trainable params
0         Non-trainable par

Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


                                                                           

  rank_zero_warn(


Epoch 9: 100%|██████████| 2442/2442 [11:52<00:00,  3.43it/s, v_num=v1]

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 2442/2442 [11:52<00:00,  3.43it/s, v_num=v1]


In [29]:
%reload_ext tensorboard
%tensorboard --logdir="/home/vri/Projects/research/metaphor-detection/notebooks/metaphor-logs" --host localhost --port 8081



In [30]:
db_trainer.validate(dataloaders=test_loader, ckpt_path='best')

Restoring states from the checkpoint path at metaphor-logs/bart-full/v1/checkpoints/epoch=0-val_loss=0.62.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at metaphor-logs/bart-full/v1/checkpoints/epoch=0-val_loss=0.62.ckpt
  rank_zero_warn(


Validation DataLoader 0: 100%|██████████| 1374/1374 [01:54<00:00, 12.00it/s]


[{'val_loss': 0.6664794087409973,
  'val_acc': 0.5613579750061035,
  'F1_score': 0.2079881876707077}]