In [None]:
import wandb
import lightning.pytorch as pl
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch import Trainer

In [None]:
# Seed the random number generators with a fixed value (42) so that repeated
# runs start from the same random state.
# NOTE(review): seeding alone does not guarantee fully deterministic training;
# no Trainer in this file sets `deterministic=True` — confirm whether that is needed.
pl.seed_everything(42)


In [None]:
# Authenticate with Weights & Biases before any logger/run is created.
wandb.login()

In [None]:
from typing import Any


class LyricsTranscriptor(pl.LightningModule):
    """Lightning module that trains a wrapped model against a supplied loss.

    Args:
        model: Callable network, invoked as ``model(x)``.
        loss: Callable invoked as ``loss(y_hat, y)`` that returns the value
            to minimise.
    """

    def __init__(self, model, loss):
        super().__init__()
        self.model = model
        self.loss_fn = loss

    def forward(self, x):
        """Return the wrapped model's output for ``x``."""
        return self.model(x)

    def _compute_and_log_loss(self, batch, stage):
        """Compute the loss for one ``(x, y)`` batch and log it as ``<stage>_loss``."""
        x, y = batch
        y_hat = self.model(x)
        loss = self.loss_fn(y_hat, y)
        self.log(f'{stage}_loss', loss)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for ``batch``; logged as ``train_loss``."""
        return self._compute_and_log_loss(batch, 'train')

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for ``batch``; logged as ``val_loss``.

        Logging ``val_loss`` is what makes the metric available to callbacks
        that monitor it (e.g. early stopping on ``val_loss``).
        """
        return self._compute_and_log_loss(batch, 'val')

    def test_step(self, batch, batch_idx):
        """Return the test loss for ``batch``; logged as ``test_loss``."""
        return self._compute_and_log_loss(batch, 'test')

    def configure_optimizers(self):
        """Return an Adam optimizer over all module parameters (lr=1e-3)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

In [None]:
class LyricsDataModule(pl.LightningDataModule):
    """Data module serving train/validation/test ``DataLoader`` objects.

    Args:
        train_dataset: Dataset used for training (shuffled each epoch).
        val_dataset: Dataset used for validation.
        batch_size: Batch size shared by every loader.
        test_dataset: Optional dataset used for testing. When ``None`` the
            validation dataset is reused, which matches the previous
            behaviour but means test metrics are not an independent estimate.
    """

    def __init__(self, train_dataset, val_dataset, batch_size=32, test_dataset=None):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.batch_size = batch_size

    def train_dataloader(self):
        """Return a shuffling loader over the training dataset."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Return a non-shuffling loader over the validation dataset."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        """Return a loader over the test dataset, or the validation dataset if none was given."""
        dataset = self.val_dataset if self.test_dataset is None else self.test_dataset
        return DataLoader(dataset, batch_size=self.batch_size)

In [None]:
# Upper bound on the number of training epochs; passed as `max_epochs` when
# the Trainer used for fitting is constructed.
MAX_EPOCHS = 1000

In [None]:
# Early stopping on the 'val_loss' metric (lower is better), with a patience of 5.
# NOTE(review): this only works if the LightningModule actually logs 'val_loss'
# during validation — confirm.
early_stopping = EarlyStopping(monitor='val_loss',  patience=5 ,mode="min", verbose=True)
# W&B logger; log_model="all" presumably uploads every checkpoint — verify against the WandbLogger docs.
wandb_logger = WandbLogger(log_model="all")
# NOTE(review): `trainer` is rebound in a later cell without `callbacks=[early_stopping]`,
# so this instance (and with it the early-stopping callback) is never used for fitting.
trainer = Trainer(logger=wandb_logger,callbacks=[early_stopping])
# Loss callable handed to LyricsTranscriptor as `loss`.
loss = F.cross_entropy

In [None]:
# Placeholders: real datasets must be assigned here before this cell can run.
train_dataset = None
val_dataset = None
datamodule = LyricsDataModule(train_dataset, val_dataset)

# Placeholder: the network to train must be assigned here.
model = None
segmenter = LyricsTranscriptor(model=model, loss=loss)

# Ask the W&B logger to track the model.
wandb_logger.watch(model)
# NOTE(review): this rebinds `trainer` and does not pass `callbacks=[early_stopping]`,
# so training runs for up to MAX_EPOCHS epochs with no early stopping — confirm intent.
trainer = pl.Trainer(max_epochs=MAX_EPOCHS, logger=wandb_logger, )
trainer.fit(segmenter, datamodule=datamodule)
# The datamodule's test loader reads from the validation dataset passed above.
trainer.test(segmenter, datamodule=datamodule)
# Close the W&B run.
wandb.finish()