In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2LMHeadModel, GPT2Tokenizer, get_linear_schedule_with_warmup
from tqdm import tqdm
import random
import numpy as np
# For reproducibility
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import os
import random
import torch
from torch.utils.data import Dataset
from copy import deepcopy
from symusic import Score
from miditok import TSD, TokenizerConfig

# Random MIDI pitch augmentation
def randomize_midi_pitch(midi_score, prob=0.2, max_change=4):
    new_score = deepcopy(midi_score)
    for track in new_score.tracks:
        for note in track.notes:
            if random.random() < prob:
                change = random.randint(-max_change, max_change)
                note.pitch = max(0, min(note.pitch + change, 127))
    return new_score

# Dataset class
class LyricsMidiDataset(Dataset):
    def __init__(self, dataframe, lyrics_tokenizer, midi_tokenizer, max_length, root_dir=None, augment=False):
        self.dataframe = dataframe
        self.lyrics_tokenizer = lyrics_tokenizer
        self.midi_tokenizer = midi_tokenizer
        self.max_length = max_length
        self.augment = augment
        self.root_dir = root_dir

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        lyrics = self.dataframe.iloc[idx]['lyrics']
        midi_path = self.dataframe.iloc[idx]['midi_path']
        if self.root_dir:
            midi_path = os.path.join(self.root_dir, midi_path)

        midi_path = os.path.normpath(midi_path)
        if not os.path.isfile(midi_path):
            raise FileNotFoundError(f"MIDI file not found: {midi_path}")

        lyrics_tokens = self.lyrics_tokenizer(
            lyrics + self.lyrics_tokenizer.eos_token,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        # Ensure padding tokens are correctly replaced for GPT-2 compatibility
        lyrics_tokens['input_ids'][lyrics_tokens['input_ids'] == self.lyrics_tokenizer.pad_token_id] = 50256


        midi_score = Score(midi_path)
        if self.augment:
            midi_score = randomize_midi_pitch(midi_score)

        midi_tokens = self.midi_tokenizer.encode(midi_score)[0].ids
        midi_tokens = self._pad_or_truncate(midi_tokens)
        # # Debugging checks
        # print("Lyrics:", lyrics)
        # print("Lyrics Tokens:", lyrics_tokens['input_ids'])
        # print("MIDI Path:", midi_path)
        # print("MIDI Tokens before padding:", midi_tokens)

        return {
            'lyrics_ids': lyrics_tokens['input_ids'].squeeze(0),
            'lyrics_attention_mask': lyrics_tokens['attention_mask'].squeeze(0),
            'midi_tokens': midi_tokens
        }

    def _pad_or_truncate(self, tokens):
        # Truncate sequences longer than `max_length`
        if len(tokens) > self.max_length:
            tokens = tokens[:self.max_length]
        # Pad sequences shorter than `max_length`
        elif len(tokens) < self.max_length:
            tokens = tokens + [0] * (self.max_length - len(tokens))
        return torch.tensor(tokens, dtype=torch.long)

In [3]:
from transformers import AutoModel, GPT2LMHeadModel
import torch.nn as nn
import torch

class LyricsGenerator(nn.Module):
    def __init__(self, lyrics_vocab_size, d_model, max_lyrics_length, max_midi_length):
        super(LyricsGenerator, self).__init__()

        # MIDI Encoder
        self.midi_encoder = AutoModel.from_pretrained("ruru2701/musicbert-v1.1")
        self.midi_projection = nn.Linear(self.midi_encoder.config.hidden_size, d_model)
        self.midi_positional_embedding = nn.Embedding(max_midi_length, d_model)

        # GPT-2 for lyrics
        self.gpt2 = GPT2LMHeadModel.from_pretrained('gpt2', pad_token_id=50256)
        self.gpt2.resize_token_embeddings(lyrics_vocab_size)
        self.lyrics_positional_embedding = nn.Embedding(max_lyrics_length, d_model)

        # Cross-Attention Layer
        self.cross_attention = nn.MultiheadAttention(embed_dim=d_model, num_heads=8)

    def forward(self, lyrics_ids, lyrics_attention_mask, midi_tokens):
        # MIDI Encoding
        midi_outputs = self.midi_encoder(input_ids=midi_tokens)
        midi_embeds = self.midi_projection(midi_outputs.last_hidden_state)
        midi_positions = torch.arange(midi_tokens.size(1), device=midi_tokens.device).unsqueeze(0)
        midi_embeds += self.midi_positional_embedding(midi_positions)

        # Lyrics Encoding
        lyrics_positions = torch.arange(lyrics_ids.size(1), device=lyrics_ids.device).unsqueeze(0)
        lyrics_embeds = self.gpt2.transformer.wte(lyrics_ids) + self.lyrics_positional_embedding(lyrics_positions)

        # Cross-Attention
        midi_embeds_t = midi_embeds.transpose(0, 1)
        lyrics_embeds_t = lyrics_embeds.transpose(0, 1)
        cross_attn_output, _ = self.cross_attention(query=lyrics_embeds_t, key=midi_embeds_t, value=midi_embeds_t)
        combined_embeds = lyrics_embeds + cross_attn_output.transpose(0, 1)

        # Concatenate and Pass through GPT-2
        combined_attention_mask = torch.cat(
            [torch.ones((lyrics_ids.size(0), midi_tokens.size(1)), device=lyrics_ids.device), lyrics_attention_mask],
            dim=1
        )
        combined_embeds = torch.cat((midi_embeds, combined_embeds), dim=1)
        outputs = self.gpt2(inputs_embeds=combined_embeds, attention_mask=combined_attention_mask)
        return outputs.logits[:, midi_tokens.size(1):, :]

# class LyricsGenerator(nn.Module):
#     def __init__(self, lyrics_vocab_size, d_model, max_lyrics_length, max_midi_length):
#         super(LyricsGenerator, self).__init__()

#         # MIDI Encoder
#         self.midi_encoder = AutoModel.from_pretrained("ruru2701/musicbert-v1.1")
#         self.midi_projection = nn.Linear(self.midi_encoder.config.hidden_size, d_model)

#         # GPT-2 for lyrics
#         self.gpt2 = GPT2LMHeadModel.from_pretrained('gpt2', pad_token_id=50256)
#         self.gpt2.resize_token_embeddings(lyrics_vocab_size)
#         self.lyrics_positional_embedding = nn.Embedding(max_lyrics_length, d_model)

#     def forward(self, lyrics_ids, lyrics_attention_mask, midi_tokens):
#         # MIDI Encoding
#         midi_outputs = self.midi_encoder(input_ids=midi_tokens)
#         midi_embeds = self.midi_projection(midi_outputs.last_hidden_state)

#         # Lyrics Encoding
#         lyrics_positions = torch.arange(lyrics_ids.size(1), device=lyrics_ids.device).unsqueeze(0)
#         lyrics_embeds = self.gpt2.transformer.wte(lyrics_ids) + self.lyrics_positional_embedding(lyrics_positions)

#         # Concatenate MIDI and Lyrics Embeddings
#         combined_embeds = torch.cat((midi_embeds, lyrics_embeds), dim=1)

#         # Extend attention mask for MIDI tokens
#         combined_attention_mask = torch.cat(
#             [torch.ones((lyrics_ids.size(0), midi_tokens.size(1)), device=lyrics_ids.device), lyrics_attention_mask],
#             dim=1
#         )

#         # Pass concatenated embeddings through GPT-2
#         outputs = self.gpt2(inputs_embeds=combined_embeds, attention_mask=combined_attention_mask)
#         return outputs.logits[:, midi_tokens.size(1):, :]


In [4]:
from tqdm import tqdm
from torch.amp import autocast, GradScaler
def train(model, train_dataloader, val_dataloader, optimizer, scheduler, epochs, device, lyrics_tokenizer):
    model.to(device)
    scaler = GradScaler()
    loss_fct = nn.CrossEntropyLoss(ignore_index=50256)
    save_dir = "model_checkpoint"
    os.makedirs(save_dir, exist_ok=True)
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        loop = tqdm(train_dataloader, leave=True)
        for i, batch in enumerate(loop):
            optimizer.zero_grad()

            lyrics_ids = batch['lyrics_ids'].to(device)
            lyrics_attention_mask = batch['lyrics_attention_mask'].to(device)
            midi_tokens = batch['midi_tokens'].to(device)

            with autocast(device_type=device.type):
                logits = model(lyrics_ids, lyrics_attention_mask, midi_tokens)
                loss = loss_fct(logits.transpose(1, 2), lyrics_ids)
                train_loss += loss.item()

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            # # Debugging 
            # if i % 300 == 0:
            #     print(f"Epoch {epoch + 1}, Batch {i}/{len(train_dataloader)}")
            #     print(f"Loss: {loss.item():.4f}")

            #     # Decode input lyrics and predictions
            #     decoded_input = lyrics_tokenizer.decode(lyrics_ids[0].tolist(), skip_special_tokens=True)
            #     predicted_tokens = logits.argmax(dim=-1)[0]
            #     decoded_prediction = lyrics_tokenizer.decode(predicted_tokens.tolist(), skip_special_tokens=True)

            #     # print(f"Input Lyrics: {decoded_input}")
            #     print("_______________________________")
            #     print(f"Predicted Lyrics: {decoded_prediction}")
            #     # print(f"MIDI Tokens: {midi_tokens[0].cpu().numpy()}")
            loop.set_description(f"Epoch {epoch}")
            loop.set_postfix(loss=loss.item())

        val_loss = validate(model, val_dataloader, loss_fct, device)
        print(f"Epoch {epoch + 1}/{epochs}, Training Loss: {train_loss / len(train_dataloader):.4f}, Validation Loss: {val_loss:.4f}")

        # Save Checkpoint
        save_checkpoint(model, optimizer, scheduler, epoch, save_dir)
        print(f"Checkpoint saved for epoch {epoch + 1}!")

def validate(model, dataloader, loss_fct, device):
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in dataloader:
            lyrics_ids = batch['lyrics_ids'].to(device)
            lyrics_attention_mask = batch['lyrics_attention_mask'].to(device)
            midi_tokens = batch['midi_tokens'].to(device)

            with autocast(device_type=device.type):
                logits = model(lyrics_ids, lyrics_attention_mask, midi_tokens)
                loss = loss_fct(logits.transpose(1, 2), lyrics_ids)
                val_loss += loss.item()
    return val_loss / len(dataloader)

def save_checkpoint(model, optimizer, scheduler, epoch, save_dir):
    os.makedirs(save_dir, exist_ok=True)
    checkpoint_path = os.path.join(save_dir, f"checkpoint_epoch_{epoch + 1}.pt")
    checkpoint = {
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict()
    }
    torch.save(checkpoint, checkpoint_path)

def load_checkpoint(model, optimizer, scheduler, path, device):
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    epoch = checkpoint['epoch']
    print(f"Checkpoint loaded. Resuming from epoch {epoch}.")
    return model, optimizer, scheduler, epoch

In [5]:
import pandas as pd
df = pd.read_csv("data/lyrics_midi_data.csv")

In [6]:
from transformers import GPT2Tokenizer
from torch.utils.data import DataLoader, random_split
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
from pathlib import Path
from miditok import TSD, TokenizerConfig
import torch
from torch.utils.data import Subset
# Tokenizers
lyrics_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
lyrics_tokenizer.add_special_tokens({'pad_token': '<|pad|>'})

# Load MIDI Tokenizer
config = TokenizerConfig(
    num_velocities=1,
    use_chords=False,
    use_rests=False,
    use_tempos=False,
    use_time_signatures=False,
)
midi_tokenizer = TSD(config)
midi_tokenizer = midi_tokenizer.from_pretrained(Path("tokenizer", "tokenizer.json"))

# Dataset and Dataloader
dataset = LyricsMidiDataset(df, lyrics_tokenizer, midi_tokenizer, max_length=512, root_dir='data', augment=True)
train_size = int(0.8 * len(dataset))
train_dataset, val_dataset = random_split(dataset, [train_size, len(dataset) - train_size])

# train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)
# val_dataloader = DataLoader(val_dataset, batch_size=4)

# Define a small subset size
subset_size = 100  # Adjust the size as needed
train_subset = Subset(train_dataset, range(subset_size))
val_subset = Subset(val_dataset, range(subset_size))

# Create DataLoaders for the subsets
train_dataloader = DataLoader(train_subset, batch_size=4, shuffle=True)
val_dataloader = DataLoader(val_subset, batch_size=4)

# Model
model = LyricsGenerator(lyrics_vocab_size=len(lyrics_tokenizer), d_model=768, max_lyrics_length=512, max_midi_length=512)

# Optimizer and Scheduler
optimizer = AdamW(model.parameters(), lr=1e-5)
total_steps = len(train_dataloader) * 10
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

# Train
train(model, train_dataloader, val_dataloader, optimizer, scheduler, epochs=2, device=device, lyrics_tokenizer=lyrics_tokenizer)


The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
  0%|          | 0/25 [00:00<?, ?it/s]We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.
Epoch 0: 100%|██████████| 25/25 [00:38<00:00,  1.55s/it, loss=6.94]


Epoch 1/2, Training Loss: 7.1577, Validation Loss: 6.8876
Checkpoint saved for epoch 1!


Epoch 1: 100%|██████████| 25/25 [00:38<00:00,  1.55s/it, loss=6.24]


Epoch 2/2, Training Loss: 6.5588, Validation Loss: 6.4437
Checkpoint saved for epoch 2!


In [11]:
def generate_lyrics(
    model, 
    midi_path, 
    lyrics_tokenizer, 
    midi_tokenizer, 
    max_midi_length, 
    max_lyrics_length, 
    num_beams=5, 
    input_text=None
):
    """
    Generates lyrics conditioned on MIDI input and optional input text using the trained model.
    
    Args:
        model: Trained LyricsGenerator model.
        midi_path: Path to the MIDI file.
        lyrics_tokenizer: Tokenizer for lyrics (e.g., GPT-2 tokenizer).
        midi_tokenizer: Tokenizer for MIDI (e.g., miditok TSD tokenizer).
        max_midi_length: Maximum length of MIDI token sequence.
        max_lyrics_length: Maximum length of lyrics sequence.
        num_beams: Number of beams for beam search.
        input_text: Optional input text to condition lyrics generation.
    
    Returns:
        Generated lyrics as a string.
    """
    device = next(model.parameters()).device
    model.eval()

    # Tokenize MIDI
    try:
        midi_score = Score(midi_path)  # Load MIDI file into Score object
        midi_tokens = midi_tokenizer.encode(midi_score)[0].ids  # Tokenize MIDI
    except Exception as e:
        raise ValueError(f"Error processing MIDI file: {e}")
    
    # Pad or truncate MIDI tokens to max_midi_length
    midi_tokens = midi_tokens[:max_midi_length]
    padding_length = max_midi_length - len(midi_tokens)
    midi_tokens = midi_tokens + [0] * padding_length  # Pad with 0s
    midi_tokens = torch.tensor(midi_tokens, dtype=torch.long).unsqueeze(0).to(device)

    # Initialize input for lyrics generation
    if input_text:
        # Tokenize
        input_ids = lyrics_tokenizer.encode(input_text, return_tensors="pt").to(device)
    else:
        # Default to starting token if no input text
        input_ids = torch.tensor(lyrics_tokenizer.encode("<|endoftext|>")).unsqueeze(0).to(device)
    
    attention_mask = torch.ones_like(input_ids).to(device)

    # Generate with beam search
    beam_output = model.gpt2.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=max_lyrics_length,
        num_beams=num_beams,
        early_stopping=True,
        num_return_sequences=1,
        pad_token_id=lyrics_tokenizer.pad_token_id,
        do_sample=True
    )

    # Decode
    generated_lyrics = lyrics_tokenizer.decode(beam_output[0], skip_special_tokens=True)

    return generated_lyrics


In [14]:
base_dir = 'data/'
midi_path = os.path.join(base_dir, df["midi_path"][9001])
input_text = None
generated_lyrics = generate_lyrics(
    model=model,
    midi_path=midi_path,
    lyrics_tokenizer=lyrics_tokenizer,
    midi_tokenizer=midi_tokenizer,
    max_midi_length=256,
    max_lyrics_length=512,
    num_beams=5,
    input_text=input_text
)
print("Generated Lyrics:")
print(generated_lyrics)


Generated Lyrics:
It's been a while since I've been able to get my hands on one of these, so I thought I'd share it with you.

I've been doing this for a while now, and it's been really fun.

I have a few things I want to share with you.

I want you to know that I love you.

I want you to know that I love you.

I want you to know that I love you.

I want you to know that I love you.

I want you to know that I love you.

I want you to know that I love you.

I want you to know that I love you.
