# 🌌 DeepLeGATo++ Training Notebook

**Next-Generation Galaxy Profile Fitting with Transformers and Neural Posterior Estimation**

This notebook trains DeepLeGATo++ on Google Colab with automatic:
- GPU detection and config selection
- Checkpoint saving to Google Drive
- Resume from previous checkpoints

## 1️⃣ Setup & Installation

In [None]:
# Mount Google Drive at /content/drive. The next cell looks for the project
# under /content/drive/MyDrive/<PROJECT_NAME>, so this must run first.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Setup project from Google Drive
import os
import sys

PROJECT_NAME = "DeepLeGATo++"
DRIVE_PATH = f"/content/drive/MyDrive/{PROJECT_NAME}"

# Check if code exists on Drive
if os.path.exists(f"{DRIVE_PATH}/deeplegato_pp"):
    print("‚úÖ Project found on Google Drive!")
    !cp -r "{DRIVE_PATH}/deeplegato_pp" /content/
    !cp -r "{DRIVE_PATH}/configs" /content/
    sys.path.insert(0, '/content')
else:
    print("‚ùå Project not found! Please upload DeepLeGATo++ folder to Google Drive.")
    print(f"Expected path: {DRIVE_PATH}")

In [None]:
# Install dependencies
!pip install -q torch>=2.1.0 pytorch-lightning>=2.1.0 timm>=0.9.12
!pip install -q nflows zuko astropy photutils
!pip install -q wandb gradio plotly seaborn
!pip install -q einops pyyaml
print("‚úÖ Dependencies installed!")

In [None]:
# Verify GPU
import torch

# Report whether a CUDA device is visible and, if so, its name and memory.
banner = "=" * 50
print(banner)
print("GPU Information")
print(banner)

if not torch.cuda.is_available():
    # No accelerator: tell the user how to enable one in Colab.
    print("‚ùå No GPU available!")
    print("Go to Runtime > Change runtime type > GPU")
else:
    gpu_name = torch.cuda.get_device_name(0)
    # total_memory is in bytes; dividing by 1e9 reports decimal gigabytes.
    vram = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"‚úÖ GPU: {gpu_name}")
    print(f"‚úÖ VRAM: {vram:.1f} GB")

## 2️⃣ Configuration

In [None]:
from deeplegato_pp.training.colab_utils import (
    setup_drive_paths,
    auto_select_config,
    get_latest_checkpoint,
)

# Resolve the project's directories for this project name and list them.
# `paths` is reused by later cells (checkpoints, outputs, models).
paths = setup_drive_paths(PROJECT_NAME)
print("\nProject paths:")
for label, location in paths.items():
    print(f"  {label}: {location}")

In [None]:
import yaml

# Let the helper choose a config file from /content/configs based on the GPU,
# then parse it into the `config` dict used by the rest of the notebook.
config_path = auto_select_config("/content/configs")

with open(config_path, 'r') as config_file:
    config = yaml.safe_load(config_file)

print(f"\nUsing config: {config_path}")
print(f"\nKey settings:")
print(f"  Backbone: {config['model']['backbone']['type']}")

# The remaining summary values all come from the `training` section.
training_cfg = config['training']
print(f"  Batch size: {training_cfg['batch_size']}")
print(f"  Accumulation: {training_cfg['accumulation_steps']}")
# Effective batch = per-step batch size times gradient-accumulation steps.
print(f"  Effective batch: {training_cfg['batch_size'] * training_cfg['accumulation_steps']}")
print(f"  Max epochs: {training_cfg['max_epochs']}")

## 3️⃣ Initialize Model & Data

In [None]:
from deeplegato_pp.models import DeepLeGAToPP
from deeplegato_pp.data import create_dataloaders

# Build the model from the loaded config.
model = DeepLeGAToPP.from_config(config)

# Summarise parameter counts; each entry is (display label, key in the dict
# returned by count_parameters()).
param_counts = model.count_parameters()
print(f"\nüìä Model Parameters:")
for label, key in (
    ("Total", "total"),
    ("Trainable", "trainable"),
    ("Backbone", "backbone"),
    ("NPE Head", "npe_head"),
):
    print(f"  {label}: {param_counts[key]:,}")

In [None]:
# Build the training and validation loaders from the config.
print("\nüìÅ Creating dataloaders...")
train_loader, val_loader = create_dataloaders(config)

# Report dataset sizes and the number of batches in one training epoch.
for label, sized in (
    ("Training samples", train_loader.dataset),
    ("Validation samples", val_loader.dataset),
    ("Batches per epoch", train_loader),
):
    print(f"  {label}: {len(sized):,}")

In [None]:
# Show up to eight images from one training batch in a 2x4 grid.
import matplotlib.pyplot as plt

batch = next(iter(train_loader))
images, params = batch['image'], batch['params']

fig, axes = plt.subplots(2, 4, figsize=(12, 6))
# Only as many panels as there are images are drawn; leftover panels are left
# untouched.
for i, ax in enumerate(axes.flatten()[:len(images)]):
    ax.imshow(images[i, 0].cpu(), cmap='viridis')
    # Columns 2 and 1 of `params` are displayed as n and Re respectively.
    ax.set_title(f"n={params[i, 2]:.1f}, Re={params[i, 1]:.1f}")
    ax.axis('off')
plt.suptitle("Sample Training Data")
plt.tight_layout()
plt.show()

## 4️⃣ Training

In [None]:
from deeplegato_pp.training import train

# Look in the checkpoints directory for the most recent checkpoint and tell
# the user whether this run will resume or start from scratch.
checkpoint = get_latest_checkpoint(paths['checkpoints'])

if not checkpoint:
    print("\nüÜï Starting fresh training run.")
else:
    print(f"\nüîÑ Found checkpoint: {checkpoint}")
    print("Training will resume from this checkpoint.")

In [None]:
# Start training.
divider = "=" * 50
print("\n" + divider)
print("üöÄ STARTING TRAINING")
print(divider)
print("\nCheckpoints will be saved to Google Drive automatically.")
print("If the session disconnects, re-run this notebook to resume.\n")

run_options = dict(
    config=config,
    resume=True,         # auto-resume from a checkpoint
    fast_dev_run=False,  # set True for a quick test
)
trained_model = train(**run_options)

## 5️⃣ Evaluate Results

In [None]:
from deeplegato_pp.training.colab_utils import get_best_checkpoint
from deeplegato_pp.inference import Predictor

# Evaluate on the GPU when one is present; otherwise fall back to the CPU so
# this cell does not fail on a CPU-only runtime (a hard-coded 'cuda' raises
# there).
eval_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the best checkpoint's weights into `model`, if one exists.
# NOTE(review): this uses `model` from the initialization cell, not the
# `trained_model` returned by train() — confirm that is the intended instance.
best_ckpt = get_best_checkpoint(paths['checkpoints'])
if best_ckpt:
    print(f"Loading best checkpoint: {best_ckpt}")
    # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True;
    # confirm the checkpoint only contains types that mode can load.
    checkpoint = torch.load(best_ckpt, map_location=eval_device)
    model.load_state_dict(checkpoint['state_dict'])
else:
    # Make the fallback explicit instead of silently evaluating whatever
    # weights `model` currently holds.
    print("No best checkpoint found; using the model's current in-memory weights.")

# Create predictor (assumes Predictor accepts any torch device string for
# `device` — TODO confirm).
predictor = Predictor(model, device=eval_device)

In [None]:
# Predict the parameters of the first validation sample and compare them with
# the ground truth.
val_batch = next(iter(val_loader))
# Put the input on the GPU when one is available, otherwise on the CPU; an
# unconditional .cuda() raises on a CPU-only runtime.
input_device = 'cuda' if torch.cuda.is_available() else 'cpu'
# [0:1] keeps the leading batch dimension for a single image.
test_image = val_batch['image'][0:1].to(input_device)
true_params = val_batch['params'][0]

# Predict
result = predictor.predict(test_image, num_samples=1000, return_samples=True)

# Print comparison
print("\n" + "=" * 60)
print("Prediction vs True Values")
print("=" * 60)

# Assumes this order matches the column order of val_batch['params'], since
# true values are looked up by position — TODO confirm.
param_names = ["magnitude", "effective_radius", "sersic_index", 
               "axis_ratio", "position_angle", "center_x", "center_y"]

for i, name in enumerate(param_names):
    pred = result['params'][name]['value']
    std = result['params'][name]['std']
    true = true_params[i].item()
    # Plain absolute difference between prediction and truth.
    error = abs(pred - true)
    print(f"{name:18s}: pred={pred:7.3f} ¬± {std:.3f}  |  true={true:7.3f}  |  error={error:.3f}")

In [None]:
# Plot the posterior for `result`, passing posterior_example.png inside the
# outputs directory as the save path.
posterior_png = paths['outputs'] / 'posterior_example.png'
predictor.plot_posterior(result, save_path=str(posterior_png))

## 6️⃣ Save Final Model

In [None]:
# Write the model to `final_model` inside the models directory and show how
# to load it again.
save_path = paths['models'] / 'final_model'
model.save_pretrained(save_path)

print(
    f"\n‚úÖ Model saved to: {save_path}",
    "\nYou can load this model later with:",
    f"  model = DeepLeGAToPP.from_pretrained('{save_path}')",
    sep="\n",
)

---

## 🎉 Training Complete!

Your trained DeepLeGATo++ model is now saved to Google Drive.

**Next steps:**
1. Test on real galaxy images
2. Export model for deployment
3. Fine-tune on specific survey data (JWST, LSST)