<a href="https://colab.research.google.com/github/maple-buice/chart-hero/blob/main/colab/transformer_training_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 1. Environment Setup

In [1]:
# Check GPU availability
!nvidia-smi

import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

Tue May 27 12:13:08 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   48C    P8             10W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [2]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Set up project directory
PROJECT_DIR = '/content/drive/MyDrive/chart-hero'
!mkdir -p {PROJECT_DIR}
%cd {PROJECT_DIR}

Mounted at /content/drive
/content/drive/MyDrive/chart-hero


In [7]:
# Clone or update repository
import os
if not os.path.exists('chart-hero'):
    !git clone https://github.com/maple-buice/chart-hero.git
else:
    %cd chart-hero
    !git pull
    %cd ..

%cd chart-hero

Cloning into 'chart-hero'...
remote: Enumerating objects: 341, done.
remote: Counting objects: 100% (341/341), done.
remote: Compressing objects: 100% (224/224), done.
remote: Total 341 (delta 177), reused 262 (delta 101), pack-reused 0 (from 0)
Receiving objects: 100% (341/341), 4.31 MiB | 18.52 MiB/s, done.
Resolving deltas: 100% (177/177), done.
/content/drive/MyDrive/chart-hero/chart-hero


In [None]:
# Install dependencies
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install -q pytorch-lightning transformers timm wandb
!pip install -q librosa soundfile scikit-learn pandas numpy matplotlib seaborn tqdm
!pip install -q ipywidgets

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.2/23.2 MB[0m [31m100.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m875.6/875.6 kB[0m [31m58.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.1/13.1 MB[0m [31m115.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m663.9/663.9 MB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m417.9/417.9 MB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m168.4/168.4 MB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.1/58.1 MB[0m [31m13.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m128.2/128.2 MB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━

## 2. Data Setup

In [None]:
# Download and extract Expanded Groove MIDI Dataset
# Note: Replace with actual download commands for your dataset

DATASET_URL = "https://storage.googleapis.com/magentadata/datasets/e-gmd/v1.0.0/e-gmd-v1.0.0.zip"
DATASET_DIR = "/content/drive/MyDrive/chart-hero/datasets"

!mkdir -p {DATASET_DIR}

# Download dataset (uncomment and modify as needed)
# !wget -O {DATASET_DIR}/dataset.zip {DATASET_URL}
# !unzip -q {DATASET_DIR}/dataset.zip -d {DATASET_DIR}

print(f"Dataset directory: {DATASET_DIR}")
!ls -la {DATASET_DIR}

In [None]:
# Prepare training data (if not already processed)
# This cell converts the raw EGMD data to transformer-compatible format

import sys
sys.path.append('/content/chart-hero')

from model_training.data_preparation import data_preparation
from model_training.transformer_data import convert_legacy_data

# Process raw EGMD data
egmd_dir = "/content/drive/MyDrive/chart-hero/datasets/expanded-groove-midi"
processed_dir = "/content/drive/MyDrive/chart-hero/datasets/processed"

if not os.path.exists(processed_dir):
    print("Processing raw EGMD data...")

    # Create data preparation instance
    data_prep = data_preparation(
        directory_path=egmd_dir,
        dataset='egmd',
        sample_ratio=1.0,
        diff_threshold=1.0
    )

    # Create audio set with batching
    data_prep.create_audio_set(
        pad_before=0.1,
        pad_after=0.1,
        fix_length=10.0,  # 10 second segments
        batching=True,
        dir_path=processed_dir,
        num_batches=20
    )

    print("Data processing completed!")
else:
    print("Processed data already exists.")

!ls -la {processed_dir}

## 3. Model Training

In [None]:
# Set up W&B logging
import wandb

# Log in to W&B. wandb.login() uses the WANDB_API_KEY environment variable
# when it is set, and otherwise prompts for a key in the cell output. The key
# is saved locally, so the training subprocess launched below can reuse it.
# Do not hardcode the key in the notebook (e.g. wandb.login(key="...")): it
# would be saved with the .ipynb and any copy pushed to GitHub.
# For a prompt-free login, store it as a Colab secret and export it first:
#   import os
#   from google.colab import userdata
#   os.environ["WANDB_API_KEY"] = userdata.get("WANDB_API_KEY")
wandb.login()

In [None]:
# Test transformer setup
# Runs the repo's test_transformer_setup.py in a separate Python process
# (presumably a quick sanity check of the model/config before the long
# training run; see the script for what it actually checks). The path is
# relative to the repo root, which the clone cell made the working directory.
!python model_training/test_transformer_setup.py

In [None]:
# Start training with cloud configuration
# Launches the training script as a separate `python` process, not inside this
# kernel. Nothing it creates in memory is visible to later cells; the resume
# and export cells work from the checkpoint files on Drive instead.
# The script path is relative to the repo root (the working directory set by
# the clone cell). DATA_DIR is the processed_dir written by the data-prep cell;
# AUDIO_DIR is the raw E-GMD directory that cell read from.
DATA_DIR = "/content/drive/MyDrive/chart-hero/datasets/processed"
AUDIO_DIR = "/content/drive/MyDrive/chart-hero/datasets/expanded-groove-midi"

!python model_training/train_transformer.py \
    --config cloud \
    --data-dir {DATA_DIR} \
    --audio-dir {AUDIO_DIR} \
    --project-name chart-hero-transformer-colab

## 4. Resume Training (if needed)

In [None]:
# Resume from checkpoint
CHECKPOINT_PATH = "/content/drive/MyDrive/chart-hero/models/last.ckpt"

if os.path.exists(CHECKPOINT_PATH):
    !python model_training/train_transformer.py \
        --config cloud \
        --data-dir {DATA_DIR} \
        --audio-dir {AUDIO_DIR} \
        --resume {CHECKPOINT_PATH} \
        --project-name chart-hero-transformer-colab
else:
    print(f"Checkpoint not found: {CHECKPOINT_PATH}")

## 5. Model Evaluation and Export

In [None]:
# Load and evaluate best model
# Reloads the best checkpoint written by training and exports its inner
# network (model.model) to ONNX for deployment. No evaluation on data is run here.
import torch
from model_training.train_transformer import DrumTranscriptionModule
from model_training.transformer_config import get_config

config = get_config("cloud")  # NOTE(review): loaded but not used below
best_model_path = "/content/drive/MyDrive/chart-hero/models/best_model.ckpt"

if not os.path.exists(best_model_path):
    print(f"Best model not found: {best_model_path}")
else:
    # Module.eval() returns the module itself, so it chains onto the load.
    model = DrumTranscriptionModule.load_from_checkpoint(best_model_path).eval()
    print("Model loaded successfully!")

    # Export to ONNX for deployment.
    # NOTE(review): the dummy input shape (1, 1, 256, 128) is assumed to match
    # what model.model expects. Per the dynamic_axes names below, axis 0 is the
    # batch and axis 2 is time. Confirm against the training config.
    onnx_path = "/content/drive/MyDrive/chart-hero/models/drum_transformer.onnx"
    dummy_input = torch.randn(1, 1, 256, 128)

    torch.onnx.export(
        model.model,
        dummy_input,
        onnx_path,
        input_names=['spectrogram'],
        output_names=['logits'],
        dynamic_axes={
            'spectrogram': {0: 'batch_size', 2: 'time'},
            'logits': {0: 'batch_size'},
        },
        opset_version=11,
        export_params=True,
        do_constant_folding=True,
    )

    print(f"Model exported to ONNX: {onnx_path}")

## 6. Cleanup

In [None]:
# Finish any W&B run owned by this kernel and print where artifacts are stored.
# No files are deleted by this cell.
# Training runs in a separate `python` subprocess (see the training cell),
# which presumably manages its own W&B run. Usually no run is active in this
# kernel, and wandb.finish() then does nothing; it is kept in case a run was
# started interactively here. Importing wandb keeps the cell runnable in a
# fresh kernel.
import wandb

wandb.finish()

# Show final model and log locations.
# NOTE(review): these are the locations this notebook expects; confirm that
# train_transformer.py --config cloud actually writes models/logs there.
print("Training completed!")
print("Models saved to: /content/drive/MyDrive/chart-hero/models/")
print("Logs saved to: /content/drive/MyDrive/chart-hero/logs/")
print("Datasets saved to: /content/drive/MyDrive/chart-hero/datasets/")