<a href="https://colab.research.google.com/github/maple-buice/chart-hero/blob/main/colab/transformer_training_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 1. Environment Setup

In [None]:
# Report which accelerator this runtime exposes: a TPU (signalled by the
# COLAB_TPU_ADDR environment variable) or, failing that, a CUDA GPU.
import os
import torch

if "COLAB_TPU_ADDR" not in os.environ:
    cuda_ok = torch.cuda.is_available()
    print(f"CUDA available: {cuda_ok}")
    if cuda_ok:
        # Device 0 is the only device inspected here.
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        props = torch.cuda.get_device_properties(0)
        # total_memory is divided by 1e9 to print decimal gigabytes.
        print(f"Memory: {props.total_memory / 1e9:.1f} GB")
else:
    print("TPU detected")


In [None]:
# Mount Google Drive
import os
from google.colab import drive
drive.mount("/content/drive")

# Set up project directory
PROJECT_DIR = "/content/drive/MyDrive/chart-hero"
os.makedirs(PROJECT_DIR, exist_ok=True)
%cd {PROJECT_DIR}

In [None]:
# Clone or update repository
import os
if not os.path.exists(".git"):
    !git clone https://github.com/maple-buice/chart-hero.git .
else:
    !git pull


In [None]:
# Install dependencies
# Installs everything listed in requirements.txt (path is relative to the
# current working directory) together with torch_xla, torchvision and
# torchaudio, in a single quiet (-q) pip invocation.
# NOTE(review): the extra packages are not version-pinned, and torch_xla is
# installed regardless of which accelerator the runtime has — confirm this
# does not conflict with the runtime's preinstalled torch build.
%pip install -q -r requirements.txt torch_xla torchvision torchaudio


## 2. Data Setup

In [None]:
import os

# Paths
DATASET_DIR = os.path.join(PROJECT_DIR, "datasets")
PROCESSED_DIR = os.path.join(DATASET_DIR, "processed_highres")
CLONEHERO_SONGS = "/content/drive/MyDrive/CloneHeroSongs"  # Update to your songs directory
os.makedirs(PROCESSED_DIR, exist_ok=True)
print("Processed dataset dir:", PROCESSED_DIR)

# Build dataset if directory is empty
if not os.listdir(PROCESSED_DIR):
    !python -m chart_hero.train.build_dataset --roots "$CLONEHERO_SONGS" --out-dir "$PROCESSED_DIR" --config cloud


## 3. Model Training

In [None]:
# Optional: log in to Weights & Biases
import wandb
wandb.login()

# Directories for models and logs
MODEL_DIR = os.path.join(PROJECT_DIR, "models")
LOG_DIR = os.path.join(PROJECT_DIR, "logs")
RUN_TAG = "colab_highres_run"

!python -m chart_hero.model_training.train_transformer --config cloud --data-dir "$PROCESSED_DIR" --model-dir "$MODEL_DIR" --log-dir "$LOG_DIR" --experiment-tag "$RUN_TAG" --use-wandb


## 4. Resume Training (Optional)

In [None]:
CHECKPOINT_PATH = os.path.join(MODEL_DIR, RUN_TAG, "last.ckpt")
if os.path.exists(CHECKPOINT_PATH):
    !python -m chart_hero.model_training.train_transformer --config cloud --data-dir "$PROCESSED_DIR" --model-dir "$MODEL_DIR" --log-dir "$LOG_DIR" --experiment-tag "$RUN_TAG" --resume --use-wandb
else:
    print(f"Checkpoint not found: {CHECKPOINT_PATH}")


## 5. Model Evaluation and Export

In [None]:
import os
import torch
from chart_hero.model_training.train_transformer import DrumTranscriptionModule
from chart_hero.model_training.transformer_config import get_config

# Cloud configuration; loaded here but not passed to load_from_checkpoint below.
config = get_config("cloud")
best_model_path = os.path.join(MODEL_DIR, RUN_TAG, "best_model.ckpt")

if os.path.exists(best_model_path):
    model = DrumTranscriptionModule.load_from_checkpoint(best_model_path)
    model.eval()
    print("Model loaded successfully!")

    # The restored module may live on the GPU (a checkpoint saved from a GPU
    # run can be restored onto that device). Build the tracing input on the
    # same device as the model's weights so the export does not fail with a
    # CPU/GPU device mismatch; fall back to CPU if the module has no parameters.
    first_param = next(model.parameters(), None)
    export_device = first_param.device if first_param is not None else torch.device("cpu")

    # Example input used for tracing. Axis 0 (batch) and axis 2 (time) are
    # declared dynamic below; the meaning of the remaining axes is presumably
    # (channel, frequency) — TODO confirm against the model definition.
    dummy_input = torch.randn(1, 1, 256, 128, device=export_device)

    onnx_path = os.path.join(MODEL_DIR, RUN_TAG, "drum_transformer.onnx")
    # Export the wrapped network (model.model), not the training module itself.
    torch.onnx.export(
        model.model,
        dummy_input,
        onnx_path,
        export_params=True,
        opset_version=11,
        do_constant_folding=True,
        input_names=["spectrogram"],
        output_names=["logits"],
        dynamic_axes={
            "spectrogram": {0: "batch_size", 2: "time"},
            "logits": {0: "batch_size"},
        },
    )
    print(f"Model exported to ONNX: {onnx_path}")
else:
    print(f"Best model not found: {best_model_path}")


## 6. Cleanup

In [None]:
# Close any active Weights & Biases run in this kernel, then summarise
# where the run's artifacts were written.
wandb.finish()
print(
    "Training completed!",
    f"Models saved to: {MODEL_DIR}",
    f"Logs saved to: {LOG_DIR}",
    f"Datasets saved to: {DATASET_DIR}",
    sep="\n",
)
