In [None]:
def is_colab():
    """Report whether this kernel is running inside Google Colab.

    Detection is done by attempting to import ``google.colab``; the import
    succeeding means we are on Colab.

    Returns:
        True on Colab, False otherwise.
    """
    try:
        import google.colab  # noqa: F401 -- imported only to probe the environment
    except ImportError:
        in_colab = False
    else:
        in_colab = True
    return in_colab

In [None]:
if is_colab():
  !pip install git+https://github.com/hoggl-dsp/cc_beats

  !apt-get install fluidsynth
  !mkdir data
  !mkdir data/sf2
  !wget -O data/sf2/Xpand_2_-_Practice_Room_Kit.sf2 'https://musical-artifacts.com/artifacts/6296/Xpand_2_-_Practice_Room_Kit.sf2'
  !wget -O data/groove_midi_only.zip 'https://storage.googleapis.com/magentadata/datasets/groove/groove-v1.0.0-midionly.zip'
  !unzip data/groove_midi_only.zip -d data/

In [None]:
import os
import random

import IPython.display as ipd
import matplotlib.pyplot as plt
import symusic
import symusic.types
import torch
import torchaudio
from IPython.display import display
from midi2audio import FluidSynth

from cc_beats import modules, tokeniser, utils

In [None]:
# Load one Groove performance and render it with FluidSynth as a sanity check
# that the dataset and soundfont are in place.
example_midi_path = 'data/groove/drummer1/session1/1_funk_80_beat_4-4.mid'
sf_path = 'data/sf2/Xpand_2_-_Practice_Room_Kit.sf2'

midi_file = symusic.Score.from_file(example_midi_path)
fs = FluidSynth(sf_path)

print(midi_file)
print(midi_file.tracks[0])

# The helper's return value must be passed to display(): it is not the last
# expression of the cell, so otherwise the audio player would be discarded
# (the other cells in this notebook call it the same way).
display(utils.midi_to_audio_display(fs, midi_file))

In [None]:
# Collect every "beat" MIDI file in the Groove dataset, skipping evaluation
# sessions. Sorting makes the file order (and hence the later train/val split)
# independent of the filesystem's directory-walk order.
groove_dir = os.path.join('data', 'groove')
midi_files = sorted(
    os.path.join(root, name)
    for root, _, names in os.walk(groove_dir)
    if 'eval' not in os.path.basename(root)
    for name in names
    if name.endswith('.mid') and 'beat' in name
)
assert midi_files, f"no matching .mid files found under {groove_dir!r}"

# Show a sample of the file list (slicing avoids IndexError on small datasets).
for path in midi_files[:20]:
    print(path)

midi_file = symusic.Score.from_file(midi_files[0])

drum_tokeniser = tokeniser.DrumSequenceTokeniser(subdivision=16, velocity_bands=4)

# Round-trip encode -> decode to check how much information the tokenisation keeps.
tok_sequence = drum_tokeniser.encode(midi_file)
returned_midi = drum_tokeniser.decode(tok_sequence)

# Copy the global metadata from the source score onto the decoded score.
returned_midi.time_signatures.extend(midi_file.time_signatures)
returned_midi.key_signatures.extend(midi_file.key_signatures)
returned_midi.tempos.extend(midi_file.tempos)

print("Actual File")
print(midi_file)
print(midi_file.tracks[0].notes[:10])
print('-----------------------------')

print("Decoded File")
print(returned_midi)
print(returned_midi.tracks[0].notes[:10])
print('-----------------------------')

display(utils.midi_to_audio_display(fs, midi_file))
display(utils.midi_to_audio_display(fs, returned_midi))

In [None]:
class DrumInpaintingTransformer(torch.nn.Module):
    """Transformer that outputs vocabulary logits for every position of a drum-token sequence.

    The architecture is: token embedding -> a stack of
    ``modules.TransformerLayerWithRelativeAttention`` layers -> linear projection
    back to the vocabulary. This class adds no absolute positional embedding and
    no causal mask; presumably position information comes from the layers'
    relative attention (TODO confirm in ``cc_beats.modules``).

    Args:
        vocab_size: number of distinct token ids (embedding rows / output logits).
        embedding_dim: model width; passed to every transformer layer.
        num_layers: number of stacked transformer layers.
        num_heads: attention heads per layer.
        *args, **kwargs: forwarded to ``torch.nn.Module.__init__``.
        dropout: dropout rate passed to every transformer layer (keyword-only;
            default 0.1 matches the previously hard-coded value).
    """

    def __init__(self, vocab_size: int, embedding_dim: int = 256, num_layers: int = 8, num_heads: int = 8,
                 *args, dropout: float = 0.1, **kwargs):
        super().__init__(*args, **kwargs)

        self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)

        self.transformer = torch.nn.ModuleList([
            modules.TransformerLayerWithRelativeAttention(embedding_dim, num_heads=num_heads, dropout=dropout)
            for _ in range(num_layers)
        ])

        self.dense_out = torch.nn.Linear(embedding_dim, vocab_size)

    def forward(self, x, attention_mask=None):
        """Return unnormalised logits over the vocabulary for each input position.

        Args:
            x: integer token ids, [batch_size, seq_len].
            attention_mask: passed unchanged to every layer; its expected format
                is defined by ``TransformerLayerWithRelativeAttention``
                (NOTE(review): confirm semantics there).

        Returns:
            Logits of shape [batch_size, seq_len, vocab_size].
        """
        x = self.embedding(x)  # [batch_size, seq_len, embedding_dim]

        for layer in self.transformer:
            x = layer(x, attention_mask)

        return self.dense_out(x)  # [batch_size, seq_len, vocab_size]

# class DrumInpaintingTransformer(torch.nn.Module):
#     def __init__(self, vocab_size: int, embedding_dim: int = 512, *args, **kwargs):
#         super().__init__(*args, **kwargs)

#         self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)

#         encoder_layer = torch.nn.TransformerEncoderLayer(
#             d_model=embedding_dim,
#             nhead=16,
#             dim_feedforward=512,
#             batch_first=True,
#         )

#         self.transformer_encoder = torch.nn.TransformerEncoder(
#             encoder_layer=encoder_layer,
#             num_layers=6
#         )

#         self.dense_out = torch.nn.Linear(embedding_dim, vocab_size)

#     def forward(self, x, attention_mask=None):
#         # Embed tokens
#         x = self.embedding(x)  # [batch_size, seq_len, d_model]

#         # Pass through transformer (attention_mask needs to be properly formatted for TransformerEncoder)
#         if attention_mask is not None:
#             # Create a mask for padding tokens (1 means masked/ignored position)
#             padding_mask = (attention_mask == 0)
#             x = self.transformer_encoder(x, src_key_padding_mask=padding_mask)
#         else:
#             x = self.transformer_encoder(x)

#         # Project to vocabulary
#         output = self.dense_out(x)  # [batch_size, seq_len, vocab_size]

#         return output

In [None]:
def create_dataset(tokeniser: "tokeniser.DrumSequenceTokeniser", midi_files: list[str], chunk_length: int = 64,
                   verbose: bool = True):
    """Tokenise MIDI files and cut each token stream into fixed-length chunks.

    Each file's token sequence is split into consecutive, non-overlapping
    windows of ``chunk_length`` tokens. The final partial window of a file is
    right-padded with the tokeniser's ``'<pad>'`` id, so chunks never span two
    files. Files that encode to an empty sequence contribute no chunks.

    Args:
        tokeniser: provides ``encode_all(paths)`` returning one token-id sequence
            per file, and ``tokeniser['<pad>']`` giving the padding id. (This
            parameter shadows the module of the same name inside the function,
            which is why the annotation is a string.)
        midi_files: paths handed to ``tokeniser.encode_all``.
        chunk_length: tokens per chunk; must be >= 1.
        verbose: if True, print each token id and its count across all chunks.

    Returns:
        ``TensorDataset`` wrapping a single [num_chunks, chunk_length] int64 tensor.

    Raises:
        ValueError: if ``chunk_length`` < 1 (the old loop never terminated), or if
            no tokens were produced at all.
    """
    if chunk_length < 1:
        raise ValueError(f"chunk_length must be >= 1, got {chunk_length}")

    seq_chunks = []
    for seq in tokeniser.encode_all(midi_files):
        # list() so the padding concatenation below is list concatenation even if
        # the tokeniser returns e.g. a numpy array (where `+` would add elementwise).
        seq = list(seq)
        for start in range(0, len(seq), chunk_length):
            chunk = seq[start:start + chunk_length]
            if len(chunk) < chunk_length:
                chunk += [tokeniser['<pad>']] * (chunk_length - len(chunk))
            seq_chunks.append(torch.tensor(chunk))

    if not seq_chunks:
        raise ValueError("tokeniser produced no tokens for midi_files; cannot build an empty dataset")

    seq_chunks = torch.stack(seq_chunks, dim=0)

    if verbose:
        token_ids, counts = seq_chunks.unique(return_counts=True)
        for tok, count in zip(token_ids.tolist(), counts.tolist()):
            print(tok, count)

    return torch.utils.data.TensorDataset(seq_chunks)

# Build the chunked training dataset and report its size alongside the vocabulary size.
dataset = create_dataset(drum_tokeniser, midi_files)
print(f"Num Chunks: {len(dataset)}")
print(f"Vocab Size: {len(drum_tokeniser)}")

In [None]:
num_epochs = 30
batch_size = 32
mask_prob = 0.25  # fraction of non-pad tokens hidden from the model per sequence
seed = 0
device = torch.device("cuda" if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else "cpu")

# Seed once so weight init, shuffling and training masks repeat on re-run
# (on the same hardware / library versions).
torch.manual_seed(seed)

model = DrumInpaintingTransformer(len(drum_tokeniser.vocab))
model = model.to(device)

pad_id = drum_tokeniser['<pad>']
mask_id = drum_tokeniser['<mask>']

# 80/20 train/validation split with its own seeded generator so it is reproducible.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_set, valid_set = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=torch.Generator().manual_seed(seed)
)

loss_fn = torch.nn.CrossEntropyLoss()
optimiser = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(valid_set, batch_size=batch_size, shuffle=False)


def random_token_mask(tokens, generator=None):
    """Choose positions to hide: each non-pad token independently with probability ``mask_prob``.

    Returns a bool tensor shaped like ``tokens``, True where the token is hidden.
    If ``generator`` (a CPU ``torch.Generator``) is given, the random draw comes
    from it, so the same masks can be reproduced.
    """
    if generator is None:
        draws = torch.rand(tokens.shape, device=tokens.device)
    else:
        draws = torch.rand(tokens.shape, generator=generator).to(tokens.device)
    # <pad> filler is never hidden, so the model is not trained or scored on predicting padding.
    return (draws < mask_prob) & (tokens != pad_id)


def masked_lm_loss(tokens, generator=None):
    """Hide a random subset of ``tokens``, run the model and return cross-entropy on the hidden positions.

    Returns None when nothing was hidden: CrossEntropyLoss over zero targets is NaN.
    """
    tokens = tokens.to(device)
    hidden = random_token_mask(tokens, generator)
    if not hidden.any():
        return None
    logits = model(tokens.masked_fill(hidden, mask_id))  # [batch, seq_len, vocab]
    # Boolean indexing keeps only hidden positions: [n_hidden, vocab] vs. [n_hidden].
    return loss_fn(logits[hidden], tokens[hidden])


train_losses = []
val_losses = []

for epoch in range(num_epochs):
    # Training phase
    model.train()
    epoch_train_losses = []
    for (batch_tokens,) in train_loader:
        loss = masked_lm_loss(batch_tokens)
        if loss is None:
            continue
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()
        epoch_train_losses.append(loss.item())

    avg_train_loss = sum(epoch_train_losses) / max(1, len(epoch_train_losses))
    train_losses.append(avg_train_loss)

    # Validation phase: re-seeding each epoch gives identical masks every time, so
    # the validation curve reflects the model changing rather than mask noise.
    model.eval()
    val_generator = torch.Generator().manual_seed(seed)
    epoch_val_losses = []
    with torch.no_grad():
        for (batch_tokens,) in val_loader:
            loss = masked_lm_loss(batch_tokens, generator=val_generator)
            if loss is not None:
                epoch_val_losses.append(loss.item())

    avg_val_loss = sum(epoch_val_losses) / max(1, len(epoch_val_losses))
    val_losses.append(avg_val_loss)

    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

# Plot training and validation loss
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(train_losses, label='Training Loss')
ax.plot(val_losses, label='Validation Loss')
ax.set(xlabel='Epochs', ylabel='Loss', title='Training and Validation Loss')
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# Serialise the whole module object (architecture and weights) via pickle.
# Unpickling it later needs the DrumInpaintingTransformer class to be importable
# under the same name in the loading session.
torch.save(obj=model, f='model.pth')

In [None]:
mask_prob = 0.25    # fraction of tokens hidden from the model
temperature = 0.5   # sampling temperature: < 1 sharpens, > 1 flattens the distribution
device = torch.device("cuda" if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else "cpu")

# NOTE: weights_only=False unpickles arbitrary Python objects -- only load checkpoints you trust.
model = torch.load('model.pth', map_location=device, weights_only=False)
model = model.to(device)
model.eval()

input_sequence = random.choice(dataset)[0].unsqueeze(0).to(device)  # [1, seq_len]

# True where a token is hidden and must be predicted.
hidden = torch.rand(input_sequence.shape, device=device) < mask_prob
masked_inputs = input_sequence.masked_fill(hidden, drum_tokeniser['<mask>'])

with torch.no_grad():
    output_logits = model(masked_inputs).squeeze(0)  # [seq_len, vocab_size]
print(output_logits.shape)

# Temperature must scale the logits *before* softmax. Dividing the
# probabilities afterwards is a no-op because multinomial renormalises.
probs = torch.softmax(output_logits / temperature, dim=-1)
output_preds = torch.multinomial(probs, num_samples=1).squeeze(-1)  # [seq_len]
print(output_preds)

# Position-aligned merge: keep original tokens, use predictions only where hidden.
# (masked_scatter would instead copy predictions[0], [1], ... in order into the
# hidden slots, i.e. take predictions made for the wrong positions.)
output_sequence = torch.where(hidden.squeeze(0), output_preds, input_sequence.squeeze(0))

input_midi = drum_tokeniser.decode(input_sequence.squeeze(0))
output_midi = drum_tokeniser.decode(output_sequence)

for in_note, out_note in zip(input_midi.tracks[0].notes, output_midi.tracks[0].notes):
    print(in_note, ' - ', out_note)

display(utils.midi_to_audio_display(fs, input_midi))
display(utils.midi_to_audio_display(fs, output_midi))

In [None]:
# A hand-written two-bar pattern: fixed kick/snare hits, with '<mask>' slots for the model to fill in.
sequence = [
    'p36_v2', '<mask>', '<mask>', 'p36_v1', 'p38_v3', '<mask>', 'p36_v1', '<mask>', '<mask>', '<mask>', '<mask>', 'p36_v1', 'p38_v3', '<mask>', 'p36_v1', '<mask>',
    '<mask>', '<mask>', '<mask>', 'p36_v1', 'p38_v3', '<mask>', 'p36_v1', '<mask>', '<mask>', '<mask>', '<mask>', 'p36_v1', 'p38_v3', '<mask>', '<mask>', '<mask>'
]
tok_sequence = torch.tensor([drum_tokeniser[tok] for tok in sequence], device=device)
mask = tok_sequence == drum_tokeniser['<mask>']  # True at the slots to fill

model.eval()
with torch.no_grad():
    output_logits = model(tok_sequence.unsqueeze(0)).squeeze(0)  # [seq_len, vocab_size]

# Scale logits before softmax (dividing probabilities after softmax has no effect on sampling).
sampling_temperature = 0.1
probs = torch.softmax(output_logits / sampling_temperature, dim=-1)
output_preds = torch.multinomial(probs, num_samples=1).squeeze(-1)  # [seq_len]
print(torch.topk(probs, 10, dim=-1))

# Template: the given hits with every masked slot turned into a rest.
template_midi = drum_tokeniser.decode(tok_sequence.masked_fill(mask, drum_tokeniser['<rest>']))
# Raw model output at every position, including the fixed ones.
predicted_midi = drum_tokeniser.decode(output_preds)
# Given hits kept as-is; predictions used only in the masked slots (position-aligned).
output_midi = drum_tokeniser.decode(torch.where(mask, output_preds, tok_sequence))

display(utils.midi_to_audio_display(fs, template_midi))
display(utils.midi_to_audio_display(fs, predicted_midi))
display(utils.midi_to_audio_display(fs, output_midi))