In [None]:
def is_colab():
    """Report whether the notebook is running inside Google Colab.

    Detection is done by attempting to import ``google.colab``.

    Returns:
        bool: True if the import succeeds, False if it raises ImportError.
    """
    try:
        import google.colab  # noqa: F401 -- imported only to probe availability
    except ImportError:
        return False
    else:
        return True
    
IS_COLAB = is_colab()

if IS_COLAB:
  !pip install git+https://github.com/hoggl-dsp/cc_beats

  !apt-get install fluidsynth
  !mkdir data
  !mkdir data/sf2
  !wget -O data/sf2/Xpand_2_-_Practice_Room_Kit.sf2 'https://musical-artifacts.com/artifacts/6296/Xpand_2_-_Practice_Room_Kit.sf2'
  
  !wget -O data/groove_midi_only.zip 'https://storage.googleapis.com/magentadata/datasets/groove/groove-v1.0.0-midionly.zip'
  !unzip data/groove_midi_only.zip -d data/

  !wget -O data/clean_midi.tar.gz 'http://hog.ee.columbia.edu/craffel/lmd/clean_midi.tar.gz'
  !tar -xf data/clean_midi.tar.gz -C data/

In [None]:
import torch
import torchaudio
import symusic
import symusic.types
from midi2audio import FluidSynth

import IPython.display as ipd

import os
import random
import tqdm
from typing import Callable

from cc_beats import modules, tokeniser, utils

# Path of the drum-kit SoundFont; this is the same location the Colab setup cell downloads it to.
sf_path = 'data/sf2/Xpand_2_-_Practice_Room_Kit.sf2'
# Shared midi2audio FluidSynth instance constructed with that SoundFont.
# NOTE(review): presumably used by utils.midi_to_audio_display to render MIDI for playback -- confirm.
FS = FluidSynth(sf_path)

In [None]:
# lakh_midi_files = []
# for root, _, files in os.walk(os.path.join('data', 'lmd_full')):
#     lakh_midi_files.extend([os.path.join(root, file) for file in files if file.endswith('.mid')])


# lakh_midi_with_drums = []
# for file in random.sample(lakh_midi_files, 10):
#     print(" ---------------------------------------------------------------------------- ")
#     print("Loading:", file)
#     try:
#         midi = symusic.Score(file)
#         midi.tracks = [track for track in midi.tracks if track.is_drum]
#         num_drum_tracks = len(midi.tracks)
#         print(f"Found {num_drum_tracks} drum tracks")
#         if num_drum_tracks > 0:
#             lakh_midi_with_drums.append(midi)
#     except:
#         print("Loading failed, skipping...")
#         continue

# drum_tokeniser = tokeniser.DrumSequenceTokeniser(subdivision=16, velocity_bands=4)
# for midi in lakh_midi_with_drums:
#     reconstructed_midi = drum_tokeniser.decode(drum_tokeniser.encode(midi))
#     print(" ---------------------------------------------------------------------------- ")
#     display(utils.midi_to_audio_display(FS, midi))
#     display(utils.midi_to_audio_display(FS, reconstructed_midi))

In [None]:
def get_midi_data_files(root_dir: str, filter_fn: Callable[[str, str], bool] = lambda x, y: True):
    """Recursively collect file paths below ``root_dir`` that pass ``filter_fn``.

    Args:
        root_dir: Directory to walk with ``os.walk``.
        filter_fn: Predicate called as ``filter_fn(directory, filename)`` for every
            file encountered; the file is kept when it returns a truthy value.
            The default accepts every file.

    Returns:
        list[str]: ``os.path.join(directory, filename)`` for each accepted file, in
        the order ``os.walk`` yields them. Empty if ``root_dir`` does not exist.
    """
    return [
        os.path.join(dirpath, filename)
        for dirpath, _, filenames in os.walk(root_dir)
        for filename in filenames
        if filter_fn(dirpath, filename)
    ]

def create_dataset(
    tokeniser: tokeniser.DrumSequenceTokeniser,
    midi_files: list[str] | list[symusic.types.Score],
    chunk_length: int = 64,
    max_sequence_length: int = 10_000_000,
):
    """Build a dataset of fixed-length token chunks from the drum tracks of MIDI files.

    Each entry of ``midi_files`` is reduced to its drum tracks, encoded with
    ``tokeniser.encode`` and cut into consecutive chunks of ``chunk_length`` tokens.
    The final, shorter chunk of a sequence is right-padded with ``tokeniser['<pad>']``.

    Args:
        tokeniser: Object providing ``encode(score)`` and ``tokeniser['<pad>']``.
            ``encode`` may return ``None``, in which case the file is skipped.
            NOTE(review): the encoded sequence is assumed to be a sliceable
            sequence of integer token ids -- confirm against the tokeniser.
        midi_files: File paths (``str`` / ``os.PathLike``) loaded with
            ``symusic.Score.from_file``, or already-loaded score objects, which are
            used directly. Passed-in scores have their ``tracks`` filtered to drum
            tracks IN PLACE.
        chunk_length: Number of tokens per chunk; must be at least 1.
        max_sequence_length: Encoded sequences longer than this are skipped entirely.

    Returns:
        torch.utils.data.TensorDataset wrapping one tensor of shape
        ``(num_chunks, chunk_length)``. When no chunk could be produced the dataset
        is empty (shape ``(0, chunk_length)``) and a warning is printed.

    Raises:
        ValueError: If ``chunk_length`` is smaller than 1 (the chunking would
            otherwise never terminate).
    """
    if chunk_length < 1:
        raise ValueError(f"chunk_length must be a positive integer, got {chunk_length}")

    midis = []
    num_failed = 0
    for file in tqdm.tqdm(midi_files, desc='Loading Midi Files'):
        try:
            if isinstance(file, (str, os.PathLike)):
                midi = symusic.Score.from_file(file)
            else:
                # Anything that is not a path is treated as an already-loaded score.
                midi = file
            midi.tracks = [track for track in midi.tracks if track.is_drum]
        except Exception:
            # Best effort: unreadable or unsupported entries are skipped. Catching
            # Exception (not a bare except) lets Ctrl-C still interrupt the loop.
            num_failed += 1
            continue
        if len(midi.tracks) > 0:
            midis.append(midi)

    print(f"Loaded {len(midis)} midi files with at least one drum track")
    if num_failed:
        print(f"Skipped {num_failed} entries that could not be loaded")

    encoded_sequences = []
    for midi in tqdm.tqdm(midis, desc='Encoding Midi Files'):
        sequence = tokeniser.encode(midi)
        if sequence is not None:
            encoded_sequences.append(sequence)

    seq_chunks = []
    for seq in tqdm.tqdm(encoded_sequences, desc='Chunking Midi Files'):
        if len(seq) > max_sequence_length:
            continue
        # Step through the sequence by index instead of repeatedly re-slicing the
        # remainder, which copied the whole tail on every chunk.
        for start in range(0, len(seq), chunk_length):
            chunk = list(seq[start:start + chunk_length])
            num_pads = chunk_length - len(chunk)
            if num_pads > 0:
                chunk.extend([tokeniser['<pad>']] * num_pads)
            seq_chunks.append(torch.tensor(chunk))

    if not seq_chunks:
        # torch.stack() raises on an empty list; return an empty dataset instead.
        print("Warning: no token chunks were produced; returning an empty dataset")
        return torch.utils.data.TensorDataset(torch.empty((0, chunk_length), dtype=torch.long))

    return torch.utils.data.TensorDataset(torch.stack(seq_chunks, dim=0))

def _is_groove_training_beat(root, file):
    """Keep files whose name contains 'beat' and whose directory name does not contain 'eval'."""
    return 'beat' in file and 'eval' not in os.path.basename(root)


# Tokeniser shared by dataset creation and the training code.
drum_tokeniser = tokeniser.DrumSequenceTokeniser(subdivision=16, velocity_bands=4)

# Gather the two corpora separately, then tokenise and chunk them together.
groove_files = get_midi_data_files(os.path.join('data', 'groove'), _is_groove_training_beat)
clean_midi_files = get_midi_data_files(os.path.join('data', 'clean_midi'))
dataset = create_dataset(drum_tokeniser, groove_files + clean_midi_files)

print("Num Chunks:", len(dataset))
print("Vocab Size:", len(drum_tokeniser))

In [None]:
class DrumInpaintingTransformer(torch.nn.Module):
    """Per-token vocabulary predictor for drum token sequences.

    Pipeline: token embedding -> ``num_layers`` relative-attention transformer
    layers -> linear projection back to vocabulary size.
    """

    def __init__(self, vocab_size: int, embedding_dim: int = 256, num_layers: int = 8, num_heads: int = 8, *args, **kwargs):
        """Create the embedding, the layer stack and the output projection.

        Args:
            vocab_size: Number of distinct tokens (embedding rows and output logits).
            embedding_dim: Width of the token embeddings and of every layer.
            num_layers: How many transformer layers to stack.
            num_heads: Attention heads per layer.
            *args, **kwargs: Forwarded unchanged to ``torch.nn.Module.__init__``.
        """
        super().__init__(*args, **kwargs)
        nn = torch.nn

        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        layer_stack = [
            modules.TransformerLayerWithRelativeAttention(embedding_dim, num_heads=num_heads, dropout=0.1)
            for _ in range(num_layers)
        ]
        self.transformer = nn.ModuleList(layer_stack)

        self.dense_out = nn.Linear(embedding_dim, vocab_size)

    def forward(self, x, attention_mask=None):
        """Map token ids to per-position vocabulary logits.

        Args:
            x: Integer token ids, expected as [batch_size, seq_len].
            attention_mask: Optional mask handed unchanged to every layer.

        Returns:
            Logits of shape [batch_size, seq_len, vocab_size].
        """
        hidden = self.embedding(x)  # [batch_size, seq_len, embedding_dim]
        for block in self.transformer:
            hidden = block(hidden, attention_mask)
        return self.dense_out(hidden)  # [batch_size, seq_len, vocab_size]

In [None]:
# Imported at the top of the cell so the dependency is visible before the long training run starts.
import matplotlib.pyplot as plt

# --- Hyper-parameters -------------------------------------------------------
num_epochs = 1
batch_size = 32
mask_prob = 0.25   # probability that a token is replaced by <mask>
log_every = 10     # print the running training loss every N batches
seed = 0           # makes weight init, train/val split and masking repeatable
device = torch.device("cuda" if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else "cpu")

torch.manual_seed(seed)

model = DrumInpaintingTransformer(len(drum_tokeniser.vocab), embedding_dim=512, num_layers=16, num_heads=16)
model = model.to(device)  # Move model to the appropriate device

# 80/20 train/validation split of the chunk dataset
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_set, valid_set = torch.utils.data.random_split(dataset, [train_size, val_size])

loss_fn = torch.nn.CrossEntropyLoss()
optimiser = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)

train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True
)

val_loader = torch.utils.data.DataLoader(
    valid_set,
    batch_size=batch_size,
    shuffle=False
)

mask_token_id = drum_tokeniser['<mask>']


def masked_token_loss(batch):
    """Mask random tokens of ``batch`` and score the model on recovering them.

    Every position is independently replaced by the ``<mask>`` token with
    probability ``mask_prob``; the cross-entropy loss is computed on the replaced
    positions only.

    NOTE(review): padded positions can be selected for masking too, so they
    contribute to the loss -- confirm this is intended.

    Args:
        batch: One batch from the DataLoader; ``batch[0]`` holds the token ids.

    Returns:
        The scalar loss tensor, or ``None`` when the random draw masked no
        position (the mean over zero elements would be NaN).
    """
    tokens = batch[0].to(device)

    is_masked = torch.rand(tokens.shape, device=device) <= mask_prob
    if not is_masked.any():
        return None

    masked_tokens = tokens.clone()
    masked_tokens[is_masked] = mask_token_id

    logits = model(masked_tokens)  # [batch, seq_len, vocab]

    # Boolean indexing selects the masked positions: logits -> [n_masked, vocab], targets -> [n_masked]
    return loss_fn(logits[is_masked], tokens[is_masked])


# --- Training loop ----------------------------------------------------------
train_losses = []
val_losses = []

for epoch in range(num_epochs):
    # Training phase
    model.train()
    total_train_loss = 0.0
    train_batches = 0

    for i, batch in enumerate(train_loader, 1):
        optimiser.zero_grad()

        loss = masked_token_loss(batch)
        if loss is None:
            # Nothing was masked in this batch, so there is nothing to learn from.
            continue

        loss.backward()
        optimiser.step()

        total_train_loss += loss.item()
        train_batches += 1

        if i % log_every == 0:
            print(f'[{i}/{len(train_loader)}] - Loss = {loss.item()}')

    avg_train_loss = total_train_loss / max(1, train_batches)
    train_losses.append(avg_train_loss)

    # Validation phase
    model.eval()
    total_val_loss = 0.0
    val_batches = 0

    with torch.no_grad():
        for batch in val_loader:
            loss = masked_token_loss(batch)
            if loss is None:
                continue

            total_val_loss += loss.item()
            val_batches += 1

    avg_val_loss = total_val_loss / max(1, val_batches)
    val_losses.append(avg_val_loss)

    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

# --- Plot training and validation loss ---------------------------------------
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(train_losses, label='Training Loss')
ax.plot(val_losses, label='Validation Loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.set_title('Training and Validation Loss')
ax.legend()
ax.grid(True)
plt.show()