In [None]:
import json
import math
import os
import time
import zlib
from pathlib import Path

import matplotlib.pyplot as plt
import mido
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from tqdm import tqdm

## Config

In [None]:
# Root of the Lakh MIDI dataset; files are looked up as <first char of id>/<id>.mid below it
lakh_midi_dir = Path("..") / "data" / "lmd_full"

In [None]:
SEED = 0  # global seed so repeated runs are reproducible

# Seed every RNG this notebook draws from.
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Config: preprocessing / evaluation switches
TARGET_TICKS_PER_BEAT: int = 4  # time resolution of the piano roll after downsampling
WINDOW_HALF_TICKS: int = 256  # half-width (downsampled ticks) of each patch fed to the model
INSTRUMENT_OVERTONES: bool = True  # render synthetic overtones into piano-roll channel 1
SEPARATE_DRUMS: bool = True  # route MIDI channel 9 (drums) to piano-roll channel 2
PATCH_NORMALIZE: bool = True  # scale each patch by its maximum value
PAD_PIANO_ROLL: bool = False  # zero-pad the roll so patches near its edges fit (boundary patches)

## Utils

In [None]:
# Marker annotations per MIDI file name; each entry's first element is a position in quarter notes.
# The context manager closes the file instead of leaving the handle to the garbage collector.
with open("../data/markers_qn.json") as markers_file:
    MARKERS = json.load(markers_file)

def parse_midi(file_path):
    """
    Parse a MIDI file into per-track note lists plus the file's annotated marker ticks.

    Each note is a dict with keys:
        time     -- onset in absolute ticks
        note     -- MIDI pitch
        velocity -- note velocity scaled by the channel's current volume (CC 7) and
                    expression (CC 11), both defaulting to 127
        duration -- ticks until the matching note-off, or -1 if none was found
        channel  -- MIDI channel
        program  -- the channel's most recent program change (default 0)

    Args:
        file_path: path to a ``.mid`` file; its base name up to ``.mid`` must be a key of ``MARKERS``.

    Returns:
        (track_data, markers_ticks, ticks_per_beat):
            track_data     -- dict mapping track name (``track_{idx}`` if unnamed) to its notes
            markers_ticks  -- sorted, de-duplicated marker positions converted to ticks
            ticks_per_beat -- the file's resolution

    Raises:
        KeyError: if the file has no entry in ``MARKERS``.
    """
    midi = mido.MidiFile(file_path, clip=True)

    # NOTE(review): tracks that share a name are merged into a single list, which is then not
    # necessarily sorted by onset time.
    track_data = {
        (track.name if track.name else f"track_{idx}"): []
        for idx, track in enumerate(midi.tracks)
    }

    file_name = os.path.basename(file_path).split('.mid')[0]
    marker_qns = MARKERS[file_name]
    markers_ticks = [int(round(x[0] * midi.ticks_per_beat)) for x in marker_qns]

    # Per-channel controller / program state.
    # NOTE(review): updated track by track rather than in global time order, and carried over
    # from one track into the next — confirm this is intended.
    channel_volumes = {i: 127 for i in range(16)}
    channel_expressions = {i: 127 for i in range(16)}
    channel_instruments = {i: 0 for i in range(16)}

    # Notes still waiting for a note-off, keyed by (track name, channel, pitch), oldest first.
    # A note-off closes the oldest matching open note -- the same note a front-to-back scan of
    # the track's notes would find -- without rescanning every note of the track.
    open_notes = {}

    for idx, track in enumerate(midi.tracks):
        track_name = track.name if track.name else f"track_{idx}"
        current_ticks = 0
        for msg in track:
            current_ticks += msg.time  # accumulate per-message delta times into absolute ticks
            if msg.type == "control_change":
                if msg.control == 7:
                    channel_volumes[msg.channel] = msg.value
                elif msg.control == 11:
                    channel_expressions[msg.channel] = msg.value
            elif msg.type == "program_change":
                channel_instruments[msg.channel] = msg.program
            elif msg.type == "note_on" and msg.velocity > 0:
                velocity = msg.velocity * (channel_volumes[msg.channel] / 127.) * (
                            channel_expressions[msg.channel] / 127.)
                note = {
                    "time": current_ticks,
                    "note": msg.note,
                    "velocity": velocity,
                    "duration": -1,  # filled in by the matching note-off
                    "channel": msg.channel,
                    "program": channel_instruments[msg.channel],
                }
                track_data[track_name].append(note)
                open_notes.setdefault((track_name, msg.channel, msg.note), []).append(note)
            elif msg.type == "note_off" or (msg.type == "note_on" and msg.velocity == 0):
                # A note_on with velocity 0 also ends a note.
                pending = open_notes.get((track_name, msg.channel, msg.note))
                if pending:
                    note = pending.pop(0)
                    note["duration"] = current_ticks - note["time"]

    # Remove duplicate marker ticks
    markers_ticks = sorted(set(markers_ticks))

    return track_data, markers_ticks, midi.ticks_per_beat

In [None]:
def instrument_overtone_intensities(program, num_harmonics=3, max_harmonic=5):
    """
    Generate a fixed pseudo-random set of overtones for an instrument program.

    Args:
        program: instrument program; only its ``str()`` form is used, to derive the seed.
        num_harmonics: number of distinct harmonics to draw.
        max_harmonic: number of candidate harmonics; multipliers are drawn from
            ``2 .. max_harmonic + 1``.

    Returns:
        (harmonics, intensities):
            harmonics   -- ascending array of ``num_harmonics`` distinct integer multipliers
            intensities -- descending array of ``num_harmonics`` values in [0, 1)

    The result is identical for the same arguments in every Python process. The global NumPy
    RNG is not touched.
    """
    # Python's built-in ``hash`` of a str is salted per process (PYTHONHASHSEED), so it cannot give
    # run-to-run stable overtones. CRC32 of the program's string form is stable.
    seed = zlib.crc32(str(program).encode("utf-8"))
    # A private RandomState avoids reseeding (and thereby disturbing) the global NumPy RNG.
    rng = np.random.RandomState(seed)

    harmonics = np.sort(rng.choice(max_harmonic, num_harmonics, replace=False) + 2)
    intensities = np.sort(rng.rand(num_harmonics))[::-1]

    return harmonics, intensities

# Program-number ranges per instrument family: 16 families, 8 consecutive programs each.
instrument_categories = {
    family: range(8 * position, 8 * position + 8)
    for position, family in enumerate([
        "Piano", "Chromatic Percussion", "Organ", "Guitar",
        "Bass", "Strings", "Ensemble", "Brass",
        "Reed", "Pipe", "Synth Lead", "Synth Pad",
        "Synth Effects", "Ethnic", "Percussive", "Sound Effects",
    ])
}

def hz_to_midi(frequency):
    """
    Convert a frequency in Hz to a (fractional) MIDI note number, with 440 Hz -> 69.

    Raises:
        ValueError: if ``frequency`` is zero or negative.
    """
    if frequency <= 0:
        raise ValueError("Frequency must be greater than 0 Hz.")
    semitones_from_a4 = 12 * np.log2(frequency / 440.0)
    return semitones_from_a4 + 69

def midi_to_hz(midi_note):
    """Convert a (possibly fractional) MIDI note number to Hz, with 69 -> 440 Hz and 12 notes per octave."""
    octaves_from_a4 = (midi_note - 69) / 12
    return 440.0 * 2 ** octaves_from_a4

In [None]:
def create_piano_roll(
    note_data,
    ticks_per_beat,
    chroma=False,
    target_ticks_per_beat=24,
    instrument_overtones=False,
    separate_drums=False
):
    """
    Render one track's notes into a 3-channel piano roll of shape (3, pitches, time).

    Channel layout:
        0 -- note velocities (drum notes too, unless ``separate_drums``)
        1 -- synthetic overtones if ``instrument_overtones``, otherwise a copy of the note velocities
        2 -- drum onsets if ``separate_drums``

    Args:
        note_data: note dicts as produced by ``parse_midi`` (``time``, ``note``, ``velocity``,
            ``duration``, ``channel``, ``program``). Notes on MIDI channel 9 are treated as drums
            and rendered as one-tick onsets. Other notes with a non-positive duration (e.g. -1 for
            a note that was never closed) are skipped.
        ticks_per_beat: resolution of the source MIDI file.
        chroma: fold pitches onto 12 pitch classes instead of 128 MIDI notes.
        target_ticks_per_beat: resolution after max-pool downsampling. Downsampling is applied
            only when ``ticks_per_beat`` is higher.
        instrument_overtones: add program-dependent overtones to channel 1.
        separate_drums: put drum onsets in channel 2 instead of channel 0.

    Returns:
        ``np.ndarray`` piano roll, or ``None`` when there is nothing to render or downsampling fails.
    """
    if len(note_data) == 0:
        return None
    num_notes = 12 if chroma else 128

    def note_span(note):
        # Drums only need their onset, so they get a fixed one-tick duration.
        duration = 1 if note["channel"] == 9 else note["duration"]
        return note["time"], note["time"] + duration

    # Size the roll by the latest note *end*. The last-started note need not end last (and may be
    # unterminated with duration -1, which would give a too-short or negative length).
    spans = [note_span(note) for note in note_data]
    duration_ticks = max((end for start, end in spans if end > start), default=0)
    if duration_ticks <= 0:
        return None
    piano_roll = np.zeros((3, num_notes, duration_ticks))
    overtone_cache = {}  # program -> (harmonics, intensities), computed once per program

    for note, (start, end) in zip(note_data, spans):
        if end - start <= 0:
            continue

        drum_track = note["channel"] == 9
        pitch_class = note["note"] % 12 if chroma else note["note"]
        velocity = note["velocity"]

        piano_roll_channel = 2 if drum_track and separate_drums else 0
        piano_roll[piano_roll_channel, pitch_class, start:end] = velocity
        if not instrument_overtones:
            piano_roll[1, pitch_class, start:end] = velocity

        # Add overtones
        if instrument_overtones and not drum_track:
            program = note["program"]
            if program not in overtone_cache:
                overtone_cache[program] = instrument_overtone_intensities(program)
            harmonics, intensities = overtone_cache[program]
            pitch = midi_to_hz(note["note"])
            max_intensity = intensities[0]
            for harmonic, intensity in zip(harmonics, intensities):
                # Round to the nearest MIDI note *before* folding to a pitch class, so a chroma
                # roll never gets the out-of-range row 12.
                overtone_note = int(np.round(hz_to_midi(pitch * harmonic)))
                overtone_row = overtone_note % 12 if chroma else overtone_note
                if overtone_row < num_notes:
                    # NOTE(review): ``intensity`` is applied twice (inside ``decay`` and again
                    # below); left unchanged — confirm intent.
                    decay = np.linspace(1.0, 0.0, end - start) * intensity / max_intensity
                    piano_roll[1, overtone_row, start:end] = velocity * intensity * decay

    # Downsample to target_ticks_per_beat ticks per beat using max pooling
    if ticks_per_beat > target_ticks_per_beat:
        pool_size = ticks_per_beat // target_ticks_per_beat
        try:
            piano_roll = F.max_pool1d(torch.tensor(piano_roll), pool_size, stride=pool_size).numpy()
        except Exception as e:
            # e.g. a roll shorter than one pooling window
            print(e)
            print(piano_roll.shape)
            return None

    return piano_roll

In [None]:
def novelty_peak_pick_gsu(novelty_function: torch.Tensor, window_size: int, threshold_window_left: int, threshold_window_right: int, threshold: float = 0.37) -> torch.Tensor:
    """
    Pick peaks from a 1-D novelty curve.

    A position ``i`` is a candidate when it holds the (first) maximum of the ``window_size``-wide
    window centred on it; the curve is zero-padded at both ends for this. A candidate is kept when
    its value exceeds the mean of the curve over
    ``[i - threshold_window_left, i + threshold_window_right]`` (clipped to the curve) by more
    than ``threshold``.

    If the curve is shorter than ``window_size``, only the global maximum is considered. It is
    kept when it exceeds the mean of the whole curve by more than ``threshold``.

    Args:
        novelty_function: 1-D tensor of novelty values.
        window_size: width of the local-maximum window.
        threshold_window_left: samples left of a candidate included in its local mean.
        threshold_window_right: samples right of a candidate included in its local mean.
        threshold: minimum margin above the local mean.

    Returns:
        1-D ``torch.long`` tensor of peak indices into ``novelty_function`` (empty if none).
    """
    device = novelty_function.device
    empty = torch.tensor([], dtype=torch.long, requires_grad=False, device=device)
    length = novelty_function.shape[-1]
    if length == 0:
        return empty

    if length < window_size:
        candidate_idx = novelty_function.argmax()
        return candidate_idx.unsqueeze(0) if novelty_function[candidate_idx] - novelty_function.mean() > threshold else empty

    window_half = window_size // 2
    pad = torch.zeros(window_half, dtype=novelty_function.dtype, device=device)
    padded = torch.cat((pad, novelty_function, pad))

    # Window i covers padded[i : i + window_size]; its centre padded[i + window_half] is curve[i].
    windows = padded.unfold(-1, window_size, 1)
    window_argmax = windows.argmax(dim=1)
    indices = torch.argwhere(window_argmax == window_half)[:, 0]
    # With an even window_size there is one extra window whose centre lies in the right padding.
    indices = indices[indices < length]
    if indices.numel() == 0:
        return empty
    candidates = novelty_function[indices]

    starts = (indices - threshold_window_left).clamp(min=0)
    ends = (indices + threshold_window_right + 1).clamp(max=length)
    means = torch.stack([
        novelty_function[start:end].mean()
        for start, end in zip(starts.tolist(), ends.tolist())
    ])

    return indices[candidates - means > threshold]

### Metrics

In [None]:
def acc_prec_recall(input, target):
    """
    Accuracy, precision and recall of logit-style predictions against binary targets.

    A prediction is positive when ``input > 0`` and negative when ``input <= 0``. Targets count
    as positive when equal to 1 and as negative when equal to 0. Each metric is 0 when its
    denominator is 0.

    Returns:
        (accuracy, precision, recall)
    """
    pred_pos = input > 0
    pred_neg = input <= 0
    target_pos = target == 1
    target_neg = target == 0

    def count(mask):
        return mask.sum().item()

    def ratio(numerator, denominator):
        return numerator / denominator if denominator > 0 else 0

    tp = count(pred_pos & target_pos)
    fp = count(pred_pos & target_neg)
    tn = count(pred_neg & target_neg)
    fn = count(pred_neg & target_pos)

    return (
        ratio(tp + tn, tp + tn + fp + fn),
        ratio(tp, tp + fp),
        ratio(tp, tp + fn),
    )

def compute_metrics(input, target):
    """
    Per-target accuracy, precision and recall, one set per index of the last dimension.

    Returns:
        dict with keys ``accuracy_<i>``, ``precision_<i>`` and ``recall_<i>`` for each index ``i``
        of ``input``'s last dimension, inserted in that order.
    """
    metrics = {}
    for target_idx in range(input.size(-1)):
        scores = acc_prec_recall(input[..., target_idx], target[..., target_idx])
        for metric_name, value in zip(("accuracy", "precision", "recall"), scores):
            metrics[f"{metric_name}_{target_idx}"] = value
    return metrics

In [None]:
from typing import List

def prec_recall(input: List[bool], target: List[bool]):
    """
    Precision and recall of boolean predictions against boolean targets.

    Elements are compared by truthiness, so 0-dim bool tensors work as well as Python bools.
    Pairs are formed with ``zip``, so trailing extra elements of the longer input are ignored.
    Each metric is 0 when its denominator is 0.

    Returns:
        (precision, recall)
    """
    tp = fp = fn = 0
    # Single pass, so one-shot iterators work too; true negatives are not needed here.
    for predicted, actual in zip(input, target):
        if predicted and actual:
            tp += 1
        elif predicted:
            fp += 1
        elif actual:
            fn += 1

    prec = tp / (tp + fp) if tp + fp > 0 else 0
    recall = tp / (tp + fn) if tp + fn > 0 else 0

    return prec, recall

### Define model classes

In [None]:
# Simple CNN boundary classifier
class BoundaryClassifier(nn.Module):
    """
    Boundary classifier built on torchvision's MobileNetV3-Small (``DEFAULT`` weights), with the
    last classifier layer replaced by a ``num_targets``-way linear head.

    Args:
        num_targets: number of output logits (default 1).
    """

    def __init__(self, num_targets=1):
        super().__init__()

        weights = torchvision.models.MobileNet_V3_Small_Weights.DEFAULT
        backbone = torchvision.models.mobilenet_v3_small(weights=weights)
        # Keep the nn.Sequential wrapper: it is part of the parameter names in the state dict,
        # so checkpoints saved from this class only load with the same structure.
        backbone.classifier[-1] = nn.Sequential(
            nn.Linear(backbone.classifier[-1].in_features, num_targets),
        )
        # Xavier-initialise every Linear in the classifier. ``modules()`` recurses into the
        # Sequential head; iterating only the classifier's direct children would skip it.
        # NOTE(review): this also re-initialises any pretrained Linear layers of the
        # classifier — confirm intent.
        for layer in backbone.classifier.modules():
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight.data)
        self.backbone = backbone

    def forward(self, x):
        """Return the backbone's raw logits (no activation) for the input batch."""
        return self.backbone(x)

In [None]:
class BaggedBoundaryClassifier(nn.Module):
    """
    Ensemble that averages the first output (``[..., 0]``) of several models.

    Args:
        models: iterable of ``nn.Module`` members.
    """

    def __init__(self, models):
        super().__init__()
        # nn.ModuleList (not a plain list) registers the members as submodules, so .to(), .eval(),
        # .parameters() and state_dict() on the ensemble reach them.
        self.models = nn.ModuleList(models)

    def forward(self, x):
        """Return the mean over members of ``model(x)[..., 0]``."""
        return torch.stack([model(x)[..., 0] for model in self.models], dim=0).mean(dim=0)

#### Load trained models

In [None]:
# Prefer Apple-silicon MPS, then CUDA, and fall back to CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# One entry per ensemble member: number of output targets and checkpoint path.
# num_targets = [5, 1, 1, 1, 1]
num_targets = [1, 1, 1, 1]
model_paths = [
    # Path('../models/pretrain_True_mn_overtones_True_normalized_True_separate_drums_True_targets_5_epoch_4 - best.pt'),
    Path('../models/pretrain_False_mn_overtones_True_normalized_True_separate_drums_True_targets_1_epoch_5 - best.pt'),
    Path('../models/pretrain_True_mn_overtones_False_normalized_True_separate_drums_False_targets_1_epoch_4 - best.pt'),
    Path('../models/pretrain_True_mn_overtones_False_normalized_True_separate_drums_True_targets_1_epoch_2 - best.pt'),
    Path('../models/pretrain_True_mn_overtones_True_normalized_True_separate_drums_True_targets_1_epoch_8 - best.pt')
]
if len(model_paths) != len(num_targets):
    raise ValueError("model_paths and num_targets must have the same length")

models = []
for model_path, n_targets in zip(model_paths, num_targets):
    # An untrained member would silently corrupt the ensemble's predictions, so fail loudly.
    if not model_path.exists():
        raise FileNotFoundError(f"Model checkpoint not found: {model_path}")
    model = BoundaryClassifier(n_targets).to(device)
    print(f"Loading model from {model_path}")
    # The checkpoints are plain state dicts, so weights_only=True avoids unpickling arbitrary objects.
    model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
    models.append(model.eval())

model = BaggedBoundaryClassifier(models).to(device).eval()

## Evaluation Loop

In [None]:
MAX_BATCH_SIZE = 128
CLASSIFICATION_THRESHOLD = 0.5
# novelty_peak_pick_gsu parameters, in units of measures (one prediction per measure patch)
PEAK_PICK_WINDOW = 8
PEAK_PICK_THRESHOLD_LEFT = 8
PEAK_PICK_THRESHOLD_RIGHT = 4
padding = WINDOW_HALF_TICKS

with open("../data/measures_qn.json") as measures_file:
    measure_qns_all = json.load(measures_file)
with open("../data/good_files_3_7_2025.json") as good_files_file:
    good_files = json.load(good_files_file)
test_files = {key: good_files[key] for key in good_files if key.endswith("test")}

# Per split: per-example metrics (evaluation_results), plus flat per-measure targets and
# thresholded predictions (targets / results).
evaluation_results = {}

targets = {}
results = {}
with torch.no_grad():
    for key in test_files:
        print(f"Processing files: {key}")
        evaluation_results[key] = []
        targets[key] = []
        results[key] = []

        for test_example in (pbar := tqdm(test_files[key], desc="Loading test examples")):
            measure_qns = measure_qns_all[test_example]
            midi_path = lakh_midi_dir / Path(f"{test_example[0]}") / Path(test_example + ".mid")

            # MIDI (best effort: unreadable or unannotated files are reported and skipped)
            try:
                track_data, markers_ticks, ticks_per_beat = parse_midi(midi_path)
            except Exception as e:
                print(f"Error loading MIDI file: {midi_path}")
                print(e)
                continue

            # Annotation: move markers and measures onto the downsampled grid of the piano roll
            if ticks_per_beat > TARGET_TICKS_PER_BEAT:
                downsample_factor = ticks_per_beat // TARGET_TICKS_PER_BEAT
                markers_ticks = [marker // downsample_factor for marker in markers_ticks]
                # Same tick -> grid mapping as the markers (and the roll's max pooling), so measures
                # and markers coincide even when ticks_per_beat is not a multiple of the target.
                measure_ticks = [int(round(qn * ticks_per_beat)) // downsample_factor for qn in measure_qns]
            else:
                print(f"Skipping {test_example} due to downsample factor")
                continue

            piano_rolls = []
            for track_name, note_data in track_data.items():
                piano_roll = create_piano_roll(
                    note_data,
                    ticks_per_beat,
                    chroma=False,
                    target_ticks_per_beat=TARGET_TICKS_PER_BEAT,
                    instrument_overtones=INSTRUMENT_OVERTONES,
                    separate_drums=SEPARATE_DRUMS
                )
                # Some tracks are empty
                if piano_roll is None:
                    continue
                piano_rolls.append(piano_roll)

            if len(piano_rolls) == 0:
                print(f"Skipping {test_example} due to empty piano rolls")
                continue

            # Right-pad every track roll to the longest one, then merge tracks by summing
            actual_length = max(track_roll.shape[-1] for track_roll in piano_rolls)
            piano_roll = torch.stack([
                F.pad(torch.tensor(track_roll), (0, actual_length - track_roll.shape[-1]))
                for track_roll in piano_rolls
            ])
            piano_roll = piano_roll.sum(dim=0).clamp(0, 127).float().to(device)

            # First and last non-zero time columns of channel 0 (first and last onset, respectively)
            nonzero_columns = piano_roll[0].nonzero(as_tuple=True)[1]
            if nonzero_columns.numel() == 0:
                continue
            first_nonzero_column = nonzero_columns.min().item()
            last_nonzero_column = nonzero_columns.max().item()

            markers_ticks = torch.tensor(markers_ticks, device=device, dtype=torch.float32)
            measure_ticks = torch.tensor(measure_ticks, device=device, dtype=torch.float32)

            # Keep annotations strictly inside the musical content...
            markers_ticks = markers_ticks[(markers_ticks > first_nonzero_column) & (markers_ticks < last_nonzero_column)]
            measure_ticks = measure_ticks[(measure_ticks > first_nonzero_column) & (measure_ticks < last_nonzero_column)]

            # ...then add the first and last onset as boundaries to both sequences
            content_edges = torch.tensor([first_nonzero_column, last_nonzero_column], dtype=torch.float32, device=device)
            markers_ticks = torch.cat([content_edges[:1], markers_ticks, content_edges[1:]])
            measure_ticks = torch.cat([content_edges[:1], measure_ticks, content_edges[1:]])

            # Crop piano roll to the first and last onset and shift the annotations accordingly
            piano_roll = piano_roll[..., first_nonzero_column:last_nonzero_column + 1]
            markers_ticks -= first_nonzero_column
            measure_ticks -= first_nonzero_column

            if PAD_PIANO_ROLL:
                piano_roll = F.pad(piano_roll, (padding, padding), mode='constant', value=0)
                markers_ticks += padding
                measure_ticks += padding

            # One patch per measure, centred on the measure start. Ticks are converted to Python
            # ints: slicing needs integer indices (float tensors raise), and this avoids a device
            # sync per measure.
            marker_positions = set(int(tick) for tick in markers_ticks.tolist())
            roll_length = piano_roll.shape[-1]
            patches = []
            curr_targets = []
            tik = time.perf_counter()
            for measure_tick in measure_ticks.long().tolist():
                start = measure_tick - WINDOW_HALF_TICKS
                end = measure_tick + WINDOW_HALF_TICKS

                # Without padding, skip patches that would extend past either end of the roll
                if not PAD_PIANO_ROLL and (start < 0 or end > roll_length):
                    continue

                curr_targets.append(measure_tick in marker_positions)

                patch = piano_roll[..., start:end]
                if PATCH_NORMALIZE:
                    # Clamp the divisor so an all-silent patch stays zero instead of becoming NaN
                    patch = patch / patch.max().clamp(min=torch.finfo(patch.dtype).tiny)

                patches.append(patch)

            if len(patches) == 0:
                continue

            batches = torch.stack(patches)
            patch_preparation_time = time.perf_counter() - tik

            # Batched inference; sigmoid turns the ensemble's mean logit into a boundary probability
            measure_predictions = []
            tik = time.perf_counter()
            for batch_start in range(0, batches.shape[0], MAX_BATCH_SIZE):
                if device.type == "mps":
                    torch.mps.empty_cache()
                logits = model(batches[batch_start:batch_start + MAX_BATCH_SIZE])
                measure_predictions += torch.sigmoid(logits).tolist()
            prediction_time = time.perf_counter() - tik

            tik = time.perf_counter()
            peak_indices = novelty_peak_pick_gsu(
                torch.tensor(measure_predictions),
                PEAK_PICK_WINDOW,
                PEAK_PICK_THRESHOLD_LEFT,
                PEAK_PICK_THRESHOLD_RIGHT,
            )
            measure_predictions_peak_pick = torch.zeros(len(curr_targets), dtype=torch.bool)
            measure_predictions_peak_pick[peak_indices] = True
            peak_pick_time = time.perf_counter() - tik

            threshold_predictions = [prediction > CLASSIFICATION_THRESHOLD for prediction in measure_predictions]
            results[key] += threshold_predictions
            targets[key] += curr_targets

            prec_threshold, recall_threshold = prec_recall(threshold_predictions, curr_targets)
            f1_threshold = 2 * prec_threshold * recall_threshold / (prec_threshold + recall_threshold) if prec_threshold + recall_threshold > 0 else 0

            prec_peak_pick, recall_peak_pick = prec_recall(measure_predictions_peak_pick.tolist(), curr_targets)
            f1_peak_pick = 2 * prec_peak_pick * recall_peak_pick / (prec_peak_pick + recall_peak_pick) if prec_peak_pick + recall_peak_pick > 0 else 0

            evaluation_results[key].append({
                test_example: {
                    "f1_threshold": f1_threshold,
                    "f1_peak_pick": f1_peak_pick,
                    "prec_threshold": prec_threshold,
                    "recall_threshold": recall_threshold,
                    "prec_peak_pick": prec_peak_pick,
                    "recall_peak_pick": recall_peak_pick,
                    "patch_preparation_time_ms": patch_preparation_time * 1000,
                    "prediction_time_ms": prediction_time * 1000,
                    "peak_pick_time_ms": peak_pick_time * 1000
                }
            })
            f1_threshold_mean = np.mean([example_result["f1_threshold"] for x in evaluation_results[key] for example_result in x.values()])
            f1_peak_pick_mean = np.mean([example_result["f1_peak_pick"] for x in evaluation_results[key] for example_result in x.values()])

            pbar.set_postfix(f1_thresholding_mean=f1_threshold_mean, f1_peak_pick_mean=f1_peak_pick_mean)


## Results

In [None]:
def _f1(precision, recall):
    """Harmonic mean of precision and recall (0 when both are 0)."""
    return 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0


# Micro-averaged scores: every measure of a split counts once
prec_tubb, recall_tubb = prec_recall(results["tubb_test"], targets["tubb_test"])
f1_tubb = _f1(prec_tubb, recall_tubb)

prec_non_tubb, recall_non_tubb = prec_recall(results["non_tubb_test"], targets["non_tubb_test"])
f1_non_tubb = _f1(prec_non_tubb, recall_non_tubb)

all_results = [result for split_results in results.values() for result in split_results]
all_targets = [target for split_targets in targets.values() for target in split_targets]
prec_all, recall_all = prec_recall(all_results, all_targets)
f1_all = _f1(prec_all, recall_all)

for label, (precision, recall, f1) in [
    ("Tubb", (prec_tubb, recall_tubb, f1_tubb)),
    ("Non-Tubb", (prec_non_tubb, recall_non_tubb, f1_non_tubb)),
    ("All", (prec_all, recall_all, f1_all)),
]:
    print(f"Precision {label}: {precision}")
    print(f"Recall {label}: {recall}")
    print(f"F1 {label}: {f1}")
    print()

In [None]:
# Macro-averaged scores (mean over examples) per split, for both decision methods.
def _split_mean(split, metric):
    """Mean of ``metric`` over every example recorded in ``evaluation_results[split]``."""
    return np.mean([
        example_result[metric]
        for per_example in evaluation_results[split]
        for example_result in per_example.values()
    ])


# The per-split lines are identical for both methods, so each section gets a heading.
# After the loop the *_mean variables hold the peak-picking values.
for method, title in [("threshold", "Thresholding"), ("peak_pick", "Peak picking")]:
    tubb_test_f1_mean = _split_mean("tubb_test", f"f1_{method}")
    tubb_test_prec_mean = _split_mean("tubb_test", f"prec_{method}")
    tubb_test_recall_mean = _split_mean("tubb_test", f"recall_{method}")

    non_tubb_test_f1_mean = _split_mean("non_tubb_test", f"f1_{method}")
    non_tubb_test_prec_mean = _split_mean("non_tubb_test", f"prec_{method}")
    non_tubb_test_recall_mean = _split_mean("non_tubb_test", f"recall_{method}")

    print(f"{title}:")
    print(f"TUBB test set F1 mean: {tubb_test_f1_mean}")
    print(f"TUBB test set Precision mean: {tubb_test_prec_mean}")
    print(f"TUBB test set Recall mean: {tubb_test_recall_mean}")

    print()

    print(f"Non-TUBB test set F1 mean: {non_tubb_test_f1_mean}")
    print(f"Non-TUBB test set Precision mean: {non_tubb_test_prec_mean}")
    print(f"Non-TUBB test set Recall mean: {non_tubb_test_recall_mean}")

    print()

In [None]:
Path("../results").mkdir(parents=True, exist_ok=True)

# Export evaluation results to JSON; the context manager guarantees the file is flushed and closed.
results_path = Path("../results") / f"evaluation_results_{int(time.time())}.json"
with open(results_path, "w") as results_file:
    json.dump(evaluation_results, results_file, indent=4)
results_path

In [None]:
# Mean patch-preparation time (ms) over the TUBB test examples
np.mean([
    metrics["patch_preparation_time_ms"]
    for per_example in evaluation_results["tubb_test"]
    for metrics in per_example.values()
])