In [8]:
import numpy as np
import pandas as pd
import torch
import os
import pickle
import torchmetrics
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizerFast, DataCollatorWithPadding
from transformers import BertForSequenceClassification, AdamW
import pytorch_lightning as pl
from transformers import BertTokenizer
from tqdm import tqdm
from sklearn.metrics import classification_report
from torchmetrics.functional import accuracy

# Register tqdm with pandas (adds the `progress_apply` family of methods).
tqdm.pandas()

In [None]:
# TODO: add a proper requirements.txt

# BERT training (a GPU is required!)

In [None]:
class LyricsDataset(Dataset):
    """Map-style dataset pairing pre-tokenized lyrics with their class labels.

    Both containers are read positionally through ``.iloc``, so they are
    expected to be pandas objects (e.g. ``pd.Series``) of equal length.
    """

    def __init__(self, encodings, labels):
        # Kept as-is; conversion to tensors happens lazily, one item at a time.
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        """Return the number of samples (one per stored encoding)."""
        return len(self.encodings)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict of ``input_ids`` and ``labels`` tensors."""
        token_ids = self.encodings.iloc[idx]
        target = self.labels.iloc[idx]
        return {
            'input_ids': torch.as_tensor(token_ids),
            'labels': torch.as_tensor(target),
        }

class LyricsClassifier(pl.LightningModule):
    """Lightning module that fine-tunes BERT for lyrics classification.

    Args:
        model_name: Hugging Face checkpoint name passed to
            ``BertForSequenceClassification.from_pretrained``.
        num_labels: Number of target classes; sizes both the classification
            head and the accuracy metric.
    """

    def __init__(self, model_name='bert-base-uncased', num_labels=5):
        super().__init__()
        # Exposes the constructor arguments as self.hparams.<name>.
        self.save_hyperparameters()
        self.bert = BertForSequenceClassification.from_pretrained(
            self.hparams.model_name,
            num_labels=self.hparams.num_labels,
        )
        # The task-based torchmetrics API does not accept the removed
        # `compute_on_step` keyword, so it is no longer passed.
        self.accuracy = torchmetrics.Accuracy(
            task="multiclass",
            num_classes=self.hparams.num_labels,
        )

    def forward(self, input_ids, labels=None, attention_mask=None):
        """Run BERT on a batch of token ids.

        Args:
            input_ids: Padded token-id tensor.
            labels: Optional class indices; when given, the returned output
                also carries the loss.
            attention_mask: Optional mask marking real tokens (1) versus
                padding (0). ``None`` makes BERT attend to every position.

        Returns:
            The model output object (exposes ``.logits`` and, with labels,
            ``.loss``).
        """
        return self.bert(input_ids, attention_mask=attention_mask, labels=labels)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        outputs = self.forward(
            batch['input_ids'],
            labels=batch['labels'],
            # The padding collator adds this key; without it the model would
            # attend to padding tokens.
            attention_mask=batch.get('attention_mask'),
        )
        loss = outputs.loss
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Update and log validation accuracy for one batch.

        Returns:
            The accuracy of this batch as a tensor.
        """
        outputs = self.forward(
            batch['input_ids'],
            labels=batch['labels'],
            attention_mask=batch.get('attention_mask'),
        )
        predicted = outputs.logits.argmax(dim=1)
        # Calling the metric accumulates state for the epoch-level value and
        # returns the accuracy of the current batch.
        batch_accuracy = self.accuracy(predicted, batch['labels'])
        # Logging the metric object lets Lightning compute/reset it per epoch.
        self.log('val_accuracy', self.accuracy, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return batch_accuracy

    def configure_optimizers(self):
        """Return the AdamW optimizer used for fine-tuning.

        Uses ``torch.optim.AdamW`` (``transformers.AdamW`` is deprecated);
        ``eps`` and ``weight_decay`` are pinned to the values the transformers
        implementation used by default so the update rule stays the same.
        """
        return torch.optim.AdamW(self.parameters(), lr=1e-5, eps=1e-6, weight_decay=0.0)


def _load_pickle(path):
    """Unpickle and return the single object stored in the file at ``path``."""
    # NOTE(security): pickle can execute arbitrary code while loading; only
    # point this at files you created yourself.
    with open(path, 'rb') as f:
        return pickle.load(f)


def load_data(
    encodings_path="/content/drive/MyDrive/NLP/new_small_try/tokenized_lyrics_small.pickle",
    labels_path="/content/drive/MyDrive/NLP/new_small_try/labels_small.pickle",
):
    """Load the pickled tokenized lyrics and their labels.

    Args:
        encodings_path: Path of the pickle holding the tokenized lyrics.
            Defaults to the Google Drive location used by the notebook.
        labels_path: Path of the pickle holding the matching labels.
            Defaults to the Google Drive location used by the notebook.

    Returns:
        tuple: ``(encodings, labels)`` exactly as they were pickled.

    Raises:
        OSError: If either file cannot be opened.
    """
    encodings = _load_pickle(encodings_path)
    labels = _load_pickle(labels_path)
    return encodings, labels

def main():
    """Fine-tune a ``LyricsClassifier`` on the pickled lyrics data.

    Loads the encodings and labels, holds out 10% of them for validation
    (fixed seed), fits the model with a Lightning trainer and hands back the
    model together with the padding collator used to build the batches.

    Returns:
        tuple: ``(model, data_collator)``.
    """
    encodings, labels = load_data()

    # The collator pads each batch using the BERT tokenizer's settings.
    bert_tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
    data_collator = DataCollatorWithPadding(tokenizer=bert_tokenizer)

    # 90/10 train/validation split; the fixed seed keeps it reproducible.
    enc_train, enc_val, y_train, y_val = train_test_split(
        encodings,
        labels,
        test_size=0.1,
        random_state=42,
    )
    train_dataset = LyricsDataset(enc_train, y_train)
    val_dataset = LyricsDataset(enc_val, y_val)

    model = LyricsClassifier()

    # Both loaders share batch size and collator; only training is shuffled.
    loader_kwargs = {'batch_size': 16, 'collate_fn': data_collator}
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    # 16-bit precision, half of the training batches per epoch, three epochs.
    trainer = pl.Trainer(
        max_epochs=3,
        limit_train_batches=0.5,
        precision=16,
    )
    trainer.fit(model, train_loader, val_loader)

    return model, data_collator

# Run the training pipeline and keep the fitted model and the padding collator
# available at module scope.
model, data_collator = main()