In [None]:
def is_colab():
    """Return True when the notebook is running inside Google Colab.

    Detection is done by attempting to import ``google.colab``; any
    ``ImportError`` is taken to mean "not on Colab".
    """
    try:
        import google.colab  # noqa: F401 -- imported only to probe for its presence
    except ImportError:
        return False
    else:
        return True

IS_COLAB = is_colab()

if IS_COLAB:
  !pip install git+https://github.com/hoggl-dsp/cc_beats

  !apt-get install fluidsynth
  !mkdir data
  !mkdir data/sf2
  !wget -O data/sf2/Xpand_2_-_Practice_Room_Kit.sf2 'https://musical-artifacts.com/artifacts/6296/Xpand_2_-_Practice_Room_Kit.sf2'

  !wget -O data/groove_midi_only.zip 'https://storage.googleapis.com/magentadata/datasets/groove/groove-v1.0.0-midionly.zip'
  !unzip data/groove_midi_only.zip -d data/

  !wget -O data/clean_midi.tar.gz 'http://hog.ee.columbia.edu/craffel/lmd/clean_midi.tar.gz'
  !tar -xf data/clean_midi.tar.gz -C data/

In [None]:
import torch
import torchaudio
import torch.utils.tensorboard

import symusic
import symusic.types
from midi2audio import FluidSynth

import IPython.display as ipd
import ipywidgets as widgets
import tqdm
import os
import random
from datetime import datetime

from typing import Callable

from cc_beats import tokeniser, modules, utils

# Location of the drum-kit SoundFont (the Colab setup cell downloads it to this path).
sf_path = 'data/sf2/Xpand_2_-_Practice_Room_Kit.sf2'
# Shared FluidSynth handle; later cells pass it to the cc_beats audio helpers
# (utils.midi_to_audio_display / utils.midi_to_audio_tensor).
FS = FluidSynth(sf_path)

In [None]:
# Groove MIDI files: only 'beat' recordings, skipping any directory named like an
# eval split. Kept under its own name -- previously this list was built and then
# immediately discarded when `midi_files` was reset for the Lakh scan below.
groove_midi_files = []
for root, _, files in os.walk(os.path.join('data', 'groove')):
    groove_midi_files.extend(
        os.path.join(root, file)
        for file in files
        if file.endswith('.mid') and 'beat' in file and 'eval' not in os.path.basename(root)
    )

# Lakh clean_midi files: every .mid file, no further filtering.
midi_files = []
for root, _, files in os.walk(os.path.join('data', 'clean_midi')):
    midi_files.extend(os.path.join(root, file) for file in files if file.endswith('.mid'))

if not midi_files:
    # Fail with an actionable message instead of an IndexError further down.
    raise FileNotFoundError(
        f"No .mid files found under {os.path.join('data', 'clean_midi')}; download and extract the dataset first."
    )

# Slicing (rather than indexing 0..19) also works when fewer than 20 files exist.
for path in midi_files[:20]:
    print(path)

midi_file = symusic.Score(midi_files[0])

# NOTE: `encoder` is a module-level name that later cells may still reference.
encoder = tokeniser.DrumSequenceEncoder(subdivision=16)

# Round-trip one file through the tokeniser to eyeball what encoding loses.
tok_sequence = encoder.encode(midi_file)
returned_midi = encoder.decode(tok_sequence)

# Copy the global events across so both versions are rendered with the same
# time signatures, key signatures and tempos.
returned_midi.time_signatures.extend(midi_file.time_signatures)
returned_midi.key_signatures.extend(midi_file.key_signatures)
returned_midi.tempos.extend(midi_file.tempos)

print("Actual File")
print(midi_file)
print(midi_file.tracks[0].notes[:10])
print('-----------------------------')

print("Decoded File")
print(returned_midi)
print(returned_midi.tracks[0].notes[:10])
print('-----------------------------')

ipd.display(utils.midi_to_audio_display(FS, midi_file))
ipd.display(utils.midi_to_audio_display(FS, returned_midi))

In [None]:
def get_midi_data_files(root_dir: str, filter_fn: Callable[[str, str], bool] = lambda root, file: True):
    """Recursively list files under ``root_dir`` in a deterministic order.

    Args:
        root_dir: Directory to walk. A missing directory yields an empty list
            (``os.walk`` silently produces nothing for it).
        filter_fn: Predicate called as ``filter_fn(directory, filename)``; only
            files for which it returns a truthy value are kept. The default
            keeps every file, whatever its extension.

    Returns:
        A list of ``os.path.join(directory, filename)`` paths. Directories are
        visited in sorted order and files are sorted within each directory, so
        the result does not depend on the filesystem's listing order.
    """
    midi_files = []
    for root, dirs, files in os.walk(root_dir):
        # Sorting `dirs` in place fixes the order in which os.walk descends.
        dirs.sort()
        midi_files.extend(os.path.join(root, file) for file in sorted(files) if filter_fn(root, file))
    return midi_files

In [None]:
# Lakh clean_set analysis
import pandas as pd

def _analyse_midi_file(file: str) -> dict:
    """Load one MIDI file and return its row of statistics.

    Any exception raised while loading or inspecting the file propagates to the
    caller, so a row is either complete or not produced at all.
    """
    midi = symusic.Score.from_file(file)
    time_signatures = midi.time_signatures
    is_four_four = (
        len(time_signatures) == 1
        and time_signatures[0].numerator == 4
        and time_signatures[0].denominator == 4
    )

    drum_tracks = [track for track in midi.tracks if track.is_drum]

    # Only notes whose pitch is known to the tokeniser's drum mapping are counted.
    drum_times = []
    drum_pitches = set()
    for track in drum_tracks:
        for note in track.notes:
            if note.pitch not in tokeniser.ROLAND_DRUM_MAPPING:
                continue
            drum_times.append(note.time)
            drum_pitches.add(note.pitch)

    return {
        'File': file,
        'File Length (quarters)': (midi.end() - midi.start()) / midi.tpq,
        'Tempos': midi.tempos,
        'Is 4/4': is_four_four,
        'Num Drum Tracks': len(drum_tracks),
        'Drum Hit Count': len(drum_times),
        'Unique Drum Hits': drum_pitches,
        'First Drum (quarters)': min(drum_times) / midi.tpq if drum_times else None,
        'Last Drum (quarters)': max(drum_times) / midi.tpq if drum_times else None,
    }


def do_lakh_analysis():
    """Scan ``data/clean_midi`` and return the files suitable for training.

    Every file found by ``get_midi_data_files`` is analysed; files that cannot
    be read get a row containing only their path. The table is then narrowed
    step by step (readable, has drum tracks, single 4/4 time signature, more
    than 256 mapped drum hits, at least one hit per quarter note, shorter than
    1000 quarters, more than two distinct drum pitches), printing the number of
    remaining files after each step.

    Returns:
        A ``pandas.DataFrame`` with one row per surviving file.
    """
    files = get_midi_data_files(os.path.join('data', 'clean_midi'))

    columns = [
        'File',
        'File Length (quarters)',
        'Tempos',
        'Is 4/4',
        'Num Drum Tracks',
        'Drum Hit Count',
        'Unique Drum Hits',
        'First Drum (quarters)',
        'Last Drum (quarters)',
    ]

    # One record per file. Building whole records (instead of appending to
    # per-column lists mid-analysis) keeps all columns the same length even when
    # a file fails part-way through.
    records = []
    for file in tqdm.tqdm(files, desc='Analysing Midi Files'):
        try:
            records.append(_analyse_midi_file(file))
        except Exception:
            # Unreadable/corrupt file: keep the path, leave the statistics missing.
            # `Exception` (not a bare except) so Ctrl-C still interrupts the scan.
            records.append({'File': file})

    lakh_df = pd.DataFrame(records, columns=columns)

    print(" ***** Filtering Dataset *****")
    print("Full Dataset:", len(lakh_df))
    # 'Num Drum Tracks' is set for every successfully analysed file, so a missing
    # value there identifies exactly the unreadable ones.
    lakh_df = lakh_df[lakh_df['Num Drum Tracks'].notna()]
    print("Without unreadable files:", len(lakh_df))
    lakh_df = lakh_df[lakh_df['Num Drum Tracks'] > 0]
    print("Without files with no drums:", len(lakh_df))
    # The column is object-typed when unreadable rows were present; cast so it is
    # always a valid boolean mask.
    lakh_df = lakh_df[lakh_df['Is 4/4'].astype(bool)]
    print("Without files not in 4/4:", len(lakh_df))
    lakh_df = lakh_df[lakh_df['Drum Hit Count'] > 256]
    print("Without files with less than 256 drum hits", len(lakh_df))
    lakh_df = lakh_df[lakh_df['Drum Hit Count'] / (lakh_df['File Length (quarters)']) >= 1.0]
    print("Without files with note density < 1 per quarter:", len(lakh_df))
    lakh_df = lakh_df[lakh_df['File Length (quarters)'] < 1000.0]
    print("Without files with > 1000 beats:", len(lakh_df))
    lakh_df = lakh_df[lakh_df['Unique Drum Hits'].apply(lambda x: len(x) > 2).astype(bool)]
    print("Without files with no drum hit variety:", len(lakh_df))

    return lakh_df

lakh_clean_df = do_lakh_analysis()

# Rich display renders the frames as HTML tables instead of wrapped plain text.
ipd.display(lakh_clean_df.describe())

ipd.display(lakh_clean_df.head(30))

In [None]:
class DrumInpaintingTransformer(torch.nn.Module):
    """Transformer that predicts, per step, a hit flag, active pitches and velocities.

    Each input step has ``1 + num_pitches`` features (the hit flag followed by
    the per-pitch activations). The features are projected to ``embedding_dim``,
    passed through ``num_layers`` relative-attention layers, and read out by
    three linear heads.

    Args:
        num_pitches: Number of drum pitches per step.
        embedding_dim: Width of the transformer layers.
        num_layers: Number of stacked transformer layers.
        num_heads: Attention heads per layer.
        dropout: Dropout rate passed to every layer.
        pitch_pos_weight: Positive-class weight for the pitch BCE loss. A number
            is expanded to a ``(1, 1, num_pitches)`` tensor; a tensor is used as
            given.
        *args, **kwargs: Forwarded to ``torch.nn.Module.__init__``.
    """

    def __init__(self, num_pitches: int, embedding_dim: int = 64, num_layers: int = 6, num_heads: int = 8, dropout: float = 0.1, pitch_pos_weight: float = 2.0, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.input_proj = torch.nn.Linear(1 + num_pitches, embedding_dim)

        self.transformer = torch.nn.ModuleList([
            modules.TransformerLayerWithRelativeAttention(embedding_dim, num_heads=num_heads, dropout=dropout, max_distance=64)
            for _ in range(num_layers)
        ])

        self.output_hit_proj = torch.nn.Linear(embedding_dim, 1)
        self.output_pitches_proj = torch.nn.Linear(embedding_dim, num_pitches)
        self.output_velocities_proj = torch.nn.Linear(embedding_dim, num_pitches)

        # Size the weight from num_pitches (it used to be hard-coded to 9) and
        # accept ints as well as floats.
        if isinstance(pitch_pos_weight, (int, float)):
            pitch_pos_weight = torch.full((1, 1, num_pitches), float(pitch_pos_weight))
        # Non-persistent buffer: follows the module across devices with .to(),
        # but is left out of state_dict so existing checkpoints still load.
        self.register_buffer('pitch_pos_weight', pitch_pos_weight, persistent=False)

    def forward(self, x: torch.Tensor, attention_mask=None):
        """Run the model on ``x`` of shape ``[batch, seq_len, 1 + num_pitches]``.

        ``attention_mask`` is accepted for interface compatibility but is
        currently not passed to the layers.

        Returns:
            ``(hit_logits, pitch_logits, velocities)`` with last dimensions
            ``1``, ``num_pitches`` and ``num_pitches``. The first two are raw
            logits; velocities have already been squashed through a sigmoid.
        """
        x = self.input_proj(x)  # [batch, seq_len, embedding_dim]

        for layer in self.transformer:
            x = layer(x)  # [batch, seq_len, embedding_dim]

        hit_logits = self.output_hit_proj(x)
        pitch_logits = self.output_pitches_proj(x)
        velocities = torch.sigmoid(self.output_velocities_proj(x))

        return hit_logits, pitch_logits, velocities

    def loss_function(self,
            preds: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
            truths: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
            inpainting_mask: torch.Tensor,
            unmasked_weight: float,
        ) -> torch.Tensor:
        """Combined hit + pitch + velocity loss for an inpainting batch.

        Args:
            preds: ``(hit_logits, pitch_logits, velocities)`` as returned by ``forward``.
            truths: ``(true_hits, true_pitches, true_vels)`` target tensors.
            inpainting_mask: Boolean tensor shaped like ``true_hits``; True marks
                steps that were hidden from the model.
            unmasked_weight: Loss weight applied outside the masked region
                (masked positions always have weight 1).

        Returns:
            Scalar tensor: the sum of the three loss terms.
        """
        pred_hits, pred_pitches, pred_vels = preds
        true_hits, true_pitches, true_vels = truths

        # Hit loss: full weight on hidden steps, `unmasked_weight` on visible ones.
        hit_weights = torch.ones_like(true_hits)
        hit_weights[~inpainting_mask] = unmasked_weight

        hit_loss = torch.nn.functional.binary_cross_entropy_with_logits(
            pred_hits,
            true_hits,
            weight=hit_weights
        )

        # Pitch loss: full weight only on hidden steps that really contain a hit.
        hit_mask = (inpainting_mask & (true_hits > 0.5)).expand_as(true_pitches)

        pitch_weights = torch.ones_like(true_pitches)
        pitch_weights[~hit_mask] = unmasked_weight

        pitch_loss = torch.nn.functional.binary_cross_entropy_with_logits(
            pred_pitches,
            true_pitches,
            weight=pitch_weights,
            pos_weight=self.pitch_pos_weight.to(pred_pitches.device)
        )

        # Velocity loss: only for pitches that are active in hidden hit steps.
        hit_pitch_mask = hit_mask & (true_pitches > 0.5)
        if hit_pitch_mask.any():
            vel_loss = torch.nn.functional.mse_loss(
                pred_vels[hit_pitch_mask],
                true_vels[hit_pitch_mask]
            )
        else:
            # mse_loss over an empty selection would be NaN; contribute nothing instead.
            vel_loss = torch.tensor(0.0, device=pred_hits.device)

        return hit_loss + pitch_loss + vel_loss


In [None]:
# Dataset
def make_dataset(encoder: 'tokeniser.DrumSequenceEncoder', midi_files: list[str], max_seq_length: int = 64):
    """Encode MIDI files and cut them into fixed-length training chunks.

    Each encoded file is a ``(hits, pitches, velocities)`` triple of 2-D tensors
    whose first dimension is the step axis. Every tensor is zero-padded at the
    end up to the next multiple of ``max_seq_length`` and split into chunks of
    exactly ``max_seq_length`` steps. Sequences whose length is already a
    multiple are not padded, and empty sequences are skipped, so no all-zero
    chunk is ever added.

    Args:
        encoder: Object providing ``encode_all(midi_files)``.
        midi_files: Paths of the MIDI files to encode.
        max_seq_length: Number of steps per chunk; must be positive.

    Returns:
        ``torch.utils.data.TensorDataset`` of stacked ``(hits, pitches, velocities)`` chunks.

    Raises:
        ValueError: If ``max_seq_length`` is not positive or no chunk was produced.
    """
    if max_seq_length <= 0:
        raise ValueError(f"max_seq_length must be positive, got {max_seq_length}")

    def pad_steps(tensor: torch.Tensor, count: int) -> torch.Tensor:
        # Append `count` all-zero steps, matching the tensor's dtype and device.
        return torch.cat([tensor, tensor.new_zeros(count, tensor.size(1))])

    split_hits = []
    split_pitches = []
    split_velocities = []
    for hits, pitches, velocities in encoder.encode_all(midi_files):
        num_steps = hits.size(0)
        if num_steps == 0:
            continue

        # (-n) % m is 0 when n is already a multiple of m. The previous
        # m - (n % m) gave m in that case and appended a whole silent chunk.
        len_padding = (-num_steps) % max_seq_length
        if len_padding > 0:
            hits = pad_steps(hits, len_padding)
            pitches = pad_steps(pitches, len_padding)
            velocities = pad_steps(velocities, len_padding)

        split_hits.extend(hits.split(max_seq_length))
        split_pitches.extend(pitches.split(max_seq_length))
        split_velocities.extend(velocities.split(max_seq_length))

    if not split_hits:
        raise ValueError("No non-empty sequences were produced from the given MIDI files")

    return torch.utils.data.TensorDataset(
        torch.stack(split_hits, dim=0),
        torch.stack(split_pitches, dim=0),
        torch.stack(split_velocities, dim=0)
    )

# Encoder used for building the dataset and, in later cells, for the example
# prompts and the sequencer UI.
drum_encoder = tokeniser.DrumSequenceEncoder(subdivision=16)
# Build fixed-length (128-step) chunks from the files that survived the Lakh filtering.
dataset = make_dataset(drum_encoder, lakh_clean_df['File'].to_list(), max_seq_length=128)

print("Dataset_Size:", len(dataset))

In [None]:
# On Colab, start TensorBoard inline; 'logs' is the directory the training cell
# creates its SummaryWriter run folders in.
if IS_COLAB:
    %load_ext tensorboard
    %tensorboard --logdir logs

In [None]:
# Set up training parameters
batch_size = 32

# Per-batch masking probability: drawn from a normal distribution with this
# mean/std and clamped to [min, max] (see the training loop).
mask_prob_config = {
    'mean': 0.4,
    'std': 0.1,
    'min': 0.1,
    'max': 0.7,
}

# Loss weight for positions that were not masked; interpolated geometrically
# from 'initial' towards 'final' over the epochs.
unmasked_weight_config = {
    'initial': 0.5,
    'final': 0.1,
}

device = torch.device("cuda" if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else "cpu")

# Baseline model; the hyperparameter search below rebinds `model` on each trial.
model = DrumInpaintingTransformer(num_pitches=9, embedding_dim=32, num_layers=20, num_heads=16, pitch_pos_weight=3.0)
model = model.to(device)  # Move model to the appropriate device

# 90/10 train/validation split of the chunked dataset.
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_set, valid_set = torch.utils.data.random_split(dataset, [train_size, val_size])

# The search trains on small random subsets to keep each trial quick. Sizes are
# capped at what is available: random.sample raises ValueError when asked for
# more items than the population holds.
TRAIN_SUBSET_SIZE = 1000
VALID_SUBSET_SIZE = 200
train_subset = torch.utils.data.Subset(
    train_set,
    random.sample(range(len(train_set)), min(TRAIN_SUBSET_SIZE, len(train_set)))
)
valid_subset = torch.utils.data.Subset(
    valid_set,
    random.sample(range(len(valid_set)), min(VALID_SUBSET_SIZE, len(valid_set)))
)

train_loader = torch.utils.data.DataLoader(
    train_subset,
    batch_size=batch_size,
    shuffle=True
)

val_loader = torch.utils.data.DataLoader(
    valid_subset,
    batch_size=batch_size,
    shuffle=False
)

NUM_HYPERPARAM_SEARCHES = 5
NUM_EPOCHS = 5

# Each search trial draws one value per key, independently and uniformly.
# NOTE(review): independent draws can pair embedding_dim=16 with num_heads=32;
# confirm the attention layer in cc_beats.modules supports more heads than
# embedding dimensions, otherwise that trial will fail.
hyperparam_choices = {
    'num_layers': [16, 24, 32],
    'num_heads': [8, 16, 32],
    'embedding_dim': [16, 32, 64],
    'pitch_pos_weight': [1.0, 2.0, 3.0],
    'dropout': [0.1, 0.2, 0.3],
}


def make_example_prompt(length: int, prompt_hits: list[tuple[int, int, float]]):
    """Build a one-item batch of prompt tensors from a list of drum events.

    Args:
        length: Number of steps in the prompt.
        prompt_hits: ``(step, pitch_index, velocity)`` triples; ``pitch_index``
            addresses one of the 9 pitch slots.

    Returns:
        ``(hits, pitches, velocities)`` on the global ``device``, with shapes
        ``(1, length, 1)``, ``(1, length, 9)`` and ``(1, length, 9)``.
    """
    hit_track = torch.zeros(1, length, 1)
    pitch_track = torch.zeros(1, length, 9)
    velocity_track = torch.zeros(1, length, 9)

    for step, pitch_slot, velocity in prompt_hits:
        # Flag the step as containing a hit, then record which pitch and how loud.
        hit_track[:, step, :] = 1.0
        pitch_track[:, step, pitch_slot] = 1.0
        velocity_track[:, step, pitch_slot] = velocity

    return tuple(track.to(device) for track in (hit_track, pitch_track, velocity_track))

def _prompt_from_midi_notes(length, notes):
    """Build an example prompt from ``(step, MIDI pitch, velocity)`` triples.

    MIDI pitches are translated to pitch slots through ``drum_encoder.pitch_to_index``.
    """
    return make_example_prompt(length, [
        (step, drum_encoder.pitch_to_index[midi_pitch], velocity)
        for step, midi_pitch, velocity in notes
    ])


# One-bar pattern (16 steps) using pitches 36, 38, 42 and 46.
BACKBEAT = _prompt_from_midi_notes(16, [
    (0, 36, 0.8),
    (0, 46, 0.8),
    (2, 42, 0.5),
    (4, 38, 0.8),
    (4, 42, 0.6),
    (6, 42, 0.5),
    (8, 36, 0.8),
    (8, 42, 0.6),
    (10, 42, 0.5),
    (12, 38, 0.8),
    (12, 42, 0.6),
    (14, 42, 0.5),
])

# Two-bar pattern (32 steps) using pitches 36 and 38 only.
PLEASE = _prompt_from_midi_notes(32, [
    (0, 36, 0.8),
    (3, 36, 0.4),
    (4, 38, 0.8),
    (6, 36, 0.7),
    (11, 36, 0.5),
    (12, 38, 0.8),
    (14, 36, 0.7),
    (19, 36, 0.5),
    (20, 38, 0.8),
    (22, 36, 0.7),
    (27, 36, 0.5),
    (28, 38, 0.8),
    (29, 38, 0.3),
])

# One-bar pattern (16 steps) using pitches 36 and 38 only.
AMEN = _prompt_from_midi_notes(16, [
    (0, 36, 0.8),
    (2, 36, 0.6),
    (4, 38, 0.8),
    (7, 38, 0.5),
    (9, 38, 0.4),
    (10, 36, 0.6),
    (11, 36, 0.5),
    (12, 38, 0.8),
    (15, 38, 0.5),
])

def _prepare_masked_batch(true_hits: torch.Tensor, true_pitches: torch.Tensor, true_vels: torch.Tensor):
    """Turn one loader batch into a masked model input plus its targets.

    The batch is moved to ``device``, each sequence is cut into 1, 2, 4 or 8
    pieces that are stacked along the batch axis, and a random subset of steps
    is hidden (hit flag and pitches zeroed) with a per-batch probability drawn
    according to ``mask_prob_config``.

    Returns:
        ``(model_input, (true_hits, true_pitches, true_vels), hit_mask)`` where
        ``hit_mask`` is True at the hidden steps.
    """
    true_hits, true_pitches, true_vels = true_hits.to(device), true_pitches.to(device), true_vels.to(device)

    # Shorter pieces on a larger batch expose the model to several sequence lengths.
    div_size = random.choices([1, 2, 4, 8], [0.3, 0.3, 0.2, 0.2])[0]
    true_hits = torch.cat(true_hits.tensor_split(div_size, dim=1), dim=0)
    true_pitches = torch.cat(true_pitches.tensor_split(div_size, dim=1), dim=0)
    true_vels = torch.cat(true_vels.tensor_split(div_size, dim=1), dim=0)

    mask_prob = torch.randn(1, device=device) * mask_prob_config['std'] + mask_prob_config['mean']
    mask_prob = torch.clamp(mask_prob, mask_prob_config['min'], mask_prob_config['max'])

    hit_mask = torch.rand_like(true_hits, device=device) < mask_prob
    masked_hits, masked_pitches = true_hits.clone(), true_pitches.clone()
    masked_hits[hit_mask] = 0.0
    masked_pitches[hit_mask.expand_as(masked_pitches)] = 0.0

    model_input = torch.cat((masked_hits, masked_pitches), dim=-1)
    return model_input, (true_hits, true_pitches, true_vels), hit_mask


# Fixed prompts rendered to TensorBoard after every epoch.
EXAMPLE_PROMPTS = [('BACKBEAT', BACKBEAT), ('PLEASE', PLEASE), ('AMEN', AMEN)]

for _ in range(NUM_HYPERPARAM_SEARCHES):
    hyperparams = {key: random.choice(value) for key, value in hyperparam_choices.items()}

    print("******************** Starting search with hyperparameters: ********************")
    for key, value in hyperparams.items():
        print(f"{key}: {value}")
    print("*******************************************************************************")
    print()
    model = DrumInpaintingTransformer(num_pitches=9, **hyperparams)
    model = model.to(device)

    optimiser = torch.optim.Adam(model.parameters(), lr=2e-5, weight_decay=1e-6)

    # The scheduler is stepped once every 10 batches (see below), hence the // 10.
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimiser,
        T_max=(NUM_EPOCHS * len(train_loader)) // 10,
        eta_min=1e-7
    )

    log_dir = os.path.join('logs', '-'.join([f'{key}_{value}' for key, value in hyperparams.items()]))

    # Same hyperparameters drawn again: use the first free '<log_dir>_<n>' instead.
    if os.path.exists(log_dir):
        suffix = 0
        while os.path.exists(f'{log_dir}_{suffix}'):
            suffix += 1
        log_dir = f'{log_dir}_{suffix}'
    os.makedirs(log_dir, exist_ok=True)

    writer = torch.utils.tensorboard.SummaryWriter(log_dir=log_dir, flush_secs=30)

    # try/finally so the event file is flushed and closed even if a trial fails.
    try:
        writer.add_text('hyperparams', str(hyperparams))

        for epoch in range(NUM_EPOCHS):
            print(f"******************** Epoch {epoch+1}/{NUM_EPOCHS} ********************")
            print("Training...")
            model.train()
            total_train_loss = 0

            # Geometric interpolation from the 'initial' towards the 'final' weight.
            unmasked_weight = unmasked_weight_config['initial'] * (unmasked_weight_config['final'] / unmasked_weight_config['initial']) ** (epoch / NUM_EPOCHS)

            for batch_idx, batch in enumerate(train_loader, 1):
                optimiser.zero_grad()

                model_input, truths, hit_mask = _prepare_masked_batch(*batch)

                loss = model.loss_function(
                    model(model_input),
                    truths,
                    hit_mask,
                    unmasked_weight=unmasked_weight
                )
                loss.backward()
                optimiser.step()

                batch_loss = loss.item()
                total_train_loss += batch_loss

                global_step = epoch * len(train_loader) + batch_idx
                writer.add_scalar('loss/train/batch', batch_loss, global_step)
                writer.add_scalar('loss/train/running', total_train_loss / batch_idx, global_step)

                if batch_idx % 10 == 0:
                    print(f'[{batch_idx}/{len(train_loader)}] - Running loss = {total_train_loss / batch_idx}')
                    lr_scheduler.step()

            # Keep track of training loss
            avg_train_loss = total_train_loss / len(train_loader)
            writer.add_scalar('loss/train/epoch', avg_train_loss, epoch)

            print("Validation...")

            # Validation phase: same random re-chunking and masking as in training.
            model.eval()
            total_val_loss = 0

            with torch.no_grad():
                for batch in val_loader:
                    model_input, truths, hit_mask = _prepare_masked_batch(*batch)

                    loss = model.loss_function(
                        model(model_input),
                        truths,
                        hit_mask,
                        unmasked_weight=unmasked_weight
                    )

                    total_val_loss += loss.item()

            avg_val_loss = total_val_loss / len(val_loader)
            writer.add_scalar('loss/val/epoch', avg_val_loss, epoch)

            print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

            # NOTE(review): add_hparams is called once per epoch; check in
            # TensorBoard whether one call after the final epoch would be cleaner.
            writer.add_hparams(hyperparams, {
                'train_loss': avg_train_loss,
                'val_loss': avg_val_loss,
            }, global_step=epoch)

            # Save the model checkpoint
            checkpoint_path = os.path.join(log_dir, f'checkpoint_epoch_{epoch+1}.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Model checkpoint saved to {checkpoint_path}")

            # Save example patterns to TensorBoard
            with torch.no_grad():
                for prompt_idx, (name, (prompt_hits, prompt_pitches, prompt_vels)) in enumerate(EXAMPLE_PROMPTS):
                    print(f"Generating example pattern: {name}")
                    model_input = torch.cat((prompt_hits, prompt_pitches), dim=-1)

                    # The architecture is fixed for the whole trial, so tracing
                    # the graph once is enough (it was re-traced per prompt per epoch).
                    if epoch == 0 and prompt_idx == 0:
                        writer.add_graph(model, model_input)
                        writer.flush()

                    hit_logits, pitch_logits, vels = model(model_input)
                    hit_probs = torch.sigmoid(hit_logits)
                    pitch_probs = torch.sigmoid(pitch_logits)
                    # Noisy threshold around 0.5 so the rendered examples vary a little.
                    hit_preds = (hit_probs > 0.5 + 0.2 * torch.randn_like(hit_probs, device=device)).float()
                    pitch_preds = (pitch_probs > 0.5 + 0.2 * torch.randn_like(pitch_probs, device=device)).float()

                    # global_step places each epoch's histogram at its own step
                    # instead of logging them all without one.
                    writer.add_histogram(f'example/{name}/hits', hit_probs, global_step=epoch)
                    writer.add_histogram(f'example/{name}/pitches', pitch_probs, global_step=epoch)
                    writer.add_histogram(f'example/{name}/velocities', vels, global_step=epoch)

                    # Decode with drum_encoder -- the same encoder the prompts were
                    # built with -- rather than the one from the exploration cell.
                    midi = drum_encoder.decode((hit_preds.squeeze(), pitch_preds.squeeze(), vels.squeeze()))
                    audio, sr = utils.midi_to_audio_tensor(FS, midi)
                    writer.add_audio(f'example/{name}/audio', audio, sample_rate=sr, global_step=epoch)
    finally:
        writer.close()



In [None]:
def make_prompt(length: int, prompt_hits: list[tuple[int, int, float]]):
    """Create ``(hits, pitches, velocities)`` prompt tensors for a single pattern.

    Args:
        length: Number of steps in the pattern.
        prompt_hits: Events as ``(step, pitch_index, velocity)``; ``pitch_index``
            selects one of the 9 pitch slots.

    Returns:
        Three tensors on the global ``device`` shaped ``(1, length, 1)``,
        ``(1, length, 9)`` and ``(1, length, 9)``.
    """
    num_pitch_slots = 9
    hits = torch.zeros(1, length, 1)
    pitches = torch.zeros(1, length, num_pitch_slots)
    velocities = torch.zeros(1, length, num_pitch_slots)

    for event in prompt_hits:
        step, slot, strength = event
        hits[:, step, :] = 1.0          # the step contains at least one hit
        pitches[:, step, slot] = 1.0    # which drum is struck
        velocities[:, step, slot] = strength

    hits_on_device = hits.to(device)
    pitches_on_device = pitches.to(device)
    velocities_on_device = velocities.to(device)
    return hits_on_device, pitches_on_device, velocities_on_device

device = torch.device("cuda" if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else "cpu")
model = DrumInpaintingTransformer(num_pitches=9, embedding_dim=32, num_layers=20, num_heads=16, pitch_pos_weight=3.0)
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
# 'drum_inpainting_model.pth' is not written by the training cell (its
# checkpoints go under logs/...), so it has to be provided separately.
model.load_state_dict(torch.load('drum_inpainting_model.pth', map_location=device))
model = model.to(device)
# A freshly constructed module is in training mode; switch dropout off for inference.
model.eval()

please_prompt = [
    (0, drum_encoder.pitch_to_index[36], 0.8),
    (3, drum_encoder.pitch_to_index[36], 0.4),
    (4, drum_encoder.pitch_to_index[38], 0.8),
    (6, drum_encoder.pitch_to_index[36], 0.7),
    (11, drum_encoder.pitch_to_index[36], 0.5),
    (12, drum_encoder.pitch_to_index[38], 0.8),
    (14, drum_encoder.pitch_to_index[36], 0.7),
    (19, drum_encoder.pitch_to_index[36], 0.5),
    (20, drum_encoder.pitch_to_index[38], 0.8),
    (22, drum_encoder.pitch_to_index[36], 0.7),
    (27, drum_encoder.pitch_to_index[36], 0.5),
    (28, drum_encoder.pitch_to_index[38], 0.8),
    (29, drum_encoder.pitch_to_index[38], 0.3),
]
prompt_hits, prompt_pitches, prompt_vels = make_prompt(32, please_prompt)

# 1) The prompt on its own. drum_encoder is used throughout, since the prompt's
#    pitch indices come from drum_encoder.pitch_to_index.
midi = drum_encoder.decode((prompt_hits.squeeze(), prompt_pitches.squeeze(), prompt_vels.squeeze()))
ipd.display(utils.midi_to_audio_display(FS, midi))

# 2) The model's prediction for the prompt; no gradients are needed here.
with torch.no_grad():
    hit_logits, pitch_logits, vels = model(torch.cat((prompt_hits, prompt_pitches), dim=-1))
hit_probs = torch.sigmoid(hit_logits)
pitch_probs = torch.sigmoid(pitch_logits)

hit_preds = (hit_probs > 0.3).to(dtype=prompt_hits.dtype)
pitch_preds = (pitch_probs > 0.4).to(dtype=prompt_pitches.dtype)

midi = drum_encoder.decode((hit_preds.squeeze(0), pitch_preds.squeeze(0), vels.squeeze(0)))
ipd.display(utils.midi_to_audio_display(FS, midi))

# 3) Prompt and prediction merged: steps where the prompt has a hit keep the
#    prompt's values, all other steps take the prediction. torch.where selects
#    position by position; masked_scatter (used before) instead copies the first
#    N source elements in order, so predictions landed on the wrong steps.
# NOTE(review): "prompt wins where it has a hit" is the assumed intent -- confirm.
hit_mask = prompt_hits > 0.5
merged_hits = torch.where(hit_mask, prompt_hits, hit_preds)
merged_pitches = torch.where(hit_mask, prompt_pitches, pitch_preds)
merged_vels = torch.where(hit_mask, prompt_vels, vels)

midi = drum_encoder.decode((merged_hits.squeeze(), merged_pitches.squeeze(), merged_vels.squeeze()))
ipd.display(utils.midi_to_audio_display(FS, midi))

# Interface

In [None]:
# Drum label -> MIDI note number. The labels become the row names of the
# sequencer grid; the numbers are used as note pitches and looked up in the
# encoder's pitch_to_index table.
drum_hits = {
    'Kick': 36,
    'Snare': 38,
    'Closed Hi-Hat': 42,
    'Open Hi-Hat': 46,
    'Low Tom': 43,
    'Mid Tom': 47,
    'High Tom': 50,
    'Crash': 49,
    'Ride': 51
}

# Drum label -> hex colour, keyed like drum_hits.
# NOTE(review): nothing in the visible notebook reads this dictionary yet.
drum_colors = {
    'Kick': '#FF5252',      # Red
    'Snare': '#FFEB3B',     # Yellow
    'Closed Hi-Hat': '#2196F3', # Blue
    'Open Hi-Hat': '#03A9F4',  # Light Blue
    'Low Tom': '#FF9800',   # Orange
    'Mid Tom': '#FF7043',   # Deep Orange
    'High Tom': '#F44336',  # Red
    'Crash': '#9C27B0',     # Purple
    'Ride': '#673AB7'       # Deep Purple
}


class DrumSequencer:
    """ipywidgets step sequencer backed by a drum inpainting model.

    The UI is a grid of toggle pads (one row per entry in ``drum_hits``, one
    column per step), sliders for density, diversity and tempo, and buttons to
    play, suggest, fill, clear and save a pattern. The interface is displayed
    as soon as the object is constructed.

    Args:
        model: Model used for velocity prediction and pattern generation. It is
            switched to eval mode whenever the sequencer runs it.
        encoder: Encoder used to turn tensors back into MIDI.
        fluid_synth: Synth handle passed to the audio rendering helper.
        num_steps: Number of steps (columns) in the grid.
    """

    def __init__(self, model: DrumInpaintingTransformer, encoder: tokeniser.DrumSequenceEncoder, fluid_synth: FluidSynth, num_steps: int = 16):
        self.model = model
        self.encoder = encoder
        self.fs = fluid_synth
        # Tensors fed to the model are created on the model's own device.
        self.device = next(model.parameters()).device

        self.num_steps = num_steps

        self._create_interface()

    def _create_interface(self):
        """Build all widgets, wire up the button callbacks and display the UI."""
        # Create main layout
        self.output = widgets.Output()
        self.drum_pads = {}

        # Create grid container: one header row/column plus one cell per pad.
        grid = widgets.GridspecLayout(
            len(drum_hits) + 1,
            self.num_steps + 1,
            grid_gap='2px',
            width='100%',
            height=f'{len(drum_hits)*40}px'
        )

        # Column headers in "beat.subdivision" form (1.1, 1.2, ... with 4 steps per beat).
        for i in range(1, self.num_steps + 1):
            grid[0, i] = widgets.Label(
                f'{(i - 1) // 4 + 1}.{(i - 1) % 4 + 1}',
                layout=widgets.Layout(margin='auto', justify_content='center')
            )

        # Add drum pads
        for i, (drum_name, pitch) in enumerate(drum_hits.items(), 1):
            # Add label for the drum
            label = widgets.Label(
                drum_name,
                align='center',
                layout=widgets.Layout(
                    height='auto',
                    width='100px',
                    padding='3px',
                    justify_content='flex-end',
            ))
            grid[i, 0] = label

            # Create pad row for this drum
            self.drum_pads[drum_name] = []
            for step in range(1, self.num_steps + 1):
                pad = widgets.ToggleButton(
                    value=False,
                    description='',
                    tooltip=f'{drum_name} Step {step}',
                    disabled=False,
                    button_style='',
                    layout=widgets.Layout(
                        width='30px',
                        height='30px',
                        padding='0px',
                        margin='auto',
                        justify_content='center',
                        align_items='center',
                        border='2px solid #888',
                        border_radius='4px'
                    )
                )
                self.drum_pads[drum_name].append(pad)
                grid[i, step] = pad

        # Add sliders for controlling generation parameters
        density_slider = widgets.FloatSlider(
            value=2.0,
            min=1.0,
            max=5.0,
            step=0.01,
            description='Density:',
            tooltip='Controls how many hits will be generated (higher == more hits)',
            layout=widgets.Layout(width='300px'),
            readout=False
        )

        diversity_slider = widgets.FloatSlider(
            value=2.0,
            min=1.0,
            max=5.0,
            step=0.01,
            description='Diversity:',
            tooltip='Controls variety of drum types (higher == more variety)',
            layout=widgets.Layout(width='300px'),
            readout=False
        )

        tempo_slider = widgets.IntSlider(
            value=120,
            min=60,
            max=200,
            step=1,
            description='Tempo:',
            tooltip='Controls playback speed',
            layout=widgets.Layout(width='300px'),
            readout=False
        )

        parameter_controls = widgets.HBox(
            [density_slider, diversity_slider, tempo_slider],
            layout = widgets.Layout(
                justify_content='space-around',
                width='100%',
                height='auto',
                padding='auto',
                margin='auto',
                align_items='center',
            )
        )

        self.density_slider = density_slider
        self.diversity_slider = diversity_slider
        self.tempo_slider = tempo_slider

        # Add control buttons
        play_btn = widgets.Button(description='Play Pattern')
        play_btn.on_click(self._play_pattern)

        suggest_btn = widgets.Button(description='Suggest Pattern')
        suggest_btn.on_click(self._suggest_pattern)

        generate_btn = widgets.Button(description='Fill Pattern')
        generate_btn.on_click(self._generate_full_pattern)

        clear_btn = widgets.Button(description='Clear')
        clear_btn.on_click(self._clear_pattern)

        save_btn = widgets.Button(description='Save')
        save_btn.on_click(self._save_pattern)

        # Arrange the elements
        controls = widgets.HBox(
            [play_btn, suggest_btn, generate_btn, clear_btn, save_btn],
            layout=widgets.Layout(
                justify_content='space-around',
                width='100%',
                height='auto',
                padding='auto',
                margin='auto',
                align_items='center',
            )
        )

        main_layout = widgets.VBox(
            [grid, parameter_controls, controls, self.output],
            layout=widgets.Layout(
                justify_content='space-around',
                width='100%',
                height='auto',
                padding='10px',
                margin='20px 0',
                align_items='center',
                grid_gap='15px'
            )
        )
        ipd.display(main_layout)

    def _make_midi_from_widgets(self):
        """Build a one-track drum score directly from the pads that are switched on."""
        track = symusic.Track(is_drum=True)

        tpq = self.encoder.tpq
        subdiv = self.encoder.subdivision

        for drum, boxes in self.drum_pads.items():
            for i, box in enumerate(boxes):
                if box.value:
                    # Step i starts at i * (4 * tpq / subdiv) ticks.
                    track.notes.append(symusic.Note(int(i * tpq * 4 / subdiv), 80, drum_hits[drum], 80))

        track.sort()
        midi = symusic.Score(tpq)
        midi.tracks.append(track)
        return midi

    def _run_model(self, hits: torch.Tensor, pitches: torch.Tensor):
        """Run the model on concatenated hits/pitches in eval mode without gradients."""
        self.model.eval()  # make sure dropout is off for inference
        with torch.no_grad():
            return self.model(torch.cat((hits, pitches), dim=-1))

    def _play_pattern(self, _):
        """Play the grid as drawn, using the model only to choose velocities."""
        hits, pitches = self._make_tensors_from_widgets()
        _, _, vels = self._run_model(hits, pitches)

        midi = self.encoder.decode((hits.squeeze(), pitches.squeeze(), vels.squeeze()))
        self._display_midi(midi)

    def _make_tensors_from_widgets(self):
        """Return ``(hits, pitches)`` tensors describing the current pad state.

        Shapes are ``(1, num_steps, 1)`` and ``(1, num_steps, len(drum_hits))``,
        allocated on the model's device.
        """
        hits = torch.zeros((1, self.num_steps, 1), device=self.device)
        pitches = torch.zeros((1, self.num_steps, len(drum_hits)), device=self.device)

        for drum, boxes in self.drum_pads.items():
            for i, box in enumerate(boxes):
                if box.value:
                    hits[:, i, :] = 1.0
                    pitches[:, i, self.encoder.pitch_to_index[drum_hits[drum]]] = 1.0

        return hits, pitches

    def _get_beta_dist(self):
        """Return the Beta distribution from which sampling thresholds are drawn.

        With ``density = density_slider / 2`` and ``diversity = diversity_slider / 20``
        the parameters work out to a mean of ``1 / (1 + density)`` and a total
        concentration ``alpha + beta`` of ``1 / diversity``: more density lowers
        the typical threshold (more hits), more diversity widens its spread.
        Over the slider ranges both parameters stay above 1.
        """
        # Density for increasing likelihood of notes (centre of beta distribution)
        density = self.density_slider.value / 2.0

        # Diversity for increasing variety of pitches (spread of beta distribution).
        # Reads the diversity slider; previously the density slider was used here
        # too, which left the diversity control without any effect.
        diversity = self.diversity_slider.value / 20.0

        denom = (diversity * (density + 1.0))
        alpha = 1.0 / denom
        beta = density / denom

        dist = torch.distributions.Beta(alpha, beta)
        return dist

    def _suggest_pattern(self, _):
        """Replace the grid with a model-suggested pattern and play it."""
        hits, pitches = self._make_tensors_from_widgets()
        pred_hits, pred_pitches, _unused_vels = self._run_model(hits, pitches)
        pred_hits = torch.sigmoid(pred_hits)
        pred_pitches = torch.sigmoid(pred_pitches)

        dist = self._get_beta_dist()

        # A pitch is a candidate when its probability beats a sampled threshold;
        # candidates carry the step's hit probability, everything else is 0.
        thresholds = dist.sample(pred_pitches.shape).to(self.device)
        new_pitches = pred_hits.where(pred_pitches > thresholds, torch.zeros_like(pred_pitches, device=self.device)).squeeze()

        for drum, boxes in self.drum_pads.items():
            pitch_idx = self.encoder.pitch_to_index[drum_hits[drum]]
            for step, box in enumerate(boxes):
                # Pad switches on only if the step's hit probability exceeds 0.5.
                box.value = (new_pitches[step, pitch_idx] > 0.5).item()

        self._play_pattern(_)

    def _generate_full_pattern(self, _):
        """Repeat the current pattern over four bars, let the model fill it in and play it."""
        # Take current pattern and copy to N bars
        hits, pitches = self._make_tensors_from_widgets()
        hits = hits.repeat((1, 4, 1))
        pitches = pitches.repeat((1, 4, 1))

        hit_logits, pitch_logits, vels = self._run_model(hits, pitches)
        hit_probs = torch.sigmoid(hit_logits)
        pitch_probs = torch.sigmoid(pitch_logits)

        # Joint probability that the step has a hit and that it is this pitch.
        pitch_probs = hit_probs * pitch_probs

        dist = self._get_beta_dist()
        thresholds = dist.sample(pitch_probs.shape).to(self.device)
        new_pitches = hit_probs.where(pitch_probs > thresholds, torch.zeros_like(pitch_probs, device=self.device)).squeeze()

        midi = self.encoder.decode((hit_probs.squeeze(), new_pitches.squeeze(), vels.squeeze()))
        self._display_midi(midi)

    def _display_midi(self, midi: symusic.types.Score):
        """Apply the tempo slider to ``midi``, remember it for saving and show an audio player."""
        midi.tempos = [symusic.Tempo(0, self.tempo_slider.value)]
        self.last_midi = midi
        with self.output:
            # Render first, then swap the old player for the new one.
            audio = utils.midi_to_audio_display(self.fs, midi)
            ipd.clear_output()
            ipd.display(audio)

    def _save_pattern(self, _):
        """Write the most recently played pattern to ``output/midis`` as a MIDI file."""
        if hasattr(self, 'last_midi'):
            output_dir = os.path.join('output', 'midis')
            os.makedirs(output_dir, exist_ok=True)

            # Generate a unique filename with timestamp
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"drum_pattern_{timestamp}.mid"
            file_path = os.path.join(output_dir, filename)

            # Save the MIDI file
            self.last_midi.dump_midi(file_path)
            with self.output:
                self.output.clear_output()
                print(f"Pattern saved to {file_path}")
                audio = utils.midi_to_audio_display(self.fs, self.last_midi)
                ipd.display(audio)
        else:
            with self.output:
                self.output.clear_output()
                print("No pattern to save. Please generate a pattern first.")


    def _clear_pattern(self, _):
        """Switch every pad off."""
        for boxes in self.drum_pads.values():
            for box in boxes:
                box.value = False

# The constructor displays the UI itself; binding the instance keeps a handle to
# it and stops the cell from also echoing the object's repr.
sequencer = DrumSequencer(model, drum_encoder, FS)
