In [None]:
#!pip install -r requirements_bert.txt

In [None]:
import numpy as np
import pandas as pd
import torch
import os
import re
import pickle
import torchmetrics
import pytorch_lightning as pl
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizerFast, DataCollatorWithPadding
from transformers import BertForSequenceClassification, AdamW
from transformers import BertTokenizer
from tqdm import tqdm
from torchmetrics.functional import accuracy
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.preprocessing import LabelEncoder
tqdm.pandas()

## BERT training

In [None]:
class LyricsDataset(Dataset):
    """Map-style dataset pairing pre-tokenized lyrics with their class labels.

    Each item is a dict ``{'input_ids': tensor, 'labels': tensor}``, the per-sample
    format that ``DataCollatorWithPadding`` pads into a batch.

    Args:
        encodings: per-sample token-id sequences. A pandas object is indexed
            positionally via ``.iloc`` (``train_test_split`` shuffles a pandas
            index, so label-based lookup would pick the wrong rows); any other
            sequence (list, numpy array) is indexed directly.
        labels: per-sample labels, indexed the same way as ``encodings``.

    Raises:
        ValueError: if ``encodings`` and ``labels`` differ in length (a sign the
            two were mis-aligned upstream).
    """

    def __init__(self, encodings, labels):
        if len(encodings) != len(labels):
            raise ValueError(
                f"encodings and labels must have the same length, "
                f"got {len(encodings)} and {len(labels)}"
            )
        self.encodings = encodings
        self.labels = labels

    @staticmethod
    def _at(values, idx):
        """Return the element at position ``idx`` (positional ``.iloc`` for pandas objects)."""
        return values.iloc[idx] if hasattr(values, "iloc") else values[idx]

    def __getitem__(self, idx):
        return {
            'input_ids': torch.as_tensor(self._at(self.encodings, idx)),
            'labels': torch.as_tensor(self._at(self.labels, idx)),
        }

    def __len__(self):
        return len(self.encodings)

class LyricsClassifier(pl.LightningModule):
    """Lightning wrapper around ``BertForSequenceClassification`` for lyrics classification.

    Args:
        model_name: Hugging Face checkpoint to fine-tune. NOTE: a larger model such as
            "bert-large-uncased" could be tried on an A100; the pre-tokenized data must
            then come from a tokenizer matching that checkpoint.
        num_labels: number of target classes.
        lr: AdamW learning rate.
    """

    def __init__(self, model_name='bert-base-uncased', num_labels=5, lr=1e-5):
        super().__init__()
        self.save_hyperparameters()
        self.bert = BertForSequenceClassification.from_pretrained(self.hparams.model_name,
                                                                  num_labels=self.hparams.num_labels)
        # `compute_on_step` was removed from torchmetrics; passing it next to `task=` raises.
        self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=self.hparams.num_labels)

    def forward(self, input_ids, labels=None, attention_mask=None):
        """Run BERT and return its output object (``.loss`` is set when ``labels`` is given).

        ``attention_mask`` should be supplied for padded batches so that padding
        tokens are ignored by self-attention.
        """
        return self.bert(input_ids, attention_mask=attention_mask, labels=labels)

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss returned by the HF model."""
        outputs = self(batch['input_ids'], batch['labels'],
                       attention_mask=batch.get('attention_mask'))
        loss = outputs.loss
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss and epoch-level accuracy; return this batch's accuracy."""
        outputs = self(batch['input_ids'], batch['labels'],
                       attention_mask=batch.get('attention_mask'))
        preds = outputs.logits.argmax(dim=-1)
        # The metric accumulates state over the epoch, so the logged value is the exact
        # accuracy over all validation samples and no per-batch GPU->CPU sync is needed.
        batch_accuracy = self.accuracy(preds, batch['labels'])
        self.log('val_loss', outputs.loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_accuracy', self.accuracy, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return batch_accuracy

    def configure_optimizers(self):
        # torch.optim.AdamW replaces the deprecated transformers.AdamW; eps and
        # weight_decay are set to transformers.AdamW's defaults so training is unchanged.
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.lr, eps=1e-6, weight_decay=0.0)

def load_data(encodings_path="/content/drive/MyDrive/NLP/new_small_try/tokenized_lyrics_small.pickle",
              labels_path="/content/drive/MyDrive/NLP/new_small_try/labels_small.pickle",
              val_test_size=0.3, random_state=42):
    """Load pickled, pre-tokenized lyrics and labels and split them into train/val/test.

    ``val_test_size`` of the data is held out and then split in half, giving
    validation and test sets of ``val_test_size / 2`` each (15% / 15% by default).

    Args:
        encodings_path: pickle file with the per-sample token ids.
        labels_path: pickle file with the per-sample labels (same order as encodings).
        val_test_size: fraction held out for validation + test.
        random_state: seed passed to both ``train_test_split`` calls.

    Returns:
        (train_encodings, train_labels, val_encodings, val_labels,
         test_encodings, test_labels)
    """
    # SECURITY: pickle.load can execute arbitrary code; only load files you created/trust.
    with open(encodings_path, 'rb') as f:
        encodings = pickle.load(f)
    with open(labels_path, 'rb') as f:
        labels = pickle.load(f)

    # train_test_split returns (X_train, X_test, y_train, y_test): features first, labels second.
    train_encodings, val_test_encodings, train_labels, val_test_labels = train_test_split(
        encodings, labels, test_size=val_test_size, random_state=random_state)

    # Split the held-out portion in half into validation and test sets.
    val_encodings, test_encodings, val_labels, test_labels = train_test_split(
        val_test_encodings, val_test_labels, test_size=0.5, random_state=random_state)

    return train_encodings, train_labels, val_encodings, val_labels, test_encodings, test_labels


SEED = 42
BATCH_SIZE = 16

# Seed python/numpy/torch (and DataLoader workers) so the classifier-head
# initialisation and the shuffling are reproducible across Restart & Run All.
pl.seed_everything(SEED, workers=True)

# The tokenizer is only used by the collator for padding; it must match the
# tokenizer that produced the pickled input_ids.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Prepare datasets
train_encodings, train_labels, val_encodings, val_labels, test_encodings, test_labels = load_data()

train_dataset = LyricsDataset(train_encodings, train_labels)
val_dataset = LyricsDataset(val_encodings, val_labels)
test_dataset = LyricsDataset(test_encodings, test_labels)

model = LyricsClassifier()

# Data loaders (the collator pads each batch and adds an attention mask)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=data_collator)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=data_collator)

# precision=16: 16-bit mixed-precision training.
# limit_train_batches=0.5: only half of the training batches are used per epoch
# (presumably to save time -- set to 1.0 for a full run).
trainer = pl.Trainer(precision=16, limit_train_batches=0.5, max_epochs=3)

# Training
trainer.fit(model, train_loader, val_loader)

## Testing

### Optionally load a trained model from a checkpoint, then evaluate it on the test set

In [None]:
#loading in model from checkpoint instead of training
#model = LyricsClassifier.load_from_checkpoint(checkpoint_path="/content/lightning_logs/version_3/checkpoints/epoch=0-step=1406.ckpt")

In [None]:
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=data_collator, num_workers=4)

# Move model to device once
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model.eval()  # Set the model to evaluation mode (disables dropout)

# Pass every class id explicitly so the report and matrix keep a row/column
# for classes that happen to be absent from the test set or the predictions.
class_ids = list(range(model.hparams.num_labels))

true_labels = []
predicted_labels = []

with torch.inference_mode():
    for batch in tqdm(test_loader, desc="Inference"):
        batch_input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device) if 'attention_mask' in batch else None

        # Call the wrapped HF model directly so the padding mask is applied;
        # labels are not passed because only the logits are needed here.
        outputs = model.bert(input_ids=batch_input_ids, attention_mask=attention_mask)

        preds = outputs.logits.argmax(dim=-1)
        predicted_labels.extend(preds.cpu().numpy())
        true_labels.extend(batch['labels'].cpu().numpy())

# Classification report
print(classification_report(true_labels, predicted_labels, labels=class_ids, zero_division=0))

# Confusion matrix
cm = confusion_matrix(true_labels, predicted_labels, labels=class_ids)
fig, ax = plt.subplots(figsize=(10, 10))
sns.heatmap(cm, annot=True, fmt="d", ax=ax)
ax.set(title="Confusion matrix", xlabel='Predicted label', ylabel='Actual label')
plt.show()