In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
from model import MusicGen
from dataset import ShardedDataset
from parse_config import Config
import os
from pathlib import Path

from tqdm.auto import tqdm

In [None]:
from torch.utils.tensorboard import SummaryWriter
# TensorBoard writer used by the training loop to record per-batch loss scalars.
tensorboard_log_dir = "runs/musicgen_jun24"
writer = SummaryWriter(log_dir=tensorboard_log_dir)

In [None]:
# Drum notes occupy their own class range; 35..81 matches the General MIDI
# percussion key map (presumably the intent -- confirm against preprocessing).
min_drums_pitch, max_drums_pitch = 35, 81
n_drums_pitches = len(range(min_drums_pitch, max_drums_pitch + 1))

# Melodic instruments use the full 0..127 MIDI pitch range.
min_pitch, max_pitch = 0, 127
n_pitches = len(range(min_pitch, max_pitch + 1))

# Model-design settings come from the project config.
design_config = Config.get('design')
sequence_length = design_config['sequence_length']

n_velocities = 128  # presumably one class per MIDI velocity value (0..127)
n_instruments = 4   # instrument ids 0..3 (drums, bass, chords, lead in custom_loss_fn)

# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Locate the processed shards for the chosen dataset prefix.
# NOTE(review): input() makes this cell interactive, so "Restart & Run All" blocks
# here; consider hard-coding the prefix in a config cell for reproducible runs.
prefix = input("Enter the name/prefix of the processed dataset")
root = os.getcwd()
shard_path = os.path.join(root, Config.get("preprocessing")['processed_data_path'], prefix)
# Lexicographic order: "shard_10" sorts before "shard_2" unless names are zero-padded.
# NOTE(review): confirm this order matches the order meta['rel_idxs'] was built in.
shards = sorted(Path(shard_path).rglob("*.pth"))
if not shards:
    # Without this check an empty/wrong prefix surfaces as "pop from empty list".
    raise FileNotFoundError(f"No .pth files found under {shard_path!r}; check the dataset prefix")
# NOTE(review): assumes the metadata file is the last .pth in sorted order -- TODO
# confirm the preprocessing naming guarantees this (e.g. "meta" vs "shard_*").
meta_path = shards.pop()
meta = torch.load(meta_path)

In [None]:
# Show which shard files will feed the dataset (metadata file already removed).
print(f"{shards}")

In [None]:
# Training data: shards located above, indexed via the metadata's rel_idxs.
batch_size = 50
dataset = ShardedDataset(paths=shards, rel_idxs=meta['rel_idxs'])
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    # Pinned host memory only speeds up host->CUDA copies; on a CPU-only run it is
    # useless and PyTorch emits a warning, so enable it only when training on CUDA.
    pin_memory=(device.type == "cuda"),
)

In [None]:
def _masked_loss(loss_fn, pred, target, mask, empty):
    """Apply ``loss_fn`` to the rows of ``pred``/``target`` selected by ``mask``.

    Returns ``empty`` (a zero tensor) when ``mask`` selects no rows, which avoids
    calling a 'mean'-reduced loss on an empty selection (that would yield NaN).
    """
    if not mask.any():
        return empty
    return loss_fn(pred[mask], target[mask])


def custom_loss_fn(out, target, *, step_scale=500.0, duration_scale=50.0,
                   step_delta=0.02, duration_delta=1.0, drums_pitch_offset=None):
    """Combined multi-head loss for one batch of next-event predictions.

    Args:
        out: dict of model head outputs with keys "instrument",
            "pitch_{drums,bass,chords,lead}", "velocity_drums", "velocity_other",
            "duration_{drums,bass,chords,lead}", "step_drums", "step_other".
            Classification heads are (batch, classes) logits; duration/step heads are
            assumed to be (batch, 1) to match the unsqueezed targets -- TODO confirm.
        target: (batch, 5) tensor with columns [pitch, velocity, duration, step,
            instrument]; instrument ids 0..3 are drums, bass, chords, lead.
        step_scale, duration_scale: multipliers applied to the summed step and
            duration Huber losses (defaults match the original hard-coded 500 / 50).
        step_delta, duration_delta: Huber ``delta`` for the step / duration heads.
        drums_pitch_offset: subtracted from drum pitch targets so the lowest drum
            note is class 0; defaults to the notebook's ``min_drums_pitch``.

    Returns:
        dict of scalar tensors with keys 'total', 'instrument', 'pitch',
        'velocity', 'duration', 'step'. Every value is a tensor even when an
        instrument is absent from the batch, so callers can always use ``.item()``.
    """
    if drums_pitch_offset is None:
        drums_pitch_offset = min_drums_pitch

    pitch_target = target[:, 0].long()
    velocity_target = target[:, 1].long()
    duration_target = target[:, 2].unsqueeze(1)
    step_target = target[:, 3].unsqueeze(1)
    instrument_target = target[:, 4].long()

    instrument_loss = F.cross_entropy(out["instrument"], instrument_target)
    # Zero on the same device/dtype as the losses; stands in for heads with no rows.
    zero = instrument_loss.new_zeros(())

    is_drums = instrument_target == 0
    is_other = ~is_drums
    # Insertion order (drums, bass, chords, lead) fixes the summation order below.
    masks = {
        "drums": is_drums,
        "bass": instrument_target == 1,
        "chords": instrument_target == 2,
        "lead": instrument_target == 3,
    }

    def duration_huber(pred, tgt):
        return F.huber_loss(pred, tgt, reduction='mean', delta=duration_delta)

    def step_huber(pred, tgt):
        return F.huber_loss(pred, tgt, reduction='mean', delta=step_delta)

    # Pitch: one classifier per instrument; drum classes are shifted by the offset.
    pitch_loss = zero
    for name, mask in masks.items():
        pitch_tgt = pitch_target - drums_pitch_offset if name == "drums" else pitch_target
        pitch_loss = pitch_loss + _masked_loss(
            F.cross_entropy, out[f"pitch_{name}"], pitch_tgt, mask, zero)

    # Velocity: drums have their own head; all other instruments share one.
    velocity_loss = (
        _masked_loss(F.cross_entropy, out["velocity_drums"], velocity_target, is_drums, zero)
        + _masked_loss(F.cross_entropy, out["velocity_other"], velocity_target, is_other, zero)
    )

    # Duration: one regression head per instrument, each averaged over its own rows.
    duration_loss = zero
    for name, mask in masks.items():
        duration_loss = duration_loss + _masked_loss(
            duration_huber, out[f"duration_{name}"], duration_target, mask, zero)

    # Step: drums vs. a shared head for everything else.
    step_loss = (
        _masked_loss(step_huber, out["step_drums"], step_target, is_drums, zero)
        + _masked_loss(step_huber, out["step_other"], step_target, is_other, zero)
    )

    # Up-weight the regression terms, presumably to balance them against the
    # cross-entropy terms (the small step delta keeps raw values tiny).
    step_loss = step_loss * step_scale
    duration_loss = duration_loss * duration_scale

    total_loss = instrument_loss + pitch_loss + velocity_loss + step_loss + duration_loss
    return {
        'total': total_loss,
        'instrument': instrument_loss,
        'pitch': pitch_loss,
        'velocity': velocity_loss,
        'duration': duration_loss,
        'step': step_loss,
    }

In [None]:
# Inspect the dataset metadata that MusicGen is constructed from.
print(f"{meta}")

In [None]:
# Build the model from dataset metadata and set up the optimizer.
learning_rate = 3e-4
model = MusicGen(meta)
# Move parameters to the target device *before* constructing the optimizer, as the
# torch.optim documentation recommends, so the optimizer tracks the final tensors.
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
start_epoch = 0

In [None]:
# Optionally resume from a checkpoint written by the training loop.
# Checkpoints are saved with the 0-based index of the epoch that just finished, so
# training resumes at the following epoch. Set resume_epoch = None to start fresh.
resume_epoch = 14
if resume_epoch is not None:
    checkpoint_path = f"./trained_weights/modelJun24-checkpoint-7-tree-1-{resume_epoch}"
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    start_epoch = resume_epoch + 1

In [None]:
# Training loop: one optimizer step per batch, TensorBoard scalars averaged over
# windows of `log_every` batches, and a checkpoint saved after every epoch.
num_epochs = 20
log_every = 10  # batches averaged into each TensorBoard point

# TensorBoard tag for each component returned by custom_loss_fn.
component_tags = {
    'pitch': 'Loss/Batch_Pitch',
    'velocity': 'Loss/Batch_Velocity',
    'duration': 'Loss/Batch_Duration',
    'step': 'Loss/Batch_Step',
    'instrument': 'Loss/Batch_Instrument',
}
# torch.save does not create missing directories.
Path("trained_weights").mkdir(parents=True, exist_ok=True)

for epoch in range(start_epoch, num_epochs):
    model.train()
    steps_per_epoch = len(dataloader)
    total_loss = 0.0
    # Running sums for the current logging window; reset after every write.
    window_total = 0.0
    window_parts = dict.fromkeys(component_tags, 0.0)
    window_batches = 0

    for batch_idx, (sequences, targets) in tqdm(enumerate(dataloader), total=steps_per_epoch):
        sequences = sequences.to(device)  # (batch size, 128, 5) per the original note
        targets = targets.to(device)      # (batch size, 5)

        optimizer.zero_grad()
        out, _ = model(sequences)  # no hidden state is carried between batches
        loss_out = custom_loss_fn(out, targets)
        loss = loss_out['total']
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        loss_value = loss.item()
        total_loss += loss_value
        window_total += loss_value
        for name in window_parts:
            # float() accepts both tensors and a plain numeric 0 component.
            window_parts[name] += float(loss_out[name])
        window_batches += 1

        # Log exactly `log_every` batches per point, averaging every tag the same way
        # (previously the first window held 11 batches and only Loss/Batch was averaged).
        if window_batches == log_every:
            global_step = batch_idx + epoch * steps_per_epoch
            writer.add_scalar('Loss/Batch', window_total / window_batches, global_step=global_step)
            for name, tag in component_tags.items():
                writer.add_scalar(tag, window_parts[name] / window_batches, global_step=global_step)
            window_total = 0.0
            window_parts = dict.fromkeys(component_tags, 0.0)
            window_batches = 0

    torch.save({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict()
    }, f"trained_weights/modelJun24-checkpoint-7-tree-1-{epoch}")
    writer.flush()

    avg_loss = total_loss / max(steps_per_epoch, 1)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

Epoch 17/20, Loss: 28.1934


  0%|          | 0/983 [00:00<?, ?it/s]

Epoch 18/20, Loss: 27.9921


  0%|          | 0/983 [00:00<?, ?it/s]

Epoch 19/20, Loss: 27.7524


  0%|          | 0/983 [00:00<?, ?it/s]

Epoch 20/20, Loss: 27.2372
