In [None]:
# Local setup: point to the Tap-to-Music repo and project root (no Colab drive needed)
# %pip install -q torch pytorch-lightning pretty_midi  # uncomment if deps are missing
from pathlib import Path
import sys

PROJECT_ROOT = Path.cwd()
TAP_TO_MUSIC_PATH = PROJECT_ROOT / "../Tap-to-Music"  # set to your local clone of https://github.com/lynnzYe/Tap-to-Music

if not TAP_TO_MUSIC_PATH.exists():
    raise FileNotFoundError(f"Tap-to-Music repo not found at {TAP_TO_MUSIC_PATH}. Clone it locally first.")

sys.path.append(str(TAP_TO_MUSIC_PATH))
sys.path.append(str(PROJECT_ROOT))

print(f"Working directory: {PROJECT_ROOT}")
print("Tap-to-Music path:", TAP_TO_MUSIC_PATH)
print("Workspace contents:", [p.name for p in PROJECT_ROOT.iterdir()])


Working directory: /Users/ffr/Desktop/10701/project
Tap-to-Music path: /Users/ffr/Desktop/10701/project/Tap-to-Music
Workspace contents: ['.DS_Store', 'hannds', 'features', '__pycache__', 'maestro-v3.0.0', '10701project.ipynb', 'features...', 'left_right_data.py', '.vscode', 'outputs', 'lightning_logs', 'Tap-to-Music']


In [11]:
# Build a DataModule backed by Tap-to-Music RangeDataset reading unconditional PKLs
import os
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from ttm.data_preparation.dataset import RangeDataset

FEATURE_FOLDER = PROJECT_ROOT / "features"

class UncondRangeDataModule(pl.LightningDataModule):
    def __init__(self, feature_folder: str, batch_size: int = 32, num_workers: int = 0):
        super().__init__()
        self.feature_folder = feature_folder
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        for split in ["train", "validation"]:
            path = os.path.join(self.feature_folder, f"unconditional-{split}.pkl")
            if not os.path.exists(path):
                raise FileNotFoundError(f"Missing {path}; run FeaturePreparation(feature='unconditional') to generate.")
        # Use RangeDataset but point to unconditional PKLs
        self.train_ds = RangeDataset(self.feature_folder, "train", feature_type="unconditional")
        self.val_ds = RangeDataset(self.feature_folder, "validation", feature_type="unconditional")

    def train_dataloader(self):
        return DataLoader(
            self.train_ds,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            drop_last=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_ds,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            drop_last=True,
        )

range_data_module = UncondRangeDataModule(str(FEATURE_FOLDER), batch_size=32, num_workers=0)
# Use this for training below
data_module = range_data_module


In [3]:
# Range-aware LSTM over full 4-dim features; predicts next pitch, velocity, and duration
import torch
import torch.nn as nn
import torch.nn.functional as F
from ttm.config import MAX_PIANO_PITCH, MIN_PIANO_PITCH

class RangeTapLSTM(nn.Module):
    def __init__(
        self,
        pitch_vocab: int = MAX_PIANO_PITCH + 1,
        pitch_emb_dim: int = 32,
        range_vocab: int = 3,
        range_emb_dim: int = 3,
        hidden: int = 128,
        layers: int = 2,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.pitch_emb = nn.Embedding(pitch_vocab + 1, pitch_emb_dim, padding_idx=pitch_vocab)
        self.range_emb = nn.Embedding(range_vocab if range_vocab is not None else len(range_bounds), range_emb_dim)
        self.input_linear = nn.Linear(pitch_emb_dim + range_emb_dim + 2, hidden)
        self.lstm = nn.LSTM(hidden, hidden, num_layers=layers, batch_first=True, dropout=dropout)
        self.pitch_head = nn.Linear(hidden, pitch_vocab + 1)
        self.vel_head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 1))
        self.dur_head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 1))

    def forward(self, feats: torch.Tensor, range_ids: torch.Tensor):
        """
        feats: (B, T, 4) = [pitch, log_dt, log_dur, velocity]
        range_ids: (B,) int
        returns pitch_logits (B,T,V), vel_pred (B,T), dur_pred (B,T)
        """
        B, T, _ = feats.shape
        pitch_idx = feats[..., 0].long().clamp(min=0, max=MAX_PIANO_PITCH)
        log_dt = feats[..., 1]
        log_dur = feats[..., 2]
        vel = feats[..., 3]

        pitch_embed = self.pitch_emb(pitch_idx)
        range_embed = self.range_emb(range_ids).unsqueeze(1).expand(B, T, -1)
        x = torch.cat([pitch_embed, range_embed, log_dt.unsqueeze(-1), log_dur.unsqueeze(-1)], dim=-1)
        x = self.input_linear(x)
        out, _ = self.lstm(x)

        pitch_logits = self.pitch_head(out)
        vel_pred = self.vel_head(out).squeeze(-1)
        dur_pred = F.softplus(self.dur_head(out).squeeze(-1))  # keep durations positive
        return pitch_logits, vel_pred, dur_pred

    @torch.no_grad()
    def generate_next(self, feats: torch.Tensor, range_id: int, temperature: float = 1.0):
        self.eval()
        logits, vel_pred, dur_pred = self(feats.unsqueeze(0), torch.tensor([range_id], device=feats.device))
        next_pitch_logits = logits[:, -1, :] / temperature
        probs = F.softmax(next_pitch_logits, dim=-1)
        next_pitch = torch.multinomial(probs, num_samples=1).squeeze()
        next_vel = vel_pred[:, -1].squeeze().clamp(0, 127)
        next_dur = dur_pred[:, -1].squeeze()
        return next_pitch.item(), next_vel.item(), next_dur.item()


In [12]:
import torch
torch.set_default_dtype(torch.float32)


In [13]:
import pytorch_lightning as pl
import torch.nn.functional as F

class RangeTapModule(pl.LightningModule):
    def __init__(self, **kwargs):
        super().__init__()
        self.model = RangeTapLSTM(**kwargs)
        self.lr = 1e-3

    def training_step(self, batch, batch_idx):
        feats, labels, range_ids = batch  # adjust to your dataloader format
        pitch_logits, vel_pred, dur_pred = self.model(feats, range_ids)
        ce = F.cross_entropy(pitch_logits.view(-1, pitch_logits.size(-1)),
                             labels.view(-1), ignore_index=88)
        vel_loss = F.mse_loss(vel_pred, feats[..., 3])  # or your target
        dur_loss = F.mse_loss(dur_pred, feats[..., 2])
        loss = ce + vel_loss + dur_loss
        self.log_dict({"train_loss": loss})
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)


In [6]:
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning import Trainer

module = RangeTapModule()
trainer = Trainer(
    max_epochs=5,
    accelerator="gpu" if torch.cuda.is_available() else "mps",
    devices=1,
    gradient_clip_val=1.0,
    log_every_n_steps=10,
)
trainer.fit(module, datamodule=range_data_module)


ðŸ’¡ Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/site-packages/pytorch_lightning/trainer/configuration_validator.py:68: You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.

  | Name  | Type         | Params | Mode 
-----------------------------------------------
0 | model | RangeTapLSTM | 320 K  | train
-----------------------------------------------
320 K     Trainable params
0         Non-trainable params
320 K     Total params
1.280     Total estimated model params size (MB)
14        Modules in train mode
0         Modules in eval mode
/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/site-packages/pytorch_lightning/trainer/connectors/da

Loading /Users/ffr/Desktop/10701/project/features/unconditional-train.pkl
Loading /Users/ffr/Desktop/10701/project/features/unconditional-validation.pkl


Training: |          | 0/? [00:00<?, ?it/s]

ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 1, the array at index 0 has size 5 and the array at index 1 has size 4

In [14]:
# Cluster-aware LSTM conditioned on per-sequence median (5th feature column)
import torch
import torch.nn as nn
import torch.nn.functional as F
from ttm.config import MAX_PIANO_PITCH, MIN_PIANO_PITCH

class ClusterTapLSTM(nn.Module):
    def __init__(
        self,
        pitch_vocab: int = MAX_PIANO_PITCH + 1,
        pitch_emb_dim: int = 32,
        hidden: int = 128,
        layers: int = 2,
        dropout: float = 0.1,
    ):
        super().__init__()
        # Expect feats shape (B, T, 5): [pitch, log_dt, log_dur, velocity, median]
        self.pitch_emb = nn.Embedding(pitch_vocab + 1, pitch_emb_dim, padding_idx=pitch_vocab)
        self.input_linear = nn.Linear(pitch_emb_dim + 4, hidden)
        self.lstm = nn.LSTM(hidden, hidden, num_layers=layers, batch_first=True, dropout=dropout)
        self.pitch_head = nn.Linear(hidden, pitch_vocab + 1)
        self.vel_head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 1))
        self.dur_head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 1))

    def forward(self, feats: torch.Tensor):
        """
        feats: (B, T, 5) = [pitch, log_dt, log_dur, velocity, median]
        Returns pitch_logits (B,T,V), vel_pred (B,T), dur_pred (B,T)
        """
        B, T, _ = feats.shape
        pitch_idx = feats[..., 0].long().clamp(min=0, max=MAX_PIANO_PITCH)
        log_dt = feats[..., 1]
        log_dur = feats[..., 2]
        vel = feats[..., 3]
        median = feats[..., 4]

        pitch_embed = self.pitch_emb(pitch_idx)
        x = torch.cat([
            pitch_embed,
            log_dt.unsqueeze(-1),
            log_dur.unsqueeze(-1),
            vel.unsqueeze(-1),
            median.unsqueeze(-1),
        ], dim=-1)
        x = self.input_linear(x)
        out, _ = self.lstm(x)

        pitch_logits = self.pitch_head(out)
        vel_pred = self.vel_head(out).squeeze(-1)
        dur_pred = F.softplus(self.dur_head(out).squeeze(-1))
        return pitch_logits, vel_pred, dur_pred

    @torch.no_grad()
    def generate_next(self, feats: torch.Tensor, temperature: float = 1.0):
        """Generate next note conditioned on the provided sequence (expects median in column 4)."""
        self.eval()
        logits, vel_pred, dur_pred = self(feats.unsqueeze(0))
        next_pitch_logits = logits[:, -1, :] / temperature
        probs = F.softmax(next_pitch_logits, dim=-1)
        next_pitch = torch.multinomial(probs, num_samples=1).squeeze()
        next_vel = vel_pred[:, -1].squeeze().clamp(0, 127)
        next_dur = dur_pred[:, -1].squeeze()
        return next_pitch.item(), next_vel.item(), next_dur.item()


In [15]:
# ClusterDataModule: wraps median-conditioned ClusterDataset (5-col features with median in col 4)
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from ttm.config import config
from ttm.data_preparation.dataset import ClusterDataset

class ClusterDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str, batch_size: int = None, num_workers: int = None):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size or config.get('unconditional', {}).get('batch_size', 16)
        self.num_workers = num_workers or config.get('unconditional', {}).get('num_workers', 0)

    def _ds(self, split: str):
        # ClusterDataset holds (features, labels, median) tuples; features include median in column 4
        return ClusterDataset(self.data_dir, split, feature_type='unconditional')

    def train_dataloader(self):
        return DataLoader(
            self._ds('train'),
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            persistent_workers=self.num_workers > 0,
            drop_last=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self._ds('validation'),
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            persistent_workers=self.num_workers > 0,
            drop_last=False,
        )

    def test_dataloader(self):
        return DataLoader(
            self._ds('test'),
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            persistent_workers=self.num_workers > 0,
            drop_last=False,
        )


In [7]:
# Generate PKLs from split HANNDs left/right MIDIs
import pickle
from pathlib import Path
from tqdm import tqdm
from ttm.data_preparation.utils import get_note_sequence_from_midi, midi_to_tap

left_dir = Path('outputs/hannds_split/left')
right_dir = Path('outputs/hannds_split/right')
output_dir = Path('features')  # same folder as unconditional-*.pkl


def midi_dir_to_pkl(midi_root: Path, out_path: Path):
    midi_paths = sorted([p for p in midi_root.rglob('*') if p.suffix.lower() in {'.mid', '.midi'}])
    data = []
    for midi_path in tqdm(midi_paths, desc=f'Processing {midi_root.name}'):
        try:
            notes = get_note_sequence_from_midi(midi_path)
            feats, labels = midi_to_tap(notes)
            data.append((feats, labels))
        except Exception as exc:
            print(f'Skip {midi_path}: {exc}')
    out_path.parent.mkdir(parents=True, exist_ok=True)
    pickle.dump(data, open(out_path, 'wb'))
    print(f'Saved {len(data)} samples to {out_path}')

midi_dir_to_pkl(left_dir, output_dir / 'hannds_left.pkl')
midi_dir_to_pkl(right_dir, output_dir / 'hannds_right.pkl')


Processing left: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 1276/1276 [01:07<00:00, 18.87it/s]


Saved 1276 samples to features/hannds_left.pkl


Processing right: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 1276/1276 [00:53<00:00, 23.82it/s]


Saved 1276 samples to features/hannds_right.pkl


In [8]:
# Build a single ClusterAugmentation PKL from HANNDs-split MIDIs (left/right combined)
import pickle
from pathlib import Path
from tqdm import tqdm
from ttm.data_preparation.utils import get_note_sequence_from_midi, midi_to_tap
from ttm.data_preparation.data_augmentation import ClusterAugmentation

split_root = Path('outputs/hannds_split')
output_dir = Path('features')
out_path = output_dir / 'hannds_cluster.pkl'

cluster_aug = ClusterAugmentation()

midi_paths = [p for p in split_root.rglob('*') if p.suffix.lower() in {'.mid', '.midi'}]
data = []
for midi_path in tqdm(sorted(midi_paths), desc='ClusterAug PKL'):
    try:
        notes = get_note_sequence_from_midi(midi_path)
        feats, labels = midi_to_tap(notes)
        feats_aug, labels_aug, median_info = cluster_aug(feats, labels)
        data.append((feats_aug, labels_aug, median_info))
    except Exception as exc:
        print(f'Skip {midi_path}: {exc}')

output_dir.mkdir(parents=True, exist_ok=True)
pickle.dump(data, open(out_path, 'wb'))
print(f'Saved {len(data)} samples to {out_path}')


ClusterAug PKL: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 2578/2578 [02:02<00:00, 21.03it/s]


Saved 2578 samples to features/hannds_cluster.pkl


In [19]:
# Train a hand-aware Cluster LSTM on hannds_cluster.pkl (pitch + hand flag + medians)
import pickle
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from ttm.config import MIN_PIANO_PITCH, MAX_PIANO_PITCH

pkl_path = Path('features/hannds_cluster.pkl')
raw = pickle.load(open(pkl_path, 'rb'))
print('Loaded samples:', len(raw))

max_len = 128

class ClusterHandDataset(Dataset):
    def __init__(self, data):
        self.data = data
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        feats, labels, median_info = self.data[idx]
        # truncate
        feats = feats[:max_len]
        labels = labels[:max_len]
        # ensure 8 columns (pitch, dt, dur, vel, median, hand, left_med, right_med)
        if feats.shape[1] < 8:
            pad_cols = 8 - feats.shape[1]
            feats = np.concatenate([feats, np.zeros((len(feats), pad_cols))], axis=1)
        # normalize pitch to 0..88
        feats = feats.copy()
        feats[:, 0] = np.clip(feats[:, 0] - MIN_PIANO_PITCH, 0, 88)
        labels = labels - MIN_PIANO_PITCH
        # pad sequence
        if len(feats) < max_len:
            pad_len = max_len - len(feats)
            pad_row = np.zeros((1, feats.shape[1]))
            pad_row[0, 0] = 88
            feats = np.concatenate([feats, np.repeat(pad_row, pad_len, axis=0)], axis=0)
            label_pad = np.full(pad_len, 88)
            labels = np.concatenate([labels, label_pad])
        hand_labels = feats[:, 5].copy()
        return torch.tensor(feats, dtype=torch.float32), torch.tensor(labels, dtype=torch.long), torch.tensor(hand_labels, dtype=torch.float32)

# simple split
split_idx = int(0.9 * len(raw))
train_ds = ClusterHandDataset(raw[:split_idx])
val_ds = ClusterHandDataset(raw[split_idx:])
train_dl = DataLoader(train_ds, batch_size=8, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=8)

class ClusterHandLSTM(nn.Module):
    def __init__(self, pitch_vocab=MAX_PIANO_PITCH+1, pitch_emb_dim=32, hidden=128, layers=2, dropout=0.1):
        super().__init__()
        self.pitch_emb = nn.Embedding(pitch_vocab+1, pitch_emb_dim, padding_idx=pitch_vocab)
        # scalars: dt, dur, vel, median, hand, left_med, right_med = 7
        self.input_linear = nn.Linear(pitch_emb_dim + 7, hidden)
        self.lstm = nn.LSTM(hidden, hidden, num_layers=layers, batch_first=True, dropout=dropout)
        self.pitch_head = nn.Linear(hidden, pitch_vocab+1)
        self.hand_head = nn.Linear(hidden, 1)
    def forward(self, x):
        pitch_idx = x[...,0].long().clamp(0, MAX_PIANO_PITCH)
        scalars = x[...,1:8]
        pitch_embed = self.pitch_emb(pitch_idx)
        h = torch.cat([pitch_embed, scalars], dim=-1)
        h = self.input_linear(h)
        out,_ = self.lstm(h)
        pitch_logits = self.pitch_head(out)
        hand_logits = self.hand_head(out).squeeze(-1)
        return pitch_logits, hand_logits

model = ClusterHandLSTM()
optim = torch.optim.AdamW(model.parameters(), lr=3e-4)

for epoch in range(50):
    model.train()
    for feats, labels, hand in train_dl:
        optim.zero_grad()
        pitch_logits, hand_logits = model(feats)
        ce = F.cross_entropy(pitch_logits.view(-1, pitch_logits.size(-1)), labels.view(-1), ignore_index=88)
        bce = F.binary_cross_entropy_with_logits(hand_logits.view(-1), hand.view(-1))
        loss = ce + 0.1 * bce
        loss.backward()
        optim.step()
    model.eval()
    with torch.no_grad():
        feats, labels, hand = next(iter(val_dl))
        pitch_logits, hand_logits = model(feats)
        ce = F.cross_entropy(pitch_logits.view(-1, pitch_logits.size(-1)), labels.view(-1), ignore_index=88)
        bce = F.binary_cross_entropy_with_logits(hand_logits.view(-1), hand.view(-1))
        print('val ce', ce.item(), 'val bce', bce.item())


Loaded samples: 2578
val ce 3.7939863204956055 val bce 0.004972664639353752
val ce 3.4566445350646973 val bce 0.0033888458274304867
val ce 3.321607828140259 val bce 0.002891015028581023
val ce 3.2729179859161377 val bce 0.0010837360750883818
val ce 3.2381386756896973 val bce 0.0007364661432802677
val ce 3.2002110481262207 val bce 0.0005740937776863575
val ce 3.148324489593506 val bce 0.0011368109844624996
val ce 3.04152250289917 val bce 0.001097196713089943
val ce 2.9973466396331787 val bce 0.0007472229190170765
val ce 2.9728150367736816 val bce 0.00044573377817869186
val ce 2.978398561477661 val bce 0.00043291784822940826
val ce 2.9560670852661133 val bce 0.0003445935435593128
val ce 2.934034585952759 val bce 0.00021225400269031525
val ce 2.9147841930389404 val bce 0.00024826545268297195
val ce 2.903575897216797 val bce 0.00021154293790459633
val ce 2.8765547275543213 val bce 0.00012388359755277634
val ce 2.8477938175201416 val bce 0.00011513056233525276
val ce 2.849213123321533 val b