AlphaBuilder v3.3 - Kaggle Training (Pre-Corrected Data)
========================================================

This notebook is designed to run on Kaggle with the GPU T4 x2 accelerator.

**VERSION 3.3 HIGHLIGHTS:**
- Uses a **pre-corrected database** (`warmup_data_v3_2.db`)
- Value targets are already updated to the new additive reward formula
- No runtime recalculation needed (faster data loading)

**New Reward Formula (Reference):**
- `reward = compliance_score(C) + volume_bonus(V)`

Expected runtime: ~4.5 hours for 30 epochs


## Environment Setup


In [None]:
# ============================================================================

import subprocess
import sys
import os

# Banner: a 60-character rule above and below the notebook title.
banner_rule = "=" * 60
print(banner_rule)
print("üöÄ AlphaBuilder v3.3 - Kaggle Training (Pre-Corrected Data)")
print(banner_rule)

# Report the accelerators that PyTorch can see in this session.
import torch

cuda_ready = torch.cuda.is_available()
print("\nüìä Hardware Detection:")
print(f"   CUDA available: {cuda_ready}")
if cuda_ready:
    gpu_total = torch.cuda.device_count()
    print(f"   GPU count: {gpu_total}")
    for gpu_index in range(gpu_total):
        print(f"   GPU {gpu_index}: {torch.cuda.get_device_name(gpu_index)}")
        # total_memory is divided by 1024**3 so the value is shown in GiB.
        gpu_mem_gib = torch.cuda.get_device_properties(gpu_index).total_memory / 1024**3
        print(f"       Memory: {gpu_mem_gib:.1f} GB")


## Clone Repository


In [None]:
# ============================================================================

REPO_URL = "https://github.com/gustavomello9600/alphabuild.git"
REPO_DIR = "/kaggle/working/alphabuild"

# If the checkout directory already exists, pull; otherwise clone from scratch.
# Either way the git command must succeed (check=True raises on failure).
if os.path.exists(REPO_DIR):
    print("\nüì• Updating repository...")
    git_command = ["git", "-C", REPO_DIR, "pull"]
else:
    print("\nüì• Cloning repository...")
    git_command = ["git", "clone", REPO_URL, REPO_DIR]
subprocess.run(git_command, check=True)

# Put the checkout first on the import path and make it the working directory.
sys.path.insert(0, REPO_DIR)
os.chdir(REPO_DIR)

print(f"   Working directory: {os.getcwd()}")


## Data Setup

**IMPORTANT**: Upload the corrected `warmup_data_v3_2.db` to Kaggle datasets and add it to this notebook.


In [None]:
# ============================================================================

# Candidate locations for the corrected dataset, checked in priority order:
# Kaggle Input -> Local Upload -> Fallback
POSSIBLE_PATHS = [
    "/kaggle/input/alphabuilder-warmup-v3-2/warmup_data_v3_2.db",
    "/kaggle/input/alphabuilder-warmup-data/warmup_data_v3_2.db",
    "/kaggle/working/warmup_data_v3_2.db",
]
# Standard (un-versioned) file name, accepted only as a last resort in case the
# user renamed the corrected file back; its contents are not verified here.
LEGACY_DATA_PATH = "/kaggle/input/alphabuilder-warmup-data/warmup_data.db"

# First candidate that exists on disk, or None when none of them does.
DATA_PATH = next((p for p in POSSIBLE_PATHS if os.path.exists(p)), None)

if DATA_PATH is None:
    print("\n‚ö†Ô∏è Corrected dataset not found automatically.")
    print("   Looking for 'warmup_data_v3_2.db'...")
    if os.path.exists(LEGACY_DATA_PATH):
        print("   Found 'warmup_data.db'. assuming this MIGHT be the new one if you renamed it.")
        DATA_PATH = LEGACY_DATA_PATH
    else:
        # List every location that was probed so a wrong dataset slug or a
        # missing attachment can be diagnosed straight from the traceback.
        searched = "\n".join(f"  - {p}" for p in [*POSSIBLE_PATHS, LEGACY_DATA_PATH])
        raise FileNotFoundError(
            f"Could not find training database! Searched:\n{searched}"
        )

print(f"\nüìÇ Training data: {DATA_PATH}")
print(f"   Size: {os.path.getsize(DATA_PATH) / 1024**2:.1f} MB")


## Configure Training


In [None]:
# ============================================================================

# Hyper-parameters and runtime switches for this training run.
CONFIG = dict(
    use_swin=False,            # False = SimpleBackbone, True = Swin-UNETR
    feature_size=24,
    batch_size=32,             # Per GPU
    epochs=30,
    learning_rate=1e-4,
    weight_decay=1e-5,
    num_workers=2,
    use_amp=True,
    val_split=0.1,
    patience=10,
    save_every=5,
    recalculate_values=False,  # DISABLED: Data is already correct
)

# Echo the configuration, one "name: value" line per entry.
print("\n‚öôÔ∏è Training Configuration:")
for option_name, option_value in CONFIG.items():
    print(f"   {option_name}: {option_value}")


## Import and Setup


In [None]:
# ============================================================================

import time
import json
import sqlite3
from pathlib import Path
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torch.amp import GradScaler, autocast
from tqdm.auto import tqdm

from alphabuilder.src.neural.model import AlphaBuilderV31
from alphabuilder.src.neural.dataset import TopologyDatasetV31
from alphabuilder.src.neural.trainer import policy_loss, weighted_value_loss, LAMBDA_POLICY


## Load Data and Model


In [None]:
# ============================================================================

print("\nüìÇ Loading dataset...")

# Direct loading of V3.3 dataset with corrected values.
# NOTE(review): the train/val split below shares this single dataset object,
# so augment=True presumably applies to validation samples as well — confirm
# against TopologyDatasetV31 and build a separate non-augmented dataset if so.
full_dataset = TopologyDatasetV31(
    db_path=Path(DATA_PATH),
    augment=True,
    preload_to_ram=True
)

print(f"   Total samples: {len(full_dataset):,}")

# Split train/val. The generator is seeded so the split is reproducible.
val_size = int(len(full_dataset) * CONFIG['val_split'])
train_size = len(full_dataset) - val_size

train_dataset, val_dataset = random_split(
    full_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42)
)

print(f"   Train samples: {len(train_dataset):,}")
print(f"   Val samples: {len(val_dataset):,}")

# Model Setup
backbone_name = 'Swin-UNETR' if CONFIG['use_swin'] else 'SimpleBackbone'
print(f"\nüß† Creating model based on {backbone_name}...")

# Use the GPU when there is one; fall back to CPU instead of crashing on
# model.to('cuda') when the accelerator is not enabled for the session.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type != 'cuda':
    print("   CUDA not available - falling back to CPU (AMP disabled)")

model = AlphaBuilderV31(
    in_channels=7,
    out_channels=2,
    feature_size=CONFIG['feature_size'],
    use_swin=CONFIG['use_swin']
)

if torch.cuda.device_count() > 1:
    print(f"   Using DataParallel on {torch.cuda.device_count()} GPUs")
    model = nn.DataParallel(model)

model = model.to(device)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=CONFIG['learning_rate'],
    weight_decay=CONFIG['weight_decay']
)

# Cosine decay over the whole run, ending at 1% of the initial learning rate.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=CONFIG['epochs'],
    eta_min=CONFIG['learning_rate'] * 0.01
)

# Gradient scaling is only meaningful for CUDA mixed precision; None selects
# the plain full-precision path in the training step.
scaler = GradScaler('cuda') if CONFIG['use_amp'] and device.type == 'cuda' else None


## Training Loop


In [None]:
# ============================================================================
# Standard training functions

def train_epoch(model, loader, optimizer, scaler, device):
    """Run one optimisation pass over ``loader`` and return averaged losses.

    Each batch is a mapping with ``'state'``, ``'policy'`` and ``'value'``
    entries. The model is called on the state and must return
    ``(pred_policy, pred_value)``. The optimised loss is
    ``value_loss + LAMBDA_POLICY * policy_loss``.

    Args:
        model: Network to train; switched to training mode.
        loader: Iterable of batches as described above.
        optimizer: Optimizer stepped once per batch.
        scaler: Gradient scaler for mixed precision. A falsy value (``None``)
            disables autocast and runs a plain backward/step.
        device: Device the batch tensors are moved to.

    Returns:
        Dict with per-batch means under ``'loss'``, ``'policy_loss'`` and
        ``'value_loss'``. All three are NaN when the loader yields no batch
        (e.g. ``drop_last=True`` with fewer samples than one batch), where
        the previous implementation raised ``ZeroDivisionError``.
    """
    model.train()
    use_amp = bool(scaler)
    total_loss = total_p_loss = total_v_loss = 0.0
    n_batches = 0
    pbar = tqdm(loader, desc="Training", leave=False)

    for batch in pbar:
        state = batch['state'].to(device, non_blocking=True)
        target_policy = batch['policy'].to(device, non_blocking=True)
        target_value = batch['value'].to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        # A single forward/loss path: autocast is a no-op when enabled=False,
        # so the mixed-precision and full-precision cases share this code.
        with autocast('cuda', enabled=use_amp):
            pred_policy, pred_value = model(state)
            v_loss = weighted_value_loss(pred_value, target_value)
            p_loss = policy_loss(pred_policy, target_policy)
            loss = v_loss + LAMBDA_POLICY * p_loss

        if use_amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        # .item() forces a device sync; read each scalar exactly once.
        batch_loss = loss.item()
        total_loss += batch_loss
        total_p_loss += p_loss.item()
        total_v_loss += v_loss.item()
        n_batches += 1
        pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

    if n_batches == 0:
        # Nothing was trained: report undefined metrics instead of dividing by zero.
        return dict.fromkeys(('loss', 'policy_loss', 'value_loss'), float('nan'))

    return {
        'loss': total_loss / n_batches,
        'policy_loss': total_p_loss / n_batches,
        'value_loss': total_v_loss / n_batches,
    }

def validate_epoch(model, loader, device):
    """Evaluate ``model`` on ``loader`` without gradients and return mean losses.

    Batches are mappings with ``'state'``, ``'policy'`` and ``'value'``
    entries, and the model must return ``(pred_policy, pred_value)``. The
    combined loss is ``value_loss + LAMBDA_POLICY * policy_loss``.

    NOTE(review): the forward pass always runs under ``autocast('cuda')``,
    independent of any AMP setting used for training.

    Args:
        model: Network to evaluate; switched to eval mode.
        loader: Iterable of validation batches.
        device: Device the batch tensors are moved to.

    Returns:
        Dict with per-batch means under ``'loss'``, ``'policy_loss'`` and
        ``'value_loss'``. All three are NaN when the loader yields no batch
        (e.g. an empty validation split), where the previous implementation
        raised ``ZeroDivisionError``.
    """
    model.eval()
    total_loss = total_p_loss = total_v_loss = 0.0
    n_batches = 0
    with torch.no_grad():
        for batch in tqdm(loader, desc="Validating", leave=False):
            state = batch['state'].to(device, non_blocking=True)
            target_policy = batch['policy'].to(device, non_blocking=True)
            target_value = batch['value'].to(device, non_blocking=True)

            with autocast('cuda'):
                pred_policy, pred_value = model(state)
                v_loss = weighted_value_loss(pred_value, target_value)
                p_loss = policy_loss(pred_policy, target_policy)
                loss = v_loss + LAMBDA_POLICY * p_loss

            total_loss += loss.item()
            total_p_loss += p_loss.item()
            total_v_loss += v_loss.item()
            n_batches += 1

    if n_batches == 0:
        # No validation data: the loss is undefined, not a crash.
        return dict.fromkeys(('loss', 'policy_loss', 'value_loss'), float('nan'))

    return {
        'loss': total_loss / n_batches,
        'policy_loss': total_p_loss / n_batches,
        'value_loss': total_v_loss / n_batches,
    }

# === Main Loop ===

import gc  # imported once here; it used to be re-imported on every epoch

print("\nüöÄ Starting training...")
checkpoint_dir = Path("/kaggle/working/checkpoints")
checkpoint_dir.mkdir(parents=True, exist_ok=True)

best_val_loss = float('inf')
patience_counter = 0
history = {'train_loss': [], 'val_loss': [], 'lr': []}
# CONFIG['batch_size'] is per GPU, so scale it by the number of visible GPUs.
batch_size = CONFIG['batch_size'] * max(1, torch.cuda.device_count())

training_start = time.time()


def _checkpoint_payload(epoch, val_loss):
    """Build the dict written to disk for a checkpoint at ``epoch``."""
    return {
        'epoch': epoch,
        # Unwrap DataParallel so the weights load into a bare model.
        'model_state_dict': model.module.state_dict() if hasattr(model, 'module') else model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_loss': val_loss,
        'config': CONFIG,
        'reward_formula': 'v3.2_additive(pre_corrected)',
    }


for epoch in range(CONFIG['epochs']):
    # === TRAIN ===
    # Loaders are rebuilt every epoch and dropped right after use so their
    # worker processes and pinned buffers are released (see cleanup below).
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=CONFIG['num_workers'], pin_memory=True, drop_last=True)
    train_metrics = train_epoch(model, train_loader, optimizer, scaler, device)

    # AGGRESSIVE CLEANUP to avoid OOM
    del train_loader
    gc.collect()
    torch.cuda.empty_cache()

    # === VAL ===
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            num_workers=CONFIG['num_workers'], pin_memory=True)
    val_metrics = validate_epoch(model, val_loader, device)

    # AGGRESSIVE CLEANUP
    del val_loader
    gc.collect()
    torch.cuda.empty_cache()

    # Read the learning rate BEFORE stepping the scheduler so the history
    # holds the rate this epoch actually trained with, not the next one.
    epoch_lr = scheduler.get_last_lr()[0]
    scheduler.step()

    history['train_loss'].append(train_metrics['loss'])
    history['val_loss'].append(val_metrics['loss'])
    history['lr'].append(epoch_lr)

    print(f"\nEpoch {epoch+1}/{CONFIG['epochs']} | Train: {train_metrics['loss']:.4f} | Val: {val_metrics['loss']:.4f}")

    # Periodic checkpoint: CONFIG['save_every'] was declared but never used.
    if CONFIG['save_every'] and (epoch + 1) % CONFIG['save_every'] == 0:
        torch.save(
            _checkpoint_payload(epoch, val_metrics['loss']),
            checkpoint_dir / f"checkpoint_epoch_{epoch+1}.pt",
        )

    # Save best
    if val_metrics['loss'] < best_val_loss:
        best_val_loss = val_metrics['loss']
        patience_counter = 0
        torch.save(_checkpoint_payload(epoch, best_val_loss), checkpoint_dir / "best_model.pt")
        print("  ‚úì Saved best model")
    else:
        patience_counter += 1

    # Stop once validation has failed to improve for `patience` epochs in a row.
    if patience_counter >= CONFIG['patience']:
        print("‚ö†Ô∏è Early stopping")
        break

print(f"\n‚úÖ Done! Best Val Loss: {best_val_loss:.4f}")
print(f"   Total time: {(time.time() - training_start) / 60:.1f} min")


In [None]:
# Copy outputs: IPython shell escapes that copy every .pt file from the
# checkpoints folder into /kaggle/working, then list the .pt files found there.
!cp /kaggle/working/checkpoints/*.pt /kaggle/working/
!ls -la /kaggle/working/*.pt