In [1]:
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
from tqdm.auto import tqdm

from model import MusicGen

In [2]:
# Inclusive pitch bounds for drum notes and the number of values they span.
# min_drum_pitch is also the offset custom_loss_fn subtracts from drum pitch
# targets before computing the drum-pitch cross-entropy.
min_drum_pitch, max_drum_pitch = 35, 81
n_drum_pitches = 1 + (max_drum_pitch - min_drum_pitch)

# Inclusive pitch bounds for every other instrument.
min_pitch, max_pitch = 0, 127
n_pitches = 1 + (max_pitch - min_pitch)

# Remaining size constants. They are not referenced by the cells shown in this
# notebook -- presumably they mirror the model / dataset layout (TODO confirm).
sequence_length = 128
n_velocities = 128
n_instruments = 4

# Train on the GPU when one is visible to PyTorch, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Read the saved training data and its companion "denorm" file from disk.
# weights_only=True restricts unpickling to tensors and plain containers.
# `denorm` is only loaded here; none of the cells shown in this notebook use it.
loaded_data = torch.load('data/lmd-10.pth', weights_only=True)
denorm = torch.load('data/lmd-10-denorm.pth', weights_only=True)

# Pair every entry of "sequences" with the matching entry of "targets" and
# serve them as shuffled mini-batches of 64.
dataset = TensorDataset(*(loaded_data[key] for key in ("sequences", "targets")))
dataloader = DataLoader(dataset=dataset, batch_size=64, shuffle=True)

In [4]:
def custom_loss_fn(out, target):
    """Sum the per-attribute losses for one batch of predictions.

    Args:
        out: Mapping holding the model outputs. ``"instrument"``,
            ``"drum_pitch"``, ``"regular_pitch"`` and ``"velocity"`` are class
            logits (fed to cross-entropy); ``"step"`` and ``"duration"`` are
            regression outputs (fed to MSE).
        target: 2-D tensor whose columns are, in order: pitch, velocity,
            duration, step, instrument.

    Returns:
        Scalar tensor: the unweighted sum of the instrument, pitch and
        velocity cross-entropy losses and the step and duration MSE losses.

    Notes:
        Reads the module-level ``min_drum_pitch`` to shift drum pitch targets
        so they start at class index 0.
    """
    step = out["step"]
    duration = out["duration"]

    # Class-index targets (cross-entropy requires integer labels).
    pitch_target = target[:, 0].long()
    velocity_target = target[:, 1].long()
    instrument_target = target[:, 4].long()

    # Regression targets take the shape of their prediction. With a fixed
    # unsqueeze(1) a head that returns (B,) instead of (B, 1) would make
    # mse_loss broadcast to (B, B) and silently compute the wrong loss.
    duration_target = target[:, 2].reshape(duration.shape)
    step_target = target[:, 3].reshape(step.shape)

    instrument_loss = F.cross_entropy(out["instrument"], instrument_target)

    # Instrument index 0 is scored by the drum-pitch head, every other
    # instrument by the regular-pitch head. A head with no rows in this batch
    # contributes 0 (cross-entropy over an empty selection would be NaN).
    is_drum = instrument_target == 0
    is_regular = ~is_drum
    drum_pitch_loss = (
        F.cross_entropy(out["drum_pitch"][is_drum], pitch_target[is_drum] - min_drum_pitch)
        if is_drum.any()
        else 0
    )
    regular_pitch_loss = (
        F.cross_entropy(out["regular_pitch"][is_regular], pitch_target[is_regular])
        if is_regular.any()
        else 0
    )
    pitch_loss = drum_pitch_loss + regular_pitch_loss

    velocity_loss = F.cross_entropy(out["velocity"], velocity_target)

    step_loss = F.mse_loss(step, step_target)
    duration_loss = F.mse_loss(duration, duration_target)

    return instrument_loss + pitch_loss + velocity_loss + step_loss + duration_loss

In [None]:
# Build the model and move it to the training device first, so the optimizer
# is constructed from parameters that already live on that device.
model = MusicGen()
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    # tqdm wraps the DataLoader itself (not enumerate(...)) so it can read
    # len(dataloader) and display a total and ETA for each epoch.
    progress = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for batch_idx, (sequences, targets) in enumerate(progress):
        # sequences: (batch size, 128, 5); targets: (batch size, 5)
        sequences = sequences.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()

        # No hidden state is passed, so the model uses its default (hidden=None).
        # The returned hidden state is not carried over between batches.
        out, hidden = model(sequences)
        loss = custom_loss_fn(out, targets)
        loss.backward()
        optimizer.step()

        # .item() detaches the scalar so the running sum holds no graph.
        total_loss += loss.item()

    # Mean of the per-batch losses over the epoch.
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

In [8]:
# Persist the trained weights. The output directory is created first so a
# missing "models/" folder cannot discard the result of the whole training run.
checkpoint_path = Path(f"models/model3-e{num_epochs}.pth")
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), checkpoint_path)