AlphaBuilder v3.2 - Kaggle Training Script (NEW REWARD)
========================================================

This script is designed to run on Kaggle with GPU T4 x2.
Upload this as a Kaggle notebook and enable GPU accelerator.

**UPDATED**: Uses new additive reward formula for value head targets:
- `compliance_score = 0.80 - 0.16 * (log10(C) - 1)`
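- `volume_bonus = (0.10 - V) * 2.0`  (decreases linearly as V grows, so lighter structures always score higher)
- `reward = compliance_score + volume_bonus`

Each term is clipped (compliance score to [-0.50, 0.85], volume bonus to [-0.60, 0.30]) and the sum is clipped to [-1, 1].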

Expected runtime: ~5 hours for 30 epochs (SimpleBackbone)


## Environment Setup


In [None]:
# Notebook bootstrap: shared stdlib imports plus a quick report of the GPUs
# this session was given.
import os
import subprocess
import sys

SEPARATOR = "=" * 60
print(SEPARATOR)
print("üöÄ AlphaBuilder v3.2 - Kaggle Training (NEW REWARD)")
print(SEPARATOR)

import torch

# Hardware report: CUDA availability, then name and total memory of each GPU.
cuda_ok = torch.cuda.is_available()
print(f"\nüìä Hardware Detection:")
print(f"   CUDA available: {cuda_ok}")
if cuda_ok:
    gpu_count = torch.cuda.device_count()
    print(f"   GPU count: {gpu_count}")
    for gpu_idx in range(gpu_count):
        gpu_props = torch.cuda.get_device_properties(gpu_idx)
        print(f"   GPU {gpu_idx}: {torch.cuda.get_device_name(gpu_idx)}")
        print(f"       Memory: {gpu_props.total_memory / 1024**3:.1f} GB")


## Clone Repository and Install Dependencies


In [None]:
# Get the project sources: clone on first run, `git pull` on re-runs.
REPO_URL = "https://github.com/gustavomello9600/alphabuild.git"
REPO_DIR = "/kaggle/working/alphabuild"

if os.path.exists(REPO_DIR):
    print(f"\nüì• Updating repository...")
    git_command = ["git", "-C", REPO_DIR, "pull"]
else:
    print(f"\nüì• Cloning repository...")
    git_command = ["git", "clone", REPO_URL, REPO_DIR]
subprocess.run(git_command, check=True)

# Make the checkout importable and use it as the working directory.
sys.path.insert(0, REPO_DIR)
os.chdir(REPO_DIR)

print(f"   Working directory: {os.getcwd()}")


## Download Training Data


In [None]:
# ============================================================================
# Locate the warm-up training database.
#
# Option A (recommended): the Kaggle dataset
#   "gustavomello9600/alphabuilder-warmup-data", attached as a notebook input.
# Option B (fallback): a copy inside the cloned repository checkout.
# ============================================================================

KAGGLE_DATA_PATH = "/kaggle/input/alphabuilder-warmup-data/warmup_data_kaggle.db"
LOCAL_DATA_PATH = "/kaggle/working/alphabuild/data/warm_up_data/warmup_data.db"

DATA_PATH = KAGGLE_DATA_PATH
if not os.path.exists(DATA_PATH):
    print("\n‚ö†Ô∏è Dataset not found in Kaggle input.")
    print("   Please add the dataset: gustavomello9600/alphabuilder-warmup-data")
    print("   Or upload warmup_data.db manually.")

    # Fall back to the repository-local copy.
    DATA_PATH = LOCAL_DATA_PATH
    if not os.path.exists(DATA_PATH):
        # Name every location that was tried, not just the last one.
        raise FileNotFoundError(
            f"Training data not found at {KAGGLE_DATA_PATH} or {LOCAL_DATA_PATH}"
        )

print(f"\nüìÇ Training data: {DATA_PATH}")
print(f"   Size: {os.path.getsize(DATA_PATH) / 1024**2:.1f} MB")


## NEW: Define Reward Functions

These replace the old tanh-normalized formula with the new additive formula.


In [None]:
# ============================================================================
# NEW REWARD FORMULA (Dec 2024)
# ============================================================================

import numpy as np

# New formula constants
COMPLIANCE_BASE = 0.80       # Score at C=10 (log10(10)=1), before clipping
COMPLIANCE_SLOPE = 0.16      # Score decrease per log10 unit of compliance
COMPLIANCE_MIN = -0.50       # Lower clip for the compliance score
COMPLIANCE_MAX = 0.85        # Upper clip for the compliance score

VOLUME_REFERENCE = 0.10      # Reference volume fraction (V=0.10 gives no bonus/penalty)
VOLUME_SENSITIVITY = 2.0     # Reward change per unit of volume fraction (0.2 per 0.1 change)
VOLUME_BONUS_MAX = 0.30      # Upper clip for the volume bonus
VOLUME_PENALTY_MAX = 0.60    # Magnitude of the lower clip for the volume bonus


def calculate_compliance_score(compliance: float) -> float:
    """
    Map compliance C to the compliance term of the reward (log10 scale).

    score = COMPLIANCE_BASE - COMPLIANCE_SLOPE * (log10(C) - 1),
    clipped to [COMPLIANCE_MIN, COMPLIANCE_MAX].

    Reference points:
    - C=10 (log10=1) -> 0.80
    - C=1,000,000 (log10=6) -> 0.00
    - Saturates at COMPLIANCE_MAX for C <= 10**0.6875 (~4.9) and at
      COMPLIANCE_MIN for C >= 10**9.125 (including +inf).

    Edge cases:
    - NaN returns COMPLIANCE_MIN, so a broken solve never yields a NaN target.
    - C <= 0 returns COMPLIANCE_MAX.
      NOTE(review): a failed solve that reports C=0 therefore gets the best
      score — confirm this is intended.
    """
    if np.isnan(compliance):
        return COMPLIANCE_MIN
    if compliance <= 0:
        return COMPLIANCE_MAX

    # Floor at C=1 so log10 never goes negative for 0 < C < 1.
    log_c = np.log10(max(compliance, 1.0))
    score = COMPLIANCE_BASE - COMPLIANCE_SLOPE * (log_c - 1.0)

    return float(np.clip(score, COMPLIANCE_MIN, COMPLIANCE_MAX))


def calculate_volume_bonus(vol_frac: float) -> float:
    """
    Calculate the volume bonus/penalty.

    bonus = (VOLUME_REFERENCE - V) * VOLUME_SENSITIVITY,
    clipped to [-VOLUME_PENALTY_MAX, VOLUME_BONUS_MAX].

    The bonus always decreases as V grows:
    - V < VOLUME_REFERENCE -> positive bonus (lean structure)
    - V > VOLUME_REFERENCE -> negative penalty (heavy structure)
    - V >= 0.40 saturates at -VOLUME_PENALTY_MAX.
    - For V in [0, 1] the bonus peaks at 0.20 (V=0), so the VOLUME_BONUS_MAX
      clip only takes effect for negative inputs.

    NaN returns -VOLUME_PENALTY_MAX (worst case) instead of propagating NaN.
    """
    if np.isnan(vol_frac):
        return -VOLUME_PENALTY_MAX

    adjustment = (VOLUME_REFERENCE - vol_frac) * VOLUME_SENSITIVITY
    return float(np.clip(adjustment, -VOLUME_PENALTY_MAX, VOLUME_BONUS_MAX))


def calculate_new_reward(compliance: float, vol_frac: float, is_valid: bool = True) -> float:
    """
    Calculate the reward using the NEW additive formula.

    R = compliance_score(C) + volume_bonus(V), clipped to [-1, 1].

    Args:
        compliance: structural compliance C of the design.
        vol_frac: volume fraction V of the design.
        is_valid: when False the design is rejected outright and -1.0 is returned.

    Returns:
        Reward in range [-1, 1]. It is never NaN, because NaN inputs map to
        the worst score of each component.
    """
    if not is_valid:
        return -1.0

    compliance_score = calculate_compliance_score(compliance)
    volume_bonus = calculate_volume_bonus(vol_frac)

    reward = compliance_score + volume_bonus

    return float(np.clip(reward, -1.0, 1.0))


# Sanity-print the reward for one compliance value per decade (10 .. 10**6)
# at the reference volume fraction.
print("üéØ NEW REWARD FORMULA TEST (at V=0.10):")
for exponent in range(1, 7):
    compliance_value = 10 ** exponent
    print(f"   C={compliance_value:>8}: R={calculate_new_reward(compliance_value, 0.10):.3f}")


## Configure Training


In [None]:
# ============================================================================
# Training hyper-parameters (printed below for the run log).
# ============================================================================

CONFIG = dict(
    use_swin=False,           # False = SimpleBackbone, True = Swin-UNETR
    feature_size=24,
    batch_size=32,            # Per GPU; the loop multiplies by the GPU count
    epochs=30,
    learning_rate=1e-4,
    weight_decay=1e-5,
    num_workers=0,
    use_amp=True,             # Mixed precision
    val_split=0.1,
    patience=10,              # Early-stopping patience, in epochs
    save_every=5,             # Periodic checkpoint interval, in epochs
    recalculate_values=True,  # Rebuild value targets with the new reward formula
)

print(f"\n‚öôÔ∏è Training Configuration:")
for option, setting in CONFIG.items():
    print(f"   {option}: {setting}")


## Import and Setup


In [None]:
# ============================================================================

import time
import json
import sqlite3
from pathlib import Path
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split, Dataset
from torch.amp import GradScaler, autocast
from tqdm.auto import tqdm

from alphabuilder.src.neural.model import AlphaBuilderV31
from alphabuilder.src.neural.dataset import TopologyDatasetV31
from alphabuilder.src.neural.trainer import policy_loss, weighted_value_loss, LAMBDA_POLICY


## NEW: Custom Dataset with Value Recalculation

Wraps the base dataset but recalculates value targets using the new formula.


In [None]:
# ============================================================================
# WRAPPER DATASET THAT RECALCULATES VALUES
# ============================================================================

class RewardRecalculationDataset(Dataset):
    """
    Wrap a base dataset and replace its value targets with the new additive reward.

    Compliance and volume fraction are read from metadata in the SQLite database:
    - v2 schema: ``episodes.metadata_blob`` (zlib-compressed pickle). One reward
      per episode is applied to every ``(episode_id, step)`` row in ``records``.
    - v1 schema: ``training_data.metadata`` (JSON), keyed by ``id``.

    Samples without usable ``compliance``/``volume_fraction`` metadata keep the
    value target produced by the base dataset.

    NOTE(review): ``__getitem__`` assumes ``base_dataset.index[idx]`` is a
    mapping carrying either ``episode_id``/``step`` or ``record_id``. Confirm
    this against TopologyDatasetV31.
    """

    def __init__(self, base_dataset, db_path: Path, recalculate: bool = True):
        self.base_dataset = base_dataset
        self.db_path = db_path
        self.recalculate = recalculate

        # (episode_id, step) for v2 rows, or record id for v1 rows -> new value target.
        self.value_overrides = {}

        if recalculate:
            self._load_compliance_volume_data()

    @staticmethod
    def _reward_from_meta(meta):
        """Return the new reward for a metadata mapping, or None if a field is missing."""
        compliance = meta.get('compliance')
        vol_frac = meta.get('volume_fraction')
        if compliance is None or vol_frac is None:
            return None
        return calculate_new_reward(compliance, vol_frac)

    def _load_compliance_volume_data(self):
        """Fill ``self.value_overrides`` from whichever schema(s) the database contains."""
        import pickle
        import zlib

        n_failed = 0
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()

            # v2 schema (episodes + records). Decode each episode's metadata once,
            # rather than once per record as a records/episodes JOIN would.
            try:
                episode_rewards = {}
                episode_rows = cursor.execute(
                    "SELECT episode_id, metadata_blob FROM episodes"
                ).fetchall()
                for ep_id, meta_blob in episode_rows:
                    if not meta_blob:
                        continue
                    try:
                        # SECURITY: pickle.loads can execute arbitrary code.
                        # Only load databases from a trusted source.
                        reward = self._reward_from_meta(pickle.loads(zlib.decompress(meta_blob)))
                    except Exception:
                        # Best effort: unreadable metadata keeps the original target.
                        n_failed += 1
                        continue
                    if reward is not None:
                        episode_rewards[ep_id] = reward

                if episode_rewards:
                    for ep_id, step in cursor.execute("SELECT episode_id, step FROM records"):
                        if ep_id in episode_rewards:
                            self.value_overrides[(ep_id, step)] = episode_rewards[ep_id]
            except sqlite3.OperationalError:
                pass  # v2 tables not present

            # v1 schema (training_data)
            try:
                v1_rows = cursor.execute("SELECT id, metadata FROM training_data").fetchall()
                for rec_id, meta_json in v1_rows:
                    if not meta_json:
                        continue
                    try:
                        reward = self._reward_from_meta(json.loads(meta_json))
                    except Exception:
                        n_failed += 1
                        continue
                    if reward is not None:
                        self.value_overrides[rec_id] = reward
            except sqlite3.OperationalError:
                pass  # v1 table not present
        finally:
            conn.close()

        print(f"   üìä Loaded {len(self.value_overrides)} value overrides")
        if n_failed:
            print(f"   WARNING: skipped {n_failed} unreadable metadata entries")

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        sample = self.base_dataset[idx]
        if not (self.recalculate and self.value_overrides):
            return sample

        info = self.base_dataset.index[idx]

        # Key format depends on which schema the base dataset was built from.
        if 'episode_id' in info and 'step' in info:
            key = (info['episode_id'], info['step'])
        elif 'record_id' in info:
            key = info['record_id']
        else:
            key = None

        # `is not None` (not truthiness) so that a record id of 0 still matches.
        if key is not None and key in self.value_overrides:
            sample['value'] = torch.tensor([self.value_overrides[key]], dtype=torch.float32)

        return sample

print("‚úÖ RewardRecalculationDataset defined")


## Create DataLoaders


In [None]:
# ============================================================================

print(f"\nüìÇ Loading dataset...")

# Load base dataset
base_dataset = TopologyDatasetV31(
    db_path=Path(DATA_PATH),
    augment=True,
    preload_to_ram=True  # Kaggle has enough RAM
)

print(f"   Base samples: {len(base_dataset):,}")

# Wrap with value recalculation if enabled
if CONFIG['recalculate_values']:
    print("   üîÑ Wrapping dataset with value recalculation...")
    full_dataset = RewardRecalculationDataset(
        base_dataset, 
        Path(DATA_PATH),
        recalculate=True
    )
else:
    full_dataset = base_dataset

print(f"   Total samples: {len(full_dataset):,}")

# Split train/val
val_size = int(len(full_dataset) * CONFIG['val_split'])
train_size = len(full_dataset) - val_size

train_dataset, val_dataset = random_split(
    full_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42)
)

print(f"   Train samples: {len(train_dataset):,}")
print(f"   Val samples: {len(val_dataset):,}")

print(f"   DataLoaders will be created per-epoch for memory management")

## Create Model


In [None]:
# ============================================================================

print(f"\nüß† Creating model...")

device = torch.device('cuda')

model = AlphaBuilderV31(
    in_channels=7,
    out_channels=2,
    feature_size=CONFIG['feature_size'],
    use_swin=CONFIG['use_swin']
)

total_params = sum(p.numel() for p in model.parameters())
print(f"   Architecture: {'Swin-UNETR' if CONFIG['use_swin'] else 'SimpleBackbone'}")
print(f"   Total parameters: {total_params:,}")

# Multi-GPU
if torch.cuda.device_count() > 1:
    print(f"   Using DataParallel on {torch.cuda.device_count()} GPUs")
    model = nn.DataParallel(model)

model = model.to(device)

# Optimizer
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=CONFIG['learning_rate'],
    weight_decay=CONFIG['weight_decay']
)

# Scheduler
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=CONFIG['epochs'],
    eta_min=CONFIG['learning_rate'] * 0.01
)

# Mixed precision
scaler = GradScaler('cuda') if CONFIG['use_amp'] else None


## Training Functions


In [None]:
# ============================================================================

def train_epoch(model, loader, optimizer, scaler, device):
    """
    Run one optimization pass over ``loader`` and return the mean losses.

    Total loss per batch is ``value_loss + LAMBDA_POLICY * policy_loss``.

    Args:
        model: network whose forward returns ``(policy_prediction, value_prediction)``.
        loader: iterable of dict batches with 'state', 'policy' and 'value' tensors.
        optimizer: optimizer over ``model``'s parameters.
        scaler: ``GradScaler`` for mixed precision (forward under CUDA autocast),
            or ``None`` for full-precision training.
        device: device the batch tensors are moved to.

    Returns:
        dict with the per-batch mean of 'loss', 'policy_loss' and 'value_loss'.

    Raises:
        ValueError: if ``loader`` yields no batches, e.g. the training split is
            smaller than one batch while ``drop_last=True``.
    """
    model.train()
    use_amp = scaler is not None
    total_loss = total_p_loss = total_v_loss = 0.0
    n_batches = 0

    pbar = tqdm(loader, desc="Training", leave=False)

    for batch in pbar:
        state = batch['state'].to(device, non_blocking=True)
        target_policy = batch['policy'].to(device, non_blocking=True)
        target_value = batch['value'].to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        # Autocast only when a GradScaler was supplied; otherwise full precision.
        with autocast('cuda', enabled=use_amp):
            pred_policy, pred_value = model(state)
            v_loss = weighted_value_loss(pred_value, target_value)
            p_loss = policy_loss(pred_policy, target_policy)
            loss = v_loss + LAMBDA_POLICY * p_loss

        if use_amp:
            # Scaled backward guards fp16 gradients against underflow;
            # scaler.step skips the update when inf/NaN gradients are found.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        total_loss += loss.item()
        total_p_loss += p_loss.item()
        total_v_loss += v_loss.item()
        n_batches += 1

        pbar.set_postfix({'loss': f'{loss.item():.4f}'})

    if n_batches == 0:
        raise ValueError(
            "train_epoch: loader yielded no batches; the training split must hold "
            "at least one full batch"
        )

    return {
        'loss': total_loss / n_batches,
        'policy_loss': total_p_loss / n_batches,
        'value_loss': total_v_loss / n_batches,
    }


def validate_epoch(model, loader, device):
    """
    Evaluate ``model`` on ``loader`` without gradients and return the mean losses.

    The forward pass always runs under CUDA autocast, regardless of whether
    training used mixed precision. Total loss per batch is
    ``value_loss + LAMBDA_POLICY * policy_loss``.

    Args:
        model: network whose forward returns ``(policy_prediction, value_prediction)``.
        loader: iterable of dict batches with 'state', 'policy' and 'value' tensors.
        device: device the batch tensors are moved to.

    Returns:
        dict with the per-batch mean of 'loss', 'policy_loss' and 'value_loss'.

    Raises:
        ValueError: if ``loader`` yields no batches, e.g. the validation split
            rounded down to zero samples.
    """
    model.eval()
    total_loss = total_p_loss = total_v_loss = 0.0
    n_batches = 0

    with torch.no_grad():
        for batch in tqdm(loader, desc="Validating", leave=False):
            state = batch['state'].to(device, non_blocking=True)
            target_policy = batch['policy'].to(device, non_blocking=True)
            target_value = batch['value'].to(device, non_blocking=True)

            with autocast('cuda'):
                pred_policy, pred_value = model(state)
                v_loss = weighted_value_loss(pred_value, target_value)
                p_loss = policy_loss(pred_policy, target_policy)
                loss = v_loss + LAMBDA_POLICY * p_loss

            total_loss += loss.item()
            total_p_loss += p_loss.item()
            total_v_loss += v_loss.item()
            n_batches += 1

    if n_batches == 0:
        raise ValueError(
            "validate_epoch: loader yielded no batches; the validation split is empty"
        )

    return {
        'loss': total_loss / n_batches,
        'policy_loss': total_p_loss / n_batches,
        'value_loss': total_v_loss / n_batches,
    }


## Training Loop


In [None]:
def get_memory_info():
    """
    Return a one-line summary of RAM usage and, if CUDA is available, GPU memory.

    GPU figures come from ``torch.cuda.memory_allocated`` and
    ``torch.cuda.memory_reserved`` called without a device argument, i.e. for
    the current CUDA device only. Other GPUs used by DataParallel are not
    included.
    """
    import psutil  # local import: only needed for these diagnostics

    # One snapshot, so that `used` and `percent` describe the same instant.
    vm = psutil.virtual_memory()
    ram_used = vm.used / 1024**3
    ram_part = f"RAM: {ram_used:.1f}GB ({vm.percent:.0f}%)"

    if not torch.cuda.is_available():
        return ram_part

    gpu_allocated = torch.cuda.memory_allocated() / 1024**3
    gpu_reserved = torch.cuda.memory_reserved() / 1024**3
    return f"{ram_part} | GPU: {gpu_allocated:.1f}GB alloc / {gpu_reserved:.1f}GB reserved"

print(f"\nüöÄ Starting training: {CONFIG['epochs']} epochs")
print("-" * 60)
print(f"üìä Initial memory: {get_memory_info()}")

checkpoint_dir = Path("/kaggle/working/checkpoints")
checkpoint_dir.mkdir(exist_ok=True)

best_val_loss = float('inf')
patience_counter = 0
history = {'train_loss': [], 'val_loss': [], 'lr': []}
batch_size = CONFIG['batch_size'] * max(1, torch.cuda.device_count())
num_workers = CONFIG['num_workers']

training_start = time.time()

for epoch in range(CONFIG['epochs']):
    epoch_start = time.time()
    
    # Create fresh train loader
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
    )
    
    train_metrics = train_epoch(model, train_loader, optimizer, scaler, device)
    
    # Cleanup
    del train_loader
    import gc
    gc.collect()
    torch.cuda.empty_cache()
    print(f"  üìä After train cleanup: {get_memory_info()}")
    
    # Create fresh val loader
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )
    
    val_metrics = validate_epoch(model, val_loader, device)
    
    del val_loader
    gc.collect()
    torch.cuda.empty_cache()
    print(f"  üìä After val cleanup: {get_memory_info()}")
    
    scheduler.step()
    current_lr = scheduler.get_last_lr()[0]
    
    history['train_loss'].append(train_metrics['loss'])
    history['val_loss'].append(val_metrics['loss'])
    history['lr'].append(current_lr)
    
    epoch_time = time.time() - epoch_start
    samples_per_sec = len(train_dataset) / epoch_time
    
    print(f"\nEpoch {epoch + 1}/{CONFIG['epochs']} ({epoch_time:.1f}s, {samples_per_sec:.0f} samples/s)")
    print(f"  Train | Loss: {train_metrics['loss']:.4f} | P: {train_metrics['policy_loss']:.4f} | V: {train_metrics['value_loss']:.4f}")
    print(f"  Val   | Loss: {val_metrics['loss']:.4f} | P: {val_metrics['policy_loss']:.4f} | V: {val_metrics['value_loss']:.4f}")
    print(f"  LR: {current_lr:.2e}")
    
    if val_metrics['loss'] < best_val_loss:
        best_val_loss = val_metrics['loss']
        patience_counter = 0
        torch.cuda.empty_cache()
        model_state = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
        torch.save({
            'epoch': epoch,
            'model_state_dict': model_state,
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': best_val_loss,
            'config': CONFIG,
            'reward_formula': 'v3.2_additive',  # NEW: Track formula version
        }, checkpoint_dir / "best_model.pt")
        print(f"  ‚úì New best model saved (val_loss: {best_val_loss:.4f})")
        del model_state
    else:
        patience_counter += 1
    
    if (epoch + 1) % CONFIG['save_every'] == 0:
        model_state = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
        torch.save({
            'epoch': epoch,
            'model_state_dict': model_state,
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_metrics['loss'],
            'config': CONFIG,
            'reward_formula': 'v3.2_additive',
        }, checkpoint_dir / f"checkpoint_epoch_{epoch + 1}.pt")
        del model_state
    
    if patience_counter >= CONFIG['patience']:
        print(f"\n‚ö†Ô∏è Early stopping triggered (patience={CONFIG['patience']})")
        break

## Final Summary


In [None]:
# ============================================================================

total_time = time.time() - training_start

banner_rule = "=" * 60
print("\n" + banner_rule)
print("üéâ Training Complete!")
print(banner_rule)
print(f"   Total time: {total_time / 3600:.1f} hours")
print(f"   Best validation loss: {best_val_loss:.4f}")
print(f"   Checkpoints saved to: {checkpoint_dir}")
print(f"   Reward formula: v3.2_additive")

# Inventory of the checkpoint files written by the training loop.
print(f"\nüìÅ Saved files:")
for ckpt_path in sorted(checkpoint_dir.glob("*.pt")):
    size_mb = ckpt_path.stat().st_size / 1024**2
    print(f"   {ckpt_path.name} ({size_mb:.1f} MB)")

# Plotting is best-effort: any failure is reported and the cell still finishes.
try:
    import matplotlib.pyplot as plt

    fig, (loss_ax, lr_ax) = plt.subplots(1, 2, figsize=(12, 4))

    loss_ax.plot(history['train_loss'], label='Train')
    loss_ax.plot(history['val_loss'], label='Val')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title('Training Progress (v3.2 Additive Reward)')
    loss_ax.legend()
    loss_ax.grid(True, alpha=0.3)

    lr_ax.plot(history['lr'])
    lr_ax.set_xlabel('Epoch')
    lr_ax.set_ylabel('Learning Rate')
    lr_ax.set_title('LR Schedule')
    lr_ax.grid(True, alpha=0.3)

    plot_path = checkpoint_dir / "training_history.png"
    plt.tight_layout()
    plt.savefig(plot_path, dpi=150)
    plt.show()
    print(f"\nüìä Training history saved to: {plot_path}")
except Exception as plot_err:
    print(f"\n‚ö†Ô∏è Could not plot history: {plot_err}")

print("\n‚úÖ Done! Download checkpoints from /kaggle/working/checkpoints/")


In [None]:
# List the checkpoint files
!ls -la /kaggle/working/checkpoints/
# Copy them to the output root (more visible in the Kaggle UI)
!cp /kaggle/working/checkpoints/*.pt /kaggle/working/
!ls -la /kaggle/working/*.pt