<a href="https://colab.research.google.com/github/maple-buice/chart-hero/blob/main/colab/transformer_training_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Chart Hero: Transformer Training on Colab

This notebook does the following:
- mounts Google Drive;
- syncs the `chart-hero` repository;
- extracts the preprocessed dataset;
- trains the drum-transcription transformer with Weights & Biases logging;
- exports the best checkpoint to ONNX.

## 1. Environment Setup

In [1]:
# Report which accelerator this runtime exposes: a Colab TPU, a CUDA GPU, or neither.
import os
import torch

on_tpu = "COLAB_TPU_ADDR" in os.environ
if on_tpu:
    print("TPU detected")
else:
    cuda_ok = torch.cuda.is_available()
    print(f"CUDA available: {cuda_ok}")
    if cuda_ok:
        gpu_name = torch.cuda.get_device_name(0)
        gpu_props = torch.cuda.get_device_properties(0)
        print(f"GPU: {gpu_name}")
        print(f"Memory: {gpu_props.total_memory / 1e9:.1f} GB")

CUDA available: False


In [15]:
# Mount Google Drive
import os
from google.colab import drive

drive.mount("/content/drive")

# Set up project directory
ROOT_DIR = "/content/drive/MyDrive/chart-hero"
PROJECT_DIR = os.path.join(ROOT_DIR, "chart-hero")
os.makedirs(PROJECT_DIR, exist_ok=True)
%cd {PROJECT_DIR}

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/chart-hero/chart-hero


In [16]:
# Clone or update repository
import os

if not os.path.exists(".git"):
    !git clone https://github.com/maple-buice/chart-hero.git .
else:
    !git reset --hard origin/main
    !git pull

HEAD is now at 9e84295 Merge pull request #52 from maple-buice/codex/optimize-cloudconfig-for-t4-high-ram-colab
Already up to date.


In [13]:
# Install dependencies
%pip install -q -r requirements.txt torch_xla torchvision torchaudio

## 2. Data Setup

In [20]:
import os

# Paths
DATASET_DIR = os.path.join(ROOT_DIR, "datasets")
DATASET_TAR = os.path.join(DATASET_DIR, "dataset.tar.gz")
PROCESSED_DIR = os.path.join(DATASET_DIR, "processed_highres")

os.makedirs(PROCESSED_DIR, exist_ok=True)
print("Processed dataset dir:", PROCESSED_DIR)

print("DATASET_TAR:", DATASET_TAR)
print("tar exists:", os.path.exists(DATASET_TAR))

# Extract prebuilt dataset archive if available
if os.path.exists(DATASET_TAR) and not os.listdir(PROCESSED_DIR):
    !tar -xzf $DATASET_TAR -C $DATASET_DIR

Processed dataset dir: /content/drive/MyDrive/chart-hero/datasets/processed_highres
DATASET_TAR: /content/drive/MyDrive/chart-hero/datasets/dataset.tar.gz
tar exists: True
tar: Ignoring unknown extended header keyword 'SCHILY.fflags'
tar: Ignoring unknown extended header keyword 'LIBARCHIVE.xattr.com.apple.FinderInfo'


## 3. Model Training

In [21]:
# Optional: log in to Weights & Biases
import wandb

wandb.login()

# Directories for models and logs
MODEL_DIR = os.path.join(PROJECT_DIR, "models")
LOG_DIR = os.path.join(PROJECT_DIR, "logs")
RUN_TAG = "colab_highres_run"

!python -m chart_hero.model_training.train_transformer --config cloud --data-dir "$PROCESSED_DIR" --model-dir "$MODEL_DIR" --log-dir "$LOG_DIR" --experiment-tag "$RUN_TAG" --use-wandb

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mmaple-buice[0m ([33mmbuice-org[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


/usr/bin/python3: Error while finding module specification for 'chart_hero.model_training.train_transformer' (ModuleNotFoundError: No module named 'chart_hero')


## 4. Resume Training (Optional)

In [None]:
CHECKPOINT_PATH = os.path.join(MODEL_DIR, RUN_TAG, "last.ckpt")
if os.path.exists(CHECKPOINT_PATH):
    !python -m chart_hero.model_training.train_transformer --config cloud --data-dir "$PROCESSED_DIR" --model-dir "$MODEL_DIR" --log-dir "$LOG_DIR" --experiment-tag "$RUN_TAG" --resume --use-wandb
else:
    print(f"Checkpoint not found: {CHECKPOINT_PATH}")

## 5. Model Evaluation and Export

In [None]:
# Load the best checkpoint and export its network to ONNX.
import os

import torch

from chart_hero.model_training.train_transformer import DrumTranscriptionModule
from chart_hero.model_training.transformer_config import get_config

# NOTE(review): `config` is not used below; the dummy-input shape is hard-coded instead.
# TODO: derive the shape from this config so the two cannot drift apart.
config = get_config("cloud")
best_model_path = os.path.join(MODEL_DIR, RUN_TAG, "best_model.ckpt")

# Shape of the tracing input. Axes 0 and 2 are exported as dynamic ("batch_size", "time").
# Reading axis 2 as time is an assumption; TODO confirm the model's spectrogram layout.
DUMMY_INPUT_SHAPE = (1, 1, 256, 128)
# Opset 14+ is needed to export scaled_dot_product_attention, and 17 adds a native
# LayerNormalization op. Opset 11 cannot express the former.
ONNX_OPSET = 17

if os.path.exists(best_model_path):
    # map_location="cpu": a GPU-saved checkpoint would otherwise be restored onto CUDA.
    # That mismatches the CPU dummy input below, or fails to load on a CPU-only runtime.
    model = DrumTranscriptionModule.load_from_checkpoint(best_model_path, map_location="cpu")
    model.eval()
    print("Model loaded successfully!")
    dummy_input = torch.randn(*DUMMY_INPUT_SHAPE)
    onnx_path = os.path.join(MODEL_DIR, RUN_TAG, "drum_transformer.onnx")
    torch.onnx.export(
        model.model,  # presumably the wrapped nn.Module inside the LightningModule
        dummy_input,
        onnx_path,
        export_params=True,
        opset_version=ONNX_OPSET,
        do_constant_folding=True,
        input_names=["spectrogram"],
        output_names=["logits"],
        dynamic_axes={
            "spectrogram": {0: "batch_size", 2: "time"},
            "logits": {0: "batch_size"},
        },
    )
    print(f"Model exported to ONNX: {onnx_path}")
else:
    print(f"Best model not found: {best_model_path}")

## 6. Cleanup

In [None]:
# Close any W&B run held by this kernel (training itself ran in a `!python` subprocess),
# then report where the artifacts live on Drive.
wandb.finish()
print("Training completed!")
for label, location in (
    ("Models", MODEL_DIR),
    ("Logs", LOG_DIR),
    ("Datasets", DATASET_DIR),
):
    print(f"{label} saved to: {location}")