<a href="https://colab.research.google.com/github/maple-buice/chart-hero/blob/main/colab/transformer_training_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 1. Environment Setup

In [None]:
# Check GPU availability
!nvidia-smi

import torch

print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Mount Google Drive
# The project directory lives on Drive, so the mount must succeed before
# anything below in this cell runs.
import os

from google.colab import drive

drive.mount('/content/drive')

# Set up project directory
# PROJECT_DIR is reused by later cells (e.g. to build the dataset path).

PROJECT_DIR = '/content/drive/MyDrive/chart-hero'
os.makedirs(PROJECT_DIR, exist_ok=True)
# Later cells use relative paths (requirements.txt, src/..., tests/...), so
# the working directory has to be the project root.
%cd {PROJECT_DIR}

In [None]:
# Clone or update repository
import os

if not os.path.exists('.git'):
    !git clone https://github.com/maple-buice/chart-hero.git .
else:
    !git pull

In [None]:
# Install dependencies
# %pip (rather than !pip) installs into the environment of the running kernel.
# requirements.txt is resolved relative to the current working directory.
%pip install -q -r requirements.txt

## 2. Data Setup

In [None]:
import os

# --- Configuration ---
DATASET_DIR = os.path.join(PROJECT_DIR, 'datasets')

# Dataset URL and paths
DATASET_URL = "https://storage.googleapis.com/magentadata/datasets/e-gmd/v1.0.0/e-gmd-v1.0.0.zip"
ZIP_FILE_NAME = os.path.basename(DATASET_URL)
DRIVE_ZIP_PATH = os.path.join(DATASET_DIR, ZIP_FILE_NAME)

EXPECTED_UNZIPPED_CONTENT_NAME = "e-gmd-v1.0.0"
EXPECTED_UNZIPPED_CONTENT_PATH = os.path.join(DATASET_DIR, EXPECTED_UNZIPPED_CONTENT_NAME)

SENTINEL_FILE_NAME = ".unzip_successful_sentinel"
SENTINEL_FILE_PATH = os.path.join(DATASET_DIR, SENTINEL_FILE_NAME)

# --- Setup ---
print("Starting dataset setup...")
os.makedirs(DATASET_DIR, exist_ok=True)

if not os.path.exists(SENTINEL_FILE_PATH):
    if not os.path.exists(DRIVE_ZIP_PATH):
        print(f"ZIP file not found at {DRIVE_ZIP_PATH}. Downloading...")
        !wget -O '{DRIVE_ZIP_PATH}' '{DATASET_URL}'
    else:
        print(f"ZIP file already exists at {DRIVE_ZIP_PATH}. Skipping download.")

    print(f"Unzipping '{DRIVE_ZIP_PATH}' to '{DATASET_DIR}'...")
    !unzip -nq '{DRIVE_ZIP_PATH}' -d '{DATASET_DIR}'

    if os.path.exists(EXPECTED_UNZIPPED_CONTENT_PATH):
        print("Unzip successful.")
        with open(SENTINEL_FILE_PATH, 'w') as f:
            f.write('unzip complete')
    else:
        print("Unzip failed.")
else:
    print("Dataset already unzipped.")

In [None]:
# Prepare training data (if not already processed)
PROCESSED_DIR = os.path.join(DATASET_DIR, 'processed')
if not os.path.exists(PROCESSED_DIR):
    print("Processing raw EGMD data...")
    !python src/chart_hero/prepare_egmd_data.py --output-dir {PROCESSED_DIR}
else:
    print("Processed data already exists.")

## 3. Model Training

In [None]:
# Set up W&B logging
# The `wandb` name imported here is used again by the final cleanup cell.
import wandb

wandb.login()

In [None]:
# Test transformer setup
# Runs the repository's setup test script with the same "cloud" config name
# that the training cell uses. The script path is relative to the project root.
!python tests/model_training/test_transformer_setup.py --config cloud

In [None]:
# Start training with cloud configuration
# The script runs in a separate Python process, with its path relative to the
# project root. Each trailing backslash continues the shell command on the
# next line, so no comments can go between those lines.
!python src/chart_hero/model_training/train_transformer.py \
    --config cloud \
    --project-name chart-hero-transformer-colab

## 4. Resume Training (if needed)

In [None]:
# Resume from checkpoint
CHECKPOINT_PATH = "/content/drive/MyDrive/chart-hero/models/local_transformer_models/last.ckpt"

if os.path.exists(CHECKPOINT_PATH):
    !python src/chart_hero/model_training/train_transformer.py \\
        --config cloud \\
        --resume {CHECKPOINT_PATH} \\
        --project-name chart-hero-transformer-colab
else:
    print(f"Checkpoint not found: {CHECKPOINT_PATH}")

## 5. Model Evaluation and Export

In [None]:
# Load the best checkpoint and export the underlying network to ONNX
import torch

# NOTE(review): assumes the chart_hero package is importable in this kernel
# (e.g. installed by requirements.txt) — confirm.
from chart_hero.model_training.train_transformer import DrumTranscriptionModule
from chart_hero.model_training.transformer_config import get_config

config = get_config("cloud")
best_model_path = "/content/drive/MyDrive/chart-hero/models/local_transformer_models/best_model.ckpt"

if os.path.exists(best_model_path):
    model = DrumTranscriptionModule.load_from_checkpoint(best_model_path)
    model.eval()
    print("Model loaded successfully!")

    # Export to ONNX for deployment
    # The restored weights may sit on the GPU when one is available, so the
    # tracing input is created on the same device as the network's parameters.
    # A CPU-only input would fail with a device mismatch during export.
    export_device = next(model.model.parameters()).device
    # NOTE(review): presumably (batch, channel, time, features) — confirm
    # against the model's expected input layout.
    dummy_input = torch.randn(1, 1, 256, 128, device=export_device)
    onnx_path = "/content/drive/MyDrive/chart-hero/models/drum_transformer.onnx"

    torch.onnx.export(
        model.model,
        dummy_input,
        onnx_path,
        export_params=True,
        opset_version=11,
        do_constant_folding=True,
        input_names=['spectrogram'],
        output_names=['logits'],
        # Axes listed here stay variable-sized in the exported graph
        dynamic_axes={
            'spectrogram': {0: 'batch_size', 2: 'time'},
            'logits': {0: 'batch_size'}
        }
    )

    print(f"Model exported to ONNX: {onnx_path}")
else:
    print(f"Best model not found: {best_model_path}")

## 6. Cleanup

In [None]:
# Close any W&B run still open in this kernel
wandb.finish()

# Summarise where the artifacts live on Drive
for message in (
    "Training completed!",
    "Models saved to: /content/drive/MyDrive/chart-hero/models/",
    "Logs saved to: /content/drive/MyDrive/chart-hero/logs/",
    "Datasets saved to: /content/drive/MyDrive/chart-hero/datasets/",
):
    print(message)