In [None]:
import random

import pandas as pd

import torch
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss

from lib import ReplacementDataset, get_replacement_model, get_replacement_tokenizer, get_replacement_collate_fn

# import LightningModule
from pytorch_lightning import LightningModule, Trainer

In [None]:
# Load the raw tweets; only the free-text column feeds the model below.
df = pd.read_csv("Tweets.csv")
text_samples = df.loc[:, "text"].to_list()

# Peek at the first few samples.
text_samples[:10]

In [None]:
class ReplacementLangaugeModel(LightningModule):
    """Lightning module that fine-tunes a token-replacement language model.

    The tokenizer is built with an ``[EMT]`` "empty" token, and the
    underlying model is created with ``len(tokenizer)`` so both agree on the
    extended vocabulary. Training and validation both minimise
    cross-entropy between the model logits and ``batch["labels"]``.

    Args:
        model_name: pretrained checkpoint name handed to the ``lib`` factory
            helpers (e.g. ``"distilbert-base-uncased"``).
        lr: learning rate for the Adam optimizer.
    """

    def __init__(self, model_name, lr=1e-4):
        super().__init__()
        # Store ``model_name`` / ``lr`` in ``self.hparams`` and in every
        # checkpoint so ``load_from_checkpoint`` can rebuild the module
        # without the caller re-supplying constructor arguments.
        self.save_hyperparameters()
        self.tokenizer = get_replacement_tokenizer(model_name, empty_token="[EMT]")
        self.empty_id = self.tokenizer.convert_tokens_to_ids("[EMT]")
        # Presumably sizes the model's vocabulary to the tokenizer *including*
        # the extra [EMT] token — confirm against get_replacement_model.
        self.model = get_replacement_model(model_name, len(self.tokenizer))
        self.lr = lr
        self.loss_fn = CrossEntropyLoss()

    def forward(self, input_ids, attention_mask):
        """Run the wrapped model; returns its raw output object (has ``.logits``)."""
        return self.model(input_ids, attention_mask=attention_mask)

    def configure_optimizers(self):
        """Optimise all parameters with Adam at ``self.lr``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

    def _compute_loss(self, batch):
        """Cross-entropy loss for one batch; shared by the train and val steps."""
        outputs = self(batch["input_ids"], batch["attention_mask"])
        # CrossEntropyLoss expects the class dimension second, i.e. (N, C, L).
        # NOTE(review): assumes logits are (batch, seq, vocab) and labels are
        # per-token class-probability targets of the same layout — permuting a
        # 3-D label tensor only fits the probability-target form of the loss;
        # confirm against get_replacement_collate_fn.
        logits = outputs.logits.permute(0, 2, 1)
        targets = batch["labels"].permute(0, 2, 1)
        return self.loss_fn(logits, targets)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        loss = self._compute_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch."""
        loss = self._compute_loss(batch)
        self.log("val_loss", loss)
        return loss


In [None]:
# Build the Lightning module, then reuse its tokenizer and [EMT] id so the
# data pipeline below shares exactly the model's vocabulary.
model = ReplacementLangaugeModel("distilbert-base-uncased")
tokenizer, empty_id = model.tokenizer, model.empty_id

In [None]:
# Hold out part of the tweets for validation. Shuffle a copy first (seeded so
# the split is reproducible) in case the CSV is ordered, e.g. by time or label.
SPLIT_SEED = 42
TRAIN_FRACTION = 0.8
BATCH_SIZE = 32
NUM_WORKERS = 12

shuffled_samples = list(text_samples)
random.Random(SPLIT_SEED).shuffle(shuffled_samples)

train_size = int(TRAIN_FRACTION * len(shuffled_samples))
train_samples = shuffled_samples[:train_size]
val_samples = shuffled_samples[train_size:]
assert len(train_samples) + len(val_samples) == len(text_samples)

# One dataset per split, so validation never scores examples seen in training.
train_dataset = ReplacementDataset(train_samples, tokenizer, empty_id)
val_dataset = ReplacementDataset(val_samples, tokenizer, empty_id)
collate_fn = get_replacement_collate_fn(tokenizer)

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=NUM_WORKERS,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=NUM_WORKERS,
)

In [None]:
# Fit with per-epoch validation; Lightning writes logs and checkpoints
# under the given root directory.
trainer = Trainer(
    default_root_dir="checkpoints/",
)
trainer.fit(
    model,
    train_loader,
    val_loader,
)