In [None]:
def is_colab():
    """Report whether the notebook is running inside a Google Colab runtime.

    The check simply tries to import ``google.colab``; that package is expected
    to be importable only on Colab.

    Returns:
        True if the import succeeds, False if it raises ImportError.
    """
    try:
        import google.colab  # noqa: F401 - imported only to probe for Colab
    except ImportError:
        return False
    else:
        return True

IS_COLAB = is_colab()

if IS_COLAB:
  !pip install git+https://github.com/hoggl-dsp/cc_beats

  !apt-get install fluidsynth
  !mkdir data
  !mkdir data/sf2
  !wget -O data/sf2/Xpand_2_-_Practice_Room_Kit.sf2 'https://musical-artifacts.com/artifacts/6296/Xpand_2_-_Practice_Room_Kit.sf2'

  !wget -O data/groove_midi_only.zip 'https://storage.googleapis.com/magentadata/datasets/groove/groove-v1.0.0-midionly.zip'
  !unzip data/groove_midi_only.zip -d data/

  !wget -O data/clean_midi.tar.gz 'http://hog.ee.columbia.edu/craffel/lmd/clean_midi.tar.gz'
  !tar -xf data/clean_midi.tar.gz -C data/

In [None]:
import torch
import torchaudio
import symusic
import symusic.types
from midi2audio import FluidSynth

import IPython.display as ipd
import ipywidgets as widgets
import tqdm
import os

from typing import Callable

from cc_beats import tokeniser, modules, utils

# Drum sound font used to render every MIDI snippet in this notebook to audio.
# The setup cell only downloads it on Colab; elsewhere it must be placed here manually.
sf_path = 'data/sf2/Xpand_2_-_Practice_Room_Kit.sf2'
if not os.path.isfile(sf_path):
    # Fail here with a clear message rather than later, at the first audio render
    # (FluidSynth presumably only stores the path at construction time).
    raise FileNotFoundError(f"Sound font not found at '{sf_path}'. Download it (see the setup cell) before continuing.")
FS = FluidSynth(sf_path)

In [None]:
# Collect MIDI files from both datasets. Paths are sorted because os.walk order depends
# on the file system, which would otherwise make `midi_files[0]` below arbitrary.
# Groove: only the "beat" takes, skipping directories whose name contains 'eval'.
groove_midi_files = sorted(
    os.path.join(root, file)
    for root, _, files in os.walk(os.path.join('data', 'groove'))
    if 'eval' not in os.path.basename(root)
    for file in files
    if file.endswith('.mid') and 'beat' in file
)

# Lakh clean_midi: every .mid file. This is the list used for the round-trip check below.
midi_files = sorted(
    os.path.join(root, file)
    for root, _, files in os.walk(os.path.join('data', 'clean_midi'))
    for file in files
    if file.endswith('.mid')
)
print(f"Groove beat files: {len(groove_midi_files)} | Lakh clean_midi files: {len(midi_files)}")
if not midi_files:
    raise FileNotFoundError("No .mid files found under data/clean_midi - download the dataset first.")

# Preview the first paths (slicing, unlike range(20), cannot raise IndexError on small sets).
for path in midi_files[:20]:
    print(path)

# Round-trip check: encode one file to drum tensors and decode it back to MIDI.
midi_file = symusic.Score(midi_files[0])

encoder = tokeniser.DrumSequenceEncoder(subdivision=16)

tok_sequence = encoder.encode(midi_file)
returned_midi = encoder.decode(tok_sequence)

# Copy the source file's global meta events so both renderings share metre, key and tempo
# (presumably decode() does not reconstruct them).
returned_midi.time_signatures.extend(midi_file.time_signatures)
returned_midi.key_signatures.extend(midi_file.key_signatures)
returned_midi.tempos.extend(midi_file.tempos)

print("Actual File")
print(midi_file)
print(midi_file.tracks[0].notes[:10])
print('-----------------------------')

print("Decoded File")
print(returned_midi)
print(returned_midi.tracks[0].notes[:10])
print('-----------------------------')

display(utils.midi_to_audio_display(FS, midi_file))
display(utils.midi_to_audio_display(FS, returned_midi))

In [None]:
def get_midi_data_files(root_dir: str, filter_fn: Callable[[str, str], bool] = lambda root, file: True):
    """Recursively list the files under ``root_dir`` accepted by ``filter_fn``.

    Args:
        root_dir: Directory to walk. A non-existent directory yields an empty list.
        filter_fn: Predicate called as ``filter_fn(root, file)``, where ``root`` is the
            directory being visited and ``file`` is the bare file name. Defaults to
            accepting every file.

    Returns:
        List of ``os.path.join(root, file)`` paths in a deterministic order: within a
        directory files are sorted by name, and sub-directories are visited in sorted
        order after the directory's own files.
    """
    midi_files = []
    for root, dirs, files in os.walk(root_dir):
        # os.walk yields entries in file-system order; sorting `dirs` in place also
        # fixes the order in which sub-directories are descended into.
        dirs.sort()
        midi_files.extend(os.path.join(root, file) for file in sorted(files) if filter_fn(root, file))
    return midi_files

In [None]:
# Lakh clean_set analysis
import pandas as pd

def _lakh_file_stats(file: str) -> dict:
    """Load one MIDI file and compute the statistics used to filter the Lakh set.

    Any exception raised while reading or inspecting the file propagates; the
    caller turns it into a row of missing values.
    """
    midi = symusic.Score.from_file(file)
    time_sigs = midi.time_signatures
    drum_tracks = [track for track in midi.tracks if track.is_drum]

    # Only drum hits whose pitch appears in the tokeniser's drum mapping are counted.
    hit_times = []
    note_set = set()
    for track in drum_tracks:
        for note in track.notes:
            if note.pitch in tokeniser.ROLAND_DRUM_MAPPING:
                hit_times.append(note.time)
                note_set.add(note.pitch)

    return {
        'File Length (quarters)': (midi.end() - midi.start()) / midi.tpq,
        'Tempos': midi.tempos,
        # Exactly one time-signature event, and it is 4/4 (files with none are excluded).
        'Is 4/4': len(time_sigs) == 1 and (time_sigs[0].numerator == 4 and time_sigs[0].denominator == 4),
        'Num Drum Tracks': len(drum_tracks),
        'Drum Hit Count': len(hit_times),
        'Unique Drum Hits': note_set,
        'First Drum (quarters)': min(hit_times) / midi.tpq if hit_times else None,
        'Last Drum (quarters)': max(hit_times) / midi.tpq if hit_times else None,
    }


def do_lakh_analysis(root_dir: str = os.path.join('data', 'clean_midi')):
    """Summarise every MIDI file under ``root_dir`` and keep the usable drum files.

    Each file is summarised by ``_lakh_file_stats``. The table is then filtered in
    stages, printing the remaining row count after each one. The stages drop:

    - unreadable files;
    - files without drum tracks;
    - files not in 4/4;
    - files with 256 or fewer mapped drum hits;
    - files with fewer than one hit per quarter note;
    - files of 1000 quarters or more;
    - files with fewer than three distinct drum pitches.

    Args:
        root_dir: Directory walked for MIDI files (defaults to ``data/clean_midi``).

    Returns:
        pandas.DataFrame with one row per retained file and the columns listed below.
    """
    files = get_midi_data_files(root_dir)

    columns = [
        'File',
        'File Length (quarters)',
        'Tempos',
        'Is 4/4',
        'Num Drum Tracks',
        'Drum Hit Count',
        'Unique Drum Hits',
        'First Drum (quarters)',
        'Last Drum (quarters)',
    ]

    # One complete record per file: a failure can no longer leave some columns one
    # entry longer than others (which made pd.DataFrame raise on mismatched lengths).
    records = []
    for file in tqdm.tqdm(files, desc='Analysing Midi Files'):
        try:
            stats = _lakh_file_stats(file)
        except Exception:
            # Unreadable / malformed file: every statistic stays missing so the row is
            # counted in "Full Dataset" and dropped by the first filter below.
            stats = {}
        records.append({'File': file, **stats})

    lakh_df = pd.DataFrame(records, columns=columns)

    print(" ***** Filtering Dataset *****")
    print("Full Dataset:", len(lakh_df))
    # 'Num Drum Tracks' is always set for a readable file, so NaN marks a failed read.
    # (The previous `lakh_df[~lakh_df.isna()]` only masked values and dropped no rows.)
    lakh_df = lakh_df.dropna(subset=['Num Drum Tracks'])
    print("Without unreadable files:", len(lakh_df))
    lakh_df = lakh_df[lakh_df['Num Drum Tracks'] > 0]
    print("Without files with no drums:", len(lakh_df))
    lakh_df = lakh_df[lakh_df['Is 4/4'].astype(bool)]
    print("Without files not in 4/4:", len(lakh_df))
    lakh_df = lakh_df[lakh_df['Drum Hit Count'] > 256]
    print("Without files with less than 256 drum hits", len(lakh_df))
    lakh_df = lakh_df[lakh_df['Drum Hit Count'] / (lakh_df['File Length (quarters)']) >= 1.0]
    print("Without files with note density < 1 per quarter:", len(lakh_df))
    lakh_df = lakh_df[lakh_df['File Length (quarters)'] < 1000.0]
    print("Without files with > 1000 beats:", len(lakh_df))
    lakh_df = lakh_df[lakh_df['Unique Drum Hits'].map(len) > 2]
    print("Without files with no drum hit variety:", len(lakh_df))

    return lakh_df
    
lakh_clean_df = do_lakh_analysis()

# Rich display renders these as HTML tables instead of truncated plain text.
display(lakh_clean_df.describe())
display(lakh_clean_df.head(30))

In [None]:
class DrumInpaintingTransformer(torch.nn.Module):
    """Transformer that predicts drum hits for (partially hidden) step sequences.

    Every step of the input is a ``1 + num_pitches`` vector: an "any hit at this
    step" flag followed by one flag per drum pitch. For each step the model returns
    hit logits, per-pitch logits and per-pitch velocities squashed to [0, 1].

    Args:
        num_pitches: Number of drum pitch classes per step.
        embedding_dim: Width of the transformer layers.
        num_layers: Number of ``TransformerLayerWithRelativeAttention`` layers.
        num_heads: Attention heads per layer.
        pitch_pos_weight: ``pos_weight`` for the per-pitch BCE loss. A scalar (int or
            float) is broadcast to every pitch; a tensor/sequence is used as given and
            must broadcast against ``[batch, seq_len, num_pitches]``.
    """

    def __init__(self, num_pitches: int, embedding_dim: int = 64, num_layers: int = 6, num_heads: int = 8, pitch_pos_weight: float = 2.0, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.input_proj = torch.nn.Linear(1 + num_pitches, embedding_dim)

        self.transformer = torch.nn.ModuleList([
            modules.TransformerLayerWithRelativeAttention(embedding_dim, num_heads=num_heads, dropout=0.1, max_distance=64)
            for _ in range(num_layers)
        ])

        self.output_hit_proj = torch.nn.Linear(embedding_dim, 1)
        self.output_pitches_proj = torch.nn.Linear(embedding_dim, num_pitches)
        self.output_velocities_proj = torch.nn.Linear(embedding_dim, num_pitches)

        if isinstance(pitch_pos_weight, (int, float)):
            # Sized from num_pitches (was hard-coded to 9); ints are accepted as well.
            pitch_pos_weight = torch.full((1, 1, num_pitches), float(pitch_pos_weight))
        else:
            pitch_pos_weight = torch.as_tensor(pitch_pos_weight, dtype=torch.float32)
        # Non-persistent buffer: moves with model.to(device) but is not written to the
        # state_dict, so checkpoints saved when this was a plain attribute still load.
        self.register_buffer('pitch_pos_weight', pitch_pos_weight, persistent=False)

    def forward(self, x: torch.Tensor, attention_mask=None):
        """Run the model over a batch of step sequences.

        Args:
            x: ``[batch, seq_len, 1 + num_pitches]``: hit flag followed by per-pitch flags.
            attention_mask: Currently unused; accepted for interface compatibility.

        Returns:
            ``(hit_logits, pitch_logits, velocities)`` with last dimensions ``1``,
            ``num_pitches`` and ``num_pitches``. Velocities have already been passed
            through a sigmoid.
        """
        x = self.input_proj(x)  # [batch, seq_len, embedding_dim]

        for layer in self.transformer:
            x = layer(x)  # [batch, seq_len, embedding_dim]

        hit_logits = self.output_hit_proj(x)
        pitch_logits = self.output_pitches_proj(x)
        # Despite the name these are sigmoid outputs in [0, 1], not logits.
        vel_logits = torch.sigmoid(self.output_velocities_proj(x))

        return hit_logits, pitch_logits, vel_logits

    def loss_function(self,
            preds: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
            truths: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
            inpainting_mask: torch.Tensor,
            unmasked_weight: float,
        ) -> torch.Tensor:
        """Inpainting loss: hit BCE + pitch BCE + velocity MSE.

        Args:
            preds: ``(hit_logits, pitch_logits, velocities)`` as returned by ``forward``.
            truths: ``(hits, pitches, velocities)`` targets with matching shapes.
            inpainting_mask: Boolean mask shaped like the hit targets; True marks the
                steps that were hidden from the model.
            unmasked_weight: BCE weight for positions outside the mask. Positions inside
                the mask weigh 1.0.

        Returns:
            Scalar loss tensor. The velocity term only covers hidden steps whose true
            hit and true pitch flag are set; it is 0 when there are none.
        """
        pred_hits, pred_pitches, pred_vels = preds
        true_hits, true_pitches, true_vels = truths

        device = pred_hits.device

        hit_weights = torch.ones_like(true_hits)
        hit_weights[~inpainting_mask] = unmasked_weight

        hit_loss = torch.nn.functional.binary_cross_entropy_with_logits(
            pred_hits,
            true_hits,
            weight=hit_weights
        )

        # Pitch targets get full weight only at hidden steps that really contain a hit.
        hit_mask = inpainting_mask & (true_hits > 0.5)
        hit_mask = hit_mask.expand_as(true_pitches)

        pitch_weights = torch.ones_like(true_pitches)
        pitch_weights[~hit_mask] = unmasked_weight

        pitch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
            pred_pitches,
            true_pitches,
            weight=pitch_weights,
            pos_weight=self.pitch_pos_weight.to(pred_pitches.device)
        )

        hit_pitch_mask = hit_mask & (true_pitches > 0.5)
        vel_loss = torch.nn.functional.mse_loss(
            pred_vels[hit_pitch_mask],
            true_vels[hit_pitch_mask],
        ) if hit_pitch_mask.any() else torch.tensor(0.0, device=device)

        return hit_loss + pitch_loss + vel_loss


In [None]:
# Dataset
def make_dataset(encoder: 'tokeniser.DrumSequenceEncoder', midi_files: list[str], max_seq_length: int = 64):
    """Encode MIDI files and cut them into fixed-length training windows.

    Every ``(hits, pitches, velocities)`` triple returned by ``encoder.encode_all``
    (time on dim 0) is zero-padded at the end up to the next multiple of
    ``max_seq_length``. It is then split into consecutive, non-overlapping windows of
    exactly ``max_seq_length`` steps. Sequences with zero steps are skipped.

    Args:
        encoder: Object with an ``encode_all(files)`` method yielding
            ``(hits, pitches, velocities)`` tensors (assumed to share their dim-0 length).
        midi_files: Paths handed to ``encoder.encode_all``.
        max_seq_length: Window length in steps.

    Returns:
        ``torch.utils.data.TensorDataset`` of stacked ``(hits, pitches, velocities)``
        windows, each with leading shape ``[num_windows, max_seq_length]``.

    Raises:
        ValueError: if ``max_seq_length`` is not positive or no window was produced.
    """
    if max_seq_length <= 0:
        raise ValueError(f'max_seq_length must be positive, got {max_seq_length}')

    def pad_steps(t: torch.Tensor, n: int) -> torch.Tensor:
        # Append n all-zero time steps, keeping dtype and device.
        return torch.cat([t, t.new_zeros((n, *t.shape[1:]))])

    split_hits, split_pitches, split_velocities = [], [], []
    for hits, pitches, velocities in encoder.encode_all(midi_files):
        num_steps = hits.size(0)
        if num_steps == 0:
            continue
        # Padding needed to reach the next multiple; 0 when already aligned. The old
        # `max_seq_length - n % max_seq_length` appended a whole all-zero window
        # whenever n was already a multiple of max_seq_length.
        len_padding = -num_steps % max_seq_length
        if len_padding > 0:
            hits = pad_steps(hits, len_padding)
            pitches = pad_steps(pitches, len_padding)
            velocities = pad_steps(velocities, len_padding)

        split_hits.extend(torch.split(hits, max_seq_length))
        split_pitches.extend(torch.split(pitches, max_seq_length))
        split_velocities.extend(torch.split(velocities, max_seq_length))

    if not split_hits:
        raise ValueError('make_dataset: no non-empty sequences were produced from midi_files')

    return torch.utils.data.TensorDataset(
        torch.stack(split_hits, dim=0),
        torch.stack(split_pitches, dim=0),
        torch.stack(split_velocities, dim=0),
    )

# Encode the filtered Lakh files into 128-step training windows.
drum_encoder = tokeniser.DrumSequenceEncoder(subdivision=16)
dataset = make_dataset(
    encoder=drum_encoder,
    midi_files=lakh_clean_df['File'].tolist(),
    max_seq_length=128,
)

print(f"Dataset_Size: {len(dataset)}")

In [None]:
# Set up training parameters
num_epochs = 1
batch_size = 32
mask_prob = 0.4        # probability that each step is hidden from the model
unmasked_weight = 0.3  # loss weight of steps the model can see (hidden steps weigh 1.0)

device = torch.device("cuda" if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else "cpu")

# NOTE: the inference cell rebuilds this model from the saved state_dict, so its
# constructor arguments must stay identical to these.
model = DrumInpaintingTransformer(num_pitches=9, embedding_dim=16, num_layers=20, num_heads=16, pitch_pos_weight=3.0)
model = model.to(device)  # Move model to the appropriate device

# 90/10 random train/validation split, batched with DataLoaders.
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_set, valid_set = torch.utils.data.random_split(dataset, [train_size, val_size])

optimiser = torch.optim.Adam(model.parameters(), lr=2e-5, weight_decay=1e-6)

train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True
)

val_loader = torch.utils.data.DataLoader(
    valid_set,
    batch_size=batch_size,
    shuffle=False
)

# The training loop calls lr_scheduler.step() once every 10 batches of each epoch, so
# T_max is the exact number of those calls. It is clamped to at least 1 because
# CosineAnnealingLR divides by T_max.
scheduler_steps = num_epochs * (len(train_loader) // 10)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimiser, T_max=max(1, scheduler_steps))

# Training loop
def _mask_batch(true_hits: torch.Tensor, true_pitches: torch.Tensor, mask_prob: float):
    """Hide random steps of a batch for the inpainting objective.

    Each step is hidden independently with probability ``mask_prob``. A hidden step
    has its hit flag and every pitch flag set to 0.

    Returns:
        ``(model_input, hit_mask)``: the masked hits and pitches concatenated on the
        last dim, and the boolean mask (True = hidden) shaped like ``true_hits``.
    """
    hit_mask = torch.rand_like(true_hits) < mask_prob
    masked_hits = true_hits.masked_fill(hit_mask, 0.0)
    masked_pitches = true_pitches.masked_fill(hit_mask.expand_as(true_pitches), 0.0)
    return torch.cat((masked_hits, masked_pitches), dim=-1), hit_mask


train_losses = []
val_losses = []

for i in range(num_epochs):
    model.train()
    total_train_loss = 0.0

    for j, (true_hits, true_pitches, true_vels) in enumerate(train_loader, 1):
        optimiser.zero_grad()

        true_hits, true_pitches, true_vels = true_hits.to(device), true_pitches.to(device), true_vels.to(device)
        model_input, hit_mask = _mask_batch(true_hits, true_pitches, mask_prob)

        loss = model.loss_function(
            model(model_input),
            (true_hits, true_pitches, true_vels),
            hit_mask,
            unmasked_weight=unmasked_weight
        )
        loss.backward()
        optimiser.step()

        total_train_loss += loss.item()

        # Log and advance the LR schedule every 10 batches.
        if j % 10 == 0:
            print(f'[{j}/{len(train_loader)}] - Running loss = {total_train_loss / j}')
            lr_scheduler.step()

    # max(1, ...) avoids a ZeroDivisionError when a loader is empty (tiny dataset).
    avg_train_loss = total_train_loss / max(1, len(train_loader))
    train_losses.append(avg_train_loss)

    # Validation phase
    model.eval()
    total_val_loss = 0.0

    with torch.no_grad():
        for true_hits, true_pitches, true_vels in val_loader:
            true_hits, true_pitches, true_vels = true_hits.to(device), true_pitches.to(device), true_vels.to(device)

            # Same masking scheme as training. The mask is re-sampled every epoch, so
            # the validation loss includes some sampling noise.
            model_input, hit_mask = _mask_batch(true_hits, true_pitches, mask_prob)

            loss = model.loss_function(
                model(model_input),
                (true_hits, true_pitches, true_vels),
                hit_mask,
                unmasked_weight=unmasked_weight
            )

            total_val_loss += loss.item()

    avg_val_loss = total_val_loss / max(1, len(val_loader))
    val_losses.append(avg_val_loss)

    print(f"Epoch {i+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

import matplotlib.pyplot as plt

# Plot training and validation loss per epoch. Markers keep short runs visible:
# with num_epochs = 1 a plain line through a single point draws nothing.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(range(1, len(train_losses) + 1), train_losses, 'b-o', label='Training Loss')
ax.plot(range(1, len(val_losses) + 1), val_losses, 'r-o', label='Validation Loss')
ax.set(xlabel='Epochs', ylabel='Loss', title='Training and Validation Loss')
ax.legend()
ax.grid(True)
plt.show()

# Save the model weights (state_dict only) to disk
model_path = 'drum_inpainting_model.pth'
torch.save(model.state_dict(), model_path)
print(f"Model saved to '{model_path}'")

In [None]:
def make_prompt(length: int, prompt_hits: list[tuple[int, int, float]], num_pitches: int = 9, target_device=None):
    """Build model-ready prompt tensors from a sparse list of drum hits.

    Args:
        length: Number of steps in the prompt.
        prompt_hits: ``(step_index, pitch_index, velocity)`` triples. ``pitch_index``
            indexes the pitch axis, e.g. ``drum_encoder.pitch_to_index[36]``.
        num_pitches: Size of the pitch axis (9 matches the model configured in this notebook).
        target_device: Device for the returned tensors. Defaults to the notebook-level ``device``.

    Returns:
        ``(hits [1, length, 1], pitches [1, length, num_pitches], velocities [1, length, num_pitches])``.
        Unprompted steps and pitches are 0.
    """
    if target_device is None:
        target_device = device  # notebook-level device chosen in the model cells

    hits = torch.zeros(1, length, 1)
    pitches = torch.zeros(1, length, num_pitches)
    velocities = torch.zeros(1, length, num_pitches)

    for index, pitch, velocity in prompt_hits:
        hits[:, index, :] = 1.0
        pitches[:, index, pitch] = 1.0
        # The loop variable used to shadow `velocities`. The function then returned a
        # float (AttributeError on .to()) and never stored any velocity.
        velocities[:, index, pitch] = velocity

    return hits.to(target_device), pitches.to(target_device), velocities.to(target_device)

device = torch.device("cuda" if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else "cpu")
# Must match the architecture used in the training cell (embedding_dim=16 there). A
# different width makes load_state_dict fail with tensor size mismatches.
model = DrumInpaintingTransformer(num_pitches=9, embedding_dim=16, num_layers=20, num_heads=16, pitch_pos_weight=3.0)
# The file only holds a state_dict, so refuse arbitrary pickled objects.
model.load_state_dict(torch.load('drum_inpainting_model.pth', map_location=device, weights_only=True))
model = model.to(device)
# Inference only: switch off the dropout configured in the transformer layers.
model.eval()
# Prompt as (step, pitch index, velocity) triples over 32 steps, using MIDI 36 (Kick)
# and 38 (Snare).
prompt = [
    (0, drum_encoder.pitch_to_index[36], 0.8),
    (3, drum_encoder.pitch_to_index[36], 0.4),
    (4, drum_encoder.pitch_to_index[38], 0.8),
    (6, drum_encoder.pitch_to_index[36], 0.7),
    (11, drum_encoder.pitch_to_index[36], 0.5),
    (12, drum_encoder.pitch_to_index[38], 0.8),
    (14, drum_encoder.pitch_to_index[36], 0.7),
    (19, drum_encoder.pitch_to_index[36], 0.5),
    (20, drum_encoder.pitch_to_index[38], 0.8),
    (22, drum_encoder.pitch_to_index[36], 0.7),
    (27, drum_encoder.pitch_to_index[36], 0.5),
    (28, drum_encoder.pitch_to_index[38], 0.8),
    (29, drum_encoder.pitch_to_index[38], 0.3),
]
prompt_hits, prompt_pitches, prompt_vels = make_prompt(32, prompt)
# squeeze(0) removes only the batch axis, keeping hits as [steps, 1]. A bare squeeze()
# would also drop the trailing hit axis.
midi = drum_encoder.decode((prompt_hits.squeeze(0), prompt_pitches.squeeze(0), prompt_vels.squeeze(0)))
ipd.display(utils.midi_to_audio_display(FS, midi))

# Run the model once over the prompt: eval mode (no dropout) and no autograd graph.
model.eval()
with torch.no_grad():
    hit_logits, pitch_logits, vels = model(torch.cat((prompt_hits, prompt_pitches), dim=-1))
hit_probs = torch.sigmoid(hit_logits)
pitch_probs = torch.sigmoid(pitch_logits)

# Binarise with hand-tuned thresholds.
pred_hits = (hit_probs > 0.3).to(dtype=prompt_hits.dtype)
pred_pitches = (pitch_probs > 0.4).to(dtype=prompt_pitches.dtype)

# Raw model prediction.
midi = drum_encoder.decode((pred_hits.squeeze(0), pred_pitches.squeeze(0), vels.squeeze(0)))
display(utils.midi_to_audio_display(FS, midi))

# Merge: keep the prompt at every step where it has a hit, and use the prediction
# everywhere else. The previous masked_scatter copied predictions *into* the prompt
# steps in source order, overwriting the prompt with misaligned values.
prompt_mask = prompt_hits > 0.5
merged_hits = torch.where(prompt_mask, prompt_hits, pred_hits)
merged_pitches = torch.where(prompt_mask.expand_as(prompt_pitches), prompt_pitches, pred_pitches)
merged_vels = torch.where(prompt_mask.expand_as(prompt_vels), prompt_vels, vels)

midi = drum_encoder.decode((merged_hits.squeeze(0), merged_pitches.squeeze(0), merged_vels.squeeze(0)))
display(utils.midi_to_audio_display(FS, midi))

# Interface

In [None]:
# Sequencer rows: display name -> MIDI note number written for that row.
# Insertion order is the top-to-bottom row order of the grid.
# NOTE(review): these notes presumably need to be known to the encoder's pitch mapping; verify.
drum_hits = dict([
    ('Kick', 36),
    ('Snare', 38),
    ('Closed Hi-Hat', 42),
    ('Open Hi-Hat', 46),
    ('Low Tom', 43),
    ('Mid Tom', 47),
    ('High Tom', 50),
    ('Crash', 49),
    ('Ride', 51),
])

class DrumSequencer:
    """Interactive step-sequencer widget that plays and extends drum patterns.

    Shows a checkbox grid with one row per entry of the module-level ``drum_hits``
    and one column per step, plus Play / Generate / Clear buttons.

    - *Play* renders the checked pattern with FluidSynth.
    - *Generate* encodes the pattern, runs ``model`` over it and plays the
      thresholded prediction.
    - *Clear* unchecks every box.

    Args:
        model: Trained inpainting model. Inference runs on the device of its parameters.
        encoder: Converts between ``symusic.Score`` and model tensors. Its ``tpq`` and
            ``subdivision`` define step timing.
        fluid_synth: Renderer used for audio playback.
        num_steps: Number of steps (grid columns).
    """

    def __init__(self, model: 'DrumInpaintingTransformer', encoder: 'tokeniser.DrumSequenceEncoder', fluid_synth: 'FluidSynth', num_steps: int = 16):
        self.model = model
        self.encoder = encoder
        self.fs = fluid_synth
        self.device = next(model.parameters()).device

        self.num_steps = num_steps

        self._create_interface()

    def _create_interface(self):
        """Build the checkbox grid and control buttons and display them once."""
        self.output = widgets.Output()
        self.drum_checkboxes = {}

        # One header row plus one row per drum; one label column plus one column per step.
        grid = widgets.GridspecLayout(
            len(drum_hits) + 1,
            self.num_steps + 1,
            grid_gap='2px',
            width='100%',
            height=f'{len(drum_hits)*40}px'
        )

        # Header row: 1-based step numbers.
        for step in range(1, self.num_steps + 1):
            grid[0, step] = widgets.Label(
                f'{step}',
                layout=widgets.Layout(margin='auto', justify_content='center')
            )

        for row, drum_name in enumerate(drum_hits, 1):
            # Row label (right-aligned through the layout's justify_content).
            grid[row, 0] = widgets.Label(
                drum_name,
                layout=widgets.Layout(
                    height='auto',
                    width='100px',
                    padding='3px',
                    justify_content='flex-end',
                ),
            )

            # One checkbox per step for this drum.
            self.drum_checkboxes[drum_name] = []
            for step in range(1, self.num_steps + 1):
                checkbox = widgets.Checkbox(
                    value=False,
                    indent=False,
                    layout=widgets.Layout(
                        height='auto',
                        width='auto',
                        margin='auto',
                        justify_content='center'
                    )
                )
                self.drum_checkboxes[drum_name].append(checkbox)
                grid[row, step] = checkbox

        play_btn = widgets.Button(description='Play')
        play_btn.on_click(self._play_pattern)

        gen_btn = widgets.Button(description='Generate')
        gen_btn.on_click(self._generate_pattern)

        clear_btn = widgets.Button(description='Clear')
        clear_btn.on_click(self._clear_pattern)

        controls = widgets.HBox([play_btn, gen_btn, clear_btn])
        # self.output is part of this layout; displaying it a second time (as before)
        # created a duplicate output area.
        display(widgets.VBox([grid, controls, self.output]))

    def _make_midi_from_widgets(self):
        """Convert the checked boxes into a single-drum-track ``symusic.Score``."""
        track = symusic.Track(is_drum=True)

        tpq = self.encoder.tpq
        subdiv = self.encoder.subdivision

        for drum, boxes in self.drum_checkboxes.items():
            for i, box in enumerate(boxes):
                if box.value:
                    # Step i starts at i * 4 * tpq / subdiv ticks.
                    # Arguments are (time, duration, pitch, velocity).
                    track.notes.append(symusic.Note(int(i * tpq * 4 / subdiv), 80, drum_hits[drum], 80))

        track.sort()
        midi = symusic.Score(tpq)
        midi.tracks.append(track)
        return midi

    def _play_pattern(self, _):
        """Button callback: render the current grid to audio in the output area."""
        midi = self._make_midi_from_widgets()
        with self.output:
            ipd.clear_output()
            display(utils.midi_to_audio_display(self.fs, midi))

    def _generate_pattern(self, _):
        """Button callback: run the model over the grid and play its prediction."""
        midi = self._make_midi_from_widgets()
        with torch.no_grad():
            self.model.eval()

            hits, pitch, _vels = self.encoder.encode(midi)
            # Use the model's own device (previously the notebook-global `device`).
            sequence = torch.cat((hits, pitch), dim=-1).unsqueeze(0).to(self.device)
            hit_logits, pitch_logits, vels = self.model(sequence)

            # Hand-tuned binarisation thresholds.
            hit_preds = (torch.sigmoid(hit_logits) > 0.2).to(torch.float32)
            pitch_preds = (torch.sigmoid(pitch_logits) > 0.3).to(torch.float32)

            midi = self.encoder.decode((hit_preds.squeeze(0), pitch_preds.squeeze(0), vels.squeeze(0)))

        with self.output:
            ipd.clear_output()
            display(utils.midi_to_audio_display(self.fs, midi))

    def _clear_pattern(self, _):
        """Button callback: uncheck every step of every drum."""
        for boxes in self.drum_checkboxes.values():
            for box in boxes:
                box.value = False
        
# Keep a reference to the sequencer. As the cell's bare last expression it would also
# echo the object's default repr below the widget.
sequencer = DrumSequencer(model, drum_encoder, FS)
