In [None]:
def is_colab():
    """Report whether the notebook is running inside Google Colab.

    Detection is done by attempting to import ``google.colab``.

    Returns:
        bool: ``True`` when the import succeeds, ``False`` when it raises
        ``ImportError``.
    """
    try:
        import google.colab  # noqa: F401 -- imported only to probe availability
    except ImportError:
        return False
    else:
        return True

IS_COLAB = is_colab()

if IS_COLAB:
  !pip install git+https://github.com/hoggl-dsp/cc_beats

  !apt-get install fluidsynth
  !mkdir data
  !mkdir data/sf2
  !wget -O data/sf2/Xpand_2_-_Practice_Room_Kit.sf2 'https://musical-artifacts.com/artifacts/6296/Xpand_2_-_Practice_Room_Kit.sf2'

  !wget -O data/groove_midi_only.zip 'https://storage.googleapis.com/magentadata/datasets/groove/groove-v1.0.0-midionly.zip'
  !unzip data/groove_midi_only.zip -d data/

  !wget -O data/clean_midi.tar.gz 'http://hog.ee.columbia.edu/craffel/lmd/clean_midi.tar.gz'
  !tar -xf data/clean_midi.tar.gz -C data/

In [None]:
import torch
import torchaudio
import symusic
import symusic.types
from midi2audio import FluidSynth

import IPython.display as ipd
import tqdm
import os

from typing import Callable

from cc_beats import tokeniser, modules, utils

# Path of the drum-kit soundfont (the Colab setup cell downloads it to this location).
sf_path = 'data/sf2/Xpand_2_-_Practice_Room_Kit.sf2'
# Shared FluidSynth renderer; later cells pass it to utils.midi_to_audio_display.
FS = FluidSynth(sf_path)

In [None]:
# Groove MIDI dataset: only the "beat" performances that are not in an eval session.
# Kept under its own name so it is no longer silently thrown away when the Lakh
# list is built below.
groove_midi_files = []
for root, _, files in os.walk(os.path.join('data', 'groove')):
    groove_midi_files.extend(
        os.path.join(root, file)
        for file in files
        if file.endswith('.mid') and 'beat' in file and 'eval' not in os.path.basename(root)
    )

# Lakh clean_midi: every .mid file. This is the list the rest of the cell works on.
midi_files = []
for root, _, files in os.walk(os.path.join('data', 'clean_midi')):
    midi_files.extend(os.path.join(root, file) for file in files if file.endswith('.mid'))

if not midi_files:
    raise FileNotFoundError(
        "No .mid files found under data/clean_midi - download and extract the dataset first."
    )

# Slice instead of range(20) so a dataset with fewer than 20 files does not raise IndexError.
for path in midi_files[:20]:
    print(path)

midi_file = symusic.Score(midi_files[0])

encoder = tokeniser.DrumSequenceEncoder(subdivision=16)

# Round-trip the first file through the tokeniser to compare original and decoded MIDI.
tok_sequence = encoder.encode(midi_file)
returned_midi = encoder.decode(tok_sequence)

# Copy the global metadata from the source score onto the decoded one so that
# both are rendered with the same time/key signatures and tempo map.
returned_midi.time_signatures.extend(midi_file.time_signatures)
returned_midi.key_signatures.extend(midi_file.key_signatures)
returned_midi.tempos.extend(midi_file.tempos)

print("Actual File")
print(midi_file)
print(midi_file.tracks[0].notes[:10])
print('-----------------------------')

print("Decoded File")
print(returned_midi)
print(returned_midi.tracks[0].notes[:10])
print('-----------------------------')

ipd.display(utils.midi_to_audio_display(FS, midi_file))
ipd.display(utils.midi_to_audio_display(FS, returned_midi))

In [None]:
def get_midi_data_files(root_dir: str, filter_fn: Callable[[str, str], bool] = lambda root, file: True) -> list[str]:
    """Recursively list the files below ``root_dir`` that pass ``filter_fn``.

    Args:
        root_dir: Directory to walk. A missing directory yields an empty list.
        filter_fn: Predicate called as ``filter_fn(directory, filename)`` for
            every file found; the file is kept when it returns a truthy value.
            The default keeps every file, whatever its extension.

    Returns:
        Paths (``directory`` joined with ``filename``) in ``os.walk`` order.
    """
    return [
        os.path.join(directory, filename)
        for directory, _subdirs, filenames in os.walk(root_dir)
        for filename in filenames
        if filter_fn(directory, filename)
    ]

In [None]:
# Lakh clean_set analysis
import pandas as pd

def _analyse_midi_file(file: str) -> dict:
    """Load one MIDI file and return its row of statistics for ``do_lakh_analysis``.

    Any exception raised while loading or inspecting the file propagates to the
    caller, so a row is either complete or not produced at all.
    """
    midi = symusic.Score.from_file(file)

    time_signatures = midi.time_signatures
    is_four_four = (
        len(time_signatures) == 1
        and time_signatures[0].numerator == 4
        and time_signatures[0].denominator == 4
    )

    drum_tracks = [track for track in midi.tracks if track.is_drum]
    onsets = [note.time for track in drum_tracks for note in track.notes]
    unique_pitches = {note.pitch for track in drum_tracks for note in track.notes}

    return {
        'File': file,
        'File Length (quarters)': (midi.end() - midi.start()) / midi.tpq,
        'Tempos': midi.tempos,
        'Is 4/4': is_four_four,
        'Num Drum Tracks': len(drum_tracks),
        'Drum Hit Count': len(onsets),
        'Unique Drum Hits': unique_pitches,
        # Onset ticks converted to quarter notes; None when the file has no drum notes.
        'First Drum (quarters)': min(onsets) / midi.tpq if onsets else None,
        'Last Drum (quarters)': max(onsets) / midi.tpq if onsets else None,
    }


def do_lakh_analysis():
    """Scan ``data/clean_midi`` and return the files usable for drum training.

    Every file gets one row of statistics. Files that cannot be loaded get a
    row whose statistics are all ``None`` and are removed by the first filter.
    The remaining filters keep files that have at least one drum track, a
    single 4/4 time signature, more than 256 drum hits, at least one drum hit
    per quarter note on average, and fewer than 1000 quarter notes in total.
    The number of surviving files is printed after each filter.

    Returns:
        pandas.DataFrame: one row per surviving file.
    """
    files = get_midi_data_files(os.path.join('data', 'clean_midi'))

    columns = [
        'File',
        'File Length (quarters)',
        'Tempos',
        'Is 4/4',
        'Num Drum Tracks',
        'Drum Hit Count',
        'Unique Drum Hits',
        'First Drum (quarters)',
        'Last Drum (quarters)',
    ]
    data = {column: [] for column in columns}
    for file in tqdm.tqdm(files, desc='Analysing Midi Files'):
        try:
            row = _analyse_midi_file(file)
        except Exception:
            # Unreadable/corrupt file: keep the path, blank every statistic.
            # `Exception` (not a bare except) lets KeyboardInterrupt stop the scan.
            row = dict.fromkeys(columns)
            row['File'] = file
        # Append the whole row at once so all columns always have equal length.
        for column in columns:
            data[column].append(row[column])

    lakh_df = pd.DataFrame(data)

    print(" ***** Filtering Dataset *****")
    print("Full Dataset:", len(lakh_df))
    # df[~df.isna()] only masks cells and keeps every row; dropna actually removes
    # the unreadable files. The file length is set for every readable file.
    lakh_df = lakh_df.dropna(subset=['File Length (quarters)'])
    print("Without unreadable files:", len(lakh_df))
    lakh_df = lakh_df[lakh_df['Num Drum Tracks'] > 0]
    print("Without files with no drums:", len(lakh_df))
    # The column is object-typed when unreadable rows were present; cast before masking.
    lakh_df = lakh_df[lakh_df['Is 4/4'].astype(bool)]
    print("Without files not in 4/4:", len(lakh_df))
    lakh_df = lakh_df[lakh_df['Drum Hit Count'] > 256]
    print("Without files with less than 256 drum hits", len(lakh_df))
    lakh_df = lakh_df[lakh_df['Drum Hit Count'] / (lakh_df['File Length (quarters)']) >= 1.0]
    print("Without files with note density < 1 per quarter:", len(lakh_df))
    lakh_df = lakh_df[lakh_df['File Length (quarters)'] < 1000.0]
    print("Without files with > 1000 beats:", len(lakh_df))

    return lakh_df
    
# Run the scan + filter pass once; the 'File' column of the result is what the
# dataset cell later feeds to make_dataset.
lakh_clean_df = do_lakh_analysis()

# Summary statistics of the surviving files.
print(lakh_clean_df.describe())

In [None]:
class DrumInpaintingTransformer(torch.nn.Module):
    """Transformer that predicts drum hits, pitches and velocities per time step.

    Each time step of the input is the concatenation of one hit flag,
    ``num_pitches`` pitch flags and ``num_pitches`` velocities. The model
    projects it to ``embedding_dim``, runs it through ``num_layers``
    relative-attention transformer layers and emits three heads: hit logits,
    pitch logits and velocities squashed to (0, 1).

    Args:
        num_pitches: Number of drum pitch classes per time step.
        embedding_dim: Width of the transformer layers.
        num_layers: Number of stacked transformer layers.
        pitch_pos_weight: ``pos_weight`` of the pitch BCE loss. A Python
            number, a sequence of numbers, or a tensor.
        *args, **kwargs: Forwarded to ``torch.nn.Module.__init__``.
    """

    def __init__(self, num_pitches: int, embedding_dim: int = 64, num_layers: int = 6, pitch_pos_weight: float = 2.0, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.input_proj = torch.nn.Linear(1 + 2 * num_pitches, embedding_dim)

        self.transformer = torch.nn.ModuleList([
            modules.TransformerLayerWithRelativeAttention(embedding_dim, num_heads=8, dropout=0.1)
            for _ in range(num_layers)
        ])

        self.output_hit_proj = torch.nn.Linear(embedding_dim, 1)
        self.output_pitches_proj = torch.nn.Linear(embedding_dim, num_pitches)
        self.output_velocities_proj = torch.nn.Linear(embedding_dim, num_pitches)

        self.hit_loss = torch.nn.BCEWithLogitsLoss()
        # BCEWithLogitsLoss needs a tensor; convert ints and sequences too, not only floats.
        if not isinstance(pitch_pos_weight, torch.Tensor):
            pitch_pos_weight = torch.as_tensor(pitch_pos_weight, dtype=torch.get_default_dtype())
        self.pitch_loss = torch.nn.BCEWithLogitsLoss(pos_weight=pitch_pos_weight)
        self.vel_loss = torch.nn.MSELoss()

    def forward(self, x: torch.Tensor, attention_mask=None):
        """Run the model on ``x`` of shape [batch, seq_len, 1 + 2 * num_pitches].

        ``attention_mask`` is accepted but currently unused: it is not passed
        to the transformer layers.

        Returns:
            ``(hit_logits, pitch_logits, velocities)`` with last dimensions
            ``1``, ``num_pitches`` and ``num_pitches``. The first two are raw
            logits; velocities have already been passed through a sigmoid.
        """
        x = self.input_proj(x)  # [batch, seq_len, embedding_dim]

        for layer in self.transformer:
            x = layer(x)  # [batch, seq_len, embedding_dim]

        hit_logits = self.output_hit_proj(x)
        pitch_logits = self.output_pitches_proj(x)
        velocities = torch.sigmoid(self.output_velocities_proj(x))

        return hit_logits, pitch_logits, velocities

    @staticmethod
    def _masked_loss(criterion, pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Apply ``criterion`` to the elements selected by ``mask``; zero if none are selected."""
        if not mask.any():
            # A mean over zero elements would be NaN. This zero stays attached to
            # the graph so that backward() still works on the summed loss.
            return pred.sum() * 0.0
        return criterion(pred[mask], target[mask])

    def loss_function(self,
            preds: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
            truths: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
            inpainting_mask: torch.Tensor
        ) -> torch.Tensor:
        """Sum of the hit, pitch and velocity losses over the inpainted steps.

        Args:
            preds: ``(hit_logits, pitch_logits, velocities)`` as returned by ``forward``.
            truths: ``(hits, pitches, velocities)`` targets with matching shapes.
            inpainting_mask: Boolean tensor shaped like the hit tensor; ``True``
                marks the time steps that contribute to the loss.

        Returns:
            Scalar tensor. The hit loss covers every masked step, the pitch loss
            only masked steps whose target is a hit, and the velocity loss only
            the active target pitches of those steps. A term with nothing
            selected contributes zero.
        """
        pred_hits, pred_pitches, pred_vels = preds
        true_hits, true_pitches, true_vels = truths

        hit_loss = self._masked_loss(self.hit_loss, pred_hits, true_hits, inpainting_mask)

        # Broadcast the per-step mask over the pitch axis. expand_as follows the
        # actual pitch count instead of assuming 9 pitch classes.
        hit_mask = (inpainting_mask & (true_hits > 0.5)).expand_as(true_pitches)
        pitch_loss = self._masked_loss(self.pitch_loss, pred_pitches, true_pitches, hit_mask)

        hit_pitch_mask = hit_mask & (true_pitches > 0.5)
        vel_loss = self._masked_loss(self.vel_loss, pred_vels, true_vels, hit_pitch_mask)

        return hit_loss + pitch_loss + vel_loss


In [None]:
# Dataset
def make_dataset(encoder: tokeniser.DrumSequenceEncoder, midi_files: list[str], max_seq_length: int = 64) -> torch.utils.data.TensorDataset:
    """Encode MIDI files and cut them into fixed-length training chunks.

    Every encoded sequence is zero-padded at the end up to the next multiple
    of ``max_seq_length`` and split into chunks of exactly that length. A
    sequence whose length is already a multiple is not padded, so no all-zero
    chunk is appended to it, and empty sequences are skipped.

    Args:
        encoder: Object whose ``encode_all(midi_files)`` yields
            ``(hits, pitches, velocities)`` tensors, time on dimension 0.
        midi_files: Paths handed to ``encoder.encode_all``.
        max_seq_length: Number of time steps per chunk.

    Returns:
        A ``TensorDataset`` of ``(hits, pitches, velocities)`` chunks, each
        tensor shaped [num_chunks, max_seq_length, features].

    Raises:
        ValueError: If encoding produced no non-empty sequence.
    """
    hit_chunks, pitch_chunks, velocity_chunks = [], [], []

    for hits, pitches, velocities in encoder.encode_all(midi_files):
        num_steps = hits.size(0)
        if num_steps == 0:
            continue

        # Steps missing to reach the next multiple; 0 when already aligned.
        len_padding = -num_steps % max_seq_length

        for chunks, tensor in (
            (hit_chunks, hits),
            (pitch_chunks, pitches),
            (velocity_chunks, velocities),
        ):
            if len_padding > 0:
                tensor = torch.cat([tensor, tensor.new_zeros(len_padding, tensor.size(1))])
            chunks.extend(tensor.split(max_seq_length))

    if not hit_chunks:
        raise ValueError("make_dataset: no non-empty sequences were produced from the given MIDI files")

    return torch.utils.data.TensorDataset(
        torch.stack(hit_chunks, dim=0),
        torch.stack(pitch_chunks, dim=0),
        torch.stack(velocity_chunks, dim=0)
    )

# Encoder used for the training data; later cells also read drum_encoder.pitch_to_index.
drum_encoder = tokeniser.DrumSequenceEncoder(subdivision=16)
# Build 128-step chunks from the files that survived the Lakh filtering.
dataset = make_dataset(drum_encoder, lakh_clean_df['File'].to_list(), max_seq_length=128)

print("Dataset_Size:", len(dataset))

In [None]:
import matplotlib.pyplot as plt

# Set up training parameters
SEED = 0          # makes the train/val split, weight init and masking repeatable
MASK_PROB = 0.25  # probability that a time step is hidden from the model
num_epochs = 1
batch_size = 32
device = torch.device("cuda" if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else "cpu")

torch.manual_seed(SEED)

model = DrumInpaintingTransformer(num_pitches=9, embedding_dim=16, num_layers=16, pitch_pos_weight=1.5)
model = model.to(device)  # Move model to the appropriate device

# 90/10 train/validation split
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_set, valid_set = torch.utils.data.random_split(dataset, [train_size, val_size])

optimiser = torch.optim.Adam(model.parameters(), lr=2e-5, weight_decay=1e-6)

lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimiser)

train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True
)

val_loader = torch.utils.data.DataLoader(
    valid_set,
    batch_size=batch_size,
    shuffle=False
)


def mask_batch(hits, pitches, vels, mask_prob=MASK_PROB):
    """Hide random time steps of a batch and build the model input.

    Each time step is selected independently with probability ``mask_prob``.
    Selected steps get hit = 0.5 and all pitches/velocities = 0.

    Returns:
        ``(model_input, mask)``: the three masked tensors concatenated on the
        last dimension, and the boolean mask shaped like ``hits``.
    """
    mask = torch.rand_like(hits) < mask_prob
    # masked_fill broadcasts the [batch, seq, 1] mask over the pitch axis. This
    # avoids mask.squeeze(), which also drops the batch axis when the batch
    # holds a single sample and then fails to index.
    masked_hits = hits.masked_fill(mask, 0.5)
    masked_pitches = pitches.masked_fill(mask, 0.0)
    masked_vels = vels.masked_fill(mask, 0.0)
    return torch.cat((masked_hits, masked_pitches, masked_vels), dim=-1), mask


# Training loop
train_losses = []
val_losses = []

for epoch in range(num_epochs):
    model.train()
    total_train_loss = 0.0

    for j, (hits, pitches, vels) in enumerate(train_loader, 1):
        hits, pitches, vels = hits.to(device), pitches.to(device), vels.to(device)
        model_input, mask = mask_batch(hits, pitches, vels)

        optimiser.zero_grad()
        preds = model(model_input)
        loss = model.loss_function(preds, (hits, pitches, vels), mask)
        loss.backward()
        optimiser.step()

        total_train_loss += loss.item()

        if j % 10 == 0:
            # Mean over the batches seen so far, current batch included.
            print(f'[{j}/{len(train_loader)}] - Running loss = {total_train_loss / j}')

    # Keep track of training loss
    avg_train_loss = total_train_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    # Validation phase
    model.eval()
    total_val_loss = 0.0

    with torch.no_grad():
        for hits, pitches, vels in val_loader:
            hits, pitches, vels = hits.to(device), pitches.to(device), vels.to(device)

            # Same masking scheme as in training
            model_input, mask = mask_batch(hits, pitches, vels)

            preds = model(model_input)
            loss = model.loss_function(preds, (hits, pitches, vels), mask)

            total_val_loss += loss.item()

    avg_val_loss = total_val_loss / len(val_loader)
    val_losses.append(avg_val_loss)

    # ReduceLROnPlateau only lowers the learning rate when it is fed the monitored metric.
    lr_scheduler.step(avg_val_loss)

    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

# Plot training and validation loss
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(range(1, len(train_losses) + 1), train_losses, 'b-', label='Training Loss')
ax.plot(range(1, len(val_losses) + 1), val_losses, 'r-', label='Validation Loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.set_title('Training and Validation Loss')
ax.legend()
ax.grid(True)
plt.show()

# Save the model to disk
torch.save(model.state_dict(), 'drum_inpainting_model.pth')
print("Model saved to 'drum_inpainting_model.pth'")

In [None]:
def make_prompt(length: int, prompt_hits: list[tuple[int, int, float]], num_pitches: int = 9, target_device=None):
    """Build a single-sequence batch of hit/pitch/velocity tensors from a hit list.

    Args:
        length: Number of time steps in the prompt.
        prompt_hits: ``(step_index, pitch_index, velocity)`` triples. Several
            triples may share a step to place simultaneous hits.
        num_pitches: Size of the pitch axis. Defaults to 9.
        target_device: Device for the returned tensors. When ``None`` the
            notebook-level ``device`` variable is used.

    Returns:
        ``(hits, pitches, velocities)`` shaped ``[1, length, 1]``,
        ``[1, length, num_pitches]`` and ``[1, length, num_pitches]``; all
        zeros except at the listed hits.
    """
    hits = torch.zeros(1, length, 1)
    pitches = torch.zeros(1, length, num_pitches)
    velocities = torch.zeros(1, length, num_pitches)

    for index, pitch, velocity in prompt_hits:
        hits[:, index, :] = 1.0
        pitches[:, index, pitch] = 1.0
        velocities[:, index, pitch] = velocity

    if target_device is None:
        target_device = device  # notebook-level default set by the surrounding cells

    return hits.to(target_device), pitches.to(target_device), velocities.to(target_device)

HIT_THRESHOLD = 0.3    # minimum hit probability for a step to count as a hit
PITCH_THRESHOLD = 0.5  # minimum pitch probability for a pitch to be switched on

device = torch.device("cuda" if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else "cpu")
# Reuses the `model` object built by the training cell and reloads its saved weights.
model.load_state_dict(torch.load('drum_inpainting_model.pth', map_location=device))
model = model.to(device)
model.eval()  # switch dropout off for inference

# General MIDI percussion: 36 = Bass Drum 1, 38 = Acoustic Snare.
kick = drum_encoder.pitch_to_index[36]
snare = drum_encoder.pitch_to_index[38]
prompt = [
    (0, kick, 0.8),
    (3, kick, 0.4),
    (4, snare, 0.8),
    (6, kick, 0.7),
    (11, kick, 0.5),
    (12, snare, 0.8),
    (14, kick, 0.7),
    (19, kick, 0.5),
    (20, snare, 0.8),
    (22, kick, 0.7),
    (27, kick, 0.5),
    (28, snare, 0.8),
    (29, snare, 0.3),
]
prompt_hits, prompt_pitches, prompt_vels = make_prompt(32, prompt)
# Decode with drum_encoder, the same encoder whose pitch_to_index built the prompt.
midi = drum_encoder.decode((prompt_hits.squeeze(), prompt_pitches.squeeze(), prompt_vels.squeeze()))
ipd.display(utils.midi_to_audio_display(FS, midi))

with torch.no_grad():
    hit_logits, pitch_logits, vels = model(torch.cat((prompt_hits, prompt_pitches, prompt_vels), dim=-1))
hit_probs = torch.sigmoid(hit_logits)
pitch_probs = torch.sigmoid(pitch_logits)

hits = (hit_probs > HIT_THRESHOLD).to(dtype=prompt_hits.dtype)
pitches = (pitch_probs > PITCH_THRESHOLD).to(dtype=prompt_pitches.dtype)

# torch.where takes each prediction from its own time step. masked_scatter
# instead consumes the source tensor from its start, one element per True in
# the mask, which misaligned predictions and positions.
# NOTE(review): the mask selects the steps where the prompt already HAS a hit,
# so predictions overwrite the prompt and the empty steps stay empty. For
# inpainting the opposite selection is probably intended - confirm.
mask = prompt_hits > 0.5
merged_hits = torch.where(mask, hits, prompt_hits)
merged_pitches = torch.where(mask, pitches, prompt_pitches)
merged_vels = torch.where(mask, vels, prompt_vels)

print(merged_hits)
print(merged_pitches)
midi = drum_encoder.decode((merged_hits.squeeze(), merged_pitches.squeeze(), merged_vels.squeeze()))
ipd.display(utils.midi_to_audio_display(FS, midi))