# Task 2

Conditioned drum beat prediction on 2-d pose data

### imports

In [94]:
#!pip install -q numpy pretty_midi torch pytorch_lightning matplotlib seaborn scipy

In [95]:
import argparse
import json
import math
import os
import random
import shutil
import subprocess
import sys
from datetime import datetime
from glob import glob
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pretty_midi as pm
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import seaborn as sns
from scipy.stats import pearsonr
from scipy.signal import correlate
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

### Model training config

In [96]:
SHIFT_SEC = 0.02
DRUM_TOKENS: Dict[str, int] = {
    "pad": 0,
    "kick": 1,
    "snare": 2,
    "hihat_closed": 3,
    "hihat_open": 4,
    "tom_low": 5,
    "tom_mid": 6,
    "tom_high": 7,
    "crash": 8,
    "ride": 9,
}
# time‑shift tokens (20 ms each, up to 2 s)
SHIFT_OFFSET = len(DRUM_TOKENS)
MAX_SHIFT = 100  # 100 × 20 ms  = 2 s
for i in range(1, MAX_SHIFT + 1):
    DRUM_TOKENS[f"shift_{i}"] = SHIFT_OFFSET + i

# sequence control tokens – **added**
DRUM_TOKENS["bos"] = len(DRUM_TOKENS)  # begin‑of‑sequence
DRUM_TOKENS["eos"] = len(DRUM_TOKENS)  # end‑of‑sequence

VOCAB_SIZE = len(DRUM_TOKENS)
IDX2TOKEN = {v: k for k, v in DRUM_TOKENS.items()}
PAD_IDX = DRUM_TOKENS["pad"]
BOS_IDX = DRUM_TOKENS["bos"]
EOS_IDX = DRUM_TOKENS["eos"]

In [97]:
### Model Training Utilities

In [98]:
def _pitch_to_token(p: int) -> str:
    # General MIDI → symbolic token
    return (
        "kick"
        if p in (35, 36)
        else (
            "snare"
            if p in (38, 40)
            else (
                "hihat_closed"
                if p in (42, 44)
                else (
                    "hihat_open"
                    if p == 46
                    else (
                        "tom_low"
                        if p in (41, 45)
                        else (
                            "tom_mid"
                            if p in (47, 48)
                            else (
                                "tom_high"
                                if p == 50
                                else (
                                    "crash"
                                    if p in (49, 57)
                                    else "ride" if p in (51, 59) else "snare"
                                )
                            )
                        )
                    )
                )
            )
        )
    )

In [99]:
def midi_to_tokens(mid: pm.PrettyMIDI, time_unit: float = SHIFT_SEC) -> List[int]:
    """Drum MIDI → event tokens (no BOS/EOS)."""
    events: List[Tuple[float, str]] = []
    for inst in mid.instruments:
        if not inst.is_drum:
            continue
        for note in inst.notes:
            events.append((note.start, _pitch_to_token(note.pitch)))
    events.sort(key=lambda x: x[0])

    tokens, prev_time = [], 0.0
    for t, tok in events:
        delta = t - prev_time
        n_shift = int(round(delta / time_unit))
        while n_shift > MAX_SHIFT:
            tokens.append(DRUM_TOKENS[f"shift_{MAX_SHIFT}"])
            n_shift -= MAX_SHIFT
        if n_shift > 0:
            tokens.append(DRUM_TOKENS[f"shift_{n_shift}"])
        tokens.append(DRUM_TOKENS[tok])
        prev_time = t
    return tokens

In [100]:
def collate_fn(batch):
    pose, tok = zip(*batch)
    return (torch.nn.utils.rnn.pad_sequence(pose, batch_first=True), torch.stack(tok))

### Model dataset class

In [101]:
class ChoreoGrooveDataset(Dataset):
    def __init__(self, root: str, seq_len: int = 512):
        self.items = sorted(glob(os.path.join(root, "*", "pose.npy")))
        self.seq_len = seq_len

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx: int):
        pose_path = self.items[idx]
        drum_path = pose_path.replace("pose.npy", "drums.mid")

        # pose → features
        pose = np.load(pose_path).reshape(-1, 51)  # (T, 17×3)
        vel = np.diff(pose, axis=0, prepend=pose[:1])
        feats = np.concatenate([pose, vel], axis=-1)  # (T, 102)
        feats = (feats - feats.mean()) / (feats.std() + 1e-5)
        feats = feats.astype(np.float32)

        # drums → tokens  [+ BOS/EOS, pad/trim]
        tokens = [BOS_IDX] + midi_to_tokens(pm.PrettyMIDI(drum_path)) + [EOS_IDX]
        if len(tokens) < self.seq_len:
            tokens += [PAD_IDX] * (self.seq_len - len(tokens))
        else:
            tokens = tokens[: self.seq_len]

        return torch.from_numpy(feats), torch.tensor(tokens, dtype=torch.long)

### Models

In [102]:
class PoseEncoder(nn.Module):
    def __init__(self, in_feats=102, embed=256):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(in_feats, 128, 5, padding=2),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Conv1d(128, embed, 3, padding=1),
            nn.ReLU(),
        )
        self.gru = nn.GRU(embed, embed, batch_first=True, bidirectional=True)
        self.proj = nn.Linear(embed * 2, embed)

    def forward(self, x):  # x (B,T,F)
        x = self.conv(x.transpose(1, 2)).transpose(1, 2)  # (B,T,E)
        x, _ = self.gru(x)
        return self.proj(x).transpose(0, 1)  # (T,B,E)

In [103]:
class DrumDecoder(nn.Module):
    def __init__(self, embed=256, layers=4, nhead=8, vocab=VOCAB_SIZE):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab, embed)
        self.pos_emb = nn.Embedding(1024, embed)
        dec_layer = nn.TransformerDecoderLayer(embed, nhead, 1024, batch_first=True)
        self.transformer = nn.TransformerDecoder(dec_layer, layers)
        self.fc_out = nn.Linear(embed, vocab)

    def forward(self, tgt, memory):  # tgt (B,L), memory (T,B,E)
        pos = torch.arange(tgt.size(1), device=tgt.device).unsqueeze(0)
        tgt = self.tok_emb(tgt) + self.pos_emb(pos)
        mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(1)).to(
            tgt.device
        )
        out = self.transformer(tgt, memory.transpose(0, 1), tgt_mask=mask)
        return self.fc_out(out)  # (B,L,V)


In [104]:
class Choreo2GrooveModel(pl.LightningModule):
    def __init__(self, in_feats: int, lr=1e-4):
        super().__init__()
        self.encoder = PoseEncoder(in_feats)
        self.decoder = DrumDecoder()
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
        self.save_hyperparameters()

    # forward
    def forward(self, poses, tokens):
        memory = self.encoder(poses)  # (T,B,E)
        if self.training:
            tgt_in = tokens[:, :-1]  # strip last (EOS / PAD)
            return self.decoder(tgt_in, memory)  # (B,L‑1,V)
        else:
            return self.decoder(tokens, memory)

    # training
    def training_step(self, batch, _):
        pose, tok = batch
        logits = self(pose, tok)
        loss = self.loss_fn(logits.reshape(-1, VOCAB_SIZE), tok[:, 1:].reshape(-1))
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

### Training Optimizations

In [105]:
epochs=5
lr=1e-4
batch_size=4
seq_len=256
version=0

In [106]:
def check_gpu_availability():
    if torch.cuda.is_available():
        gpu_count = torch.cuda.device_count()
        print(f"Using GPU")
        return True, gpu_count
    else:
        print("Using CPU")
        return False, 0

In [107]:
# Check GPU availability
has_gpu, gpu_count = check_gpu_availability()

Using GPU


In [108]:
# Setup DataLoader - optimized for GPU
num_workers = (
    0 if sys.platform.startswith("win") else min(4, gpu_count * 2) if has_gpu else 2
)
pin_memory = has_gpu  # Use pinned memory for faster GPU transfer

In [109]:
# Adjust batch size for GPU if available
if has_gpu and batch_size < 8:
    original_batch_size = batch_size
    batch_size = min(16, batch_size * 2)  # Increase batch size for GPU
    print(
        f"GPU detected: increasing batch size from {original_batch_size} to {batch_size}"
    )

GPU detected: increasing batch size from 4 to 8


### Dataset

In [110]:
dataset = ChoreoGrooveDataset("dataset_root", seq_len=seq_len)

In [111]:
# Calculate input features from first sample
sample_pose, _ = dataset[0]
in_feats = sample_pose.shape[-1]
print(f"Dataset loaded: {len(dataset)} samples, {in_feats} features per frame")

Dataset loaded: 76 samples, 102 features per frame


### Initialize Model

In [112]:
model = Choreo2GrooveModel(in_feats=in_feats, lr=lr)

In [113]:
dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=num_workers > 0,
    )

### Logging

In [114]:
checkpoint_callback = ModelCheckpoint(
        monitor="train_loss",
        filename="choreo2groove-{epoch:02d}-{train_loss:.3f}",
        save_top_k=1,
        mode="min",
        save_last=True,
    )

In [115]:
logger = TensorBoardLogger("lightning_logs", version=version)

### Training

In [116]:
trainer_kwargs = {
        "max_epochs": epochs,
        "callbacks": [checkpoint_callback],
        "logger": logger,
        "log_every_n_steps": 10,
        "check_val_every_n_epoch": 1,
        "enable_progress_bar": True,
        "enable_model_summary": True,
    }

In [117]:
if has_gpu:
        trainer_kwargs.update(
            {
                "accelerator": "gpu",
                "devices": min(gpu_count, 1),  # Use 1 GPU for now
                "precision": "16-mixed",  # Mixed precision for faster training
            }
        )
        print("GPU training enabled with mixed precision")
else:
    trainer_kwargs.update(
        {
            "accelerator": "cpu",
            "devices": 1,
        }
    )
    print("CPU training mode")

GPU training enabled with mixed precision


In [118]:
trainer = pl.Trainer(**trainer_kwargs)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [119]:
# Start training
print(f"Starting training for {epochs} epochs...")
start_time = datetime.now()

Starting training for 5 epochs...


In [120]:
trainer.fit(model, dataloader)

end_time = datetime.now()
training_duration = end_time - start_time
print(f"\nTraining completed")
print(f"Training duration: {training_duration}")

# Get final metrics
final_loss = trainer.callback_metrics.get("train_loss", "unknown")
print(f"Final training loss: {final_loss}")

c:\Users\hajin\miniconda3\envs\cse-153-assignment2\Lib\site-packages\pytorch_lightning\callbacks\model_checkpoint.py:654: Checkpoint directory lightning_logs\lightning_logs\version_0\checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type             | Params | Mode 
-----------------------------------------------------
0 | encoder | PoseEncoder      | 1.1 M  | train
1 | decoder | DrumDecoder      | 4.5 M  | train
2 | loss_fn | CrossEntropyLoss | 0      | train
-----------------------------------------------------
5.6 M     Trainable params
0         Non-trainable params
5.6 M     Total params
22.474    Total estimated model params size (MB)
72        Modules in train mode
0         Modules in eval mode
c:\Users\hajin\miniconda3\envs\cse-153-assignment2\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `

Epoch 4: 100%|██████████| 10/10 [00:01<00:00,  7.63it/s, v_num=0]

`Trainer.fit` stopped: `max_epochs=5` reached.


Epoch 4: 100%|██████████| 10/10 [00:01<00:00,  5.08it/s, v_num=0]

Training completed
Training duration: 0:00:10.149039
Final training loss: 0.8090565204620361


# Generate Drum Beats

### Utilities

In [123]:
sys.path.append(".")

In [124]:
def load_trained_model(checkpoint_path):
    """Load the trained model with GPU support"""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model = Choreo2GrooveModel(in_feats=102, lr=1e-4)
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()
    model = model.to(device)
    return model

In [125]:
def token_to_pitch(token_name):
    """Convert token name back to MIDI pitch"""
    pitch_map = {
        "kick": 36,
        "snare": 38,
        "hihat_closed": 42,
        "hihat_open": 46,
        "tom_low": 45,
        "tom_mid": 47,
        "tom_high": 50,
        "crash": 49,
        "ride": 51,
    }
    return pitch_map.get(token_name, 38)

In [None]:
def tokens_to_midi(tokens, time_unit=SHIFT_SEC, bpm=120):
    """Convert drum tokens back to MIDI"""
    midi = pm.PrettyMIDI(initial_tempo=bpm)
    drums = pm.Instrument(program=0, is_drum=True, name="Generated_Drums")

    current_time = 0.0

    for token_id in tokens:
        if token_id >= VOCAB_SIZE:
            continue

        token_name = IDX2TOKEN.get(token_id, "unknown")
        if token_name in ("pad", "bos", "eos"):  # <<< skip BOS/EOS
            continue
        elif token_name.startswith("shift_"):
            shift_amount = int(token_name.split("_")[1])
            current_time += shift_amount * time_unit
        elif token_name in [
            "kick",
            "snare",
            "hihat_closed",
            "hihat_open",
            "tom_low",
            "tom_mid",
            "tom_high",
            "crash",
            "ride",
        ]:
            pitch = token_to_pitch(token_name)
            velocity = random.randint(80, 120)
            note = pm.Note(pitch, velocity, current_time, current_time + 0.1)
            drums.notes.append(note)

    midi.instruments.append(drums)
    return midi

In [None]:
def generate_drum_beat(model, pose_data, max_length=256):
    device = next(model.parameters()).device
    pose_tensor = torch.from_numpy(pose_data).unsqueeze(0).to(device)

    memory = model.encoder(pose_tensor)
    pose_dur = pose_data.shape[0] * SHIFT_SEC  # duration

    seq, elapsed = [BOS_IDX], 0.0
    with torch.no_grad():
        for _ in range(max_length):
            cur = torch.tensor(seq, dtype=torch.long, device=device).unsqueeze(0)
            logits = model.decoder(cur, memory)[0, -1]
            nxt = torch.multinomial(torch.softmax(logits / 0.8, -1), 1).item()
            seq.append(nxt)

            if IDX2TOKEN[nxt].startswith("shift_"):
                elapsed += int(IDX2TOKEN[nxt].split("_")[1]) * SHIFT_SEC
            if nxt == EOS_IDX or elapsed >= pose_dur:
                break
    return seq