In [11]:
import pandas as pd
import torch
from torch import nn
from src.utils import initialize_lyrics_tokenizer, initialize_midi_tokenizer, load_checkpoint
from src.model import LyricsGenerator
from src.data import prepare_dataloaders

def calculate_perplexity(model, dataloader, loss_fct, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch in dataloader:
            # Move data to the correct device
            lyrics_ids = batch['lyrics_ids'].to(device)
            lyrics_attention_mask = batch['lyrics_attention_mask'].to(device)
            midi_tokens = batch['midi_tokens'].to(device)

            with torch.autocast(device_type=device.type):
                logits = model(lyrics_ids, lyrics_attention_mask, midi_tokens)
                loss = loss_fct(logits.transpose(1, 2), lyrics_ids)
                total_loss += loss.item()
    
    avg_loss = total_loss / len(dataloader)
    perplexity = torch.exp(torch.tensor(avg_loss))
    return perplexity.item()


In [14]:
# Initialize model and tokenizer
lyrics_tokenizer = initialize_lyrics_tokenizer()
midi_tokenizer = initialize_midi_tokenizer()
model = LyricsGenerator(lyrics_tokenizer, d_model=768, max_lyrics_length=512, max_midi_length=512)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Load model from checkpoint
model = load_checkpoint(model=model, path='model_checkpoint/final_checkpoint.pth', inference=True, device=device)
# Data
df = pd.read_csv('data/lyrics_midi_data.csv')
train_dataloader, val_dataloader = prepare_dataloaders(
            df=df,
            lyrics_tokenizer=lyrics_tokenizer,
            midi_tokenizer=midi_tokenizer,
            max_length=512,
            root_dir='data/',
            batch_size=4,
        )

# Loss function
loss_fct = nn.CrossEntropyLoss(ignore_index=lyrics_tokenizer.pad_token_id)

# Calculate perplexity
val_perplexity = calculate_perplexity(model, val_dataloader, loss_fct, device)
print(f"Validation Perplexity: {val_perplexity:.4f}")

Checkpoint loaded for inference.
Validation Perplexity: 1615.6462
