In [None]:
from datasets import load_dataset, load_metric, Dataset, DatasetDict
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
from transformers import TrainingArguments, Trainer
from transformers import AutoTokenizer, DataCollatorWithPadding
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from tqdm.auto import tqdm
from sklearn.metrics import classification_report
from sklearn import metrics
from scipy.special import softmax
import torch
from torch.utils.data import DataLoader

# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

In [None]:
# Read the main CSV; the path is resolved relative to the notebook's working directory.
# Later cells use its "id", "tweet" and "label" columns.
df = pd.read_csv(filepath_or_buffer="Q3/pre_main_data.csv")
df

In [None]:
# Fixed seed so train/validation/test membership is identical on every
# Restart & Run All; without it each run shuffles the rows differently.
SEED = 42

full_dataset = Dataset.from_pandas(df)

# Hold out 20% of the rows for testing, then carve 10% of the remaining
# rows off as a validation split (72% / 8% / 20% overall).
trainvalid_test = full_dataset.train_test_split(test_size=0.2, seed=SEED)
train_valid = trainvalid_test["train"].train_test_split(test_size=0.1, seed=SEED)

dataset = DatasetDict({
    "train": train_valid["train"],
    "validation": train_valid["test"],
    "test": trainvalid_test["test"],
})
dataset

In [None]:
# Checkpoint name shared by the tokenizer here and the model loaded in a later cell.
checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# Collator that pads the examples of a batch with this tokenizer.
data_collator = DataCollatorWithPadding(tokenizer)

def tokenize_function(example):
    """Tokenize the ``"tweet"`` entry of ``example`` with the notebook-level tokenizer.

    Truncation is enabled; no padding is requested here.

    Args:
        example: Mapping with a ``"tweet"`` key (a batch of texts when used
            with ``map(..., batched=True)``).

    Returns:
        Whatever the tokenizer returns for those texts.
    """
    tweets = example["tweet"]
    encoded = tokenizer(tweets, truncation=True)
    return encoded

# Tokenize every split in batches, drop the raw "id"/"tweet" columns, and
# rename "label" to "labels".
tokenized_datasets = (
    dataset.map(tokenize_function, batched=True)
    .remove_columns(["id", "tweet"])
    .rename_column("label", "labels")
)
# Make the datasets return torch tensors.
tokenized_datasets.set_format("torch")

# Only the training loader shuffles; both loaders batch 8 examples and
# delegate padding to the collator.
train_dataloader = DataLoader(
    tokenized_datasets["train"],
    batch_size=8,
    shuffle=True,
    collate_fn=data_collator,
)
eval_dataloader = DataLoader(
    tokenized_datasets["validation"],
    batch_size=8,
    collate_fn=data_collator,
)

tokenized_datasets

In [None]:
# Load the pretrained model with a 3-label classification head and move it
# to the selected device; the trailing expression displays the model.
plm = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=3)
plm = plm.to(device)
plm

In [None]:
class Head(torch.nn.Module):
    """Bidirectional-LSTM + MLP classification head over PLM hidden states.

    The module also owns its loss (``training_criterion``) and its Adam
    optimizer (``optimizer``); the training loop reads both attributes.

    Args:
        input_size: Feature size of each incoming hidden-state vector.
        num_classes: Number of output logits.
        hidden_size: Per-direction LSTM hidden size. The LSTM is
            bidirectional, so the MLP receives ``2 * hidden_size`` features.
    """

    def __init__(self, input_size=768, num_classes=3, hidden_size=384):
        super().__init__()
        # `device` is the notebook-level torch.device chosen in the setup cell.
        self.lstm = torch.nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=1,
            bidirectional=True,
            batch_first=True,
        ).to(device)
        self.label_net = torch.nn.Sequential(
            torch.nn.Linear(2 * hidden_size, 512),
            torch.nn.Tanh(),
            torch.nn.LayerNorm(512),
            torch.nn.Dropout(0.1),
            torch.nn.Linear(512, num_classes),
        ).to(device)
        self.training_criterion = torch.nn.CrossEntropyLoss()
        # Created after the sub-modules so it covers all of their parameters.
        self.optimizer = torch.optim.Adam(self.parameters(), lr=5e-4, weight_decay=0)

    def forward(self, plm_last_hidden_states, attention_mask=None):
        """Return class logits for a batch of hidden-state sequences.

        Args:
            plm_last_hidden_states: Tensor of shape (batch, seq_len, input_size).
            attention_mask: Optional (batch, seq_len) mask with non-zero
                entries for real tokens. When given, each sequence is
                processed only up to its own length so padding cannot
                influence the result. Assumes right-padded sequences.

        Returns:
            Tensor of shape (batch, num_classes) with unnormalised logits.
        """
        if attention_mask is None:
            _, (hn, _) = self.lstm(plm_last_hidden_states)
        else:
            # pack_padded_sequence needs lengths on the CPU and >= 1.
            lengths = attention_mask.sum(dim=1).clamp(min=1).to("cpu")
            packed = torch.nn.utils.rnn.pack_padded_sequence(
                plm_last_hidden_states,
                lengths,
                batch_first=True,
                enforce_sorted=False,
            )
            _, (hn, _) = self.lstm(packed)

        # Summarise with the *final* state of each direction. Slicing the
        # output at the last timestep would give the backward direction a
        # state that has only seen a single token.
        summary = torch.cat([hn[-2], hn[-1]], dim=-1)
        return self.label_net(summary)

In [None]:
# progress_bar = tqdm(range(num_training_steps))

class Trainer:
    """Trains a ``Head`` on the last hidden states of a frozen pretrained model.

    NOTE(review): this class has the same name as ``transformers.Trainer``,
    which the imports cell also brings in; once this cell runs, ``Trainer``
    refers to this class.

    Args:
        plm: Pretrained model. It is called with ``output_hidden_states=True``
            and only its last hidden-state layer is fed to the head.
    """

    def __init__(self, plm):
        self.plm = plm
        self.head = Head()
        print(self.head)

    def train(self, num_epochs=5):
        """Fit the head on the notebook-level ``train_dataloader``.

        Only the head's parameters are updated; the pretrained model is run
        in eval mode without gradient tracking.

        Args:
            num_epochs: Number of full passes over ``train_dataloader``.
        """
        # Use the model given to the constructor, not the notebook global.
        self.plm.eval()
        self.head.train()
        for _ in tqdm(range(num_epochs)):
            for batch in tqdm(train_dataloader):
                batch = {k: v.to(device) for k, v in batch.items() if k not in ["id", "tweet"]}
                # Keep labels away from the PLM so it does not compute an
                # unused loss; they are only needed by the head's criterion.
                labels = batch.pop("labels").long()

                # The PLM is not trained: skip autograd bookkeeping entirely.
                with torch.no_grad():
                    outputs = self.plm(**batch, output_hidden_states=True, return_dict=True)
                last_hidden_states = outputs.hidden_states[-1]

                logits = self.head(last_hidden_states)
                loss = self.head.training_criterion(logits, labels)

                self.head.optimizer.zero_grad()
                loss.backward()
                self.head.optimizer.step()

trainer = Trainer(plm)
trainer.train()

In [None]:
# Bare expression: as the last line of a notebook cell it displays the DataLoader object's repr.
train_dataloader