In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split 
from torchmetrics.classification import F1Score, Accuracy

from transformers import DistilBertTokenizer, DistilBertForTokenClassification, DistilBertConfig, AdamW, DistilBertModel
import pytorch_lightning as pl
from pytorch_lightning import callbacks
from pytorch_lightning.loggers import TensorBoardLogger

import pandas as pd
import numpy as np
import os

from typing import Any

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision('medium')

In [3]:
DATA_PATH = '../data/VUA18/'

In [4]:
train_df = pd.read_csv(DATA_PATH+'train.tsv', sep='\t', encoding='utf-8')
test_df = pd.read_csv(DATA_PATH+'test.tsv', sep='\t', encoding='utf-8')

In [5]:
# print(f"The number of training samples is: {len(train_df)} and the number of test samples is: {len(test_df)}")

In [6]:
train_df.head()

Unnamed: 0,index,label,sentence,POS,w_index
0,b1g-fragment02 841,0,If it now seems self-evident that monitoring o...,ADP,42
1,fef-fragment03 667,0,Which equation should we use in a practical ca...,NOUN,10
2,as6-fragment01 76,0,It was initiated partly in response to the fur...,ADP,10
3,ew1-fragment01 108,1,You fully know as an old pressman the difficul...,VERB,10
4,fpb-fragment01 1152,0,It was a condition of her gift to you the ten ...,ADJ,5


In [7]:
class MetaphorDataset(Dataset):
    def __init__(self, df, tokenizer, max_length):
        self.data = df
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx) -> dict:
        sentence = self.data.iloc[idx]['sentence']
        label = self.data.iloc[idx]['label']
        pos_tag = self.data.iloc[idx]['POS']

        sentence_encoding = self.tokenizer(sentence, 
                                           truncation = True, 
                                           padding='max_length', 
                                           max_length=self.max_length,
                                           return_tensors= 'pt')
        # pos_encoding = self.tokenizer(pos_tag,
        #                               truncation = True,
        #                               padding = 'max_length',
        #                               max_length = 5,
        #                               return_tensors= 'pt')
        
        return {
            'input_ids': sentence_encoding['input_ids'],
            'attention_mask': sentence_encoding['attention_mask'],
            # 'pos_ids': pos_encoding['input_ids'],
            'label': torch.tensor(label, dtype = torch.long)
        }

In [8]:
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
# model = DistilBertModel.from_pretrained('distilbert-base-uncased')


In [9]:
max_length = 128  # You can adjust this based on your dataset
batch_size = 32   # You can adjust this as well
train_dataset = MetaphorDataset(train_df, tokenizer, max_length)
test_dataset = MetaphorDataset(test_df, tokenizer, max_length)


In [10]:
val_size = int(0.2 * len(train_dataset))
train_size = int(len(train_dataset)- val_size)
train_set, val_set = random_split(train_dataset, [train_size, val_size])

print(f"samples in train set: {len(train_set)}")
print(f"samples in test set: {len(test_dataset)}")
print(f"samples in val set: {len(val_set)}")

samples in train set: 78127
samples in test set: 43947
samples in val set: 19531


In [11]:
train_loader = DataLoader(train_set, batch_size = batch_size, shuffle = True)
val_loader = DataLoader(val_set, batch_size = batch_size, shuffle = True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [12]:
for batch in test_loader:
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']
    # pos_ids = batch['pos_ids']
    labels = batch['label']

    print("Input IDs:", input_ids.shape)
    print("Attention Mask:", attention_mask.shape)
    # print("POS IDs:", pos_ids.shape)
    print("Labels:", labels.shape)
    break

Input IDs: torch.Size([32, 1, 128])
Attention Mask: torch.Size([32, 1, 128])
Labels: torch.Size([32])


In [13]:
config = {
    'lr' : 2e-5
}

In [77]:
class MetaphorClassifier(pl.LightningModule):
    def __init__(self, num_classes, model_name = 'distilbert-base-uncased', config= config) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.config = config 
        self.model = DistilBertModel.from_pretrained(model_name)
        hidden_size = self.model.config.hidden_size 
        self.classifier = nn.Linear(hidden_size, num_classes)

        # print(self.model)

    def forward(self, batch):
        input_ids = batch['input_ids'].squeeze(1)
        attention_mask = batch['attention_mask'].squeeze(1)
        # pos_ids = batch['pos_ids']
        sen_outputs = self.model(input_ids = input_ids, attention_mask = attention_mask)
        pooled_sen_output = sen_outputs['last_hidden_state'][:, 0]

        # pos_outputs = self.model(input_ids = pos_ids)[0]
        # pooled_pos_output = pos_outputs[:, 0]

        # combined_output = torch.cat((pooled_sen_output, pooled_pos_output), dim=1)

        # if pos_ids is not None:
        #     pos_outputs = self.model(input_ids=pos_ids)[0]
        #     pooled_pos_output = pos_outputs[:, 0]
        #     combined_output = torch.cat((pooled_sen_output, pooled_pos_output), dim=1)
        # else:
        #     combined_output = pooled_sen_output

        logits = self.classifier(pooled_sen_output)
        return logits
    
    def common_step(self, batch, step_type:str):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        # pos_ids = batch['pos_ids']

        data = {'input_ids': input_ids,
                'attention_mask': attention_mask}
                # 'pos_ids' : pos_ids}

        labels = batch['label']
        logits = self.forward(data)

        device = logits.device
        predictions = torch.argmax(logits, dim=1).to(device)
        labels = labels.to(device)
        loss = nn.CrossEntropyLoss()(logits, labels)

        accuracy = Accuracy(task='binary').to(device)
        
        self.log(f'{step_type}_loss', loss)
        self.log(f'{step_type}_acc', accuracy(predictions, labels))
        return loss, logits, labels
    
    def accuracy(self, correct, labels):
        return correct/len(labels)


    def training_step(self, batch, batch_idx):
        loss, logits, labels = self.common_step(batch, step_type='train')
        return loss
    
    def validation_step(self, batch, batch_idx):
        loss, logits, labels = self.common_step(batch, step_type='val')
        device = logits.device
        predictions = torch.argmax(logits, dim=1).to(device)
        labels = labels.to(device)
        f1=F1Score(task='binary').to(device)

        self.log('F1_score', f1(predictions, labels))
        return loss
    
    def test_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['label']
        data = {'input_ids': input_ids,
                'attention_mask': attention_mask}
                # 'pos_ids' : pos_ids}

        # labels = batch['label']
        logits = self.forward(data)

        device = logits.device
        predictions = torch.argmax(logits, dim=1).to(device)
        labels = labels.to(device)
        accuracy = Accuracy(task='binary').to(device)
        f1=F1Score(task='binary').to(device)

        self.log('test_acc', accuracy(predictions, labels))
        self.log('test_f1', predictions, labels)

        # accuracy = Accuracy()(logits, labels)
        # f1 = F1Score(logits, labels)

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr = self.config['lr'])



In [74]:
def train_distilbert(config, num_epochs = 8, num_classes=2, checkpoint = None):

    model = MetaphorClassifier(config = config , num_classes = num_classes)
    tlogger = TensorBoardLogger(save_dir="metaphor-logs", name ="distilbert-full", version = 'v1')

    db_callbacks = [
        callbacks.ModelCheckpoint(monitor= 'val_loss',
                                  save_top_k = 1,
                                   save_on_train_epoch_end= False,
                                    filename = '{epoch}-{val_loss:.2f}' )

    ]

    trainer = pl.Trainer(accelerator="gpu",
                         logger= tlogger,
                         log_every_n_steps =2,
                         precision = 16,
                         enable_checkpointing= True,
                         callbacks= db_callbacks,
                         devices = 1,
                         enable_progress_bar= True,
                         max_epochs= num_epochs)
    
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader, ckpt_path=checkpoint)
    return trainer

In [75]:
db_trainer = train_distilbert(config= config, num_epochs= 10)

  rank_zero_warn(
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type            | Params
-----------------------------------------------
0 | model      | DistilBertModel | 66.4 M
1 | classifier | Linear          | 1.5 K 
-----------------------------------------------
66.4 M    Trainable params
0         Non-trainable params
66.4 M    Total params
265.458   Total estimated model params size (MB)


Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00, 21.89it/s]

  rank_zero_warn(
  rank_zero_warn(


                                                                           

  rank_zero_warn(


Epoch 9: 100%|██████████| 2442/2442 [05:07<00:00,  7.95it/s, v_num=v1]

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 2442/2442 [05:07<00:00,  7.95it/s, v_num=v1]


In [69]:
%reload_ext tensorboard
%tensorboard --logdir="/home/vri/Projects/research/metaphor-detection/notebooks/metaphor-logs" --host localhost --port 8081



In [81]:
db_trainer.validate(dataloaders=test_loader, ckpt_path='best')

Restoring states from the checkpoint path at metaphor-logs/distilbert-full/v1/checkpoints/epoch=0-val_loss=0.36-v2.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at metaphor-logs/distilbert-full/v1/checkpoints/epoch=0-val_loss=0.36-v2.ckpt
  rank_zero_warn(


Validation DataLoader 0: 100%|██████████| 1374/1374 [01:00<00:00, 22.70it/s]


[{'val_loss': 0.3958614766597748,
  'val_acc': 0.8581245541572571,
  'F1_score': 0.002378622768446803}]