In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split 
from torchmetrics.classification import F1Score, Accuracy

from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight

from transformers import DistilBertTokenizer, DistilBertForTokenClassification, DistilBertConfig, AdamW, DistilBertModel
import pytorch_lightning as pl
from pytorch_lightning import callbacks
from pytorch_lightning.loggers import TensorBoardLogger

import pandas as pd
import numpy as np
import os

from typing import Any

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision('medium')

In [3]:
DATA_PATH = '../data/VUA18/'

In [4]:
train_df = pd.read_csv(DATA_PATH+'train.tsv', sep='\t', encoding='utf-8')
test_df = pd.read_csv(DATA_PATH+'test.tsv', sep='\t', encoding='utf-8')

In [5]:
# print(f"The number of training samples is: {len(train_df)} and the number of test samples is: {len(test_df)}")

In [6]:
train_df.head()

Unnamed: 0,index,label,sentence,POS,w_index
0,b1g-fragment02 841,0,If it now seems self-evident that monitoring o...,ADP,42
1,fef-fragment03 667,0,Which equation should we use in a practical ca...,NOUN,10
2,as6-fragment01 76,0,It was initiated partly in response to the fur...,ADP,10
3,ew1-fragment01 108,1,You fully know as an old pressman the difficul...,VERB,10
4,fpb-fragment01 1152,0,It was a condition of her gift to you the ten ...,ADJ,5


In [7]:
train_df, val_df = train_test_split(train_df, test_size=0.2)

In [8]:
train_df

Unnamed: 0,index,label,sentence,POS,w_index
25312,as6-fragment01 79,0,The scope was extended in an enhanced Urban Pr...,VERB,2
65590,fet-fragment01 197,0,In the churchyard in his last parish a family ...,ADJ,5
23269,as6-fragment01 72,1,The main thrust of the Government's policy is ...,NOUN,46
87309,kcu-fragment02 2108,0,come home from work she said I've been looking...,VERB,24
73907,kbj-fragment17 1470,1,Those two add?,VERB,2
...,...,...,...,...,...
69075,ac2-fragment06 1482,0,We're not gon na change anyone's mind.,ADV,1
30130,fet-fragment01 36,0,There is still a great deal of Greece all thro...,NOUN,22
80941,a1h-fragment06 144,0,"It was a blessing that, in response to congrat...",INTJ,27
41098,b1g-fragment02 798,0,This contribution cuts across many of the ques...,NOUN,7


In [9]:
val_df

Unnamed: 0,index,label,sentence,POS,w_index
91201,a3e-fragment03 35,0,"You may, even as a novice, care so passionatel...",ADJ,17
80346,kb7-fragment10 2284,0,I mean we've just been to look at some others ...,NOUN,9
36265,cdb-fragment04 945,0,"In high spirits, his father was talking about ...",ADJ,16
76505,c8t-fragment01 95,0,She said briefly: They know me.,ADV,2
19291,a6u-fragment02 294,0,The physicality so characteristic of Kahlo's w...,DET,17
...,...,...,...,...,...
43437,fet-fragment01 38,0,"The youth of it, Alexander thought.",ADP,2
42038,a1j-fragment34 558,0,There could also be controversy over the execu...,ADP,16
94551,a6u-fragment02 274,0,"This small painting on metal, in the style of ...",ADJ,19
89914,g0l-fragment01 68,0,"Our information continued for a while, but it ...",NOUN,1


In [10]:
class_weights = compute_class_weight('balanced', classes=[0, 1], y=train_df['label'])
new_weights = torch.tensor(class_weights, dtype=torch.float32) 

In [11]:
class MetaphorDataset(Dataset):
    def __init__(self, df, tokenizer, max_length):
        self.data = df
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx) -> dict:
        sentence = self.data.iloc[idx]['sentence']
        label = self.data.iloc[idx]['label']
        pos_tag = self.data.iloc[idx]['POS']

        sentence_encoding = self.tokenizer(sentence, 
                                           truncation = True, 
                                           padding='max_length', 
                                           max_length=self.max_length,
                                           return_tensors= 'pt')
        # pos_encoding = self.tokenizer(pos_tag,
        #                               truncation = True,
        #                               padding = 'max_length',
        #                               max_length = 5,
        #                               return_tensors= 'pt')
        
        return {
            'input_ids': sentence_encoding['input_ids'],
            'attention_mask': sentence_encoding['attention_mask'],
            # 'pos_ids': pos_encoding['input_ids'],
            'label': torch.tensor(label, dtype = torch.long)
        }

In [12]:
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
# model = DistilBertModel.from_pretrained('distilbert-base-uncased')


In [13]:
max_length = 128  # You can adjust this based on your dataset
batch_size = 32   # You can adjust this as well
train_dataset = MetaphorDataset(train_df, tokenizer, max_length)
val_dataset = MetaphorDataset(val_df, tokenizer, max_length)
test_dataset = MetaphorDataset(test_df, tokenizer, max_length)


In [14]:
# val_size = int(0.2 * len(train_dataset))
# train_size = int(len(train_dataset)- val_size)
# train_set, val_set = random_split(train_dataset, [train_size, val_size])

print(f"samples in train set: {len(train_dataset)}")
print(f"samples in test set: {len(test_dataset)}")
print(f"samples in val set: {len(val_dataset)}")

samples in train set: 78126
samples in test set: 43947
samples in val set: 19532


In [15]:
train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
val_loader = DataLoader(val_dataset, batch_size = batch_size, shuffle = True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [16]:
for batch in test_loader:
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']
    # pos_ids = batch['pos_ids']
    labels = batch['label']

    print("Input IDs:", input_ids.shape)
    print("Attention Mask:", attention_mask.shape)
    # print("POS IDs:", pos_ids.shape)
    print("Labels:", labels.shape)
    break

Input IDs: torch.Size([32, 1, 128])
Attention Mask: torch.Size([32, 1, 128])
Labels: torch.Size([32])


In [17]:
config = {
    'lr' : 1e-3
}

In [24]:
class MetaphorClassifier(pl.LightningModule):
    def __init__(self, num_classes, model_name = 'distilbert-base-uncased', config= config, class_weights = None) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.config = config 
        self.model = DistilBertModel.from_pretrained(model_name)
        self.weights = class_weights
        hidden_size = self.model.config.hidden_size 
        self.classifier = nn.Linear(hidden_size, num_classes)

        # print(self.model)

    def forward(self, batch):
        input_ids = batch['input_ids'].squeeze(1)
        attention_mask = batch['attention_mask'].squeeze(1)
        # pos_ids = batch['pos_ids']
        sen_outputs = self.model(input_ids = input_ids, attention_mask = attention_mask)
        pooled_sen_output = sen_outputs['last_hidden_state'][:, 0]

        # pos_outputs = self.model(input_ids = pos_ids)[0]
        # pooled_pos_output = pos_outputs[:, 0]

        # combined_output = torch.cat((pooled_sen_output, pooled_pos_output), dim=1)

        # if pos_ids is not None:
        #     pos_outputs = self.model(input_ids=pos_ids)[0]
        #     pooled_pos_output = pos_outputs[:, 0]
        #     combined_output = torch.cat((pooled_sen_output, pooled_pos_output), dim=1)
        # else:
        #     combined_output = pooled_sen_output

        logits = self.classifier(pooled_sen_output)
        return logits
    
    def common_step(self, batch, step_type:str):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        # pos_ids = batch['pos_ids']

        data = {'input_ids': input_ids,
                'attention_mask': attention_mask}
                # 'pos_ids' : pos_ids}

        labels = batch['label']
        logits = self.forward(data)

        device = logits.device
        predictions = torch.argmax(logits, dim=1).to(device)
        labels = labels.to(device)
        loss = nn.CrossEntropyLoss(weight=self.weights.to(device))(logits, labels)

        accuracy = Accuracy(task='binary').to(device)
        
        self.log(f'{step_type}_loss', loss)
        self.log(f'{step_type}_acc', accuracy(predictions, labels))
        return loss, logits, labels
    
    def accuracy(self, correct, labels):
        return correct/len(labels)


    def training_step(self, batch, batch_idx):
        loss, logits, labels = self.common_step(batch, step_type='train')
        return loss
    
    def validation_step(self, batch, batch_idx):
        loss, logits, labels = self.common_step(batch, step_type='val')
        device = logits.device
        predictions = torch.argmax(logits, dim=1).to(device)
        labels = labels.to(device)
        f1=F1Score(task='binary').to(device)

        self.log('F1_score', f1(predictions, labels))
        return loss
    
    def test_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['label']
        data = {'input_ids': input_ids,
                'attention_mask': attention_mask}
                # 'pos_ids' : pos_ids}

        # labels = batch['label']
        logits = self.forward(data)

        device = logits.device
        predictions = torch.argmax(logits, dim=1).to(device)
        labels = labels.to(device)
        accuracy = Accuracy(task='binary').to(device)
        f1=F1Score(task='binary').to(device)

        self.log('test_acc', accuracy(predictions, labels))
        self.log('test_f1', predictions, labels)

        # accuracy = Accuracy()(logits, labels)
        # f1 = F1Score(logits, labels)

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr = self.config['lr'])



In [25]:
def train_distilbert(config, num_epochs = 8, num_classes=2, checkpoint = None):

    model = MetaphorClassifier(config = config , num_classes = num_classes, class_weights=new_weights)
    tlogger = TensorBoardLogger(save_dir="metaphor-logs", name ="distilbert-weighted", version = 'v1')

    db_callbacks = [
        callbacks.ModelCheckpoint(monitor= 'val_loss',
                                  save_top_k = 1,
                                   save_on_train_epoch_end= False,
                                    filename = '{epoch}-{val_loss:.2f}' )

    ]

    trainer = pl.Trainer(accelerator="gpu",
                         logger= tlogger,
                         log_every_n_steps =2,
                         precision = 16,
                         enable_checkpointing= True,
                         callbacks= db_callbacks,
                         devices = 1,
                         enable_progress_bar= True,
                         max_epochs= num_epochs)
    
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader, ckpt_path=checkpoint)
    return trainer

In [26]:
db_trainer = train_distilbert(config= config, num_epochs= 10)

  rank_zero_warn(
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type            | Params
-----------------------------------------------
0 | model      | DistilBertModel | 66.4 M
1 | classifier | Linear          | 1.5 K 
-----------------------------------------------
66.4 M    Trainable params
0         Non-trainable params
66.4 M    Total params
265.458   Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


                                                                           

  rank_zero_warn(


Epoch 9: 100%|██████████| 2442/2442 [05:03<00:00,  8.06it/s, v_num=v1]

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 2442/2442 [05:03<00:00,  8.06it/s, v_num=v1]


In [None]:
%reload_ext tensorboard
%tensorboard --logdir="/home/vri/Projects/research/metaphor-detection/notebooks/metaphor-logs" --host localhost --port 8081



In [27]:
db_trainer.test(dataloaders=test_loader, ckpt_path='best')

Restoring states from the checkpoint path at metaphor-logs/distilbert-weighted/v1/checkpoints/epoch=0-val_loss=0.69.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at metaphor-logs/distilbert-weighted/v1/checkpoints/epoch=0-val_loss=0.69.ckpt
  rank_zero_warn(


Testing DataLoader 0:   0%|          | 0/1374 [00:00<?, ?it/s]

ValueError: `self.log(test_f1, tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0'))` was called, but the tensor must have a single element. You can try doing `self.log(test_f1, tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0').mean())`