# Automatic Piano Fingering Detection from Video

**Computer Vision Final Project**

---

## Project Goal

Given a video of a piano performance with synchronized MIDI data, automatically determine the finger assignment (1-5, thumb to pinky) for each played note.

**Input**: Video + MIDI -> **Output**: Per-note finger labels (L1-L5 for left hand, R1-R5 for right hand)

## Pipeline Architecture

```
Video -> Keyboard Detection -> Hand Pose Estimation -> Temporal Filtering -> Finger-Key Assignment -> Neural Refinement -> Fingering Labels
          (OpenCV)              (MediaPipe)             (Hampel+SavGol)      (Gaussian Prob.)          (BiLSTM)
```

| Stage | Method | Input | Output |
|-------|--------|-------|--------|
| 1. Keyboard Detection | Canny + Hough + Clustering + Black-Key Analysis | Video frames | 88 key bounding boxes |
| 2. Hand Pose Estimation | MediaPipe Hands | Video frame | 21-keypoint hand skeletons |
| 3. Temporal Filtering | Hampel + SavGol filters | Raw landmarks | Filtered landmarks (T x 21 x 3) |
| 4. Finger Assignment | Gaussian probability | MIDI events + fingertips + keys | FingerAssignment per note |
| 5. Neural Refinement | BiLSTM + Attention | Initial assignments | Refined predictions |

## Dataset

**PianoVAM** (KAIST)
- 107 piano performances with synchronized video, audio, MIDI
- Pre-extracted 21-keypoint hand skeletons (MediaPipe)
- Top-view camera angle (1920 x 1080, 60fps)

## Table of Contents

0. [Environment Setup](#0)
1. [Data Exploration](#1)
2. [Stage 1: Keyboard Detection (OpenCV)](#2)
3. [Stage 2: Hand Pose Estimation (MediaPipe)](#3)
4. [Stage 3: Temporal Filtering](#4)
5. [Stage 4: Finger-Key Assignment](#5)
6. [Baseline Pipeline on Multiple Samples](#6)
7. [Stage 5: Neural Refinement (BiLSTM)](#7)
8. [Evaluation & Results](#8)

---
<a id='0'></a>
## 0. Environment Setup

In [None]:
import os, sys, subprocess

# True only inside a Google Colab kernel: get_ipython() exists in IPython
# kernels, and its repr mentions google.colab when the shell is Colab's.
IN_COLAB = 'google.colab' in str(get_ipython()) if 'get_ipython' in dir() else False

# Becomes True when mediapipe must be reinstalled after an incompatible
# version was already imported into this kernel (see the local branch).
MEDIAPIPE_NEEDS_RESTART = False

if IN_COLAB:
    REPO_URL = 'https://github.com/esnylmz/computer-vision.git'
    BRANCH = 'v4'
    if not os.path.exists('computer-vision'):
        subprocess.run(['git', 'clone', '--branch', BRANCH, '--single-branch', REPO_URL], check=True)
    os.chdir('computer-vision')
    # Bring an existing clone up to the tip of BRANCH (--ff-only: never create a merge commit).
    subprocess.run(['git', 'fetch', 'origin', BRANCH], check=True)
    subprocess.run(['git', 'checkout', BRANCH], check=True)
    subprocess.run(['git', 'pull', '--ff-only', 'origin', BRANCH], check=True)
    # Editable install of the checked-out project.
    subprocess.run([sys.executable, '-m', 'pip', 'install', '-q', '-e', '.'], check=True)
    # mediapipe-numpy2 keeps mp.solutions API and works with numpy 2.x on Colab
    subprocess.run([sys.executable, '-m', 'pip', 'install', '-q', 'mediapipe-numpy2'], check=True)
    # The repository root is now the working directory.
    PROJECT_ROOT = os.getcwd()
    print('\nColab environment ready')
else:
    # Local runs are expected to start one directory below the repo root,
    # so the project root is the parent of the current working directory.
    PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), '..'))
    if PROJECT_ROOT not in sys.path:
        sys.path.insert(0, PROJECT_ROOT)

    # make sure we have a compatible mediapipe (solutions API removed in 0.10.31+)
    try:
        import mediapipe as _mp
        if not hasattr(_mp, 'solutions'):
            print('WARNING: mediapipe version too new, reinstalling compatible version...')
            subprocess.run([sys.executable, '-m', 'pip', 'install', '-q', 'mediapipe-numpy2'], check=True)
            # The incompatible module is already cached in sys.modules, so any
            # later `import mediapipe` in this kernel would still return it.
            MEDIAPIPE_NEEDS_RESTART = True
    except ImportError:
        subprocess.run([sys.executable, '-m', 'pip', 'install', '-q', 'mediapipe-numpy2'], check=True)

    if MEDIAPIPE_NEEDS_RESTART:
        print('WARNING: mediapipe was reinstalled after being imported; '
              'restart the kernel and re-run the notebook before Stage 2.')
    print('Local environment ready')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
import json, time, warnings
from pathlib import Path
from tqdm.notebook import tqdm

warnings.filterwarnings('ignore', category=UserWarning)
sns.set_style('whitegrid')

print(f'NumPy  : {np.__version__}')
print(f'Pandas : {pd.__version__}')
print(f'OpenCV : {cv2.__version__}')

import torch
print(f'PyTorch: {torch.__version__}')
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Device : {DEVICE}')

In [None]:
from src.data.dataset import PianoVAMDataset, PianoVAMSample
from src.data.midi_utils import MidiProcessor, MidiEvent
from src.data.video_utils import VideoProcessor
from src.utils.config import load_config, Config

from src.keyboard.detector import KeyboardDetector, KeyboardRegion
from src.keyboard.homography import HomographyComputer
from src.keyboard.key_localization import KeyLocalizer
from src.keyboard.auto_detector import AutoKeyboardDetector, AutoDetectionResult

from src.hand.skeleton_loader import SkeletonLoader, HandLandmarks
from src.hand.temporal_filter import TemporalFilter
from src.hand.fingertip_extractor import FingertipExtractor, FingertipData

from src.assignment.gaussian_assignment import GaussianFingerAssigner, FingerAssignment
from src.assignment.midi_sync import MidiVideoSync
from src.assignment.hand_separation import HandSeparator

from src.refinement.model import FingeringRefiner, FeatureExtractor, SequenceDataset
from src.refinement.constraints import BiomechanicalConstraints
from src.refinement.decoding import constrained_viterbi_decode
from src.refinement.train import train_refiner, collate_fn

from src.evaluation.metrics import FingeringMetrics, EvaluationResult, aggregate_results
from src.evaluation.visualization import ResultVisualizer

from src.pipeline import FingeringPipeline

import os

# Config files live under <project root>/configs. In Colab the cwd is the repo
# root. Locally, the setup cell treats the cwd's parent as PROJECT_ROOT, so the
# cwd-relative path may not exist; fall back to PROJECT_ROOT in that case.
config_name = 'configs/colab.yaml' if IN_COLAB else 'configs/default.yaml'
config_path = config_name
if not os.path.exists(config_path) and 'PROJECT_ROOT' in globals():
    _candidate = os.path.join(PROJECT_ROOT, config_name)
    if os.path.exists(_candidate):
        config_path = _candidate
config = load_config(config_path)
print(f'All modules imported | Config: {config_path}')
print(f'Project: {config.project_name} v{config.version}')

---
<a id='1'></a>
## 1. Data Exploration

In [None]:
# Number of samples requested per split for quick exploration.
MAX_EXPLORE = 20

print('Loading PianoVAM dataset splits ...\n')
# All three splits are opened with streaming=True and capped at MAX_EXPLORE.
# NOTE(review): streaming presumably avoids downloading the whole split up front - confirm in PianoVAMDataset.
train_dataset = PianoVAMDataset(split='train', streaming=True, max_samples=MAX_EXPLORE)
val_dataset = PianoVAMDataset(split='validation', streaming=True, max_samples=MAX_EXPLORE)
test_dataset = PianoVAMDataset(split='test', streaming=True, max_samples=MAX_EXPLORE)

# First training sample; the stage-by-stage demo cells below all work on it.
sample = next(iter(train_dataset))
print(f'\nSample ID      : {sample.id}')
print(f'Composer       : {sample.metadata["composer"]}')
print(f'Piece          : {sample.metadata["piece"]}')
print(f'Skill Level    : {sample.metadata["skill_level"]}')
print(f'Keyboard Corners: {sample.metadata["keyboard_corners"]}')

In [None]:
print('Collecting dataset statistics ...')
# Full pass over the training split (no max_samples cap) to collect
# per-performance metadata for the distribution plots below.
stats_ds = PianoVAMDataset(split='train', max_samples=None)

composers = []
skill_levels = []
for s in stats_ds:
    meta = s.metadata
    composers.append(meta['composer'])
    skill_levels.append(meta['skill_level'])

# Count once and reuse for both the plot and the printed summary.
skill_counts = pd.Series(skill_levels).value_counts()
composer_counts = pd.Series(composers).value_counts()
n_train = len(skill_levels)

fig, axes = plt.subplots(1, 2, figsize=(14, 4))
skill_ax, composer_ax = axes
skill_counts.plot.bar(ax=skill_ax, color='steelblue')
skill_ax.set_title(f'Skill Level Distribution (n={n_train})')
composer_counts.head(10).plot.barh(ax=composer_ax, color='darkorange')
composer_ax.set_title('Top 10 Composers')
plt.tight_layout()
plt.show()

print(f'Train samples    : {n_train}')
print(f'Unique composers : {len(set(composers))}')
print(f'Skill levels     : {dict(skill_counts)}')

---
<a id='2'></a>
## 2. Stage 1 — Keyboard Detection (Classical Computer Vision)

We detect the piano keyboard from raw video frames using classical computer vision:

**Automatic Detection Pipeline (Canny + Hough)**
1. Preprocessing — CLAHE contrast enhancement + Gaussian blur
2. Canny edge detection (Otsu-adaptive + fixed thresholds, merged)
3. Morphological closing to connect nearby edge fragments
4. Hough line transform — detect horizontal and vertical lines
5. Horizontal line clustering — group nearby lines, select keyboard top/bottom pair
6. Black-key contour analysis — refine x-boundaries using dark-region segmentation
7. Multi-frame consensus — sample multiple frames, take median bbox
8. Homography computation for perspective normalization
9. 88-key localization

**Corner-based Detection (Ground Truth)**
- Uses 4-point corner annotations from PianoVAM metadata
- Serves as ground truth for evaluating automatic detection (IoU)

In [None]:
# Download a video frame from PianoVAM
print(f'Downloading video for sample {sample.id} ...')
video_path = train_dataset.download_file(sample.video_path)
print(f'Video saved to: {video_path}')

vp = VideoProcessor()
vp.open(video_path)
try:
    print(f'Resolution: {vp.info.width}x{vp.info.height}')
    print(f'FPS: {vp.info.fps}')
    print(f'Total frames: {vp.info.frame_count}')
    print(f'Duration: {vp.info.duration:.1f}s')

    # grab a frame from the middle of the video
    mid_frame_idx = vp.info.frame_count // 2
    frame_bgr = vp.get_frame(mid_frame_idx)
finally:
    # release the video handle even if reading fails
    vp.close()

# get_frame can return None (other cells check for it); fail with a clear
# message instead of an opaque cv2.cvtColor error.
if frame_bgr is None:
    raise RuntimeError(f'Could not read frame {mid_frame_idx} from {video_path}')
frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)

plt.figure(figsize=(14, 6))
plt.imshow(frame_rgb)
plt.title(f'Raw Video Frame (frame {mid_frame_idx})')
plt.axis('off')
plt.show()

In [None]:
# ── Image Preprocessing Pipeline ──
# Visual walk-through of the preprocessing / edge-detection steps on the
# Stage-1 frame (frame_bgr); this cell only plots and prints.
gray = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2GRAY)
blurred = cv2.GaussianBlur(gray, (5, 5), 0)

# CLAHE (Contrast-Limited Adaptive Histogram Equalisation) for lighting normalisation
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
enhanced = clahe.apply(gray)
enhanced_blur = cv2.GaussianBlur(enhanced, (5, 5), 0)

# Canny edge detection with different thresholds
edges_low = cv2.Canny(blurred, 30, 100)
edges_mid = cv2.Canny(blurred, 50, 150)
edges_high = cv2.Canny(blurred, 100, 200)

# Otsu-based automatic threshold on CLAHE-enhanced image
# (Otsu's value becomes Canny's upper threshold and half of it the lower one,
# with the lower floored at 10 and the upper capped at 255)
otsu_thresh, _ = cv2.threshold(enhanced_blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
otsu_low, otsu_high = max(10, int(otsu_thresh * 0.5)), min(255, int(otsu_thresh * 1.0))
edges_otsu = cv2.Canny(enhanced_blur, otsu_low, otsu_high)
# Union of the fixed (50, 150) edges and the Otsu-adaptive edges
edges_merged = cv2.bitwise_or(edges_mid, edges_otsu)

# Morphological closing to connect fragmented edges
# (15 x 1 rectangular kernel: bridges gaps along the horizontal direction only)
kernel_h = cv2.getStructuringElement(cv2.MORPH_RECT, (15, 1))
edges_closed = cv2.morphologyEx(edges_merged, cv2.MORPH_CLOSE, kernel_h)

fig, axes = plt.subplots(3, 3, figsize=(18, 14))

axes[0, 0].imshow(frame_rgb);               axes[0, 0].set_title('Original (RGB)')
axes[0, 1].imshow(gray, cmap='gray');        axes[0, 1].set_title('Grayscale')
axes[0, 2].imshow(enhanced, cmap='gray');    axes[0, 2].set_title('CLAHE Enhanced')

axes[1, 0].imshow(edges_low, cmap='gray');   axes[1, 0].set_title('Canny (30, 100)')
axes[1, 1].imshow(edges_mid, cmap='gray');   axes[1, 1].set_title('Canny (50, 150)')
axes[1, 2].imshow(edges_high, cmap='gray');  axes[1, 2].set_title('Canny (100, 200)')

axes[2, 0].imshow(edges_otsu, cmap='gray');  axes[2, 0].set_title(f'Otsu-adaptive Canny ({otsu_low}, {otsu_high})')
axes[2, 1].imshow(edges_merged, cmap='gray');axes[2, 1].set_title('Merged Edges (fixed + Otsu)')
axes[2, 2].imshow(edges_closed, cmap='gray');axes[2, 2].set_title('After Morphological Close')

for ax in axes.flat:
    ax.axis('off')

plt.suptitle('Image Preprocessing & Edge Detection Pipeline', fontsize=14)
plt.tight_layout()
plt.show()

print(f'Otsu threshold: {otsu_thresh:.0f}  →  Canny range: ({otsu_low}, {otsu_high})')

In [None]:
# Hough Line Transform on the edge map
# Uses the fixed-threshold Canny (50, 150) edges from the preprocessing cell.
edges = edges_mid

# Probabilistic Hough: returns segments shaped [[x1, y1, x2, y2]], or None.
lines = cv2.HoughLinesP(
    edges, rho=1, theta=np.pi/180, threshold=100,
    minLineLength=100, maxLineGap=10
)

print(f'Total lines detected: {len(lines) if lines is not None else 0}')

ANGLE_TOL = np.pi / 18  # 10 degrees

# separate horizontal and vertical lines
line_vis = frame_rgb.copy()
horizontal_lines = []
vertical_lines = []

if lines is not None:
    for line in lines:
        x1, y1, x2, y2 = line[0]
        # Orientation folded into [0, pi/2] so it does not depend on endpoint
        # order. With signed deltas, a right-to-left horizontal segment would
        # give an angle near pi and fall into the "vertical" branch.
        angle = np.arctan2(abs(int(y2) - int(y1)), abs(int(x2) - int(x1)))
        if angle < ANGLE_TOL:  # within 10 degrees of horizontal
            horizontal_lines.append((x1, y1, x2, y2))
            cv2.line(line_vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
        elif angle > np.pi / 2 - ANGLE_TOL:  # within 10 degrees of vertical
            vertical_lines.append((x1, y1, x2, y2))
            cv2.line(line_vis, (x1, y1), (x2, y2), (255, 0, 0), 1)

print(f'Horizontal lines: {len(horizontal_lines)}')
print(f'Vertical lines  : {len(vertical_lines)}')

fig, axes = plt.subplots(1, 2, figsize=(18, 6))
axes[0].imshow(edges, cmap='gray')
axes[0].set_title('Canny Edge Map')
axes[1].imshow(line_vis)
axes[1].set_title('Hough Lines (green=horizontal, red=vertical)')
for ax in axes:
    ax.axis('off')
plt.tight_layout()
plt.show()

In [None]:
# ── Automatic keyboard detection (Canny + Hough + clustering) ──
# 4-point corner annotation from the PianoVAM metadata; used as ground truth below.
corners = sample.metadata['keyboard_corners']

auto_detector = AutoKeyboardDetector({
    'canny_low': config.keyboard.canny_low,
    'canny_high': config.keyboard.canny_high,
    'hough_threshold': config.keyboard.hough_threshold,
})

# Run full single-frame auto-detection with intermediate results
auto_result = auto_detector.detect_single_frame(frame_bgr, return_intermediates=True)
print(f'Auto-detection success: {auto_result.success}')
if auto_result.success:
    print(f'  Auto bbox           : {auto_result.consensus_bbox}')
    print(f'  Top line y          : {auto_result.top_line_y:.0f}')
    print(f'  Bottom line y       : {auto_result.bottom_line_y:.0f}')
    # Intermediate collections may be None/empty; count all of them the same way.
    n_horiz = len(auto_result.horizontal_lines) if auto_result.horizontal_lines else 0
    n_vert  = len(auto_result.vertical_lines) if auto_result.vertical_lines else 0
    n_bk    = len(auto_result.black_key_contours) if auto_result.black_key_contours else 0
    n_clust = len(auto_result.line_clusters) if auto_result.line_clusters else 0
    print(f'  Horizontal lines    : {n_horiz}')
    print(f'  Vertical lines      : {n_vert}')
    print(f'  Black key contours  : {n_bk}')
    print(f'  Line clusters       : {n_clust}')
else:
    print('  (auto-detection failed on this frame)')

# Corner-based detection (ground truth)
detector = KeyboardDetector({
    'canny_low': config.keyboard.canny_low,
    'canny_high': config.keyboard.canny_high,
    'hough_threshold': config.keyboard.hough_threshold
})
keyboard_region = detector.detect_from_corners(corners)
print(f'\nCorner-based detection: {len(keyboard_region.key_boundaries)} keys')
print(f'  Bounding box   : {keyboard_region.bbox}')
print(f'  White key width: {keyboard_region.white_key_width:.1f} px')

# Compute IoU between auto-detected and corner-based
if auto_result.success:
    iou = auto_detector.evaluate_against_corners(auto_result, corners)
    print(f'\n>>> IoU (auto vs corners): {iou:.3f} <<<')

In [None]:
# ── Visualize auto-detection intermediates ──
# 2 x 2 grid built from the single-frame auto_result computed above; the last
# panel overlays the corner-based keyboard_region.bbox as the reference.
fig, axes = plt.subplots(2, 2, figsize=(18, 12))

# 1. Hough lines on frame
line_vis = auto_detector.visualize_lines(frame_bgr, auto_result)
axes[0, 0].imshow(cv2.cvtColor(line_vis, cv2.COLOR_BGR2RGB))
axes[0, 0].set_title('Hough Lines (green=horizontal, red=vertical)')

# 2. Line clusters & selected top/bottom edges
cluster_vis = auto_detector.visualize_clusters(frame_bgr, auto_result)
axes[0, 1].imshow(cv2.cvtColor(cluster_vis, cv2.COLOR_BGR2RGB))
axes[0, 1].set_title('Line Clusters & Selected Edges (cyan)')

# 3. Black key contours
bk_vis = auto_detector.visualize_black_keys(frame_bgr, auto_result)
axes[1, 0].imshow(cv2.cvtColor(bk_vis, cv2.COLOR_BGR2RGB))
axes[1, 0].set_title('Black Key Contour Detection')

# 4. Final comparison: auto (green) vs corner GT (red)
compare_vis = auto_detector.visualize_detection(
    frame_bgr, auto_result, corner_bbox=keyboard_region.bbox
)
axes[1, 1].imshow(cv2.cvtColor(compare_vis, cv2.COLOR_BGR2RGB))
axes[1, 1].set_title('Auto-detected (green) vs Corner GT (red)')

for ax in axes.flat:
    ax.axis('off')

plt.suptitle('Automatic Keyboard Detection — Intermediate Stages', fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
# ── Multi-frame consensus auto-detection from video ──
print(f'Running multi-frame auto-detection on {video_path} ...')
video_auto_result = auto_detector.detect_from_video(str(video_path), return_intermediates=False)

# One entry per sampled frame; None marks a frame where detection failed.
per_frame = video_auto_result.per_frame_bboxes or []
valid_bboxes = [b for b in per_frame if b is not None]
n_valid = len(valid_bboxes)
n_total = len(per_frame)
print(f'  Sampled frames   : {n_total}')
print(f'  Successful frames: {n_valid}/{n_total}')
print(f'  Consensus bbox   : {video_auto_result.consensus_bbox}')

if video_auto_result.success:
    video_iou = auto_detector.evaluate_against_corners(video_auto_result, corners)
    print(f'  IoU (multi-frame vs corners): {video_iou:.3f}')

    # Show per-frame bboxes; skip the plot when no frame produced a bbox
    # (np.array([])[:, 0] would raise IndexError).
    if valid_bboxes:
        arr = np.array(valid_bboxes)
        fig, axes = plt.subplots(1, 2, figsize=(16, 4))

        axes[0].plot(arr[:, 0], 'o-', label='x1')
        axes[0].plot(arr[:, 2], 's-', label='x2')
        axes[0].set_title('Per-frame x-coordinates')
        axes[0].legend()
        axes[0].set_xlabel('Sampled frame #')

        axes[1].plot(arr[:, 1], 'o-', label='y_top')
        axes[1].plot(arr[:, 3], 's-', label='y_bottom')
        axes[1].set_title('Per-frame y-coordinates')
        axes[1].legend()
        axes[1].set_xlabel('Sampled frame #')

        plt.suptitle('Multi-Frame Consensus: per-frame bbox coordinates', fontsize=13)
        plt.tight_layout()
        plt.show()
else:
    print('  Multi-frame auto-detection failed')

In [None]:
# Homography warping: perspective correction
hc = HomographyComputer()
# H maps image pixels into the rectified (warped) keyboard space.
H = keyboard_region.homography
H_inv = np.linalg.inv(H)

# warp the frame to get a top-down view of the keyboard
# NOTE(review): the output size (1718, 213) is hard-coded; presumably it matches
# the warped-space extent of keyboard_region.key_boundaries - confirm.
warped = cv2.warpPerspective(frame_bgr, H, (1718, 213))
warped_rgb = cv2.cvtColor(warped, cv2.COLOR_BGR2RGB)

# draw keyboard corners on the original frame
corner_vis = frame_rgb.copy()
if keyboard_region.corners:
    # OpenCV drawing functions require integer pixel tuples; round in case
    # the corner annotations are floats (no-op for ints).
    pts = [tuple(int(round(v)) for v in keyboard_region.corners[k])
           for k in ['LT', 'RT', 'RB', 'LB']]
    for i in range(4):
        p1 = pts[i]
        p2 = pts[(i + 1) % 4]
        cv2.line(corner_vis, p1, p2, (0, 255, 0), 3)
        cv2.circle(corner_vis, p1, 8, (255, 0, 0), -1)

fig, axes = plt.subplots(2, 1, figsize=(16, 8))
axes[0].imshow(corner_vis)
axes[0].set_title('Detected Keyboard Corners')
axes[1].imshow(warped_rgb)
axes[1].set_title('Perspective-Corrected Keyboard (Homography Warp)')
for ax in axes:
    ax.axis('off')
plt.tight_layout()
plt.show()

In [None]:
# Visualize 88-key layout in warped space
localizer = KeyLocalizer(keyboard_region.key_boundaries)
white_keys = localizer.get_white_keys()
black_keys = localizer.get_black_keys()

print(f'White keys: {len(white_keys)}  |  Black keys: {len(black_keys)}')

fig, ax = plt.subplots(figsize=(18, 3))
for ki in white_keys:
    x1, y1, x2, y2 = ki.bbox
    ax.add_patch(plt.Rectangle((x1, y1), x2-x1, y2-y1, linewidth=0.8,
                               edgecolor='black', facecolor='white'))
# Black keys are drawn after white keys so they sit on top.
for ki in black_keys:
    x1, y1, x2, y2 = ki.bbox
    ax.add_patch(plt.Rectangle((x1, y1), x2-x1, y2-y1, linewidth=0.5,
                               edgecolor='black', facecolor='#333'))

for note_name in ['A0', 'C4', 'C8']:
    ki = localizer.get_key_by_name(note_name)
    if ki:
        ax.annotate(ki.note_name, xy=ki.center, fontsize=7, color='red', ha='center', va='bottom')

# The key boxes come from keyboard_region.key_boundaries, which this notebook
# treats as homography-warped coordinates. keyboard_region.bbox is in
# image-pixel coordinates (it is used to crop raw frames), so derive the axis
# limits from the key boxes themselves.
key_boxes = np.array([ki.bbox for ki in list(white_keys) + list(black_keys)], dtype=float)
if key_boxes.size:
    ax.set_xlim(key_boxes[:, 0].min() - 10, key_boxes[:, 2].max() + 10)
    # inverted y-axis: image convention, y grows downwards
    ax.set_ylim(key_boxes[:, 3].max() + 10, key_boxes[:, 1].min() - 10)
ax.set_aspect('equal')
ax.set_title('88-Key Layout (Warped Space)')
plt.tight_layout()
plt.show()

---
<a id='3'></a>
## 3. Stage 2 - Hand Pose Estimation (MediaPipe)

We run MediaPipe Hands on actual video frames to detect 21 hand landmarks per hand,
then compare the detections with the pre-extracted skeleton data provided by the PianoVAM dataset.

In [None]:
# mp.solutions is the legacy Solutions API; the setup cell checks that the
# installed mediapipe still provides it.
import mediapipe as mp

mp_hands = mp.solutions.hands
mp_drawing = mp.solutions.drawing_utils
mp_drawing_styles = mp.solutions.drawing_styles

print(f'MediaPipe version: {mp.__version__}')

In [None]:
# Run MediaPipe hand detection on video frames
vp = VideoProcessor()
vp.open(video_path)
try:
    # pick several frames spread across the video
    sample_frame_indices = [300, 600, 1200, 2400, 4800]
    sample_frames = []
    for idx in sample_frame_indices:
        f = vp.get_frame(idx)
        # get_frame returns None for frames it cannot read
        if f is not None:
            sample_frames.append((idx, f))
finally:
    vp.close()

if not sample_frames:
    print('No frames could be read for MediaPipe detection')
else:
    fig, axes = plt.subplots(1, len(sample_frames), figsize=(5 * len(sample_frames), 6))
    if len(sample_frames) == 1:
        axes = [axes]

    # static_image_mode=True: every frame is processed as an independent image.
    # The context manager closes the detector even if processing fails.
    with mp_hands.Hands(
        static_image_mode=True,
        max_num_hands=2,
        min_detection_confidence=0.5
    ) as hands_detector:
        for i, (fidx, frame) in enumerate(sample_frames):
            # Local name (not frame_rgb) so the global mid-video frame from
            # Stage 1 is not overwritten by the last sampled frame.
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            results = hands_detector.process(rgb)

            annotated = rgb.copy()
            n_hands = 0
            if results.multi_hand_landmarks:
                n_hands = len(results.multi_hand_landmarks)
                for hand_lm in results.multi_hand_landmarks:
                    mp_drawing.draw_landmarks(
                        annotated, hand_lm, mp_hands.HAND_CONNECTIONS,
                        mp_drawing_styles.get_default_hand_landmarks_style(),
                        mp_drawing_styles.get_default_hand_connections_style()
                    )

            axes[i].imshow(annotated)
            axes[i].set_title(f'Frame {fidx} ({n_hands} hands)')
            axes[i].axis('off')

    plt.suptitle('MediaPipe Hand Detection on Video Frames', fontsize=14)
    plt.tight_layout()
    plt.show()

In [None]:
# Extract landmarks from MediaPipe and visualize
# Focus on one frame to show the 21-keypoint structure

# Third sampled frame when available, otherwise the first one.
demo_frame_bgr = sample_frames[2][1] if len(sample_frames) > 2 else sample_frames[0][1]
demo_frame_rgb = cv2.cvtColor(demo_frame_bgr, cv2.COLOR_BGR2RGB)

hands_detector = mp_hands.Hands(static_image_mode=True, max_num_hands=2, min_detection_confidence=0.5)
results = hands_detector.process(demo_frame_rgb)
hands_detector.close()

# Frame size in pixels. Landmark x/y are multiplied by these below (so they are
# treated as normalised coordinates); the skeleton-overlay cell reuses w and h.
h, w = demo_frame_rgb.shape[:2]
annotated = demo_frame_rgb.copy()

# Landmark indices of the five fingertips, thumb -> pinky.
fingertip_indices = [4, 8, 12, 16, 20]
finger_names = {4: 'Thumb', 8: 'Index', 12: 'Middle', 16: 'Ring', 20: 'Pinky'}

if results.multi_hand_landmarks:
    # multi_handedness is paired with multi_hand_landmarks by position.
    for hand_lm, hand_info in zip(results.multi_hand_landmarks, results.multi_handedness):
        # NOTE(review): MediaPipe documents handedness as assuming a mirrored
        # (selfie) image; for this camera Left/Right may be swapped - verify.
        label = hand_info.classification[0].label
        mp_drawing.draw_landmarks(annotated, hand_lm, mp_hands.HAND_CONNECTIONS)

        # mark fingertips
        for tip_idx in fingertip_indices:
            lm = hand_lm.landmark[tip_idx]
            px, py = int(lm.x * w), int(lm.y * h)
            cv2.circle(annotated, (px, py), 8, (255, 0, 0), -1)
            cv2.putText(annotated, finger_names[tip_idx], (px + 10, py - 5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 0), 1)

        print(f'{label} hand detected:')
        for tip_idx in fingertip_indices:
            lm = hand_lm.landmark[tip_idx]
            print(f'  {finger_names[tip_idx]:6s} tip: ({lm.x:.4f}, {lm.y:.4f}, {lm.z:.4f})')

plt.figure(figsize=(14, 8))
plt.imshow(annotated)
plt.title('MediaPipe 21-Keypoint Hand Skeleton with Fingertip Labels')
plt.axis('off')
plt.show()

In [None]:
# Compare our MediaPipe detection with pre-extracted skeleton JSON
print(f'Downloading skeleton JSON for sample {sample.id} ...')
skeleton_data = train_dataset.load_skeleton(sample)

loader = SkeletonLoader()
# NOTE: relies on the private SkeletonLoader._parse_json; the result has
# 'left' and 'right' entries holding per-frame hand data.
hands_parsed = loader._parse_json(skeleton_data)

# Later cells index these arrays as [frame, landmark, coord]
# (TODO confirm to_array's exact shape contract).
left_raw = loader.to_array(hands_parsed['left'])
right_raw = loader.to_array(hands_parsed['right'])

print(f'\nPre-extracted skeleton:')
print(f'  Left  hand frames: {len(hands_parsed["left"])}, array shape: {left_raw.shape}')
print(f'  Right hand frames: {len(hands_parsed["right"])}, array shape: {right_raw.shape}')

# A frame counts as valid when none of its landmark values is NaN.
left_valid = int(np.sum(~np.any(np.isnan(left_raw.reshape(len(left_raw), -1)), axis=1))) if left_raw.size > 0 else 0
right_valid = int(np.sum(~np.any(np.isnan(right_raw.reshape(len(right_raw), -1)), axis=1))) if right_raw.size > 0 else 0
print(f'  Valid frames - Left: {left_valid}  Right: {right_valid}')

In [None]:
# Overlay pre-extracted skeleton on the same frame for comparison
# Video frame index of the MediaPipe demo frame above.
# NOTE(review): assumes skeleton array row i corresponds to video frame i.
demo_fidx = sample_frames[2][0] if len(sample_frames) > 2 else sample_frames[0][0]
comparison = demo_frame_rgb.copy()

# draw pre-extracted landmarks: right hand green, left hand cyan
for hand_key, color in [('right', (0, 255, 0)), ('left', (0, 200, 255))]:
    arr = right_raw if hand_key == 'right' else left_raw
    # Skip the hand if the frame is out of range or any landmark is NaN.
    if demo_fidx < len(arr) and not np.any(np.isnan(arr[demo_fidx])):
        lm = arr[demo_fidx]
        # Normalised landmarks scaled by the demo frame's width/height (w, h).
        for j in range(21):
            px = int(lm[j, 0] * w)
            py = int(lm[j, 1] * h)
            cv2.circle(comparison, (px, py), 4, color, -1)
        # connect landmarks (edge list intended to mirror mp_hands.HAND_CONNECTIONS)
        connections = [(0,1),(1,2),(2,3),(3,4),(0,5),(5,6),(6,7),(7,8),
                       (5,9),(9,10),(10,11),(11,12),(9,13),(13,14),(14,15),(15,16),
                       (13,17),(17,18),(18,19),(19,20),(0,17)]
        for c1, c2 in connections:
            p1 = (int(lm[c1, 0] * w), int(lm[c1, 1] * h))
            p2 = (int(lm[c2, 0] * w), int(lm[c2, 1] * h))
            cv2.line(comparison, p1, p2, color, 2)

plt.figure(figsize=(14, 8))
plt.imshow(comparison)
plt.title(f'Pre-extracted Skeleton Overlay (green=right, cyan=left) - Frame {demo_fidx}')
plt.axis('off')
plt.show()

---
<a id='4'></a>
## 4. Stage 3 - Temporal Filtering

MediaPipe landmarks are noisy. We apply a 3-stage filtering pipeline:
1. Hampel filter (outlier detection via Median Absolute Deviation)
2. Linear interpolation (fill gaps < 30 frames)
3. Savitzky-Golay filter (smoothing)

In [None]:
# All filter parameters come from the loaded config (config.hand.*).
tf = TemporalFilter(
    hampel_window=config.hand.hampel_window,
    hampel_threshold=config.hand.hampel_threshold,
    max_interpolation_gap=config.hand.interpolation_max_gap,
    savgol_window=config.hand.savgol_window,
    savgol_order=config.hand.savgol_order
)

# Filter each hand independently; an empty array is passed through unchanged.
left_filtered = tf.process(left_raw) if left_raw.size > 0 else left_raw
right_filtered = tf.process(right_raw) if right_raw.size > 0 else right_raw

print('Filtering complete')
print(f'Left  filtered shape: {left_filtered.shape}')
print(f'Right filtered shape: {right_filtered.shape}')

In [None]:
# Visualize filtering effect: index fingertip x-coordinate
# Prefer the right hand; fall back to the left when the right-hand array is empty.
hand_arr_raw = right_raw if right_raw.size > 0 else left_raw
hand_arr_filt = right_filtered if right_filtered.size > 0 else left_filtered
hand_label = 'Right' if right_raw.size > 0 else 'Left'

lm_idx = 8  # index fingertip
# Plot at most the first 3000 frames.
T = min(3000, len(hand_arr_raw))

# Coordinate 0 = x of the chosen landmark.
raw_signal = hand_arr_raw[:T, lm_idx, 0]
filt_signal = hand_arr_filt[:T, lm_idx, 0]

fig, ax = plt.subplots(figsize=(16, 4))
ax.plot(raw_signal, alpha=0.5, label='Raw', linewidth=0.5)
ax.plot(filt_signal, label='Filtered', linewidth=1)
ax.set_title(f'{hand_label} Hand - Index Fingertip X-Coordinate')
ax.set_xlabel('Frame')
ax.set_ylabel('X (normalized)')
ax.legend()
plt.tight_layout()
plt.show()

In [None]:
# Optical flow: track fingertip motion between frames
FLOW_GAP = 5  # number of frames between the two images being compared
flow_start = 1000

vp = VideoProcessor()
vp.open(video_path)
try:
    frame1_bgr = vp.get_frame(flow_start)
    frame2_bgr = vp.get_frame(flow_start + FLOW_GAP)
finally:
    vp.close()

if frame1_bgr is not None and frame2_bgr is not None:
    gray1 = cv2.cvtColor(frame1_bgr, cv2.COLOR_BGR2GRAY)
    gray2 = cv2.cvtColor(frame2_bgr, cv2.COLOR_BGR2GRAY)

    # crop to keyboard region for clearer visualization; numpy slicing needs
    # integer indices, and clamping stops out-of-frame values from wrapping around
    img_h, img_w = gray1.shape[:2]
    x1 = max(0, int(round(keyboard_region.bbox[0])))
    y1 = max(0, int(round(keyboard_region.bbox[1])))
    x2 = min(img_w, int(round(keyboard_region.bbox[2])))
    y2 = min(img_h, int(round(keyboard_region.bbox[3])))
    gray1_crop = gray1[y1:y2, x1:x2]
    gray2_crop = gray2[y1:y2, x1:x2]

    flow = cv2.calcOpticalFlowFarneback(
        gray1_crop, gray2_crop, None,
        pyr_scale=0.5, levels=3, winsize=15, iterations=3, poly_n=5, poly_sigma=1.2, flags=0
    )

    magnitude, angle = cv2.cartToPolar(flow[..., 0], flow[..., 1])
    # The flow spans FLOW_GAP frames; divide to report a per-frame displacement.
    magnitude_per_frame = magnitude / FLOW_GAP

    # HSV visualization: hue = direction (radians -> OpenCV hue range 0-180),
    # value = normalised magnitude
    hsv = np.zeros((*gray1_crop.shape, 3), dtype=np.uint8)
    hsv[..., 0] = angle * 180 / np.pi / 2
    hsv[..., 1] = 255
    hsv[..., 2] = cv2.normalize(magnitude, None, 0, 255, cv2.NORM_MINMAX)
    flow_vis = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    axes[0].imshow(cv2.cvtColor(frame1_bgr[y1:y2, x1:x2], cv2.COLOR_BGR2RGB))
    axes[0].set_title(f'Frame {flow_start}')
    axes[1].imshow(cv2.cvtColor(frame2_bgr[y1:y2, x1:x2], cv2.COLOR_BGR2RGB))
    axes[1].set_title(f'Frame {flow_start + FLOW_GAP}')
    axes[2].imshow(flow_vis)
    axes[2].set_title('Dense Optical Flow (Farneback)')
    for ax in axes:
        ax.axis('off')
    plt.suptitle('Optical Flow: Hand Motion Between Frames', fontsize=14)
    plt.tight_layout()
    plt.show()

    print(f'Mean flow magnitude: {np.mean(magnitude_per_frame):.2f} px/frame')
    print(f'Max flow magnitude : {np.max(magnitude_per_frame):.2f} px/frame')
else:
    print('Could not read frames for optical flow')

In [None]:
# Fingertip extraction
extractor = FingertipExtractor()

sample_fidx = 500
# Only use the frame if it exists and none of its landmarks is NaN.
if sample_fidx < len(right_filtered) and not np.any(np.isnan(right_filtered[sample_fidx])):
    ftips = extractor.extract(right_filtered[sample_fidx], frame_idx=sample_fidx, hand_type='right')
    print(f'Frame {sample_fidx} - Right hand fingertips:')
    for f_num in range(1, 6):
        pos = ftips.get_position_2d(f_num)
        # Explicit None check: `if pos:` is ambiguous (raises) for NumPy positions.
        if pos is not None:
            print(f'  {extractor.FINGER_NAMES[f_num]:6s}: ({pos[0]:.4f}, {pos[1]:.4f})')

    span = extractor.compute_hand_span(ftips)
    print(f'  Hand span: {span:.4f}')
else:
    print(f'Frame {sample_fidx}: no complete right-hand landmarks available')

---
<a id='5'></a>
## 5. Stage 4 - Finger-Key Assignment

The assignment uses a Gaussian probability model in image-pixel space (not the homography-warped space).
It uses only the horizontal (x) distance between fingertip and key, which avoids a y-bias caused by the different finger lengths.
Both hands are tried for each key, and the higher-confidence assignment is kept.
A max-distance gate rejects the assignment when the hand is too far from the key.

In [None]:
# Load MIDI/TSV annotations
print(f'Downloading TSV annotations for sample {sample.id} ...')
tsv_df = train_dataset.load_tsv_annotations(sample)

# Every note gets a fixed synthetic duration.
# NOTE(review): an offset column in the TSV, if present, is not used - confirm intended.
NOTE_DURATION_SEC = 0.3
has_velocity = 'velocity' in tsv_df.columns

midi_events = []
for _, row in tsv_df.iterrows():
    onset = float(row['onset'])
    midi_events.append({
        'onset': onset,
        'offset': onset + NOTE_DURATION_SEC,
        'pitch': int(row['note']),
        # default to a mid velocity (64) when the column is missing or NaN
        'velocity': int(row['velocity']) if has_velocity and pd.notna(row['velocity']) else 64
    })

print(f'Total MIDI events: {len(midi_events)}')
if midi_events:
    pitches = [e['pitch'] for e in midi_events]
    print(f'Pitch range: {min(pitches)} - {max(pitches)}')
else:
    # min()/max() would raise on an empty sequence
    print('Pitch range: n/a (no events)')

In [None]:
# Synchronize MIDI events with video frames
# Presumably maps each event's onset time to a video frame index at
# config.video_fps. The resulting events expose frame_idx, key_idx and
# onset_time, which the assignment loop reads.
midi_sync = MidiVideoSync(fps=config.video_fps)
synced_events = midi_sync.sync_events(midi_events)
print(f'Synced events: {len(synced_events)}')

In [None]:
# Pixel size of the PianoVAM videos (1920 x 1080 per the dataset notes above);
# normalised skeleton landmarks are multiplied by these to get pixel coordinates.
# NOTE(review): hard-coded; vp.info.width/height would reflect the actual video.
FRAME_W, FRAME_H = 1920, 1080

def project_keys_to_pixel_space(key_boundaries_warped, homography, *, half_height=5.0):
    """Project warped-space key boxes back into image-pixel space.

    Each key's left edge, right edge and centre are taken on the key's
    horizontal mid-line in warped space and mapped through the inverse of
    ``homography``. The output box spans the projected left/right x and a thin
    band of ``2 * half_height`` pixels around the projected centre y.

    Args:
        key_boundaries_warped: mapping ``key -> (x1, y1, x2, y2)`` in warped space.
        homography: 3x3 matrix mapping image pixels to warped space.
        half_height: half of the output box height in pixels (default 5.0).

    Returns:
        dict ``key -> (left_x, cy - half_height, right_x, cy + half_height)``
        in pixel space, with ``left_x <= right_x`` even when the homography
        mirrors the x-axis.
    """
    H_inv = np.linalg.inv(homography)
    result = {}
    for k, (x1, y1, x2, y2) in key_boundaries_warped.items():
        cy = (y1 + y2) / 2.0
        # homogeneous points: left edge, right edge, centre on the mid-line
        pts_w = np.array([[x1, cy, 1.0], [x2, cy, 1.0], [(x1 + x2) / 2.0, cy, 1.0]], dtype=np.float64)
        pts_p = (H_inv @ pts_w.T).T
        pts_p = pts_p[:, :2] / pts_p[:, 2:3]  # perspective divide
        lx, rx = pts_p[0, 0], pts_p[1, 0]
        cy_px = pts_p[2, 1]
        # order the x-extent so the box is valid for mirroring homographies too
        result[k] = (min(lx, rx), cy_px - half_height, max(lx, rx), cy_px + half_height)
    return result

key_boundaries_px = project_keys_to_pixel_space(keyboard_region.key_boundaries, keyboard_region.homography)

# scale landmarks from [0,1] to pixels. Skip empty arrays, which have no
# landmark axes to index (same guard as process_sample_baseline).
left_px = left_filtered.copy()
if left_px.size > 0:
    left_px[:, :, 0] *= FRAME_W
    left_px[:, :, 1] *= FRAME_H

right_px = right_filtered.copy()
if right_px.size > 0:
    right_px[:, :, 0] *= FRAME_W
    right_px[:, :, 1] *= FRAME_H

assigner = GaussianFingerAssigner(
    key_boundaries=key_boundaries_px,
    sigma=config.assignment.sigma,
    candidate_range=config.assignment.candidate_keys
)

print(f'Sigma (auto): {assigner.sigma:.1f} px')
print(f'Max distance: {assigner.max_distance_px:.0f} px ({assigner.max_distance_sigma} sigma)')

In [None]:
# Run assignment: try BOTH hands for every key, pick higher confidence
def _assign_with_hand(hand_px, hand, event):
    """Assign ``event`` using one hand's landmarks at the event's frame.

    Returns None when the frame lies beyond the landmark array or any landmark
    in that frame is NaN; otherwise returns whatever the assigner returns.
    """
    if event.frame_idx >= len(hand_px):
        return None
    frame_landmarks = hand_px[event.frame_idx]
    if np.any(np.isnan(frame_landmarks)):
        return None
    return assigner.assign_from_landmarks(
        frame_landmarks, event.key_idx, hand, event.frame_idx, event.onset_time
    )


assignments = []
skipped = 0

for event in synced_events:
    # Keys without a pixel-space centre cannot be scored at all.
    if event.key_idx not in assigner.key_centers:
        skipped += 1
        continue

    # Right hand first: on equal confidence max() keeps the first candidate.
    per_hand = (
        _assign_with_hand(right_px, 'right', event),
        _assign_with_hand(left_px, 'left', event),
    )
    candidates = [a for a in per_hand if a is not None]
    if not candidates:
        skipped += 1
        continue
    assignments.append(max(candidates, key=lambda a: a.confidence))

print(f'Total events  : {len(synced_events)}')
print(f'Assigned      : {len(assignments)}')
print(f'Skipped       : {skipped}')
print(f'Coverage      : {len(assignments)/max(1,len(synced_events))*100:.1f}%')

In [None]:
# Assignment statistics
# Three panels: per-finger counts, per-hand counts, confidence histogram.
if assignments:
    fingers = [a.assigned_finger for a in assignments]
    hands_list = [a.hand for a in assignments]
    confs = [a.confidence for a in assignments]

    fig, axes = plt.subplots(1, 3, figsize=(16, 4))

    finger_names_map = {1: 'Thumb', 2: 'Index', 3: 'Middle', 4: 'Ring', 5: 'Pinky'}
    # One colour per finger number 1-5 (looked up as colors[finger - 1]).
    colors = ['#e41a1c', '#377eb8', '#4daf4a', '#984ea3', '#ff7f00']
    fc = pd.Series(fingers).value_counts().sort_index()
    fc.plot.bar(ax=axes[0], color=[colors[i-1] for i in fc.index])
    axes[0].set_xticklabels([finger_names_map[i] for i in fc.index], rotation=45)
    axes[0].set_title('Finger Distribution')

    pd.Series(hands_list).value_counts().plot.bar(ax=axes[1], color=['coral', 'skyblue'])
    axes[1].set_title('Hand Distribution')

    axes[2].hist(confs, bins=30, color='mediumseagreen', edgecolor='white')
    axes[2].set_title('Assignment Confidence')

    plt.tight_layout()
    plt.show()

    print('\nSample assignments:')
    for a in assignments[:10]:
        print(f'  Frame {a.frame_idx:>5d} | {a.label} | MIDI {a.midi_pitch} ({a.finger_name:6s}) | conf={a.confidence:.3f}')

---
<a id='6'></a>
## 6. Baseline Pipeline on Multiple Samples

In [None]:
# Per-sample memo, kept across calls because it is the default value of
# process_sample_baseline's `cache` argument. It holds the raw skeleton JSON,
# the TSV annotations, the filtered pixel-space landmarks and the pixel-space
# key boundaries, each keyed by sample id.
_SAMPLE_CACHE = {'skeleton': {}, 'tsv': {}, 'filtered_landmarks': {}, 'keys_px': {}}

def process_sample_baseline(sample, dataset, config, max_duration_sec=120, cache=_SAMPLE_CACHE):
    result = {'sample_id': sample.id, 'assignments': [], 'error': None}
    try:
        corners = sample.metadata['keyboard_corners']
        det = KeyboardDetector()
        kb = det.detect_from_corners(corners)

        if sample.id not in cache['keys_px']:
            cache['keys_px'][sample.id] = project_keys_to_pixel_space(kb.key_boundaries, kb.homography)
        kb_px = cache['keys_px'][sample.id]

        if sample.id not in cache['filtered_landmarks']:
            if sample.id not in cache['skeleton']:
                cache['skeleton'][sample.id] = dataset.load_skeleton(sample)
            skel = cache['skeleton'][sample.id]

            ldr = SkeletonLoader()
            h = ldr._parse_json(skel)
            la = ldr.to_array(h['left'])
            ra = ldr.to_array(h['right'])

            t = TemporalFilter(
                hampel_window=config.hand.hampel_window,
                hampel_threshold=config.hand.hampel_threshold,
                max_interpolation_gap=config.hand.interpolation_max_gap,
                savgol_window=config.hand.savgol_window,
                savgol_order=config.hand.savgol_order
            )
            if la.size > 0: la = t.process(la)
            if ra.size > 0: ra = t.process(ra)
            if la.size > 0: la = la.copy(); la[:,:,0] *= FRAME_W; la[:,:,1] *= FRAME_H
            if ra.size > 0: ra = ra.copy(); ra[:,:,0] *= FRAME_W; ra[:,:,1] *= FRAME_H
            cache['filtered_landmarks'][sample.id] = (la, ra)

        la, ra = cache['filtered_landmarks'][sample.id]
        if max_duration_sec:
            mf = int(max_duration_sec * config.video_fps)
            if la.size > 0: la = la[:mf]
            if ra.size > 0: ra = ra[:mf]

        if sample.id not in cache['tsv']:
            cache['tsv'][sample.id] = dataset.load_tsv_annotations(sample)
        tsv = cache['tsv'][sample.id]
        if max_duration_sec:
            tsv = tsv[tsv['onset'] <= float(max_duration_sec)].copy()

        midi_evts = [{'onset': float(r['onset']), 'offset': float(r['onset'])+0.3,
                      'pitch': int(r['note']),
                      'velocity': int(r['velocity']) if 'velocity' in r and pd.notna(r['velocity']) else 64}
                     for _, r in tsv.iterrows()]

        sync = MidiVideoSync(fps=config.video_fps)
        synced = sync.sync_events(midi_evts)

        asgn = GaussianFingerAssigner(key_boundaries=kb_px, sigma=config.assignment.sigma,
                                      candidate_range=config.assignment.candidate_keys)

        for ev in synced:
            fidx, kidx = ev.frame_idx, ev.key_idx
            if kidx not in asgn.key_centers: continue
            ar = None
            if fidx < len(ra):
                lm = ra[fidx]
                if not np.any(np.isnan(lm)):
                    ar = asgn.assign_from_landmarks(lm, kidx, 'right', fidx, ev.onset_time)
            al = None
            if fidx < len(la):
                lm = la[fidx]
                if not np.any(np.isnan(lm)):
                    al = asgn.assign_from_landmarks(lm, kidx, 'left', fidx, ev.onset_time)
            cands = [a for a in (ar, al) if a is not None]
            if cands:
                result['assignments'].append(max(cands, key=lambda a: a.confidence))
    except Exception as e:
        result['error'] = str(e)
    return result

In [None]:
# Run the baseline pipeline on the first NUM_SAMPLES training pieces.
from itertools import islice

NUM_SAMPLES = 10
MAX_DURATION_SEC = 120

all_results = []
# islice stops after NUM_SAMPLES items; an enumerate/break loop would pull one
# extra sample from the (streaming) dataset before breaking.
for i, samp in enumerate(islice(train_dataset, NUM_SAMPLES)):
    print(f'Processing {i+1}/{NUM_SAMPLES}: {samp.id} - {samp.metadata["piece"][:40]}')
    res = process_sample_baseline(samp, train_dataset, config, max_duration_sec=MAX_DURATION_SEC)
    if res['error']:
        print(f'  Error: {res["error"][:100]}')
    else:
        print(f'  Assigned {len(res["assignments"])} notes')
    all_results.append(res)

total_assigned = sum(len(r['assignments']) for r in all_results)
n_failed = sum(1 for r in all_results if r['error'])
print(f'\nTotal assignments: {total_assigned}')
print(f'Samples with errors: {n_failed}/{len(all_results)}')

In [None]:
# Pool baseline assignments from every processed training sample and plot
# finger, hand and confidence distributions. all_confs is reused later.
all_fingers = [a.assigned_finger for r in all_results for a in r['assignments']]
all_hands = [a.hand for r in all_results for a in r['assignments']]
all_confs = [a.confidence for r in all_results for a in r['assignments']]

if all_fingers:
    finger_colors = {1: '#e41a1c', 2: '#377eb8', 3: '#4daf4a', 4: '#984ea3', 5: '#ff7f00'}
    finger_labels = {1: 'Thumb', 2: 'Index', 3: 'Middle', 4: 'Ring', 5: 'Pinky'}
    # Colour by hand label, not by value_counts() rank, so left/right keep the
    # same colour regardless of which hand played more. Labels are assumed to
    # be 'left'/'right' (TODO confirm); unknown labels are drawn grey.
    hand_colors = {'left': 'skyblue', 'right': 'coral'}

    fig, axes = plt.subplots(1, 3, figsize=(16, 4))

    fc = pd.Series(all_fingers).value_counts().sort_index()
    fc.plot.bar(ax=axes[0], color=[finger_colors.get(f, 'gray') for f in fc.index])
    axes[0].set_xticklabels([finger_labels.get(f, str(f)) for f in fc.index], rotation=45)
    axes[0].set_title(f'Finger Distribution (n={len(all_fingers)})')

    hc = pd.Series(all_hands).value_counts()
    hc.plot.bar(ax=axes[1], color=[hand_colors.get(h, 'gray') for h in hc.index])
    axes[1].set_title('Hand Distribution')

    mean_conf = np.mean(all_confs)
    axes[2].hist(all_confs, bins=30, color='mediumseagreen', edgecolor='white')
    axes[2].axvline(mean_conf, color='red', ls='--', label=f'mean={mean_conf:.3f}')
    axes[2].set_title('Confidence Distribution')
    axes[2].legend()

    plt.tight_layout()
    plt.show()

In [None]:
# ── Keyboard Auto-Detection IoU Across Multiple Samples ──
# Compare the Canny/Hough auto-detector's keyboard box with the annotated
# corners for the first NUM_SAMPLES training pieces. Detection failures and
# exceptions are scored as IoU 0.0 so they lower the mean instead of vanishing.
from itertools import islice

print('Evaluating auto-detection (Canny/Hough) IoU across samples ...\n')

auto_det = AutoKeyboardDetector({
    'canny_low': config.keyboard.canny_low,
    'canny_high': config.keyboard.canny_high,
    'hough_threshold': config.keyboard.hough_threshold,
    'num_sample_frames': 5,
})

iou_scores = []
eval_ds = PianoVAMDataset(split='train', streaming=True, max_samples=NUM_SAMPLES)

for samp in islice(eval_ds, NUM_SAMPLES):
    try:
        vpath = eval_ds.download_file(samp.video_path)
        res = auto_det.detect_from_video(str(vpath))
        if res.success:
            iou = auto_det.evaluate_against_corners(res, samp.metadata['keyboard_corners'])
            iou_scores.append(iou)
            print(f'  {samp.id}: IoU = {iou:.3f}  (bbox={res.consensus_bbox})')
        else:
            iou_scores.append(0.0)
            print(f'  {samp.id}: FAILED')
    except Exception as e:  # best-effort: score the sample as a failure and keep going
        iou_scores.append(0.0)
        print(f'  {samp.id}: Error - {str(e)[:60]}')

if not iou_scores:
    # np.min/np.max raise ValueError on an empty list (and np.mean gives nan),
    # so skip the summary entirely when nothing was evaluated.
    print('\nAuto-detection IoU summary: no samples were evaluated.')
else:
    mean_iou = np.mean(iou_scores)
    print(f'\nAuto-detection IoU summary ({len(iou_scores)} samples):')
    print(f'  Mean IoU : {mean_iou:.3f}')
    print(f'  Median   : {np.median(iou_scores):.3f}')
    print(f'  Min / Max: {np.min(iou_scores):.3f} / {np.max(iou_scores):.3f}')
    print(f'  Success  : {sum(1 for s in iou_scores if s > 0)}/{len(iou_scores)}')

    fig, ax = plt.subplots(figsize=(10, 4))
    ax.bar(range(len(iou_scores)), iou_scores, color='steelblue', edgecolor='white')
    ax.axhline(mean_iou, color='red', ls='--', label=f'Mean = {mean_iou:.3f}')
    ax.set_xlabel('Sample')
    ax.set_ylabel('IoU')
    ax.set_title('Auto-Detection IoU vs Corner Annotations')
    ax.set_ylim(0, 1)
    ax.legend()
    plt.tight_layout()
    plt.show()

---
<a id='7'></a>
## 7. Stage 5 - Neural Refinement (BiLSTM)

Architecture: Input(20) -> Linear(128) -> BiLSTM(128 x 2 layers) -> Self-Attention -> Linear(128) -> Linear(5)

In [None]:
# Build BiLSTM training sequences from the baseline assignments of the first
# MAX_TRAIN_SAMPLES training pieces.
# NOTE: 'labels' is the baseline's own predicted finger (identical to
# 'fingers'), so presumably the refiner learns to reproduce/smooth the
# baseline rather than ground-truth fingering (self-training) - confirm
# against SequenceDataset / train_refiner.
from itertools import islice

print('Preparing training sequences from baseline assignments ...')

MAX_TRAIN_SAMPLES = 20
MIN_SEQUENCE_NOTES = 10  # pieces with fewer assigned notes are dropped
train_sequences = []
train_ds_full = PianoVAMDataset(split='train', streaming=True, max_samples=MAX_TRAIN_SAMPLES)

for samp in islice(train_ds_full, MAX_TRAIN_SAMPLES):
    res = process_sample_baseline(samp, train_ds_full, config, max_duration_sec=120)
    if res['error']:
        # Previously swallowed silently; any partial assignments are still used below.
        print(f'  Warning: {samp.id}: {res["error"][:80]}')
    asgns = res['assignments']
    if len(asgns) < MIN_SEQUENCE_NOTES:
        continue
    seq = {
        'pitches': [a.midi_pitch for a in asgns],
        'fingers': [a.assigned_finger for a in asgns],
        'onsets': [a.note_onset for a in asgns],
        'hands': [a.hand for a in asgns],
        'labels': [a.assigned_finger for a in asgns],
    }
    train_sequences.append(seq)

print(f'Training sequences: {len(train_sequences)}')
print(f'Total notes: {sum(len(s["pitches"]) for s in train_sequences)}')

In [None]:
# Stage 5 training: sequential 80/20 split of the baseline sequences, then fit
# the BiLSTM refiner. ``trained_model`` stays None unless at least three
# sequences are available.
feature_extractor = FeatureExtractor(normalize_pitch=True)
input_size = feature_extractor.get_input_size()

trained_model = None
if len(train_sequences) <= 2:
    print('Not enough data for training')
else:
    rcfg = config.refinement
    cut = max(1, int(0.8 * len(train_sequences)))
    train_seqs, val_seqs = train_sequences[:cut], train_sequences[cut:]

    train_torch_ds = SequenceDataset(train_seqs, feature_extractor, max_len=256)
    val_torch_ds = SequenceDataset(val_seqs, feature_extractor, max_len=256)

    model = FingeringRefiner(
        input_size=input_size,
        hidden_size=rcfg.hidden_size,
        num_layers=rcfg.num_layers,
        dropout=rcfg.dropout,
        bidirectional=rcfg.bidirectional,
    ).to(DEVICE)

    n_params = sum(p.numel() for p in model.parameters())
    print(f'Model parameters: {n_params:,}')
    print(model)

    training_config = dict(
        hidden_size=rcfg.hidden_size,
        num_layers=rcfg.num_layers,
        dropout=rcfg.dropout,
        batch_size=min(rcfg.batch_size, len(train_torch_ds)),  # never exceed the dataset size
        learning_rate=rcfg.learning_rate,
        epochs=rcfg.epochs,
        early_stopping_patience=rcfg.early_stopping_patience,
        device=DEVICE,
        checkpoint_dir='/content/checkpoints' if IN_COLAB else './outputs/checkpoints',
    )

    print('\nTraining BiLSTM refinement model ...')
    val_arg = val_torch_ds if len(val_torch_ds) > 0 else None
    trained_model = train_refiner(
        train_dataset=train_torch_ds,
        val_dataset=val_arg,
        config=training_config,
    )
    print('Training complete')

In [None]:
def refine_assignments(model, assignments, feature_extractor, device='cpu', use_constraints=True):
    """Re-predict the finger of every baseline assignment with the refiner model.

    Args:
        model: trained refiner, or None to skip refinement.
        assignments: baseline FingerAssignment list for one sample.
        feature_extractor: builds the model input from pitches, fingers, onsets and hands.
        device: device the input tensor is moved to.
        use_constraints: decode with ``constrained_viterbi_decode`` (non-strict
            biomechanical constraints) instead of a per-note argmax.

    Returns:
        ``assignments`` itself when it is empty or ``model`` is None. Otherwise
        a new list of FingerAssignment with the predicted finger and, as its
        confidence, the model's softmax probability of that finger. All other
        fields are copied from the input.
    """
    if not assignments or model is None:
        return assignments

    pitches, fingers, onsets, hands = [], [], [], []
    for a in assignments:
        pitches.append(a.midi_pitch)
        fingers.append(a.assigned_finger)
        onsets.append(a.note_onset)
        hands.append(a.hand)

    batch = feature_extractor.extract(pitches, fingers, onsets, hands).unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        probs = torch.softmax(model(batch), dim=-1).squeeze(0).cpu().numpy()

    if not use_constraints:
        # Column k of probs is finger k + 1.
        pred_fingers = (np.argmax(probs, axis=-1) + 1).tolist()
    else:
        pred_fingers = constrained_viterbi_decode(
            probs=probs, pitches=pitches, hands=hands,
            constraints=BiomechanicalConstraints(strict=False)
        ).fingers

    confs = [float(probs[i, f - 1]) for i, f in enumerate(pred_fingers)]

    refined = []
    for i, a in enumerate(assignments):
        refined.append(FingerAssignment(
            note_onset=a.note_onset, frame_idx=a.frame_idx, midi_pitch=a.midi_pitch,
            key_idx=a.key_idx, assigned_finger=int(pred_fingers[i]), hand=a.hand,
            confidence=float(confs[i]), fingertip_position=a.fingertip_position
        ))
    return refined


# Refine each training-sample result that has assignments and report how many
# fingers the refiner changed relative to the baseline.
if trained_model is not None and all_results:
    print('Refining baseline predictions ...')
    for res in all_results:
        if not res['assignments']:
            continue
        original = res['assignments']
        refined = refine_assignments(trained_model, original, feature_extractor, DEVICE)
        res['refined_assignments'] = refined
        n_changed = sum(o.assigned_finger != r.assigned_finger for o, r in zip(original, refined))
        print(f'  {res["sample_id"]}: {n_changed}/{len(original)} changed')
    print('Refinement done')

---
<a id='8'></a>
## 8. Evaluation & Results

In [None]:
# Train-split evaluation. IFR is the number of constraint violations divided by
# the number of note transitions (len - 1, at least 1), computed for baseline
# and, when present, refined fingerings.
metrics = FingeringMetrics()  # not referenced in this cell
constraints = BiomechanicalConstraints()

banner = '=' * 70
print(banner)
print('EVALUATION RESULTS')
print(banner)

baseline_ifrs = []
refined_ifrs = []

for res in all_results:
    if not res['assignments']:
        continue
    asgns = res['assignments']
    pitches = [a.midi_pitch for a in asgns]
    hl = [a.hand for a in asgns]

    base_fingers = [a.assigned_finger for a in asgns]
    ifr = len(constraints.validate_sequence(base_fingers, pitches, hl)) / max(1, len(asgns) - 1)
    baseline_ifrs.append(ifr)
    mc = np.mean([a.confidence for a in asgns])

    parts = [f'  {res["sample_id"]} - {len(asgns)} notes', f'Baseline IFR={ifr:.3f}', f'conf={mc:.3f}']

    if 'refined_assignments' in res:
        ref = res['refined_assignments']
        ref_fingers = [a.assigned_finger for a in ref]
        ri = len(constraints.validate_sequence(ref_fingers, pitches, hl)) / max(1, len(ref) - 1)
        refined_ifrs.append(ri)
        parts.append(f'Refined IFR={ri:.3f}')

    print(' | '.join(parts))

print('\n' + banner)
if baseline_ifrs:
    print(f'BASELINE Mean IFR: {np.mean(baseline_ifrs):.3f} +/- {np.std(baseline_ifrs):.3f}')
if refined_ifrs:
    print(f'REFINED  Mean IFR: {np.mean(refined_ifrs):.3f} +/- {np.std(refined_ifrs):.3f}')
    imp = np.mean(baseline_ifrs) - np.mean(refined_ifrs)
    print(f'Improvement: {imp:+.3f}')
print(banner)

In [None]:
# Test set evaluation: run the baseline (plus refinement when a model was
# trained) on up to 5 test pieces, then report per-sample IFR using the
# ``constraints`` object from the train evaluation cell.
print('Processing test split ...\n')
test_ds_eval = PianoVAMDataset(split='test', streaming=True, max_samples=5)
test_results = []

for i, samp in enumerate(test_ds_eval, start=1):
    print(f'  Test {i}: {samp.id}')
    res = process_sample_baseline(samp, test_ds_eval, config)
    if not res['error']:
        n = len(res['assignments'])
        if n > 0 and trained_model is not None:
            res['refined_assignments'] = refine_assignments(
                trained_model, res['assignments'], feature_extractor, DEVICE)
        print(f'    {n} notes assigned')
    else:
        print(f'    Error: {res["error"][:80]}')
    test_results.append(res)

banner = '=' * 70
print('\n' + banner)
print('TEST SET RESULTS')
print(banner)

test_baseline_ifrs = []
test_refined_ifrs = []

for res in test_results:
    if not res['assignments']:
        continue
    asgns = res['assignments']
    pitches = [a.midi_pitch for a in asgns]
    hl = [a.hand for a in asgns]

    base_fingers = [a.assigned_finger for a in asgns]
    ifr = len(constraints.validate_sequence(base_fingers, pitches, hl)) / max(1, len(asgns) - 1)
    test_baseline_ifrs.append(ifr)
    line = f'  {res["sample_id"]} - {len(asgns)} notes | Baseline IFR={ifr:.3f}'

    if 'refined_assignments' in res:
        ref = res['refined_assignments']
        ref_fingers = [a.assigned_finger for a in ref]
        ri = len(constraints.validate_sequence(ref_fingers, pitches, hl)) / max(1, len(ref) - 1)
        test_refined_ifrs.append(ri)
        line += f' | Refined IFR={ri:.3f}'
    print(line)

if test_baseline_ifrs:
    print(f'\nTEST Baseline Mean IFR: {np.mean(test_baseline_ifrs):.3f}')
if test_refined_ifrs:
    print(f'TEST Refined  Mean IFR: {np.mean(test_refined_ifrs):.3f}')

In [None]:
# Final summary figure: per-sample train IFR bars (baseline vs refined side by
# side) and the pooled baseline confidence histogram.
if baseline_ifrs:
    fig, (ax_ifr, ax_conf) = plt.subplots(1, 2, figsize=(14, 5))

    x = np.arange(len(baseline_ifrs))
    w = 0.35  # bar width; the two series are offset by half a width each
    ax_ifr.bar(x - w / 2, baseline_ifrs, w, label='Baseline', color='steelblue')
    if refined_ifrs:
        ax_ifr.bar(x + w / 2, refined_ifrs, w, label='Refined', color='coral')
    ax_ifr.set_xlabel('Sample')
    ax_ifr.set_ylabel('IFR (lower = better)')
    ax_ifr.set_title('IFR Comparison (Train Samples)')
    ax_ifr.legend()

    if all_confs:
        mean_conf = np.mean(all_confs)
        ax_conf.hist(all_confs, bins=30, color='mediumseagreen', edgecolor='white')
        ax_conf.axvline(mean_conf, color='red', ls='--', label=f'mean={mean_conf:.3f}')
        ax_conf.set_title('Confidence Distribution')
        ax_conf.legend()

    plt.suptitle('Piano Fingering Detection - Results Summary', fontsize=14)
    plt.tight_layout()
    plt.show()

In [None]:
# Save results: per-test-sample summary JSON and, if trained, the refiner weights.
import json
from pathlib import Path

output_dir = Path('/content/outputs' if IN_COLAB else './outputs')
output_dir.mkdir(parents=True, exist_ok=True)


def _sequence_ifr(fingers, pitches, hands):
    """Return constraint violations per note transition (denominator at least 1)."""
    violations = constraints.validate_sequence(fingers, pitches, hands)
    return len(violations) / max(1, len(fingers) - 1)


results_summary = {
    'pipeline': 'piano-fingering-detection',
    'baseline_method': 'Gaussian Assignment (x-only, both hands, max-distance gate)',
    'refinement_method': 'BiLSTM + Attention + Constrained Viterbi',
    'test_results': []
}

# IFRs are computed per result here. test_baseline_ifrs / test_refined_ifrs
# skip samples with no assignments, so indexing them by position in
# test_results attached scores to the wrong sample_id whenever a sample failed.
for res in test_results:
    asgns = res.get('assignments') or []
    entry = {'sample_id': res['sample_id'], 'num_assignments': len(asgns)}
    if asgns:
        pitches = [a.midi_pitch for a in asgns]
        hands = [a.hand for a in asgns]
        entry['baseline_ifr'] = float(_sequence_ifr([a.assigned_finger for a in asgns], pitches, hands))
        if 'refined_assignments' in res:
            refined_fingers = [a.assigned_finger for a in res['refined_assignments']]
            entry['refined_ifr'] = float(_sequence_ifr(refined_fingers, pitches, hands))
    if res.get('error'):
        entry['error'] = res['error']
    results_summary['test_results'].append(entry)

if test_baseline_ifrs:
    results_summary['mean_baseline_ifr'] = float(np.mean(test_baseline_ifrs))
if test_refined_ifrs:
    results_summary['mean_refined_ifr'] = float(np.mean(test_refined_ifrs))

with open(output_dir / 'evaluation_results.json', 'w', encoding='utf-8') as f:
    json.dump(results_summary, f, indent=2)

if trained_model is not None:
    torch.save(trained_model.state_dict(), output_dir / 'refinement_model.pt')

print(f'Results saved to {output_dir}')