# Tempo Segmentation Parameter Fitting

Use this notebook to tune `TempoSegmentationParams` against aligned performance data.
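More aligned recordings are needed before the fitted parameters can be considered reliable; treat the results as provisional until further testing.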


## Dataset Format

Provide a JSON array where each entry describes one recording pair:

```json
[
  {
    "title": "My Recording",
    "actual": "../scores/reference.scoredata",
    "played": "../performances/take1.midi",
    "sections": [
      { "end_ind": 120, "label": "intro" },
      { "end_ind": 268, "label": "verse" },
      { "end_ind": 410, "label": "coda" }
    ]
  }
]
```

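* `actual` points to a `.scoredata` (or `.pb`) protobuf, or to a reference `.midi` (or `.mid`) file.
* `played` is the performed recording, normally a `.midi` file; it is loaded with the same suffix rules as `actual`.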
* `sections` is an ordered list. Each `end_ind` is the inclusive index of the final
  reference note in that tempo segment. The first segment begins at `0`; each
  subsequent segment starts at the previous end index plus one. The final section
  should end at the last reference note index; note that `parse_sections` does not
  enforce this, so reference notes after the last `end_ind` belong to no segment.


In [1]:
from __future__ import annotations

import itertools
import json
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Sequence, Tuple
from bisect import bisect_right

import numpy as np

# Make the project's `src` directory importable, based on the current working
# directory of the kernel.
# NOTE(review): BACKEND_ROOT is three levels above the working directory; the
# base path printed by the dataset-loading cell suggests that may be one level
# above the `backend` folder — confirm the intended depth.
NOTEBOOK_DIR = Path.cwd().resolve()
BACKEND_ROOT = NOTEBOOK_DIR.parents[2]
SRC_ROOT = BACKEND_ROOT / "src"
_src_entry = str(SRC_ROOT)
if _src_entry not in sys.path:
    # Prepend so the local sources are found before other sys.path entries.
    sys.path.insert(0, _src_entry)

from scoring import (
    scoring_native,
    analyze_tempo,
    extract_midi_notes,
    extract_pb_notes,
)
from scoring.edit_distance import find_ops
from scoring.notes_pb2 import NoteList, TempoSection

  import pkg_resources


In [2]:
# Load the dataset configuration (a JSON array, one entry per recording pair).
# Validation uses explicit raises rather than `assert`, because assertions are
# removed when Python runs with -O and would then silently skip these checks.
DATA_PATH = Path("fit_tempo_dataset.json")
if not DATA_PATH.exists():
    raise FileNotFoundError(f"Dataset config not found: {DATA_PATH}")
with DATA_PATH.open("r", encoding="utf-8") as handle:
    raw_entries = json.load(handle)
if not isinstance(raw_entries, list) or not raw_entries:
    raise ValueError("Dataset must be a non-empty list.")
len(raw_entries)

1

In [3]:
@dataclass
class TempoDataset:
    """One reference/performance recording pair prepared for parameter fitting."""

    # Display name of the recording pair.
    title: str
    # Start time of every reference note, in note order.
    actual_times: List[float]
    # Start time of every performed note, in note order.
    played_times: List[float]
    # Sorted index pairs produced by the alignment step; the first element of
    # each pair is used as the reference-note index.
    aligned_pairs: List[Tuple[int, int]]
    # End position (within the aligned sequence) of every segment but the last.
    boundaries: List[int]
    # (start, end, label) triples expressed in aligned-sequence positions.
    segments: List[Tuple[int, int, str]]
    # (start, end, label) triples expressed in reference-note indices.
    reference_segments: List[Tuple[int, int, str]]
    # Reference-note index of each aligned pair, in aligned order.
    aligned_actual_indices: List[int]

    @property
    def note_count(self) -> int:
        """Return how many reference notes the dataset holds."""
        return len(self.actual_times)

    @property
    def alignment_count(self) -> int:
        """Return how many aligned note pairs the dataset holds."""
        return len(self.aligned_pairs)

    @property
    def reference_boundaries(self) -> List[int]:
        """Return the end reference index of every segment except the last."""
        interior = self.reference_segments[:-1]
        return [last for _first, last, _label in interior]


def resolve_path(base: Path, raw: str) -> Path:
    """Turn *raw* into a usable path, anchoring relative values at *base*.

    Absolute inputs are returned as given (not resolved); relative inputs are
    joined onto *base* and then resolved.
    """
    candidate = Path(raw)
    if candidate.is_absolute():
        return candidate
    return (base / candidate).resolve()


def load_note_list(path: Path) -> NoteList:
    """Load notes from *path*, choosing the parser by file suffix.

    The suffix comparison is case-insensitive. MIDI files (`.midi`, `.mid`)
    are handed to `extract_midi_notes` as a string path; protobuf files
    (`.scoredata`, `.pb`) are read as bytes and handed to `extract_pb_notes`.

    Raises:
        ValueError: if the suffix is none of the supported ones.
    """
    extension = path.suffix.lower()
    if extension in (".midi", ".mid"):
        return extract_midi_notes(str(path))
    if extension in (".scoredata", ".pb"):
        return extract_pb_notes(path.read_bytes())
    raise ValueError(f"Unsupported note format: {path}")


def compute_alignment(
    actual_list: NoteList, played_list: NoteList
) -> List[Tuple[int, int]]:
    """Align the reference notes with the performed notes.

    Calls `find_ops` on the two note sequences, keeps only the second item of
    its result (the aligned index pairs), coerces both indices of every pair
    to `int`, and returns the pairs in ascending order.
    """
    _, raw_pairs = find_ops(actual_list.notes, played_list.notes)
    pairs = [(int(first), int(second)) for first, second in raw_pairs]
    pairs.sort()
    return pairs


def to_time_list(note_list: NoteList) -> List[float]:
    """Return the `start_time` of every note in *note_list* as a float."""
    notes = note_list.notes
    return [float(entry.start_time) for entry in notes]


def parse_sections(
    sections: Sequence[Dict[str, object]],
    aligned_actual_indices: Sequence[int],
) -> Tuple[List[Tuple[int, int, str]], List[int], List[Tuple[int, int, str]]]:
    """Convert labelled section ends into aligned and reference segments.

    Each section supplies `end_ind` (inclusive reference-note index of its
    last note) and optionally `label` (defaults to ``section_<position>``).
    The end is mapped to the last aligned position whose reference index is
    at or before `end_ind`, found by bisecting *aligned_actual_indices*.

    Args:
        sections: Ordered section descriptors.
        aligned_actual_indices: Reference index of every aligned note; the
            bisection assumes this is in ascending order.

    Returns:
        A triple ``(aligned_segments, boundaries, reference_segments)``: the
        ``(start, end, label)`` spans in aligned positions, the aligned end of
        every span but the last, and the spans in reference indices.

    Raises:
        ValueError: if either input is empty, an end index is negative or
            decreases, an end precedes every aligned note, or two sections
            map to the same (or an earlier) aligned position.
        KeyError: if a section has no ``end_ind`` entry.
    """
    if not sections:
        raise ValueError("Sections list may not be empty.")
    if not aligned_actual_indices:
        raise ValueError("Cannot derive sections without any aligned notes.")

    aligned_spans: List[Tuple[int, int, str]] = []
    reference_spans: List[Tuple[int, int, str]] = []
    # Cursors hold the first index of the section currently being built,
    # i.e. one past the end of the previous section.
    ref_cursor = 0
    aligned_cursor = 0

    for position, spec in enumerate(sections):
        ref_end = int(spec["end_ind"])
        name = str(spec.get("label", f"section_{position}"))
        if ref_end < 0:
            raise ValueError("Section end index must be non-negative.")
        if ref_end < ref_cursor - 1:
            raise ValueError("Section end indices must be non-decreasing.")

        # Last aligned position whose reference index is <= ref_end.
        aligned_end = bisect_right(aligned_actual_indices, ref_end) - 1
        if aligned_end < 0:
            raise ValueError(
                f"Section end index {ref_end} precedes all aligned reference notes."
            )
        if aligned_end < aligned_cursor:
            raise ValueError(
                f"Alignment collapsed segment '{name}' (end {ref_end}) to a non-increasing index."
            )

        aligned_spans.append((aligned_cursor, aligned_end, name))
        reference_spans.append((ref_cursor, ref_end, name))
        ref_cursor = ref_end + 1
        aligned_cursor = aligned_end + 1

    # Defensive invariant: a bisect result can never pass the last position.
    last_aligned_index = len(aligned_actual_indices) - 1
    final_end = aligned_cursor - 1
    if final_end > last_aligned_index:
        raise ValueError(
            f"Final section exceeded aligned index {last_aligned_index}, got {final_end}."
        )

    cut_points = [end for _start, end, _name in aligned_spans[:-1]]
    return aligned_spans, cut_points, reference_spans


def build_dataset(entry: Dict[str, object], base_dir: Path) -> TempoDataset:
    """Build a `TempoDataset` from one dataset-config entry.

    Resolves the ``actual`` and ``played`` paths against *base_dir*, loads
    both note lists, aligns them, and converts the entry's ``sections`` into
    aligned and reference segments. ``title`` defaults to ``"Untitled"``.

    Raises:
        KeyError: if ``actual``, ``played`` or ``sections`` is missing.
        ValueError: propagated from loading or from section parsing.
    """
    title = str(entry.get("title", "Untitled"))
    # Resolve both paths before loading anything, then load in the same order.
    sources = [resolve_path(base_dir, entry[key]) for key in ("actual", "played")]
    reference_notes, performed_notes = [load_note_list(src) for src in sources]

    pairs = compute_alignment(reference_notes, performed_notes)
    reference_indices = [ref_idx for ref_idx, _ in pairs]
    aligned_segments, cut_points, ref_segments = parse_sections(
        entry["sections"], reference_indices
    )
    return TempoDataset(
        title=title,
        actual_times=to_time_list(reference_notes),
        played_times=to_time_list(performed_notes),
        aligned_pairs=pairs,
        boundaries=cut_points,
        segments=aligned_segments,
        reference_segments=ref_segments,
        aligned_actual_indices=reference_indices,
    )

In [4]:
def _format_spans(spans):
    """Render (start, end, label) triples as 'label:start-end' strings."""
    return [f"{label}:{start}-{end}" for start, end, label in spans]


# Build one TempoDataset per config entry. Relative paths in the config are
# resolved against the grandparent directory of the dataset file, which is
# printed once per entry.
datasets: List[TempoDataset] = []
for entry in raw_entries:
    path = DATA_PATH.absolute().parent.parent
    print(path)
    datasets.append(build_dataset(entry, path))

# Print a short summary of every dataset: counts, boundaries and segments.
for ds in datasets:
    aligned_summary = _format_spans(ds.segments)
    reference_summary = _format_spans(ds.reference_segments)
    headline = (
        f"{ds.title}: ref_notes={ds.note_count}, aligned={ds.alignment_count}, "
        f"boundaries={ds.boundaries}, ref_boundaries={ds.reference_boundaries}"
    )
    print(headline)
    print("  aligned segments ->", aligned_summary)
    print("  reference segments ->", reference_summary)
len(datasets)

[32m2025-09-28 15:51:06.058[0m | [1mINFO    [0m | [36mscoring.edit_distance[0m:[36mfind_ops[0m:[36m139[0m - [1m	[preprocess] took 2.178 ms[0m


/Users/timothyliu/PycharmProjects/note/backend/resources


[32m2025-09-28 15:51:07.104[0m | [1mINFO    [0m | [36mscoring.edit_distance[0m:[36mfind_ops[0m:[36m140[0m - [1m	[edit_distance] took 1046.102 ms[0m
[32m2025-09-28 15:51:07.107[0m | [1mINFO    [0m | [36mscoring.edit_distance[0m:[36mfind_ops[0m:[36m142[0m - [1m	[postprocess] took 1.975 ms[0m


Spider Dance Take 1: ref_notes=1774, aligned=1563, boundaries=[405, 699, 1239], ref_boundaries=[413, 744, 1381]
  aligned segments -> ['intro:0-405', 'difficult chorus:406-699', 'bridge:700-1239', 'chorus:1240-1410']
  reference segments -> ['intro:0-413', 'difficult chorus:414-744', 'bridge:745-1381', 'chorus:1382-1574']


1

In [5]:
def boundaries_from_sections(
    sections: Sequence[TempoSection],
    aligned_actual_indices: Sequence[int],
) -> List[int]:
    """Map predicted tempo-section ends onto aligned-sequence positions.

    The last section is skipped, since its end is not an interior boundary.
    Every other section contributes the last aligned position whose reference
    index is at or before the section's end.

    NOTE(review): this assumes each section's `end_index` is a reference-note
    index, the same space as *aligned_actual_indices* — confirm against
    `analyze_tempo`.

    Args:
        sections: Objects exposing `end_index`, or 3-tuples whose middle item
            is the end index.
        aligned_actual_indices: Reference index of every aligned note; the
            bisection assumes ascending order.

    Returns:
        One aligned position per section except the last; an empty list when
        *sections* is empty.

    Raises:
        ValueError: if sections are given but no aligned indices exist, or a
            section end precedes every aligned reference note.
    """
    if not sections:
        return []
    if not aligned_actual_indices:
        raise ValueError("Cannot map tempo sections without aligned note indices.")

    def _reference_end(item):
        # Prefer the attribute; fall back to tuple-style (start, end, label).
        value = getattr(item, "end_index", None)
        if value is not None:
            return value
        _, value, _ = item
        return value

    cut_points: List[int] = []
    for item in sections[:-1]:
        ref_end = _reference_end(item)
        slot = bisect_right(aligned_actual_indices, int(ref_end)) - 1
        if slot < 0:
            raise ValueError(
                f"Tempo section end {ref_end} precedes all aligned reference notes."
            )
        cut_points.append(slot)
    return cut_points


def score_boundaries(
    predicted: Sequence[int], truth: Sequence[int], tolerance: int = 3
) -> Dict[str, float]:
    """Score predicted boundaries against ground truth with a tolerance.

    A prediction matches a truth boundary when they differ by at most
    *tolerance*; each truth boundary can be matched at most once. Both inputs
    are sorted before matching, so the score does not depend on the order in
    which the boundaries are supplied (for already-sorted inputs the result
    is the same as a first-fit scan in input order).

    Args:
        predicted: Predicted boundary positions.
        truth: Ground-truth boundary positions.
        tolerance: Maximum absolute difference that still counts as a match.

    Returns:
        A dict with ``precision``, ``recall`` and ``f1``. All three are 1.0
        when both inputs are empty and 0.0 when exactly one is empty.
    """
    if not predicted and not truth:
        return {"precision": 1.0, "recall": 1.0, "f1": 1.0}
    if not predicted or not truth:
        return {"precision": 0.0, "recall": 0.0, "f1": 0.0}

    ordered_pred = sorted(predicted)
    ordered_truth = sorted(truth)
    truth_count = len(ordered_truth)

    matches = 0
    cursor = 0  # index of the first truth boundary not yet matched or discarded
    for boundary in ordered_pred:
        # Truth boundaries too far to the left can never match this or any
        # later (larger) prediction, so they are discarded for good.
        while cursor < truth_count and ordered_truth[cursor] < boundary - tolerance:
            cursor += 1
        if cursor < truth_count and ordered_truth[cursor] <= boundary + tolerance:
            matches += 1
            cursor += 1

    # Both sequences are non-empty here, so the divisions are safe.
    precision = matches / len(predicted)
    recall = matches / len(truth)
    if precision + recall == 0.0:
        f1 = 0.0
    else:
        f1 = 2 * precision * recall / (precision + recall)
    return {"precision": precision, "recall": recall, "f1": f1}

In [6]:
# Candidate values for each TempoSegmentationParams field. The sweep below
# evaluates the full Cartesian product of these lists (4 x 4 x 4 x 4 = 256
# combinations).
PARAM_GRID = {
    "min_segment_length": [6, 8, 10, 12],
    "penalty": [2.0, 3.5, 5.0, 6.5],
    "smoothing_window": [3, 5, 7, 9],
    # None is passed through to TempoSegmentationParams unchanged.
    "max_segments": [None, 6, 8, 10],
}


def evaluate_param_grid(
    datasets: Sequence[TempoDataset],
    grid: Dict[str, Sequence[float | int | None]],
    *,
    tolerance: int = 3,
) -> List[Dict[str, float]]:
    """Score every parameter combination in *grid* across all *datasets*.

    For each combination of ``min_segment_length``, ``penalty``,
    ``smoothing_window`` and ``max_segments``, `analyze_tempo` is run on every
    dataset, its sections are mapped to aligned boundaries, and those are
    scored against the dataset's labelled boundaries.

    Args:
        datasets: Prepared recording pairs.
        grid: Candidate values per parameter; must contain the four keys
            named above.
        tolerance: Matching tolerance forwarded to `score_boundaries`.

    Returns:
        One row per combination holding the parameter values (as floats, with
        ``max_segments`` encoded as -1.0 when it was None) and the mean
        precision, recall, f1 and instability over the datasets. Rows are
        ordered by f1, then precision (both descending), then instability
        (ascending).
    """
    axes = (
        grid["min_segment_length"],
        grid["penalty"],
        grid["smoothing_window"],
        grid["max_segments"],
    )
    rows: List[Dict[str, float]] = []
    for min_len, penalty, window, max_segments in itertools.product(*axes):
        if max_segments is None:
            segment_cap = None
        else:
            segment_cap = int(max_segments)
        params = scoring_native.TempoSegmentationParams(
            min_segment_length=int(min_len),
            penalty=float(penalty),
            smoothing_window=int(window),
            max_segments=segment_cap,
        )

        # Per-dataset samples of every metric, averaged once all are in.
        samples: Dict[str, List[float]] = {
            "precision": [],
            "recall": [],
            "f1": [],
            "instability": [],
        }
        for ds in datasets:
            sections, instability = analyze_tempo(
                ds.actual_times,
                ds.played_times,
                ds.aligned_pairs,
                params,
            )
            predicted = boundaries_from_sections(sections, ds.aligned_actual_indices)
            scores = score_boundaries(predicted, ds.boundaries, tolerance=tolerance)
            for metric in ("precision", "recall", "f1"):
                samples[metric].append(scores[metric])
            samples["instability"].append(float(instability))

        row: Dict[str, float] = {
            "min_segment_length": float(min_len),
            "penalty": float(penalty),
            "smoothing_window": float(window),
            # -1.0 stands in for "no cap" so every value in the row is a float.
            "max_segments": -1.0 if max_segments is None else float(max_segments),
        }
        for metric, values in samples.items():
            row[metric] = float(np.mean(values))
        rows.append(row)

    def _rank(row: Dict[str, float]) -> Tuple[float, float, float]:
        # Negated so lower instability wins under the descending sort.
        return (row["f1"], row["precision"], -row["instability"])

    return sorted(rows, key=_rank, reverse=True)


# Run the full sweep over every loaded dataset (default tolerance) and
# preview the five top-ranked rows.
grid_results = evaluate_param_grid(datasets, PARAM_GRID)
grid_results[:5]

[32m2025-09-28 15:51:07.252[0m | [1mINFO    [0m | [36m__main__[0m:[36mevaluate_param_grid[0m:[36m33[0m - [1m	[analyze_tempo] took 66.312 ms[0m
[32m2025-09-28 15:51:07.319[0m | [1mINFO    [0m | [36m__main__[0m:[36mevaluate_param_grid[0m:[36m33[0m - [1m	[analyze_tempo] took 65.935 ms[0m
[32m2025-09-28 15:51:07.386[0m | [1mINFO    [0m | [36m__main__[0m:[36mevaluate_param_grid[0m:[36m33[0m - [1m	[analyze_tempo] took 66.868 ms[0m
[32m2025-09-28 15:51:07.452[0m | [1mINFO    [0m | [36m__main__[0m:[36mevaluate_param_grid[0m:[36m33[0m - [1m	[analyze_tempo] took 65.609 ms[0m
[32m2025-09-28 15:51:07.518[0m | [1mINFO    [0m | [36m__main__[0m:[36mevaluate_param_grid[0m:[36m33[0m - [1m	[analyze_tempo] took 65.780 ms[0m
[32m2025-09-28 15:51:07.585[0m | [1mINFO    [0m | [36m__main__[0m:[36mevaluate_param_grid[0m:[36m33[0m - [1m	[analyze_tempo] took 66.028 ms[0m
[32m2025-09-28 15:51:07.653[0m | [1mINFO    [0m | [36m__main__[0

[{'min_segment_length': 6.0,
  'penalty': 2.0,
  'smoothing_window': 9.0,
  'max_segments': -1.0,
  'precision': 0.0,
  'recall': 0.0,
  'f1': 0.0,
  'instability': 170.4219207763672},
 {'min_segment_length': 6.0,
  'penalty': 2.0,
  'smoothing_window': 9.0,
  'max_segments': 6.0,
  'precision': 0.0,
  'recall': 0.0,
  'f1': 0.0,
  'instability': 170.4219207763672},
 {'min_segment_length': 6.0,
  'penalty': 2.0,
  'smoothing_window': 9.0,
  'max_segments': 8.0,
  'precision': 0.0,
  'recall': 0.0,
  'f1': 0.0,
  'instability': 170.4219207763672},
 {'min_segment_length': 6.0,
  'penalty': 2.0,
  'smoothing_window': 9.0,
  'max_segments': 10.0,
  'precision': 0.0,
  'recall': 0.0,
  'f1': 0.0,
  'instability': 170.4219207763672},
 {'min_segment_length': 6.0,
  'penalty': 3.5,
  'smoothing_window': 9.0,
  'max_segments': -1.0,
  'precision': 0.0,
  'recall': 0.0,
  'f1': 0.0,
  'instability': 170.4219207763672}]

In [7]:
# evaluate_param_grid returns its rows sorted best-first, so the first row is
# the top-ranked parameter combination.
best = grid_results[0]
best

{'min_segment_length': 6.0,
 'penalty': 2.0,
 'smoothing_window': 9.0,
 'max_segments': -1.0,
 'precision': 0.0,
 'recall': 0.0,
 'f1': 0.0,
 'instability': 170.4219207763672}

In [8]:
def params_from_row(row: Dict[str, float]) -> scoring_native.TempoSegmentationParams:
    """Rebuild segmentation parameters from a grid-result row.

    Grid rows store every value as a float and use a negative
    ``max_segments`` to mean "no cap"; this converts the values back to the
    types passed to `TempoSegmentationParams`.
    """
    raw_cap = row["max_segments"]
    if raw_cap < 0:
        segment_cap = None
    else:
        segment_cap = int(raw_cap)
    return scoring_native.TempoSegmentationParams(
        min_segment_length=int(row["min_segment_length"]),
        penalty=float(row["penalty"]),
        smoothing_window=int(row["smoothing_window"]),
        max_segments=segment_cap,
    )


# Re-run the best parameter set on each dataset to get a per-recording
# breakdown (default matching tolerance).
best_params = params_from_row(best)
per_dataset_scores: List[Dict[str, float]] = []
for ds in datasets:
    sections, instability = analyze_tempo(
        ds.actual_times, ds.played_times, ds.aligned_pairs, best_params
    )
    predicted_boundaries = boundaries_from_sections(
        sections, ds.aligned_actual_indices
    )
    scores = score_boundaries(predicted_boundaries, ds.boundaries)

    summary = {"title": ds.title}
    summary.update((metric, scores[metric]) for metric in ("precision", "recall", "f1"))
    summary["instability"] = float(instability)
    per_dataset_scores.append(summary)
per_dataset_scores

[32m2025-09-28 15:51:24.390[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m13[0m - [1m	[analyze_tempo] took 66.811 ms[0m


[{'title': 'Spider Dance Take 1',
  'precision': 0.0,
  'recall': 0.0,
  'f1': 0.0,
  'instability': 170.4219207763672}]

In [9]:
# Persist the winning parameters, their scores and the ten best grid rows
# next to the dataset file.
OUTPUT_PATH = DATA_PATH.with_suffix(".fitted_params.json")
payload = {
    "params": {
        field: getattr(best_params, field)
        for field in (
            "min_segment_length",
            "penalty",
            "smoothing_window",
            "max_segments",
        )
    },
    "aggregate_scores": {
        metric: best[metric]
        for metric in ("precision", "recall", "f1", "instability")
    },
    "per_dataset": per_dataset_scores,
    "grid_top": grid_results[:10],
}
serialized = json.dumps(payload, indent=2)
OUTPUT_PATH.write_text(serialized)
OUTPUT_PATH

PosixPath('fit_tempo_dataset.fitted_params.json')