# Drum Transcription Model Training (OaF-style)

This notebook trains a drum transcription model based on the Onsets and Frames architecture.

**Features:**
- Multi-class drum detection (kick, snare, hi-hat, toms, cymbals)
- Velocity prediction for dynamics
- Ghost note detection

**Dataset:** E-GMD (Expanded Groove MIDI Dataset) - 444 hours of human drum performances

**Requirements:** Colab Pro recommended (~225 GB disk needed for E-GMD)

**Estimated time:** 6-12 hours on A100, 12-24 hours on T4

In [None]:
# Check GPU: print nvidia-smi's device/driver summary to confirm a GPU runtime is attached.
!nvidia-smi

In [None]:
# Mount Google Drive at /content/drive so checkpoints and plots written under
# /content/drive/MyDrive survive a Colab runtime reset.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Install dependencies
# NOTE(review): mir_eval, scikit-learn and wandb are installed but none of the cells
# below use them — confirm whether they are still needed.
!pip install -q torch torchaudio librosa pretty_midi mir_eval scikit-learn tqdm wandb

In [None]:
import os

EGMD_URL = "https://storage.googleapis.com/magentadata/datasets/e-gmd/v1.0.0/e-gmd-v1.0.0.zip"
DATA_DIR = "/content/e-gmd"

# Clean up any failed previous attempts
!rm -rf /content/e-gmd.zip /content/e-gmd-v1.0.0 /content/e-gmd
!pip cache purge 2>/dev/null
!rm -rf /root/.cache/pip /tmp/* 2>/dev/null

print("=== Disk space (cleaned) ===")
!df -h /content

if not os.path.exists(DATA_DIR):
    print("\nStreaming E-GMD download + extraction (~90 GB, ~15 min)...")
    print("(zip is piped directly to extractor â€” never stored on disk)")

    # bsdtar reads the zip from stdin and extracts on the fly
    # This avoids needing 2x disk space (zip + extracted)
    !wget -q --show-progress -O - "{EGMD_URL}" | bsdtar -xf - -C /content/

    # Rename extracted directory to expected path
    !mv /content/e-gmd-v1.0.0 /content/e-gmd 2>/dev/null || true

    print("\n=== Disk after extraction ===")
    !df -h /content
else:
    print("Dataset already exists, skipping download.")

# Verify
print(f"\nWAV files:")
!find {DATA_DIR} -name "*.wav" | wc -l
print(f"MIDI files:")
!find {DATA_DIR} -name "*.midi" | wc -l

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import numpy as np
import librosa
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import json

# Report the runtime we are about to train on.
_cuda = torch.cuda.is_available()
for _line in (f"PyTorch version: {torch.__version__}", f"CUDA available: {_cuda}"):
    print(_line)
if _cuda:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Configuration

# Drum class -> MIDI pitches. The dataset's MIDI uses an electronic-kit mapping
# (General MIDI plus edge/rim articulations such as 22, 26, 58), per the GMD/E-GMD
# documentation — verify against the dataset page. Every pitch belonging to a class
# must be listed: unlisted pitches are dropped from the targets, so those hits would
# be trained as silence.
_DRUM_CLASSES = {
    'kick': [35, 36],
    'snare': [37, 38, 40],            # 37 = side stick, 40 = rim
    'hihat_closed': [22, 42, 44],     # 22 = closed edge, 44 = pedal
    'hihat_open': [26, 46],           # 26 = open edge
    'tom_low': [41, 43, 45, 47, 58],  # floor tom (43 head / 58 rim), tom 2 (45 head / 47 rim)
    'tom_high': [48, 50],             # tom 1 (48 head / 50 rim)
    'crash': [49, 52, 55, 57],        # 55 / 52 = crash 1 / crash 2 edge
    'ride': [51, 53, 59],             # 53 = bell, 59 = edge
}

CONFIG = {
    'sample_rate': 16000,
    'hop_length': 256,
    'n_mels': 128,
    'n_fft': 2048,

    # Drum classes (see mapping notes above)
    'drum_classes': _DRUM_CLASSES,
    # Derived so it can never drift out of sync with 'drum_classes'.
    'num_classes': len(_DRUM_CLASSES),

    # Training
    'batch_size': 16,
    'learning_rate': 1e-4,
    'num_epochs': 50,
    'chunk_length_sec': 5.0,
}

# Create class mapping: MIDI pitch -> class index (index = position in 'drum_classes').
MIDI_TO_CLASS = {}
for class_idx, (class_name, midi_notes) in enumerate(CONFIG['drum_classes'].items()):
    for midi_note in midi_notes:
        # A pitch listed under two classes would silently keep only the later one.
        if midi_note in MIDI_TO_CLASS:
            raise ValueError(
                f"MIDI pitch {midi_note} is assigned to more than one drum class"
            )
        MIDI_TO_CLASS[midi_note] = class_idx

CLASS_NAMES = list(CONFIG['drum_classes'].keys())
print(f"Drum classes: {CLASS_NAMES}")

In [None]:
class DrumDataset(Dataset):
    """Dataset for drum transcription training using E-GMD.

    Pairs every ``*.wav`` under ``data_dir`` with the ``*.midi`` file of the same
    stem, splits the pairs 90/10 into train/val with a fixed seed, and yields a
    random ``chunk_length_sec`` window of one recording per item.

    Each item is a dict with:
        ``mel_spec``: log-mel spectrogram, shape ``(n_mels, num_frames)``.
        ``onsets`` / ``frames`` / ``velocities``: float tensors of shape
        ``(num_classes, num_frames)``; velocities hold MIDI velocity / 127 at
        onset frames.
    """

    def __init__(self, data_dir, split='train', chunk_length_sec=5.0, sample_rate=16000):
        self.data_dir = Path(data_dir)
        self.chunk_length = chunk_length_sec
        self.sr = sample_rate
        self.hop_length = CONFIG['hop_length']

        # Build the transform once instead of once per __getitem__ call.
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.sr,
            n_fft=CONFIG['n_fft'],
            hop_length=self.hop_length,
            n_mels=CONFIG['n_mels'],
        )

        # Load metadata
        self.samples = self._load_samples(split)
        print(f"Loaded {len(self.samples)} samples for {split}")

    def _load_samples(self, split):
        """Return the audio/MIDI path pairs for ``split``.

        ``split == 'train'`` gets the first 90% of the seeded shuffle; any other
        value gets the remaining 10%.
        """
        samples = []

        # E-GMD structure: drummer/session/file.wav + file.midi
        # NOTE: E-GMD uses .midi extension, NOT .mid
        # sorted() makes the file order — and therefore the seeded split below —
        # independent of filesystem enumeration order (e.g. after re-extraction on
        # a new Colab VM), so train/val never swap files between sessions.
        for audio_path in sorted(self.data_dir.rglob('*.wav')):
            midi_path = audio_path.with_suffix('.midi')
            if midi_path.exists():
                samples.append({
                    'audio': str(audio_path),
                    'midi': str(midi_path)
                })

        if not samples:
            # Debug: check if files exist with different extension
            wav_count = len(list(self.data_dir.rglob('*.wav')))
            midi_count = len(list(self.data_dir.rglob('*.midi')))
            mid_count = len(list(self.data_dir.rglob('*.mid')))
            print(f"WARNING: 0 matched samples!")
            print(f"  WAV files found: {wav_count}")
            print(f"  .midi files found: {midi_count}")
            print(f"  .mid files found: {mid_count}")
            # Show sample filenames
            for f in list(self.data_dir.rglob('*.wav'))[:3]:
                print(f"  Sample WAV: {f.name}")
                midi_candidate = f.with_suffix('.midi')
                print(f"    Expected MIDI: {midi_candidate.name} (exists: {midi_candidate.exists()})")

        # Split 90/10 train/val with a local RNG: the split is fixed without
        # reseeding numpy's global generator (used for chunk sampling).
        # NOTE(review): if the dataset holds several recordings of the same
        # performance (e.g. re-rendered on different kits), a per-file split can put
        # them in both train and val; consider the dataset's own split metadata.
        rng = np.random.RandomState(42)
        rng.shuffle(samples)
        split_idx = int(len(samples) * 0.9)

        if split == 'train':
            return samples[:split_idx]
        return samples[split_idx:]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]

        # Load audio
        audio, sr = torchaudio.load(sample['audio'])
        if sr != self.sr:
            audio = torchaudio.functional.resample(audio, sr, self.sr)

        # Convert to mono
        if audio.shape[0] > 1:
            audio = audio.mean(dim=0, keepdim=True)
        audio = audio.squeeze(0)

        # Random chunk
        chunk_samples = int(self.chunk_length * self.sr)
        if len(audio) > chunk_samples:
            # +1: randint's upper bound is exclusive; allow the final window too.
            start = np.random.randint(0, len(audio) - chunk_samples + 1)
            audio = audio[start:start + chunk_samples]
            start_time = start / self.sr
        else:
            # Pad if too short
            audio = F.pad(audio, (0, chunk_samples - len(audio)))
            start_time = 0

        # Compute log-mel spectrogram
        mel_spec = torch.log(self.mel_transform(audio) + 1e-8)

        # Load MIDI and create targets
        import pretty_midi
        midi = pretty_midi.PrettyMIDI(sample['midi'])
        onsets, frames, velocities = self._build_targets(midi, start_time, mel_spec.shape[-1])

        return {
            'mel_spec': mel_spec,
            'onsets': onsets,
            'frames': frames,
            'velocities': velocities
        }

    def _build_targets(self, midi, start_time, num_frames):
        """Rasterise drum notes that start inside the current chunk into targets.

        Args:
            midi: parsed MIDI object exposing ``instruments``; each instrument has
                ``is_drum`` and ``notes`` with ``start``/``end`` (seconds),
                ``pitch`` and ``velocity``.
            start_time: chunk start in seconds.
            num_frames: number of spectrogram frames in the chunk.

        Returns:
            ``(onsets, frames, velocities)``, each a float tensor of shape
            ``(CONFIG['num_classes'], num_frames)``.
        """
        num_classes = CONFIG['num_classes']
        onsets = torch.zeros(num_classes, num_frames)
        frames = torch.zeros(num_classes, num_frames)
        velocities = torch.zeros(num_classes, num_frames)

        end_time = start_time + self.chunk_length

        for instrument in midi.instruments:
            if not instrument.is_drum:
                continue
            for note in instrument.notes:
                if not (start_time <= note.start < end_time):
                    continue
                class_idx = MIDI_TO_CLASS.get(note.pitch)
                if class_idx is None:
                    continue

                onset_frame = int((note.start - start_time) * self.sr / self.hop_length)
                end_frame = int((note.end - start_time) * self.sr / self.hop_length)

                if 0 <= onset_frame < num_frames:
                    onsets[class_idx, onset_frame] = 1
                    # Keep the loudest hit when two notes of a class share a frame.
                    velocities[class_idx, onset_frame] = max(
                        float(velocities[class_idx, onset_frame]), note.velocity / 127.0
                    )

                # Drum notes are often shorter than one hop, giving
                # end_frame == onset_frame; always mark at least the onset frame
                # active so frame targets never contradict onset targets.
                last_frame = max(end_frame, onset_frame + 1)
                frames[class_idx, max(0, onset_frame):min(num_frames, last_frame)] = 1

        return onsets, frames, velocities

In [None]:
class DrumTranscriptionModel(nn.Module):
    """Onsets and Frames style model for drum transcription.

    A CNN acoustic model feeds two BiLSTM branches: an onset branch, and a frame
    branch that also receives the onset probabilities. Velocity is regressed from
    the onset-branch features.

    Args:
        n_mels: number of mel bins in the input spectrogram.
        num_classes: number of drum classes predicted per frame.

    Forward input is ``(batch, n_mels, time)``; forward output is
    ``(onset_pred, frame_pred, velocity_pred)``, each ``(batch, num_classes, time)``
    with sigmoid values in ``[0, 1]``.
    """

    def __init__(self, n_mels=128, num_classes=8):
        super().__init__()

        # Acoustic model (CNN). The input is laid out as (batch, 1, n_mels, time),
        # so pooling kernels are (freq, time). Pool FREQUENCY only: the time axis
        # must keep one step per target frame. (A (1, 2) kernel would halve time
        # and leave all n_mels bins, mismatching flat_size and the targets.)
        self.conv_stack = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d((2, 1)),
            nn.Dropout(0.25),

            nn.Conv2d(32, 64, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d((2, 1)),
            nn.Dropout(0.25),

            nn.Conv2d(64, 128, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Dropout(0.25),
        )

        # Two frequency poolings by 2 leave n_mels // 4 bins of 128 channels each.
        self.flat_size = 128 * (n_mels // 4)

        # Onset detection branch
        self.onset_lstm = nn.LSTM(
            self.flat_size, 128, num_layers=2,
            batch_first=True, bidirectional=True, dropout=0.3
        )
        self.onset_fc = nn.Linear(256, num_classes)

        # Frame detection branch (combined with onset)
        self.frame_lstm = nn.LSTM(
            self.flat_size + num_classes, 128, num_layers=2,
            batch_first=True, bidirectional=True, dropout=0.3
        )
        self.frame_fc = nn.Linear(256, num_classes)

        # Velocity prediction
        self.velocity_fc = nn.Linear(256, num_classes)

    def forward(self, x):
        # x: (batch, n_mels, time)
        x = x.unsqueeze(1)  # Add channel dim -> (batch, 1, n_mels, time)

        # CNN
        x = self.conv_stack(x)  # (batch, 128, n_mels//4, time)

        # Reshape for LSTM: (batch, time, features)
        batch, channels, freq, time = x.shape
        x = x.permute(0, 3, 1, 2).reshape(batch, time, -1)

        # Onset detection
        onset_features, _ = self.onset_lstm(x)
        onset_pred = torch.sigmoid(self.onset_fc(onset_features))

        # Frame detection (with onset info)
        frame_input = torch.cat([x, onset_pred], dim=-1)
        frame_features, _ = self.frame_lstm(frame_input)
        frame_pred = torch.sigmoid(self.frame_fc(frame_features))

        # Velocity
        velocity_pred = torch.sigmoid(self.velocity_fc(onset_features))

        # Permute back: (batch, classes, time)
        onset_pred = onset_pred.permute(0, 2, 1)
        frame_pred = frame_pred.permute(0, 2, 1)
        velocity_pred = velocity_pred.permute(0, 2, 1)

        return onset_pred, frame_pred, velocity_pred

In [None]:
# Build the train/val datasets (DrumDataset applies the same seeded 90/10 split to
# both) and wrap them in loaders; only the training loader shuffles and pins memory.
train_dataset, val_dataset = (DrumDataset(DATA_DIR, split=_split) for _split in ('train', 'val'))

_loader_kwargs = dict(batch_size=CONFIG['batch_size'], num_workers=4)
train_loader = DataLoader(train_dataset, shuffle=True, pin_memory=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

In [None]:
# Initialize model, optimizer, LR schedule and loss functions.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Training on: {device}")

model = DrumTranscriptionModel(n_mels=CONFIG['n_mels'], num_classes=CONFIG['num_classes'])
model = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG['learning_rate'])
# Halve the LR once validation loss has not improved for more than 5 epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=5)

# Onset/frame heads emit sigmoid probabilities -> BCE; velocity is a regression -> MSE.
onset_criterion, frame_criterion = nn.BCELoss(), nn.BCELoss()
velocity_criterion = nn.MSELoss()

# Count parameters
total_params = sum(param.numel() for param in model.parameters())
print(f"Model parameters: {total_params:,}")

In [None]:
def train_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Per batch the loss is onset BCE + frame BCE + 0.5 * velocity MSE, where the
    velocity term only covers positions whose onset target is > 0.5 (and is 0 when
    a batch has no onsets). Gradients are clipped to a global norm of 1.0.
    """
    model.train()
    running = 0
    progress = tqdm(loader, desc='Training')

    for batch in progress:
        mel = batch['mel_spec'].to(device)
        onset_t = batch['onsets'].to(device)
        frame_t = batch['frames'].to(device)
        vel_t = batch['velocities'].to(device)

        optimizer.zero_grad()
        outputs = model(mel)

        # Trim predictions and targets to a shared number of frames.
        n = min(outputs[0].shape[-1], onset_t.shape[-1])
        onset_p, frame_p, vel_p = (out[..., :n] for out in outputs)
        onset_t, frame_t, vel_t = (tgt[..., :n] for tgt in (onset_t, frame_t, vel_t))

        # Velocity is only meaningful where a hit actually starts.
        hit = onset_t > 0.5
        vel_loss = (
            velocity_criterion(vel_p[hit], vel_t[hit])
            if hit.any()
            else torch.tensor(0.0, device=device)
        )
        loss = onset_criterion(onset_p, onset_t) + frame_criterion(frame_p, frame_t) + 0.5 * vel_loss

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        step_loss = loss.item()
        running += step_loss
        progress.set_postfix({'loss': f'{step_loss:.4f}'})

    return running / len(loader)


def validate(model, loader, device):
    """Return the mean validation loss over ``loader`` without updating weights.

    Uses the same objective as ``train_epoch`` (onset BCE + frame BCE +
    0.5 * velocity MSE on positions whose onset target is > 0.5), so the train and
    validation curves — and best-checkpoint selection — measure the same thing.

    Raises:
        ValueError: if ``loader`` yields no batches (e.g. an empty validation split).
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batch in tqdm(loader, desc='Validation'):
            mel_spec = batch['mel_spec'].to(device)
            onsets_target = batch['onsets'].to(device)
            frames_target = batch['frames'].to(device)
            velocities_target = batch['velocities'].to(device)

            onset_pred, frame_pred, velocity_pred = model(mel_spec)

            # Trim predictions and targets to a shared number of frames.
            min_len = min(onset_pred.shape[-1], onsets_target.shape[-1])
            onset_pred = onset_pred[..., :min_len]
            frame_pred = frame_pred[..., :min_len]
            velocity_pred = velocity_pred[..., :min_len]
            onsets_target = onsets_target[..., :min_len]
            frames_target = frames_target[..., :min_len]
            velocities_target = velocities_target[..., :min_len]

            onset_loss = onset_criterion(onset_pred, onsets_target)
            frame_loss = frame_criterion(frame_pred, frames_target)

            # Velocity loss only where there are onsets (as in training).
            onset_mask = onsets_target > 0.5
            if onset_mask.any():
                vel_loss = velocity_criterion(
                    velocity_pred[onset_mask],
                    velocities_target[onset_mask]
                )
            else:
                vel_loss = torch.tensor(0.0, device=device)

            total_loss += (onset_loss + frame_loss + 0.5 * vel_loss).item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError("validation loader yielded no batches; check the validation split")
    return total_loss / num_batches

In [None]:
# Training loop
SAVE_DIR = Path('/content/drive/MyDrive/drum_model_results')
SAVE_DIR.mkdir(parents=True, exist_ok=True)

best_val_loss = float('inf')
history = {'train_loss': [], 'val_loss': []}

print("Starting training...")
print(f"Saving checkpoints to: {SAVE_DIR}")


def _checkpoint(epoch, val_loss):
    """Build a checkpoint holding everything needed to resume, not just weights.

    Includes scheduler state, best_val_loss and the loss history so a run that is
    cut off (e.g. by a Colab disconnect) can continue with the same LR schedule.
    Reads the cell-level model/optimizer/scheduler/history at call time.
    """
    return {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'val_loss': val_loss,
        'best_val_loss': best_val_loss,
        'history': history,
        'config': CONFIG,
        'class_names': CLASS_NAMES,
    }


for epoch in range(CONFIG['num_epochs']):
    print(f"\nEpoch {epoch+1}/{CONFIG['num_epochs']}")

    train_loss = train_epoch(model, train_loader, optimizer, device)
    val_loss = validate(model, val_loader, device)

    scheduler.step(val_loss)

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)

    print(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    # Persist the loss curves every epoch so they survive a runtime disconnect.
    with open(SAVE_DIR / 'history.json', 'w') as fh:
        json.dump(history, fh, indent=2)

    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(_checkpoint(epoch, val_loss), SAVE_DIR / 'best_drum_model.pt')
        print(f"  Saved best model (val_loss: {val_loss:.4f})")

    # Save checkpoint every 5 epochs (Colab can disconnect)
    if (epoch + 1) % 5 == 0:
        torch.save(_checkpoint(epoch, val_loss), SAVE_DIR / f'drum_model_epoch_{epoch+1}.pt')
        print(f"  Saved checkpoint at epoch {epoch+1}")

print("\nTraining complete!")

In [None]:
# Plot the per-epoch train/val loss curves and save the figure next to the checkpoints.
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 4))
for _key, _label in (('train_loss', 'Train Loss'), ('val_loss', 'Val Loss')):
    plt.plot(history[_key], label=_label)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Drum Transcription Training')
plt.legend()
plt.savefig(SAVE_DIR / 'training_history.png')
plt.show()

In [None]:
# List saved models: every *.pt checkpoint file currently in SAVE_DIR.
print("Saved models:")
!ls -la {SAVE_DIR}/*.pt

## Training Complete!

Your trained drum transcription model is saved to `Google Drive/drum_model_results/`

**Files:**
- `best_drum_model.pt` - Best model checkpoint
- `drum_model_epoch_XX.pt` - Periodic checkpoints
- `training_history.png` - Loss curves

**Next:** Copy the model to your StemScribe backend and update `drum_transcriber_v2.py` to use it.