# Section Boundary Detection for MIDI Music

In [None]:
import functools
import json
import os
from pathlib import Path
import re
from typing import List, Tuple, Dict, Any
import zlib

import mido
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import torch
import torchvision
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

## Config

In [None]:
# Root of the Lakh MIDI dataset; files are looked up below as
# <lakh_midi_dir>/<first character of the id>/<id>.mid.
lakh_midi_dir = Path("../data/lmd_full")

In [None]:
SEED = 0  # single seed shared by torch and numpy

# Seed torch's default generator (used e.g. by random_take and DataLoader
# shuffling) and numpy's global generator (used e.g. by DataFrame.sample).
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# Config

# --- Piano-roll rendering ---
TARGET_TICKS_PER_BEAT = 4  # time resolution after max-pool downsampling
INSTRUMENT_OVERTONES = True  # render synthetic per-program overtones into roll channel 1
SEPARATE_DRUMS = True  # route MIDI channel 9 (drum) onsets into roll channel 2
PAD_PIANO_ROLL = False  # zero-pad rolls so patches centred near the edges can be cut

# --- Patches and targets ---
WINDOW_HALF_TICKS = 256  # half width of a patch, in TARGET_TICKS_PER_BEAT ticks
# Target 0 = boundary at the patch centre; targets 1..NUM_TARGETS-1 = nearest
# boundary within 2**i bars of the centre.
NUM_TARGETS = 1
PATCH_NORMALIZE = True  # divide every patch by its own maximum

# --- Training ---
PRETRAIN = True  # also part of the saved model's file name
POSITIVE_OVERSAMPLING_FACTOR = 2  # boundary patches are repeated this often (train splits only)
NEGATIVE_UNDERSAMPLING_FACTOR = 1  # keep roughly 1 in N non-boundary patches (train splits only)

## Utils

In [None]:
# Section-marker annotations per MIDI file, keyed by file name without ".mid";
# each entry is a sequence whose items start with a position in quarter notes
# (see parse_midi). read_text() closes the file immediately instead of leaking
# the handle that a bare open() would.
MARKERS = json.loads(Path("../markers_qn.json").read_text(encoding="utf-8"))

def parse_midi(file_path):
    """
    Parse a MIDI file into per-track note events plus its section-marker ticks.

    Each note event is a dict with the keys:
        time      -- onset in absolute ticks
        note      -- MIDI pitch
        velocity  -- note-on velocity scaled by the channel's current volume
                     (CC 7) and expression (CC 11), both defaulting to 127
        duration  -- ticks until the matching note-off, or -1 if none was found
        channel   -- MIDI channel
        program   -- the channel's most recent program change (default 0)

    Messages from all tracks are processed in absolute-time order. Channel state
    (volume, expression, program) set in one track therefore affects notes in
    other tracks only from the moment it is sent.

    Returns:
        track_data: dict mapping track name (``track_<idx>`` for unnamed tracks)
            to that track's note events in onset order. Tracks sharing a name
            share one list.
        markers_ticks: sorted, de-duplicated marker positions in ticks, taken
            from MARKERS (keyed by the file name without ".mid").
        ticks_per_beat: the file's resolution.

    Raises:
        KeyError: if the file has no entry in MARKERS. Errors from mido when
            reading the file propagate unchanged.
    """
    midi = mido.MidiFile(file_path, clip=True)

    file_name = os.path.basename(file_path).split('.mid')[0]
    marker_qns = MARKERS[file_name]
    # Several markers can round to the same tick: de-duplicate, then sort.
    markers_ticks = sorted({int(round(x[0] * midi.ticks_per_beat)) for x in marker_qns})

    # Collect (absolute tick, track name, message) for every message of every
    # track so that channel state can be applied in true time order.
    track_data = {}
    events = []
    for idx, track in enumerate(midi.tracks):
        track_name = track.name if track.name else f"track_{idx}"
        track_data.setdefault(track_name, [])
        current_ticks = 0
        for msg in track:
            current_ticks += msg.time
            events.append((current_ticks, track_name, msg))
    # list.sort is stable, so simultaneous events keep track order and each
    # track keeps its own message order.
    events.sort(key=lambda event: event[0])

    channel_volumes = {channel: 127 for channel in range(16)}
    channel_expressions = {channel: 127 for channel in range(16)}
    channel_instruments = {channel: 0 for channel in range(16)}
    # Notes still waiting for a note-off, keyed by (track, channel, pitch). FIFO
    # order: a note-off closes the earliest still-open note with that pitch.
    open_notes = {}

    for current_ticks, track_name, msg in events:
        if msg.type == "control_change":
            if msg.control == 7:
                channel_volumes[msg.channel] = msg.value
            elif msg.control == 11:
                channel_expressions[msg.channel] = msg.value
        elif msg.type == "program_change":
            channel_instruments[msg.channel] = msg.program
        elif msg.type == "note_on" and msg.velocity > 0:
            velocity = msg.velocity * (channel_volumes[msg.channel] / 127.) * (
                        channel_expressions[msg.channel] / 127.)
            note = {
                "time": current_ticks,
                "note": msg.note,
                "velocity": velocity,
                "duration": -1,
                "channel": msg.channel,
                "program": channel_instruments[msg.channel]
            }
            track_data[track_name].append(note)
            open_notes.setdefault((track_name, msg.channel, msg.note), []).append(note)
        elif msg.type == "note_off" or (msg.type == "note_on" and msg.velocity == 0):
            pending = open_notes.get((track_name, msg.channel, msg.note))
            if pending:
                note = pending.pop(0)
                note["duration"] = current_ticks - note["time"]

    return track_data, markers_ticks, midi.ticks_per_beat

In [None]:
@functools.lru_cache(maxsize=None)
def instrument_overtone_intensities(program, num_harmonics=3, max_harmonic=5):
    """
    Return synthetic overtone harmonics and their intensities for an instrument.

    The values are pseudo-random but fixed for a given ``program``. The seed is
    the CRC-32 of ``str(program)``. Unlike the built-in ``hash`` (randomised per
    interpreter run for strings), CRC-32 gives the same overtones in every run,
    so precomputed piano rolls stay reproducible. A private generator is used,
    so the global numpy RNG is neither consumed nor reseeded.

    Args:
        program: instrument identifier with a stable ``str``, e.g. a MIDI
            program number or a category name. Must be hashable (results are
            cached).
        num_harmonics: number of distinct harmonics to pick.
        max_harmonic: harmonics are drawn from 2 .. max_harmonic + 1.

    Returns:
        (harmonics, intensities):
            harmonics -- sorted int array of ``num_harmonics`` distinct
                frequency multipliers.
            intensities -- float array in [0, 1), sorted in descending order.
        Both arrays are read-only because the cached result is shared.
    """
    seed = zlib.crc32(str(program).encode("utf-8"))
    rng = np.random.default_rng(seed)

    harmonics = np.sort(rng.choice(max_harmonic, num_harmonics, replace=False) + 2)
    intensities = np.sort(rng.random(num_harmonics))[::-1]

    harmonics.setflags(write=False)
    intensities.setflags(write=False)
    return harmonics, intensities

In [None]:
# The 16 General MIDI instrument families; each family covers 8 consecutive
# program numbers, so family k owns programs 8*k .. 8*k + 7.
instrument_categories = {
    family: range(8 * family_idx, 8 * (family_idx + 1))
    for family_idx, family in enumerate((
        "Piano", "Chromatic Percussion", "Organ", "Guitar",
        "Bass", "Strings", "Ensemble", "Brass",
        "Reed", "Pipe", "Synth Lead", "Synth Pad",
        "Synth Effects", "Ethnic", "Percussive", "Sound Effects",
    ))
}

# Preview the (harmonics, intensities) pair generated for each category. Note
# that this seeds on the category *name* string, not on a program number.
[instrument_overtone_intensities(instr) for instr in instrument_categories.keys()]

In [None]:
def hz_to_midi(frequency):
    """Convert a frequency in Hz to a (fractional) MIDI note number, with A4 = 440 Hz = 69.

    Raises:
        ValueError: if ``frequency`` is not positive.
    """
    if frequency <= 0:
        raise ValueError("Frequency must be greater than 0 Hz.")
    octaves_above_a4 = np.log2(frequency / 440.0)
    return 69 + 12 * octaves_above_a4


def midi_to_hz(midi_note):
    """Convert a (possibly fractional) MIDI note number to its frequency in Hz."""
    semitones_above_a4 = midi_note - 69
    return 440.0 * 2**(semitones_above_a4 / 12)

hz_to_midi(440), midi_to_hz(69)

In [None]:
def create_piano_roll(
    note_data,
    ticks_per_beat,
    chroma=False,
    target_ticks_per_beat=24,
    instrument_overtones=False,
    separate_drums=False
):
    """
    Render one track's note events into a 3-channel piano roll.

    Channels:
        0 -- note velocities (drums too, unless ``separate_drums``).
        1 -- synthetic overtones if ``instrument_overtones``; otherwise a copy
             of every note's velocity, drums included.
        2 -- drum onsets (MIDI channel 9) if ``separate_drums``.

    Args:
        note_data: note-event dicts with ``time``, ``note``, ``velocity``,
            ``duration``, ``channel`` and ``program`` keys, as produced by
            parse_midi. Non-drum notes with a non-positive duration are
            skipped. Drum notes always get a one-tick onset.
        ticks_per_beat: resolution of ``note_data``.
        chroma: fold pitches into 12 pitch classes instead of 128 MIDI notes.
        target_ticks_per_beat: output resolution. When the input is finer, the
            time axis is max-pooled by ``ticks_per_beat // target_ticks_per_beat``.
        instrument_overtones: add per-program synthetic overtones to channel 1.
        separate_drums: route drum notes to channel 2 instead of channel 0.

    Returns:
        A float array of shape (3, 12 or 128, T), or None if the track has no
        notes, has zero length, or could not be downsampled.
    """
    if len(note_data) == 0:
        return None
    num_notes = 12 if chroma else 128
    # The roll must reach the latest note *end*. The last onset's end is not
    # enough, because an earlier sustained note can outlast it. Unterminated
    # notes (duration -1) count as zero length; drum notes need at least one
    # tick for their onset.
    duration_ticks = max(
        note["time"] + max(note["duration"], 1 if note["channel"] == 9 else 0)
        for note in note_data
    )
    if duration_ticks <= 0:
        return None
    piano_roll = np.zeros((3, num_notes, duration_ticks))

    for note in note_data:
        # fixed duration for drum tracks since we only need the onsets
        drum_track = note["channel"] == 9
        duration = 1 if drum_track else note["duration"]

        start = note["time"]
        end = min(start + duration, duration_ticks)
        if end - start <= 0:
            continue

        pitch_class = note["note"] % 12 if chroma else note["note"]

        velocity = note["velocity"]
        # Without separate_drums, drums land in channel 0 with everything else.
        piano_roll_channel = 2 if drum_track and separate_drums else 0
        piano_roll[piano_roll_channel, pitch_class, start:end] = velocity
        if not instrument_overtones:
            piano_roll[1, pitch_class, start:end] = velocity

        # Add overtones
        if instrument_overtones and not drum_track:
            harmonics, intensities = instrument_overtone_intensities(note["program"])
            pitch = midi_to_hz(note["note"])
            max_intensity = intensities[0]
            # Linear fade-out over the note's length, shared by all harmonics.
            fade = np.linspace(1.0, 0.0, end - start)
            for harmonic, intensity in zip(harmonics, intensities):
                # Round to the nearest MIDI note *before* folding to a pitch
                # class. Folding first can round e.g. 11.7 up to the invalid
                # class 12.
                overtone_note = int(np.round(hz_to_midi(pitch * harmonic)))
                if chroma:
                    overtone_note %= 12
                if overtone_note < num_notes:
                    decay = fade * intensity / max_intensity
                    piano_roll[1, overtone_note, start:end] = velocity * intensity * decay

    # Downsample to target_ticks_per_beat ticks per beat using max pooling
    if ticks_per_beat > target_ticks_per_beat:
        pool_size = ticks_per_beat // target_ticks_per_beat
        try:
            piano_roll = F.max_pool1d(torch.tensor(piano_roll), pool_size, stride=pool_size).numpy()
        except Exception as e:
            # Best effort: treat a roll that cannot be pooled (presumably one
            # shorter than a single pooling window) as an empty track.
            print(e)
            print(piano_roll.shape)
            return None

    return piano_roll

In [None]:
def plot_piano_roll(piano_roll, start_tick, num_ticks, label=None, markers_ticks=None):
    """
    Plot a window of a 3-channel piano roll as three stacked panels.

    The panels show channel 0 (main notes), channel 1 (overtones) and
    channel 2 (drums) for ticks ``start_tick`` to ``start_tick + num_ticks``.
    Each ``markers_ticks`` entry inside that window is drawn as a dashed red
    vertical line on every panel. ``label``, if given, becomes the figure
    title; previously it ended up only above the drums panel.
    """
    if markers_ticks is None:
        markers_ticks = []

    x_min = start_tick
    x_max = start_tick + num_ticks
    extent = (x_min, x_max, 0, piano_roll.shape[-2])
    visible_markers = [marker for marker in markers_ticks if x_min <= marker <= x_max]

    fig = plt.figure(figsize=(12, 6))
    panel_titles = ("Notes", "Overtones", "Drums")
    ax = None
    for channel, panel_title in enumerate(panel_titles):
        ax = fig.add_subplot(3, 1, channel + 1)
        for marker in visible_markers:
            ax.axvline(marker, color="red", linestyle="--")
        ax.imshow(piano_roll[channel, :, x_min:x_max], aspect="auto", origin="lower", cmap="viridis", extent=extent)
        ax.set_ylabel("Note")
        ax.set_title(panel_title)

    # Only the bottom panel gets the shared time-axis label.
    ax.set_xlabel("Time (ticks)")
    if label is not None:
        fig.suptitle(label)

In [None]:
def random_take(one_in_n: int) -> bool:
    """Return True with probability 1 / ``one_in_n``, drawing from torch's default RNG."""
    draw = torch.randint(0, one_in_n, ())
    return bool(draw.item() == 0)

## Test Example

In [None]:
lakh_test_example = "0148c9c216484115f87daac532ef57db"

midi_path = lakh_midi_dir / Path(f"{lakh_test_example[0]}") / Path(f"{lakh_test_example}.mid")
# Load with clip=True, as parse_midi does, so that this load accepts the same
# files parse_midi accepts (it is only used here for ticks_per_beat).
midi = mido.MidiFile(midi_path, clip=True)

track_data, markers_ticks, _ = parse_midi(midi_path)

track_data.keys()

#### Create Piano Roll

In [None]:
# Render every track; create_piano_roll returns None for empty tracks.
track_rolls = (
    create_piano_roll(
        note_data,
        midi.ticks_per_beat,
        chroma=False,
        target_ticks_per_beat=TARGET_TICKS_PER_BEAT,
        instrument_overtones=INSTRUMENT_OVERTONES,
        separate_drums=SEPARATE_DRUMS,
    )
    for note_data in track_data.values()
)
piano_rolls = [roll for roll in track_rolls if roll is not None]

# Zero-pad every roll on the time axis to the longest one so they can be stacked.
actual_length = max(roll.shape[-1] for roll in piano_rolls)
piano_rolls = [
    torch.nn.functional.pad(torch.tensor(roll), (0, actual_length - roll.shape[-1]))
    for roll in piano_rolls
]

print([roll.shape for roll in piano_rolls])
# Sum the per-track rolls into one and clamp back into the MIDI velocity range.
piano_roll = torch.stack(piano_rolls).sum(dim=0).clamp(0, 127)

In [None]:
# Inspect the merged roll: (3 channels, 128 pitches, ticks at TARGET_TICKS_PER_BEAT).
piano_roll.shape

#### Plot Piano Roll

In [None]:
# Resolution the single track below is rendered at (create_piano_roll's default).
# The hard-coded plot window is in ticks at this resolution.
PLOT_TICKS_PER_BEAT = 24

# Bring the marker ticks to the rendered roll's resolution, using the same
# integer pool size that create_piano_roll applies.
if midi.ticks_per_beat > PLOT_TICKS_PER_BEAT:
    downsample_factor = midi.ticks_per_beat // PLOT_TICKS_PER_BEAT
    markers_ticks_downsampled = [marker // downsample_factor for marker in markers_ticks]
else:
    markers_ticks_downsampled = list(markers_ticks)

track_name = "track_4"
drum_track = "drum" in track_name.lower()   # TODO: look for channel 9 instead
piano_roll_bass = create_piano_roll(
    track_data[track_name],
    midi.ticks_per_beat,
    chroma=False,
    target_ticks_per_beat=PLOT_TICKS_PER_BEAT,
    instrument_overtones=True,
)

# NOTE(review): 9233 is presumably a tick near a section marker of this file — confirm.
plot_piano_roll(piano_roll_bass, 9233 - WINDOW_HALF_TICKS, WINDOW_HALF_TICKS * 2, label=track_name, markers_ticks=markers_ticks_downsampled)

## Dataset Preparation

#### Precompute piano rolls and save features and labels in a structured manner.

In [None]:
good_files = json.loads(Path("../data/good_files_3_7_2025.json").read_text(encoding="utf-8"))

DATA_DIR = Path(f"/Volumes/ExtremePro/lakh_data_{TARGET_TICKS_PER_BEAT}_overtones_{INSTRUMENT_OVERTONES}_separate_drums_{SEPARATE_DRUMS}")


def _merged_piano_roll(track_data, ticks_per_beat):
    """
    Render every track with the global config and sum them into one roll.

    Per-track rolls are zero-padded on the time axis to the longest one,
    summed, and clamped to [0, 127]. Returns None when every track is empty.
    """
    track_rolls = []
    for note_data in track_data.values():
        piano_roll = create_piano_roll(
            note_data,
            ticks_per_beat,
            chroma=False,
            target_ticks_per_beat=TARGET_TICKS_PER_BEAT,
            instrument_overtones=INSTRUMENT_OVERTONES,
            separate_drums=SEPARATE_DRUMS
        )
        # Some tracks are empty
        if piano_roll is not None:
            track_rolls.append(torch.tensor(piano_roll))

    if not track_rolls:
        return None

    actual_length = max(roll.shape[-1] for roll in track_rolls)
    padded = [F.pad(roll, (0, actual_length - roll.shape[-1])) for roll in track_rolls]
    return torch.stack(padded).sum(dim=0).clamp(0, 127)


def create_lakh_dataset():
    """
    Convert the Lakh MIDI files listed in ``good_files`` into piano-roll tensors.

    Each key of ``good_files`` (e.g. "tubb_train") names a split directory under
    DATA_DIR. Each listed file id is parsed, rendered at TARGET_TICKS_PER_BEAT,
    and saved with torch.save as ``DATA_DIR/<key>/<id>.pt``, a dict with:
        piano_roll         -- float32 merged roll
        segment_boundaries -- float32 marker ticks at TARGET_TICKS_PER_BEAT
        measure_ticks      -- float32 measure starts at TARGET_TICKS_PER_BEAT

    Files that already exist are skipped, so the function can be re-run to
    resume. Files with missing annotations or MIDI data, parse errors, too coarse
    or non-divisible ticks_per_beat, or only empty tracks are skipped with a
    message.

    Directory structure:
    - DATA_DIR/
        - tubb_train/
        - non_tubb_train/
        - tubb_val/
        - non_tubb_val/
        - tubb_test/
        - non_tubb_test/
    """
    splits = ("tubb_train", "non_tubb_train", "tubb_val", "non_tubb_val", "tubb_test", "non_tubb_test")
    DATA_DIR.mkdir(parents=True, exist_ok=True)
    for split in splits:
        (DATA_DIR / split).mkdir(exist_ok=True)

    measure_qns_all = json.loads(Path("../data/measures_qn.json").read_text(encoding="utf-8"))
    for key in good_files:
        print(f"Processing files: {key}")
        # A key outside the standard splits still needs a directory to save into.
        (DATA_DIR / key).mkdir(exist_ok=True)
        for test_example in tqdm(good_files[key], desc="Loading test examples"):
            save_path = DATA_DIR / Path(key) / Path(f"{test_example}.pt")
            if save_path.exists():
                continue

            measure_qns = measure_qns_all.get(test_example)
            if measure_qns is None:
                print(f"Missing measure annotations: {test_example}")
                continue
            midi_path = lakh_midi_dir / Path(f"{test_example[0]}") / Path(test_example + ".mid")
            if not midi_path.exists():
                print(f"Missing MIDI file: {midi_path}")
                continue

            # MIDI
            try:
                track_data, markers_ticks, ticks_per_beat = parse_midi(midi_path)
            except Exception as e:
                # Best effort: skip files that fail to parse (corrupt MIDI,
                # no entry in MARKERS, ...).
                print(f"Error loading MIDI file: {midi_path}")
                print(e)
                continue

            # Annotation
            if ticks_per_beat <= TARGET_TICKS_PER_BEAT:
                print(f"Skipping {test_example} due to downsample factor")
                continue
            if ticks_per_beat % TARGET_TICKS_PER_BEAT != 0:
                # create_piano_roll pools by the integer factor
                # ticks_per_beat // TARGET_TICKS_PER_BEAT. The labels below use the
                # exact ratio, so they would drift away from the roll.
                print(f"Skipping {test_example}: ticks_per_beat {ticks_per_beat} is not a multiple of {TARGET_TICKS_PER_BEAT}")
                continue
            markers_ticks = [int(round(marker * TARGET_TICKS_PER_BEAT / ticks_per_beat)) for marker in markers_ticks]
            measure_ticks = [int(round(qn * TARGET_TICKS_PER_BEAT)) for qn in measure_qns]

            piano_roll = _merged_piano_roll(track_data, ticks_per_beat)
            if piano_roll is None:
                print(f"Skipping {test_example} due to empty piano rolls")
                continue

            torch.save({
                "piano_roll": piano_roll.to(torch.float32),
                "segment_boundaries": torch.tensor(markers_ticks).to(torch.float32),
                "measure_ticks": torch.tensor(measure_ticks).to(torch.float32)
            }, save_path)

create_lakh_dataset()

#### Load and process (padding, etc.) precomputed piano rolls and prepare the patch metadata DataFrame

In [None]:
# Paths defined above in create_lakh_dataset()
# Every precomputed roll, gathered split by split in a fixed order. Non-.pt
# files and hidden files (names starting with ".") are ignored.
piano_roll_paths = [
    path
    for split_dir in (
        "tubb_train", "non_tubb_train",
        "tubb_val", "non_tubb_val",
        "tubb_test", "non_tubb_test",
    )
    for path in (DATA_DIR / split_dir).iterdir()
    if path.suffix == ".pt" and not path.name.startswith(".")
]

def get_piano_rolls():
    """
    Load the precomputed piano rolls and build the patch table used for training.

    For every file in ``piano_roll_paths``:
    - The roll is cropped to the span between the first and last nonzero column
      of channel 0.
    - Segment and measure boundaries are restricted to that span, the crop edges
      are added to both, and all ticks are shifted to the cropped roll.
    - If PAD_PIANO_ROLL is set, the roll is zero-padded by WINDOW_HALF_TICKS on
      both sides.

    One patch, centred on each measure boundary, is then recorded. Without
    padding, only patches that lie fully inside the roll are kept. In train
    splits, boundary patches are repeated POSITIVE_OVERSAMPLING_FACTOR times and
    non-boundary patches are kept with probability
    1 / NEGATIVE_UNDERSAMPLING_FACTOR (via random_take).

    Returns:
        piano_rolls: list of processed roll tensors, indexed by ``piano_roll_idx``.
        patch_data: dict mapping sample_idx to a sample dict with:
            filename                -- file stem
            from, to                -- patch bounds (tensors)
            piano_roll_idx          -- index into ``piano_rolls``
            patch_idx               -- centre tick
            is_segment_boundary     -- bool
            key                     -- split directory name
            nearest_segment_boundary -- float
    """
    padding = WINDOW_HALF_TICKS

    piano_rolls = []
    patch_data = {}
    sample_idx = 0
    piano_roll_idx = 0

    for piano_roll_path in tqdm(piano_roll_paths, desc="Loading inputs and labels"):
        try:
            data = torch.load(piano_roll_path)
        except RuntimeError:
            print(f"Error loading {piano_roll_path}")
            continue

        # Only the split directory decides train vs. eval, not the rest of the path.
        is_train = 'train' in piano_roll_path.parent.name
        positive_oversampling_factor = POSITIVE_OVERSAMPLING_FACTOR if is_train else 1
        negative_undersampling_factor = NEGATIVE_UNDERSAMPLING_FACTOR if is_train else 1

        piano_roll = data["piano_roll"]
        segment_boundaries = data["segment_boundaries"]
        measure_boundaries = data["measure_ticks"]

        # Compute first and last nonzero columns of the first channel (first and last onset, respectively)
        if piano_roll.dim() == 4:
            batch_mask = piano_roll[0]  # Select the first batch
        else:
            batch_mask = piano_roll
        channel_mask = batch_mask[0]  # Select the first channel

        # Find nonzero column indices; skip rolls with no onsets in channel 0.
        nonzero_indices = channel_mask.nonzero(as_tuple=True)
        if nonzero_indices[1].numel() == 0:
            continue
        first_nonzero_column = nonzero_indices[1].min().item()
        last_nonzero_column = nonzero_indices[1].max().item()

        # Throw out markers before first onset or after last onset
        segment_boundaries = segment_boundaries[segment_boundaries > first_nonzero_column]
        segment_boundaries = segment_boundaries[segment_boundaries < last_nonzero_column]
        measure_boundaries = measure_boundaries[measure_boundaries > first_nonzero_column]
        measure_boundaries = measure_boundaries[measure_boundaries < last_nonzero_column]

        # Add first and last nonzero column to the segment boundaries
        edges = (
            torch.tensor([first_nonzero_column], dtype=torch.float32, device=piano_roll.device),
            torch.tensor([last_nonzero_column], dtype=torch.float32, device=piano_roll.device),
        )
        segment_boundaries = torch.cat([edges[0], segment_boundaries, edges[1]])
        measure_boundaries = torch.cat([edges[0], measure_boundaries, edges[1]])

        # Crop piano roll to the first and last onset
        piano_roll = piano_roll[..., first_nonzero_column:last_nonzero_column + 1]
        # Adjust segment boundaries to the cropped piano roll
        segment_boundaries -= first_nonzero_column
        measure_boundaries -= first_nonzero_column

        # Pad piano roll to the left and right for boundary segment extraction
        if PAD_PIANO_ROLL:
            piano_roll = F.pad(piano_roll, (padding, padding), mode='constant', value=0)
            segment_boundaries += padding
            measure_boundaries += padding

        piano_rolls.append(piano_roll)

        for i in measure_boundaries:
            # A patch is the slice [i - padding, i + padding). Without padding it
            # must fit inside the roll; start == 0 and end == length are both valid.
            if not PAD_PIANO_ROLL and (i - padding < 0 or i + padding > piano_roll.shape[-1]):
                continue

            is_segment_boundary = (segment_boundaries == i).any().item()
            if is_segment_boundary:
                repetitions = positive_oversampling_factor
            else:
                repetitions = int(random_take(one_in_n=negative_undersampling_factor))

            nearest_segment_boundary = segment_boundaries[torch.argmin(torch.abs(segment_boundaries - i))].item()

            sample = {
                # Metadata
                "filename": piano_roll_path.stem,
                "from": i - padding,
                "to": i + padding,
                # Data
                "piano_roll_idx": piano_roll_idx,
                "patch_idx": i,
                "is_segment_boundary": is_segment_boundary,
                "key": piano_roll_path.parent.stem, # non_tubb_train, non_tubb_val, tubb_train, tubb_val

                # Nearest segment boundary (the patch centre itself for boundary patches)
                "nearest_segment_boundary": nearest_segment_boundary
            }

            for _ in range(repetitions):
                patch_data[sample_idx] = sample
                sample_idx += 1

        piano_roll_idx += 1
    return piano_rolls, patch_data


### Patch statistics

In [None]:
piano_rolls, patch_data = get_piano_rolls()
metadata_df = pd.DataFrame.from_dict(patch_data, orient="index")

num_piano_rolls = len(piano_rolls)
num_patches = len(metadata_df)

# Split membership is encoded in the "key" column (the source directory name).
train_mask = metadata_df["key"].str.contains("train")
val_mask = metadata_df["key"].str.contains("val")
test_mask = metadata_df["key"].str.contains("test")

segment_boundary_mask = metadata_df["is_segment_boundary"].eq(True)
non_boundary_mask = ~segment_boundary_mask

num_train_patches = len(metadata_df[train_mask])
num_positive_train_patches = len(metadata_df[train_mask & segment_boundary_mask])
num_negative_train_patches = len(metadata_df[train_mask & non_boundary_mask])

num_val_patches = len(metadata_df[val_mask])
num_positive_val_patches = len(metadata_df[val_mask & segment_boundary_mask])
num_negative_val_patches = len(metadata_df[val_mask & non_boundary_mask])

num_test_patches = len(metadata_df[test_mask])
num_positive_test_patches = len(metadata_df[test_mask & segment_boundary_mask])
num_negative_test_patches = len(metadata_df[test_mask & non_boundary_mask])

for description, count in (
    ("Total number of piano rolls", num_piano_rolls),
    ("Total number of patches", num_patches),
    ("Number of train patches", num_train_patches),
    ("Number of positive train patches", num_positive_train_patches),
    ("Number of negative train patches", num_negative_train_patches),
    ("Number of val patches", num_val_patches),
    ("Number of positive val patches", num_positive_val_patches),
    ("Number of negative val patches", num_negative_val_patches),
    ("Number of test patches", num_test_patches),
    ("Number of positive test patches", num_positive_test_patches),
    ("Number of negative test patches", num_negative_test_patches),
):
    print(f"{description}: {count}")

### Torch PianoRollDataset Interface

In [None]:
def transpose_augmentation(piano_roll, transpose_range=6):
    """
    Randomly transpose a piano roll by up to ``transpose_range`` semitones either way.

    The shift is drawn uniformly from [-transpose_range, transpose_range]
    (both ends inclusive) using torch's RNG. It is applied along the pitch axis
    (dim -2) with torch.roll, so pitches pushed past the top or bottom wrap
    around to the other end.
    """
    # torch.randint's upper bound is exclusive: +1 makes the range symmetric
    # (and keeps transpose_range=0 valid instead of an empty range).
    transpose_amount = torch.randint(-transpose_range, transpose_range + 1, ())
    return torch.roll(piano_roll, transpose_amount.item(), dims=-2)


class PianoRollDataset(Dataset):
    """
    Map-style dataset of fixed-width piano-roll patches and their boundary targets.

    Each row of ``metadata_df`` describes one patch:
        piano_roll_idx           -- selects a roll from ``piano_rolls``
        from, to                 -- time slice bounds (tensors)
        is_segment_boundary      -- target source
        nearest_segment_boundary -- target source

    Rows are fetched with ``.loc[idx]``, so the index must be 0..len-1.
    """

    def __init__(self, piano_rolls, metadata_df, normalize=False, transpose_augmentation=True):
        self.piano_rolls = piano_rolls
        self.metadata_df = metadata_df
        self.normalize = normalize  # scale each patch so that its maximum is 1
        self.transpose_augmentation = transpose_augmentation  # random pitch shift per item

    def __len__(self):
        return len(self.metadata_df)

    def __getitem__(self, idx):
        """Return ``(patch, targets)``: the sliced roll and a float32 target vector."""
        sample = self.metadata_df.loc[idx]
        piano_roll = self.piano_rolls[sample["piano_roll_idx"]]
        patch = piano_roll[..., sample["from"].int():sample["to"].int()]

        center = sample["from"] + (sample["to"] - sample["from"]) / 2
        nearest_segment_boundary = sample["nearest_segment_boundary"]

        # Target 0: is there a boundary at the centre?
        # Targets 1..NUM_TARGETS-1: is the nearest boundary within 2**i bars of
        # the centre (TARGET_TICKS_PER_BEAT * 4 ticks per bar)?
        # TODO: assumes a single segment boundary per patch
        main_target = [sample["is_segment_boundary"]]
        additional_targets = [(nearest_segment_boundary - center).abs() <= 2**i * TARGET_TICKS_PER_BEAT * 4 for i in range(NUM_TARGETS - 1)]

        targets = torch.tensor(main_target + additional_targets).to(torch.float32)

        if self.normalize:
            # Leave silent patches untouched: dividing by a zero maximum
            # would turn the whole patch into NaNs.
            patch_max = patch.max()
            if patch_max > 0:
                patch = patch / patch_max

        if self.transpose_augmentation:
            patch = transpose_augmentation(patch)

        return patch, targets

    def metadata_at(self, idx):
        """Return the file name and patch bounds of item ``idx``."""
        sample = self.metadata_df.loc[idx]
        return {
            "filename": sample["filename"],
            "from": sample["from"],
            "to": sample["to"]
        }

In [None]:
# Class balance of the patch table. Train-split counts already include positive
# oversampling and negative undersampling.
metadata_df["is_segment_boundary"].value_counts()

### Visualize Dataset Patches

In [None]:
dataset = PianoRollDataset(piano_rolls, metadata_df, normalize=True)

# Draw one random patch (with the dataset's default augmentation) and describe it.
sample_idx = int(torch.randint(0, len(dataset), ()))
patch, targets = dataset[sample_idx]
is_segment_boundary = targets[0]

patch_info = dataset.metadata_at(sample_idx)
from_tick, to_tick, filename = patch_info["from"], patch_info["to"], patch_info["filename"]
print(f"Sample {sample_idx} from {filename} ({from_tick} to {to_tick}), is_segment_boundary: {is_segment_boundary.item()}")

plt.figure(figsize=(6.4*2, 4*2))

cmap = 'gist_yarg'
center_tick = (from_tick + to_tick) / 2

# One panel per roll channel; the red line marks the patch centre.
channel_titles = ("Channel 0 (Piano Roll)", "Channel 1 (Overtones)", "Channel 2 (Drums)")
for channel, channel_title in enumerate(channel_titles):
    plt.subplot(3, 1, channel + 1)
    plt.imshow(patch.squeeze().permute(1, 2, 0)[..., channel], aspect="auto", origin="lower", cmap=cmap, extent=[from_tick, to_tick, 0, patch.shape[-2]])
    plt.axvline(x=center_tick, color="r", alpha=0.5)
    plt.xlabel("Time (ticks)")
    plt.ylabel("Note")
    plt.title(channel_title)

plt.tight_layout()

SAVE_PLOT = False
if SAVE_PLOT:
    plt.savefig(os.path.expanduser(f"~/Downloads/patch_{filename}.png"))

## Training

#### Define the Model class

In [None]:
# MobileNetV3-Small boundary classifier
class BoundaryClassifier(nn.Module):
    """
    torchvision MobileNetV3-Small whose last classifier layer is replaced by a
    linear head with one logit per target.
    """

    def __init__(self, num_targets=None, pretrained=None):
        """
        Args:
            num_targets: number of output logits; defaults to the NUM_TARGETS config.
            pretrained: start from torchvision's default MobileNetV3-Small weights;
                defaults to the PRETRAIN config.
        """
        super().__init__()
        if num_targets is None:
            num_targets = NUM_TARGETS
        if pretrained is None:
            pretrained = PRETRAIN

        weights = torchvision.models.MobileNet_V3_Small_Weights.DEFAULT if pretrained else None
        backbone = torchvision.models.mobilenet_v3_small(weights=weights)
        # Keep the head wrapped in a Sequential so state_dict keys stay
        # compatible with previously saved checkpoints.
        backbone.classifier[-1] = nn.Sequential(
            nn.Linear(backbone.classifier[-1].in_features, num_targets),
        )
        # Xavier-initialise every Linear in the classifier. modules() recurses
        # into the Sequential, so the new head is included; iterating only the
        # classifier's direct children would skip it.
        for layer in backbone.classifier.modules():
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight.data)
        self.backbone = backbone

    def forward(self, x):
        """Return raw logits of shape (batch, num_targets); x is assumed to be (batch, 3, H, W)."""
        return self.backbone(x)

# class BoundaryClassifier(nn.Module):
#     def __init__(self):
#         super().__init__()

#         self.conv1 = nn.Conv2d(3, 32, kernel_size=(6, 8))
#         self.pool1 = nn.MaxPool2d(kernel_size=(6, 3))
#         self.conv2 = nn.Conv2d(32, 64, kernel_size=(3, 6))
#         self.dense = nn.Linear(64 * 18 * 504, 128)
#         self.out = nn.Linear(128, NUM_TARGETS)

#     def forward(self, x):
#         x = F.relu(self.conv1(x))
#         x = self.pool1(x)
#         x = F.relu(self.conv2(x))
#         x = x.view(x.size(0), -1)
#         x = F.relu(self.dense(x))
#         return self.out(x)

### Metrics

In [None]:
def acc_prec_recall(input, target):
    """
    Compute accuracy, precision and recall of logits ``input`` against 0/1 ``target``.

    A logit > 0 counts as a positive prediction. Any ratio with a zero
    denominator is reported as 0. Returns ``(accuracy, precision, recall)``.
    """
    predicted_pos = input > 0
    predicted_neg = input <= 0
    actual_pos = target == 1
    actual_neg = target == 0

    tp = (predicted_pos & actual_pos).sum().item()
    fp = (predicted_pos & actual_neg).sum().item()
    tn = (predicted_neg & actual_neg).sum().item()
    fn = (predicted_neg & actual_pos).sum().item()

    def ratio(numerator, denominator):
        return numerator / denominator if denominator > 0 else 0

    return ratio(tp + tn, tp + tn + fp + fn), ratio(tp, tp + fp), ratio(tp, tp + fn)

In [None]:
def compute_metrics(input, target):
    """Score each target column independently with ``acc_prec_recall``.

    Column ``i`` of the last dimension of ``input`` (logits) and ``target``
    (labels) yields the keys ``accuracy_i``, ``precision_i`` and ``recall_i``,
    inserted in that order, target by target.
    """
    prefixes = ("accuracy_", "precision_", "recall_")
    metrics = {}
    for target_idx in range(input.size(-1)):
        scores = acc_prec_recall(input[..., target_idx], target[..., target_idx])
        for prefix, value in zip(prefixes, scores):
            metrics[prefix + str(target_idx)] = value
    return metrics

### Dataset Initialization

In [None]:
# Training hyper-parameters
BATCH_SIZE = 32
NUM_EPOCHS = 30
RESUME_TRAINING = False  # load an existing checkpoint for this config before training

# Device preference: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Shuffle the patch metadata once, then split it by its "key" column into the
# training pool and the two validation sets, each with a fresh 0..n-1 index.
metadata_df = metadata_df.sample(frac=1)
metadata_train = metadata_df[metadata_df["key"].isin(["tubb_train", "non_tubb_train"])].reset_index(drop=True)
metadata_val_tubb = metadata_df[metadata_df["key"] == "tubb_val"].reset_index(drop=True)
metadata_val_non_tubb = metadata_df[metadata_df["key"] == "non_tubb_val"].reset_index(drop=True)

# Only the training loader shuffles; validation order stays fixed.
dataset_train = PianoRollDataset(piano_rolls, metadata_train, normalize=PATCH_NORMALIZE)
dataloader_train = DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)

dataset_val_tubb = PianoRollDataset(piano_rolls, metadata_val_tubb, normalize=PATCH_NORMALIZE)
dataloader_val_tubb = DataLoader(dataset_val_tubb, batch_size=BATCH_SIZE, shuffle=False)

dataset_val_non_tubb = PianoRollDataset(piano_rolls, metadata_val_non_tubb, normalize=PATCH_NORMALIZE)
dataloader_val_non_tubb = DataLoader(dataset_val_non_tubb, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
import copy

EARLY_STOPPING_PATIENCE = 5  # epochs without a validation-F1 improvement before stopping early


def _f1_score(precision, recall):
    """Return the F1 score (harmonic mean of precision and recall), or 0 when both are 0."""
    if precision + recall > 0:
        return 2 * (precision * recall) / (precision + recall)
    return 0


def _evaluate(model, dataloader, criterion, device):
    """Run ``model`` over every batch of ``dataloader`` without tracking gradients.

    Returns:
        ``(mean_loss, outputs, targets)``: the criterion averaged over batches, plus the
        per-batch logits and targets as lists of tensors on ``device`` (ready for ``torch.cat``).
    """
    outputs, targets_per_batch = [], []
    total_loss = 0
    with torch.no_grad():
        for piano_roll, targets in dataloader:
            piano_roll, targets = piano_roll.to(device), targets.to(device)
            output = model(piano_roll)
            outputs.append(output)
            targets_per_batch.append(targets)
            total_loss += criterion(output, targets.float()).item()
    return total_loss / len(dataloader), outputs, targets_per_batch


writer = SummaryWriter()

SAVE_PATH = Path("./models")
SAVE_PATH.mkdir(exist_ok=True)
model_name = Path(f"pretrain_{PRETRAIN}_mn_overtones_{INSTRUMENT_OVERTONES}_normalized_{PATCH_NORMALIZE}_separate_drums_{SEPARATE_DRUMS}_targets_{NUM_TARGETS}.pt")

model = BoundaryClassifier().to(device)

if RESUME_TRAINING and (model_path := SAVE_PATH / model_name).exists():
    print(f"Loading model from {model_path}")
    # map_location lets a checkpoint saved on one device (e.g. CUDA) load on another (MPS/CPU).
    model.load_state_dict(torch.load(model_path, map_location=device))

criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

# Per-epoch histories; the scalar metric lists track target 0 only.
train_losses = []

val_losses_tubb = []
val_accuracies_tubb = []
val_precisions_tubb = []
val_recalls_tubb = []

val_losses_non_tubb = []
val_accuracies_non_tubb = []
val_precisions_non_tubb = []
val_recalls_non_tubb = []

# Full per-target metric dicts, one per epoch.
metrics_tubb_all = []
metrics_non_tubb_all = []

# Deep copy so `best_model` is a frozen snapshot, not an alias of the live `model`.
best_model = copy.deepcopy(model)
# Start below any attainable F1 so the first epoch always writes a checkpoint; with 0,
# a run whose F1 never rises above 0 would save nothing (leaving any stale file in place).
best_val_f1 = float("-inf")
epochs_without_improvement = 0

for epoch in range(NUM_EPOCHS):
    # Reload piano rolls and reinitialize dataloaders if we have negative undersampling to reshuffle
    if NEGATIVE_UNDERSAMPLING_FACTOR:
        piano_rolls, patch_data = get_piano_rolls()
        metadata_df = pd.DataFrame.from_dict(patch_data, orient="index")

        metadata_df = metadata_df.sample(frac=1)
        metadata_train = metadata_df[metadata_df["key"].isin(["tubb_train", "non_tubb_train"])]
        metadata_val_tubb = metadata_df[metadata_df["key"] == "tubb_val"]
        metadata_val_non_tubb = metadata_df[metadata_df["key"] == "non_tubb_val"]
        metadata_train.reset_index(drop=True, inplace=True)
        metadata_val_tubb.reset_index(drop=True, inplace=True)
        metadata_val_non_tubb.reset_index(drop=True, inplace=True)

        train_fnames = set(metadata_train['filename'])
        val_tubb_fnames = set(metadata_val_tubb['filename'])
        val_non_tubb_fnames = set(metadata_val_non_tubb['filename'])

        dataset_train = PianoRollDataset(piano_rolls, metadata_train, normalize=PATCH_NORMALIZE)
        dataloader_train = DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)

        dataset_val_tubb = PianoRollDataset(piano_rolls, metadata_val_tubb, normalize=PATCH_NORMALIZE)
        dataloader_val_tubb = DataLoader(dataset_val_tubb, batch_size=BATCH_SIZE, shuffle=False)

        dataset_val_non_tubb = PianoRollDataset(piano_rolls, metadata_val_non_tubb, normalize=PATCH_NORMALIZE)
        dataloader_val_non_tubb = DataLoader(dataset_val_non_tubb, batch_size=BATCH_SIZE, shuffle=False)

    # Log some labeled images from the training dataloader
    example_images, example_targets = next(iter(dataloader_train))
    for i in range(min(4, len(example_images))):
        label = example_targets[i]
        writer.add_image(
            tag=f"Train/Example_image_{i}_label_{label}",
            img_tensor=example_images[i],
            global_step=epoch,
            dataformats="CHW",
        )

    # --- Training ---
    model.train()
    train_loss = 0
    pbar = tqdm(dataloader_train)
    for step, (piano_roll, targets) in enumerate(pbar):
        pbar.set_description(f"Epoch {epoch + 1}/{NUM_EPOCHS}")
        piano_roll, targets = piano_roll.to(device), targets.to(device)

        optimizer.zero_grad()
        output = model(piano_roll)
        loss = criterion(output, targets.float())
        loss.backward()
        optimizer.step()

        # Log the Python float rather than the graph-attached tensor.
        batch_loss = loss.item()
        train_loss += batch_loss
        writer.add_scalar("Loss/Train", batch_loss, epoch * len(dataloader_train) + step)
        pbar.set_postfix({"Train Loss (current/average)": f"{batch_loss:.4f}/{train_loss / (step + 1):.4f}"})
    writer.flush()
    train_loss /= len(dataloader_train)
    train_losses.append(train_loss)

    # --- Validation ---
    model.eval()
    val_loss_tubb, val_outputs_tubb, val_targets_tubb = _evaluate(model, dataloader_val_tubb, criterion, device)
    val_loss_non_tubb, val_outputs_non_tubb, val_targets_non_tubb = _evaluate(model, dataloader_val_non_tubb, criterion, device)

    metrics_tubb = compute_metrics(torch.cat(val_outputs_tubb), torch.cat(val_targets_tubb))
    val_losses_tubb.append(val_loss_tubb)
    val_accuracies_tubb.append(metrics_tubb["accuracy_0"])
    val_precisions_tubb.append(metrics_tubb["precision_0"])
    val_recalls_tubb.append(metrics_tubb["recall_0"])

    metrics_non_tubb = compute_metrics(torch.cat(val_outputs_non_tubb), torch.cat(val_targets_non_tubb))
    val_losses_non_tubb.append(val_loss_non_tubb)
    val_accuracies_non_tubb.append(metrics_non_tubb["accuracy_0"])
    val_precisions_non_tubb.append(metrics_non_tubb["precision_0"])
    val_recalls_non_tubb.append(metrics_non_tubb["recall_0"])

    metrics_tubb_all.append(metrics_tubb)
    metrics_non_tubb_all.append(metrics_non_tubb)

    f1_tubb = _f1_score(metrics_tubb["precision_0"], metrics_tubb["recall_0"])
    f1_non_tubb = _f1_score(metrics_non_tubb["precision_0"], metrics_non_tubb["recall_0"])
    print(f"Epoch {epoch + 1}/{NUM_EPOCHS}, Train Loss: {train_loss:.4f}, Val Loss (Tubb/Non-Tubb): ({val_loss_tubb:.4f}, {val_loss_non_tubb:.4f}), Val F1 (Tubb/Non-Tubb): ({f1_tubb:.4f}, {f1_non_tubb:.4f})")

    # Model selection uses the unweighted mean of the two splits' F1.
    val_f1 = (f1_tubb + f1_non_tubb) / 2

    # --- Log metrics to TensorBoard ---
    # Done before the early-stopping check so the final epoch is not lost on `break`.
    writer.add_scalar("Loss/Val_Tubb", val_loss_tubb, epoch)
    writer.add_scalar("Accuracy/Val_Tubb", metrics_tubb["accuracy_0"], epoch)
    writer.add_scalar("Precision/Val_Tubb", metrics_tubb["precision_0"], epoch)
    writer.add_scalar("Recall/Val_Tubb", metrics_tubb["recall_0"], epoch)

    writer.add_scalar("Loss/Val_Non_Tubb", val_loss_non_tubb, epoch)
    writer.add_scalar("Accuracy/Val_Non_Tubb", metrics_non_tubb["accuracy_0"], epoch)
    writer.add_scalar("Precision/Val_Non_Tubb", metrics_non_tubb["precision_0"], epoch)
    writer.add_scalar("Recall/Val_Non_Tubb", metrics_non_tubb["recall_0"], epoch)

    writer.flush()

    # --- Checkpointing / early stopping ---
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        best_model = copy.deepcopy(model)
        torch.save(model.state_dict(), SAVE_PATH / model_name)
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= EARLY_STOPPING_PATIENCE:
            print(f"Early stopping at epoch {epoch + 1}")
            break

writer.close()

# Save metrics_tubb_all and metrics_non_tubb_all as JSON; `with` closes each file deterministically.
metrics_suffix = f"overtones_{INSTRUMENT_OVERTONES}_normalized_{PATCH_NORMALIZE}_separate_drums_{SEPARATE_DRUMS}_targets_{NUM_TARGETS}_2.json"
with open(f"./metrics_tubb_all_{metrics_suffix}", "w") as f:
    json.dump(metrics_tubb_all, f)
with open(f"./metrics_non_tubb_all_{metrics_suffix}", "w") as f:
    json.dump(metrics_non_tubb_all, f)

In [None]:
# Harmonic mean of the epoch-averaged precision and recall, shown as (Tubb, Non-Tubb).
_mean_p_tubb, _mean_r_tubb = np.mean(val_precisions_tubb), np.mean(val_recalls_tubb)
_mean_p_non_tubb, _mean_r_non_tubb = np.mean(val_precisions_non_tubb), np.mean(val_recalls_non_tubb)
(
    2 * (_mean_p_tubb * _mean_r_tubb) / (_mean_p_tubb + _mean_r_tubb),
    2 * (_mean_p_non_tubb * _mean_r_non_tubb) / (_mean_p_non_tubb + _mean_r_non_tubb),
)

In [None]:
# Save `best_model`, not `model`. This path is the checkpoint the training loop writes
# whenever validation F1 improves; saving the live model here would overwrite it with the
# last epoch's weights after early stopping.
# NOTE: this differs from saving `model` only if the training loop stores a copy
# (not an alias of `model`) in `best_model`.
torch.save(best_model.state_dict(), SAVE_PATH / model_name)

In [None]:
# Display the per-epoch validation accuracy history (target 0) for the non-Tubb split.
val_accuracies_non_tubb

In [None]:
# Plot per-epoch training loss and the validation metrics of both validation splits.
# Every history list holds exactly one entry per epoch, so the curves are plotted
# unsliced and share the same epoch x-axis as the training loss.
fig, ax = plt.subplots(1, 4, figsize=(20, 5))

ax[0].plot(train_losses, label="Train Loss")
ax[0].plot(val_losses_tubb, label="Val Loss (Tubb)")
ax[0].plot(val_losses_non_tubb, label="Val Loss (Non-Tubb)")
ax[0].set_title("Loss")

metric_panels = [
    (ax[1], "Accuracy", val_accuracies_tubb, val_accuracies_non_tubb),
    (ax[2], "Precision", val_precisions_tubb, val_precisions_non_tubb),
    (ax[3], "Recall", val_recalls_tubb, val_recalls_non_tubb),
]
for axis, title, tubb_values, non_tubb_values in metric_panels:
    axis.plot(tubb_values, label=f"{title} (Tubb)")
    axis.plot(non_tubb_values, label=f"{title} (Non-Tubb)")
    axis.set_title(title)

for axis in ax:
    axis.set_xlabel("Epoch")
    axis.legend()

In [None]:
# Total number of parameter elements in the model (all parameters, regardless of requires_grad).
sum(map(torch.numel, model.parameters()))

In [None]:
# Persist the per-epoch validation metric dicts of both splits as JSON under ../training_output.
output_dir = Path("../training_output")
output_dir.mkdir(parents=True, exist_ok=True)

# NOTE: "channels_3" is a fixed tag in the file name, not derived from the config.
metrics_suffix = f"overtones_{INSTRUMENT_OVERTONES}_normalized_{PATCH_NORMALIZE}_channels_3_targets_{NUM_TARGETS}_2.json"
for prefix, split_metrics in (
    ("val_metrics_tubb_all", metrics_tubb_all),
    ("val_metrics_non_tubb_all", metrics_non_tubb_all),
):
    # `with` closes (and flushes) each file deterministically instead of leaving it to GC.
    with open(output_dir / f"{prefix}_{metrics_suffix}", "w") as f:
        json.dump(split_metrics, f)