In [8]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, TensorDataset
from model import MusicGen

from tqdm.auto import tqdm

In [9]:
# Inclusive range of pitch numbers the model predicts
# (21..108 looks like the 88-key piano range in MIDI numbering — TODO confirm).
min_pitch, max_pitch = 21, 108
# Number of pitch classes; both endpoints are included, hence the +1.
n_pitches = (max_pitch + 1) - min_pitch
# Number of velocity classes the model predicts.
n_velocities = 128
# Presumably the number of notes per input sequence — verify against the dataset.
sequence_length = 128
# Prefer the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [10]:
# Read the preprocessed tensors from disk. The file is expected to hold a dict
# with "sequences" and "targets" entries.
loaded_data = torch.load('dataset/maestro-100.pth', weights_only=True)
# Pair each input sequence with its target so they are batched together.
dataset = TensorDataset(
    loaded_data["sequences"],
    loaded_data["targets"],
)
# Serve reshuffled mini-batches of 64 samples.
dataloader = DataLoader(dataset=dataset, batch_size=64, shuffle=True)

In [11]:
def custom_loss_fn(out, target):
    """Return the unweighted sum of the four per-note losses.

    The last dimension of ``out`` is split, in order, into ``n_pitches``
    pitch logits, ``n_velocities`` velocity logits, one duration value and
    one step value. Columns 0-3 of ``target`` hold the matching ground truth:
    pitch class index, velocity class index, duration and step.

    Pitch and velocity are scored with cross-entropy (targets are cast to
    int64 class indices); duration and step are scored with mean squared
    error. All four terms use the default mean reduction.
    """
    fn = nn.functional
    pitch_logits, velocity_logits, duration_hat, step_hat = out.split(
        [n_pitches, n_velocities, 1, 1], dim=-1
    )

    # Classification heads: logits are (batch_size, n_classes).
    pitch_term = fn.cross_entropy(pitch_logits, target[:, 0].long())
    velocity_term = fn.cross_entropy(velocity_logits, target[:, 1].long())

    # Regression heads: predictions are (batch_size, 1), so the target column
    # gets a matching trailing axis.
    duration_term = fn.mse_loss(duration_hat, target[:, 2, None])
    step_term = fn.mse_loss(step_hat, target[:, 3, None])

    return pitch_term + velocity_term + step_term + duration_term

In [14]:
# Build the model and place it on the target device before constructing the
# optimizer, so the optimizer is created over the parameters as they will be
# trained.
model = MusicGen()
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

num_epochs = 3
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    # Wrap the DataLoader itself rather than enumerate(...): enumerate has no
    # length, so tqdm could not show a total or an ETA.
    progress = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for sequences, targets in progress:
        # sequences: (batch size, 128, 4)
        sequences = sequences.to(device)
        # targets: (batch size, 4)
        targets = targets.to(device)

        optimizer.zero_grad()

        # No hidden state is passed, so the model falls back to its default
        # (hidden = None); the returned hidden state is not used for training.
        out, hidden = model(sequences)
        loss = custom_loss_fn(out, targets)
        loss.backward()
        optimizer.step()
        # .item() detaches the scalar so the graph is not kept alive.
        total_loss += loss.item()

    # Mean of the per-batch losses over the epoch.
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

0it [00:00, ?it/s]

Epoch 1/3, Loss: 7.8762


0it [00:00, ?it/s]

Epoch 2/3, Loss: 7.7062


0it [00:00, ?it/s]

Epoch 3/3, Loss: 7.6397


In [15]:
# Persist only the learned parameters (state dict), tagging the file name with
# the number of epochs trained.
torch.save(
    obj=model.state_dict(),
    f=f"models/model2-e{num_epochs}.pth",
)