<a href="https://colab.research.google.com/github/maple-buice/chart-hero/blob/main/colab/transformer_training_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 1. Environment Setup

In [None]:
# Check GPU availability
!nvidia-smi

import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

Wed Jun 11 21:19:51 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   40C    P8              9W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
# Mount Google Drive
# The Drive is attached at /content/drive; every path used below lives under it.
from google.colab import drive
drive.mount('/content/drive')

# Set up project directory
# PROJECT_DIR is interpolated into the shell/magic lines via {PROJECT_DIR}:
# the directory is created if missing (mkdir -p), then made the working
# directory so that relative paths in the following cell resolve inside it.
PROJECT_DIR = '/content/drive/MyDrive/chart-hero'
!mkdir -p {PROJECT_DIR}
%cd {PROJECT_DIR}

Mounted at /content/drive
/content/drive/MyDrive/chart-hero


In [25]:
# Clone or update repository
import os
if not os.path.exists('chart-hero'):
    !git clone https://github.com/maple-buice/chart-hero.git
else:
    %cd chart-hero
    !git pull
    %cd ..

%cd chart-hero

Cloning into 'chart-hero'...
remote: Enumerating objects: 390, done.[K
remote: Counting objects: 100% (390/390), done.[K
remote: Compressing objects: 100% (266/266), done.[K
remote: Total 390 (delta 204), reused 276 (delta 108), pack-reused 0 (from 0)[K
Receiving objects: 100% (390/390), 4.33 MiB | 10.56 MiB/s, done.
Resolving deltas: 100% (204/204), done.
/content/drive/MyDrive/chart-hero/chart-hero/chart-hero/chart-hero


In [17]:
# Install dependencies
!pip uninstall -y torch torchvision torchaudio fastai
!pip install -q torch torchvision torchaudio pytorch-lightning --index-url https://download.pytorch.org/whl/cu118
!pip install -q transformers timm wandb
!pip install -q librosa soundfile scikit-learn pandas numpy matplotlib seaborn tqdm
!pip install -q ipywidgets mido

Found existing installation: torch 2.7.1+cu118
Uninstalling torch-2.7.1+cu118:
  Successfully uninstalled torch-2.7.1+cu118
Found existing installation: torchvision 0.22.1+cu118
Uninstalling torchvision-0.22.1+cu118:
  Successfully uninstalled torchvision-0.22.1+cu118
Found existing installation: torchaudio 2.7.1+cu118
Uninstalling torchaudio-2.7.1+cu118:
  Successfully uninstalled torchaudio-2.7.1+cu118
Found existing installation: fastai 2.7.19
Uninstalling fastai-2.7.19:
  Successfully uninstalled fastai-2.7.19


## 2. Data Setup

In [18]:
import os
import subprocess
import shutil
from google.colab import drive

# --- Configuration ---
# Plain names first; every path further down is derived from them.
DRIVE_MOUNT_POINT = '/content/drive'
PROJECT_DIR_NAME = 'chart-hero'  # Drive folder that holds the project
DATASET_DIR_NAME = 'datasets'

# Dataset archive.
# IMPORTANT: verify this URL and the name of the folder it unpacks to for your specific dataset.
DATASET_URL = "https://storage.googleapis.com/magentadata/datasets/e-gmd/v1.0.0/e-gmd-v1.0.0.zip"
ZIP_FILE_NAME = DATASET_URL.rsplit('/', 1)[-1]

# Name of a file/folder expected to exist once the archive has been unpacked
# (for e-gmd-v1.0.0.zip that is 'e-gmd-v1.0.0').
EXPECTED_UNZIPPED_CONTENT_NAME = "e-gmd-v1.0.0"

# Marker file written after a successful unzip so later runs can skip the work.
SENTINEL_FILE_NAME = ".unzip_successful_sentinel"

# Directory layout inside Google Drive.
MY_DRIVE_PATH = os.path.join(DRIVE_MOUNT_POINT, 'MyDrive')
PROJECT_DRIVE_PATH = os.path.join(DRIVE_MOUNT_POINT, 'MyDrive', PROJECT_DIR_NAME)
DATASET_DIR = os.path.join(PROJECT_DRIVE_PATH, DATASET_DIR_NAME)

# The ZIP, the unpacked content and the sentinel all sit directly under DATASET_DIR.
DRIVE_ZIP_PATH, EXPECTED_UNZIPPED_CONTENT_PATH, SENTINEL_FILE_PATH = (
    os.path.join(DATASET_DIR, leaf_name)
    for leaf_name in (ZIP_FILE_NAME, EXPECTED_UNZIPPED_CONTENT_NAME, SENTINEL_FILE_NAME)
)

# --- Setup ---
print("Starting dataset setup...")

# 1. Mount Google Drive, unless MyDrive is visible and the mount point is a live mount
drive_is_ready = os.path.exists(MY_DRIVE_PATH) and os.path.ismount(DRIVE_MOUNT_POINT)
if drive_is_ready:
    print(f"Google Drive already mounted at {DRIVE_MOUNT_POINT}.")
else:
    print(f"Mounting Google Drive at {DRIVE_MOUNT_POINT}...")
    drive.mount(DRIVE_MOUNT_POINT, force_remount=True)

# 2. Make sure the project and dataset directories exist in Drive
for drive_dir in (PROJECT_DRIVE_PATH, DATASET_DIR):
    os.makedirs(drive_dir, exist_ok=True)
print(f"Project directory in Drive: {PROJECT_DRIVE_PATH}")
print(f"Dataset directory in Drive: {DATASET_DIR}")


# --- Main Logic ---
# Downloads the dataset ZIP into Drive (if missing) and unpacks it next to the
# ZIP. A sentinel file records a finished unzip so later runs skip all of this.
# Raises RuntimeError when the download or the unzip command fails.

# 3. Check if unzipping is already complete (sentinel file exists)
if os.path.exists(SENTINEL_FILE_PATH):
    print(f"Dataset already successfully unzipped. Sentinel file found: {SENTINEL_FILE_PATH}")
    if os.path.exists(EXPECTED_UNZIPPED_CONTENT_PATH):
        print(f"Verified: Expected content '{EXPECTED_UNZIPPED_CONTENT_NAME}' exists at '{EXPECTED_UNZIPPED_CONTENT_PATH}'.")
    else:
        print(f"WARNING: Sentinel file exists, but expected content '{EXPECTED_UNZIPPED_CONTENT_NAME}' not found at '{EXPECTED_UNZIPPED_CONTENT_PATH}'.")
        print("This might indicate that the unzipped content was moved or deleted after the sentinel was created.")
        print("If you encounter issues, consider deleting the sentinel file and this cell's output, then re-running.")
else:
    print(f"Sentinel file not found at {SENTINEL_FILE_PATH}. Proceeding with dataset download and/or unzip.")

    # 4. Download the dataset ZIP if it doesn't exist in Google Drive
    if not os.path.exists(DRIVE_ZIP_PATH):
        print(f"ZIP file not found at {DRIVE_ZIP_PATH}. Downloading...")
        # Download under a temporary name and rename only on success. Otherwise
        # an interrupted download leaves a truncated file at DRIVE_ZIP_PATH that
        # the next run would take for a complete archive and skip re-downloading.
        partial_zip_path = DRIVE_ZIP_PATH + ".part"
        # Argument list (no shell): paths are passed verbatim, no quoting issues.
        download_command = ["wget", "-O", partial_zip_path, DATASET_URL]
        print(f"Executing: {' '.join(download_command)}")
        download_result = subprocess.run(download_command, capture_output=True, text=True)
        if download_result.returncode == 0:
            os.replace(partial_zip_path, DRIVE_ZIP_PATH)
            print(f"Download successful. ZIP file saved to {DRIVE_ZIP_PATH}")
        else:
            print("Download failed.")
            print(f"Stdout: {download_result.stdout}")
            print(f"Stderr: {download_result.stderr}")
            # Clean up potentially incomplete ZIP file
            if os.path.exists(partial_zip_path):
                os.remove(partial_zip_path)
                print(f"Removed potentially incomplete ZIP file: {partial_zip_path}")
            raise RuntimeError(f"Failed to download dataset from {DATASET_URL} to {DRIVE_ZIP_PATH}")
    else:
        print(f"ZIP file already exists at {DRIVE_ZIP_PATH}. Skipping download.")
        file_size = os.path.getsize(DRIVE_ZIP_PATH)
        print(f"Existing ZIP file size: {file_size / (1024*1024):.2f} MB")


    # 5. Unzip the dataset from Google Drive to Google Drive
    print(f"Checking for expected unzipped content at: {EXPECTED_UNZIPPED_CONTENT_PATH}")
    if os.path.exists(EXPECTED_UNZIPPED_CONTENT_PATH):
        print(f"Main expected content '{EXPECTED_UNZIPPED_CONTENT_NAME}' already exists at '{EXPECTED_UNZIPPED_CONTENT_PATH}'.")
        print("Unzip will attempt to complete by skipping existing files (due to -nq option).")
    else:
        print(f"Main expected content '{EXPECTED_UNZIPPED_CONTENT_NAME}' not found. Proceeding with unzip.")

    print(f"Unzipping '{DRIVE_ZIP_PATH}' to '{DATASET_DIR}'...")
    # -n: never overwrite existing files (lets an interrupted unzip resume); -q: quiet
    unzip_command = ["unzip", "-nq", DRIVE_ZIP_PATH, "-d", DATASET_DIR]
    print(f"Executing: {' '.join(unzip_command)}")
    unzip_result = subprocess.run(unzip_command, capture_output=True, text=True)

    # Check unzip outcome
    # unzip return codes:
    # 0: success
    # 1: warnings were reported, but processing completed
    # Other codes: more serious errors
    if unzip_result.returncode in (0, 1):
        print("Unzip command executed.")
        # Only the first 500 characters of each stream are echoed.
        if unzip_result.stdout:
            print(f"Unzip stdout: {unzip_result.stdout[:500]}...")
        if unzip_result.stderr:
            print(f"Unzip stderr: {unzip_result.stderr[:500]}...")

        # Verify expected content after unzip attempt.
        # NOTE: this only checks that the top-level folder exists, not that
        # every archive member was extracted.
        if os.path.exists(EXPECTED_UNZIPPED_CONTENT_PATH):
            print(f"Unzip appears successful. Expected content '{EXPECTED_UNZIPPED_CONTENT_NAME}' found at '{EXPECTED_UNZIPPED_CONTENT_PATH}'.")
            # Create sentinel file
            print(f"Creating sentinel file: {SENTINEL_FILE_PATH}")
            completed_at = subprocess.check_output(['date']).decode('utf-8').strip()
            with open(SENTINEL_FILE_PATH, 'w') as sentinel_file:
                sentinel_file.write(f"Unzip completed successfully for {ZIP_FILE_NAME} into {DATASET_DIR} on {completed_at}")
            print("Sentinel file created.")
        else:
            print(f"ERROR: Unzip command finished (exit code {unzip_result.returncode}), but expected content '{EXPECTED_UNZIPPED_CONTENT_NAME}' still NOT found at '{EXPECTED_UNZIPPED_CONTENT_PATH}'.")
            print("This could indicate an issue with the ZIP file structure, an empty ZIP, or the unzip process itself did not extract the main folder as expected.")
            print(f"Please check the ZIP file contents and the '{EXPECTED_UNZIPPED_CONTENT_NAME}' variable if this is incorrect.")
            # Do NOT create sentinel file if primary content is missing.
    else:
        print(f"Unzip failed with exit code {unzip_result.returncode}.")
        print(f"Stdout: {unzip_result.stdout}")
        print(f"Stderr: {unzip_result.stderr}")
        raise RuntimeError(f"Failed to unzip dataset from {DRIVE_ZIP_PATH}. Check logs for details.")

# --- Verification (Final Check) ---
# Read-only report: existence of the three artefacts, then a peek at the dataset folder.
print("--- Final Verification ---")
for artefact_label, artefact_path in (
    ("sentinel file", SENTINEL_FILE_PATH),
    ("ZIP file in Drive", DRIVE_ZIP_PATH),
    ("expected unzipped content", EXPECTED_UNZIPPED_CONTENT_PATH),
):
    artefact_status = 'Exists' if os.path.exists(artefact_path) else 'Not found'
    print(f"Checking existence of {artefact_label}: {artefact_path} -> {artefact_status}")

if not os.path.exists(EXPECTED_UNZIPPED_CONTENT_PATH):
    print(f"Cannot list contents as '{EXPECTED_UNZIPPED_CONTENT_PATH}' does not exist.")
else:
    print(f"Listing contents of '{EXPECTED_UNZIPPED_CONTENT_PATH}' (first few items):")
    try:
        contents = os.listdir(EXPECTED_UNZIPPED_CONTENT_PATH)
        # Show at most five entries; an ellipsis line signals that more exist.
        for entry_name in contents[:5]:
            print(f"- {entry_name}")
        if len(contents) > 5:
            print("  ...")
    except Exception as list_error:
        print(f"Could not list contents: {list_error}")

print("Dataset setup cell execution complete.")

Starting dataset setup...
Google Drive already mounted at /content/drive.
Project directory in Drive: /content/drive/MyDrive/chart-hero
Dataset directory in Drive: /content/drive/MyDrive/chart-hero/datasets
Dataset already successfully unzipped. Sentinel file found: /content/drive/MyDrive/chart-hero/datasets/.unzip_successful_sentinel
Verified: Expected content 'e-gmd-v1.0.0' exists at '/content/drive/MyDrive/chart-hero/datasets/e-gmd-v1.0.0'.
--- Final Verification ---
Checking existence of sentinel file: /content/drive/MyDrive/chart-hero/datasets/.unzip_successful_sentinel -> Exists
Checking existence of ZIP file in Drive: /content/drive/MyDrive/chart-hero/datasets/e-gmd-v1.0.0.zip -> Exists
Checking existence of expected unzipped content: /content/drive/MyDrive/chart-hero/datasets/e-gmd-v1.0.0 -> Exists
Listing contents of '/content/drive/MyDrive/chart-hero/datasets/e-gmd-v1.0.0' (first few items):
- drummer5
- drummer7
- drummer6
- e-gmd-v1.0.0.csv
- LICENSE
  ...
Dataset setup cel

In [None]:
# Prepare training data (if not already processed)
# This cell converts the raw EGMD data to transformer-compatible format

import sys
sys.path.append('/content/chart-hero')

from model_training.data_preparation import data_preparation
from model_training.transformer_data import convert_legacy_data

# Process raw EGMD data
egmd_dir = "/content/drive/MyDrive/chart-hero/datasets/e-gmd-v1.0.0"
processed_dir = "/content/drive/MyDrive/chart-hero/datasets/processed"

if not os.path.exists(processed_dir):
    print("Processing raw EGMD data...")

    # Create data preparation instance
    data_prep = data_preparation(
        directory_path=egmd_dir,
        dataset='egmd',
        sample_ratio=1.0,
        diff_threshold=1.0
    )

    # Create audio set with batching
    data_prep.create_audio_set(
        pad_before=0.1,
        pad_after=0.1,
        fix_length=10.0,  # 10 second segments
        batching=True,
        dir_path=processed_dir,
        num_batches=20
    )

    print("Data processing completed!")
else:
    print("Processed data already exists.")

!ls -la {processed_dir}

ModuleNotFoundError: No module named 'mido'

## 3. Model Training

In [20]:
# Set up W&B logging
import wandb

# Login to W&B (you'll need to provide your API key)
# This is the last expression of the cell, so its return value is shown as the cell output.
wandb.login()

# Or set the API key directly
# (do not save or commit the notebook with a real key filled in)
# wandb.login(key="your-wandb-api-key")

True

In [26]:
# Test transformer setup
# The script path is relative, so the working directory must be the repository
# root when this runs (the clone cell ends with a %cd into the checkout).
!python model_training/test_transformer_setup.py --config cloud

INFO:__main__:TRANSFORMER SETUP TESTS
INFO:__main__:
------------------------------
INFO:__main__:Running Configuration test (using 'cloud' config context)...
INFO:__main__:------------------------------
INFO:__main__:Testing configuration classes...
INFO:__main__:--- Validating invoked config: 'cloud' ---
INFO:__main__:✓ Invoked config 'cloud' (resolved to cloud): VALIDATED - device cuda, batch_size=64
INFO:__main__:--- Checking standard config: 'local' ---
INFO:__main__:--- Checking standard config: 'auto' ---
INFO:__main__:Auto-detected config for verification: CloudConfig
INFO:__main__:✓ Config 'auto': Loaded & Validated - device cuda, batch_size=64
INFO:__main__:✓ Configuration test PASSED
INFO:__main__:
------------------------------
INFO:__main__:Running Model Architecture test (using 'cloud' config context)...
INFO:__main__:------------------------------
INFO:__main__:Testing model architecture...
Model created with 7,214,217 total parameters
Trainable parameters: 7,214,217
INF

In [27]:
# Start training with cloud configuration
DATA_DIR = "/content/drive/MyDrive/chart-hero/datasets/processed"
AUDIO_DIR = "/content/drive/MyDrive/chart-hero/datasets/e-gmd-v1.0.0"

!python model_training/train_transformer.py \
    --config cloud \
    --data-dir {DATA_DIR} \
    --audio-dir {AUDIO_DIR} \
    --project-name chart-hero-transformer-colab

usage: train_transformer.py [-h] [--config CONFIG] [--use-wandb]
                            [--quick-test] [--debug]
                            [--experiment-tag EXPERIMENT_TAG]
                            [--data-dir DATA_DIR] [--audio-dir AUDIO_DIR]
                            [--monitor-gpu] [--batch-size BATCH_SIZE]
                            [--hidden-size HIDDEN_SIZE]
                            [--learning-rate LEARNING_RATE]
                            [--num-workers NUM_WORKERS]
                            [--accumulate-grad-batches ACCUMULATE_GRAD_BATCHES]
train_transformer.py: error: unrecognized arguments: --project-name chart-hero-transformer-colab


## 4. Resume Training (if needed)

In [None]:
# Resume from checkpoint
CHECKPOINT_PATH = "/content/drive/MyDrive/chart-hero/models/last.ckpt"

if os.path.exists(CHECKPOINT_PATH):
    !python model_training/train_transformer.py \
        --config cloud \
        --data-dir {DATA_DIR} \
        --audio-dir {AUDIO_DIR} \
        --resume {CHECKPOINT_PATH} \
        --project-name chart-hero-transformer-colab
else:
    print(f"Checkpoint not found: {CHECKPOINT_PATH}")

## 5. Model Evaluation and Export

In [None]:
# Load and evaluate best model
# Loads the best checkpoint from Drive (if present), switches it to eval mode
# and exports the wrapped network (`model.model`) to ONNX next to it.
import os

import torch
from model_training.train_transformer import DrumTranscriptionModule
from model_training.transformer_config import get_config

# NOTE(review): `config` is not referenced again in this cell.
config = get_config("cloud")
best_model_path = "/content/drive/MyDrive/chart-hero/models/best_model.ckpt"

if os.path.exists(best_model_path):
    model = DrumTranscriptionModule.load_from_checkpoint(best_model_path)
    model.eval()
    print("Model loaded successfully!")

    # Export to ONNX for deployment
    # The tracing input must live on the same device as the weights; a CPU
    # tensor fed to a CUDA-resident model fails with a device mismatch.
    model_device = next(model.parameters()).device
    dummy_input = torch.randn(1, 1, 256, 128, device=model_device)
    onnx_path = "/content/drive/MyDrive/chart-hero/models/drum_transformer.onnx"

    torch.onnx.export(
        model.model,
        dummy_input,
        onnx_path,
        export_params=True,
        opset_version=11,
        do_constant_folding=True,
        input_names=['spectrogram'],
        output_names=['logits'],
        # Axes 0 and 2 of the input (and axis 0 of the output) stay variable-sized.
        dynamic_axes={
            'spectrogram': {0: 'batch_size', 2: 'time'},
            'logits': {0: 'batch_size'}
        }
    )

    print(f"Model exported to ONNX: {onnx_path}")
else:
    print(f"Best model not found: {best_model_path}")

Best model not found: /content/drive/MyDrive/chart-hero/models/best_model.ckpt


## 6. Cleanup

In [None]:
# Close the active W&B run, then print where the artefacts live on Drive.
wandb.finish()

# Fixed summary lines; nothing is computed or checked here.
for summary_line in (
    "Training completed!",
    "Models saved to: /content/drive/MyDrive/chart-hero/models/",
    "Logs saved to: /content/drive/MyDrive/chart-hero/logs/",
    "Datasets saved to: /content/drive/MyDrive/chart-hero/datasets/",
):
    print(summary_line)