# Rainbow Pipeline Training - RunPod

Phase 2: Multiclass Rebracketing Classification

## 1. Install Dependencies

In [None]:
# Install the packages imported by this notebook and (presumably) by the
# training code it runs from /workspace/white/training; -q keeps pip's
# output short.
!pip install -q torch transformers pandas pyarrow scikit-learn pyyaml tqdm wandb matplotlib seaborn

## 2. Verify Data

In [None]:
from pathlib import Path
import pandas as pd

data_dir = Path("/workspace/data")


def format_size(num_bytes):
    """Return *num_bytes* as a short human-readable string.

    Sizes are computed in binary megabytes (1024**2 bytes); anything above
    1000 MB is shown in gigabytes instead, e.g. "512.0 MB" or "2.0 GB".
    """
    size_mb = num_bytes / (1024**2)
    if size_mb > 1000:
        return f"{size_mb / 1024:.1f} GB"
    return f"{size_mb:.1f} MB"


print("Files in /workspace/data:")
# Say so explicitly when there is nothing to list; otherwise a missing upload
# looks exactly like a successful check that printed no files.
if not data_dir.is_dir():
    print("  (directory not found - upload the data first)")
else:
    parquet_files = sorted(data_dir.glob("*.parquet"))
    if not parquet_files:
        print("  (no .parquet files found)")
    for f in parquet_files:
        print(f"  {f.name}: {format_size(f.stat().st_size)}")

In [None]:
# Quick look at the manifest: how many tracks it holds and which columns exist.
manifest_path = "/workspace/data/base_manifest_db.parquet"
df = pd.read_parquet(manifest_path)

n_tracks = len(df)
column_names = list(df.columns)
print(f"Total tracks: {n_tracks}")
print(f"Columns: {column_names}")

In [None]:
# Check rebracketing type distribution.
# Rows whose `training_data` is not a dict, or that have no
# 'rebracketing_type' key, map to None here.
types = df['training_data'].apply(lambda x: x.get('rebracketing_type') if isinstance(x, dict) else None)
print("Rebracketing type distribution:")
# dropna=False so unlabeled rows appear as their own count instead of being
# silently left out of the distribution (the default drops None/NaN).
print(types.value_counts(dropna=False))

## 3. Clone/Update Code

In [None]:
import os
from pathlib import Path

code_dir = Path("/workspace/white")

if not code_dir.exists():
    !git clone https://github.com/brotherclone/white.git /workspace/white
    !cd /workspace/white && git checkout feature/dataPrep
else:
    !cd /workspace/white && git pull

print("\nCode ready at /workspace/white")

## 4. Check GPU

In [None]:
import torch

# Report the PyTorch build and, when CUDA is usable, GPU 0's name and memory.
cuda_ok = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    gpu_name = torch.cuda.get_device_name(0)
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {gpu_name}")
    print(f"Memory: {gpu_props.total_memory / 1e9:.1f} GB")

## 5. Set Up Training Config

In [None]:
import os

import torch
import yaml

# The repo's training code and YAML configs live in training/.
os.chdir("/workspace/white/training")

# Load the multiclass config. The tracked repo file is only read here, never
# rewritten: yaml.dump would re-sort its keys and drop its comments, and the
# local modification can make the `git pull` in section 3 fail.
with open("config_multiclass.yml") as f:
    config = yaml.safe_load(f)

# Point data and output at the RunPod volume.
config["data"]["manifest_path"] = "/workspace/data/base_manifest_db.parquet"
config["logging"]["output_dir"] = "/workspace/output"

# Use the GPU when one is visible. Otherwise fall back to CPU with a loud
# warning rather than hard-coding "cuda" (which presumably fails once the
# trainer tries to use it). Mixed precision is only enabled on CUDA.
use_cuda = torch.cuda.is_available()
if not use_cuda:
    print("WARNING: CUDA not available - training will run on CPU (slow).")
config["training"]["device"] = "cuda" if use_cuda else "cpu"
config["training"]["mixed_precision"] = use_cuda

# Optional: Adjust for your GPU memory
# config["training"]["batch_size"] = 16  # Reduce if OOM

# Save the RunPod-specific config to its own file for reference. The training
# cell passes the in-memory `config` dict to Trainer directly.
runpod_config_path = "config_multiclass_runpod.yml"
with open(runpod_config_path, "w") as f:
    yaml.dump(config, f, default_flow_style=False, sort_keys=False)

print("Config updated for RunPod")
print(f"  Saved to: {runpod_config_path}")
print(f"  Data: {config['data']['manifest_path']}")
print(f"  Output: {config['logging']['output_dir']}")
print(f"  Device: {config['training']['device']}")
print(f"  Batch size: {config['training']['batch_size']}")
print(f"  Epochs: {config['training']['epochs']}")

## 6. Optional: Set Up WandB

In [None]:
# Uncomment and run if you want WandB logging
# import wandb
# wandb.login()

## 7. Run Training

In [None]:
# Make the repo's training/ modules importable. The entry is added only once,
# so re-running this cell does not keep prepending duplicates to sys.path.
import random
import sys

import numpy as np
import torch

training_dir = "/workspace/white/training"
if training_dir not in sys.path:
    sys.path.insert(0, training_dir)

# NOTE: Python caches `train` after its first import. Restart the kernel to
# pick up code pulled in section 3 once this cell has already run.
from train import Trainer

# Seed Python, NumPy and PyTorch (CPU and all CUDA devices) from the config,
# so whichever of these RNGs the training code uses is reproducible.
seed = config["reproducibility"]["seed"]
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Create trainer and run
trainer = Trainer(config)
trainer.train()

## 8. View Results

In [None]:
import json
from pathlib import Path

output_dir = Path("/workspace/output")

# history.json presumably comes from the training run. It maps metric names
# to lists of values, and only the last entry of each list is reported below.
history_path = output_dir / "history.json"
with open(history_path) as fh:
    history = json.load(fh)

print("Training History:")
# These metrics are expected to be present; a missing key raises KeyError.
for label, key in (
    ("Final train loss", "train_loss"),
    ("Final val loss", "val_loss"),
    ("Final val accuracy", "val_acc"),
):
    print(f"  {label}: {history[key][-1]:.4f}")

# Macro F1 is optional and shown only when the run recorded it.
if "val_macro_f1" in history:
    print(f"  Final val macro F1: {history['val_macro_f1'][-1]:.4f}")

In [None]:
# Plot training curves: loss and accuracy, train vs. validation.
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Each panel reads history["<split>_<metric>"] for split in train/val.
panels = [
    (axes[0], "loss", "Loss"),
    (axes[1], "acc", "Accuracy"),
]
for ax, metric, title in panels:
    for split, split_label in (("train", "Train"), ("val", "Val")):
        values = history[f"{split}_{metric}"]
        # Presumably one value per epoch; number the x-axis 1..N so it
        # matches the "Epoch" label instead of starting at 0.
        ax.plot(range(1, len(values) + 1), values, label=split_label)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(title)
    ax.legend()
    ax.set_title(title)

plt.tight_layout()
plt.show()

In [None]:
# View confusion matrix (if saved).
# Import `display` explicitly instead of relying on the IPython-injected
# builtin, so the cell also works when the notebook is executed headlessly.
from IPython.display import Image, display

cm_path = output_dir / "confusion_matrix.png"
if cm_path.exists():
    display(Image(filename=str(cm_path)))
else:
    # Include the path so it is clear where the image was expected.
    print(f"Confusion matrix not found at {cm_path}")

## 9. Download Checkpoint

The best model is saved at `/workspace/output/checkpoint_best.pt`

To download it, either:
1. Use the RunPod file browser, or
2. Copy it with `scp` from your local machine (replace `<PORT>` and `<IP>` with your pod's SSH port and public IP):
```bash
scp -P <PORT> root@<IP>:/workspace/output/checkpoint_best.pt ./
```

In [None]:
# List everything in /workspace/output (e.g. history.json and
# checkpoint_best.pt) with human-readable sizes (-h) in long format (-l).
!ls -lh /workspace/output/