# GazeGaussian Enhanced (DiT) - 2-Step Training

## Overview
This notebook trains the enhanced GazeGaussian model with:
1. **DiT Neural Renderer** (replacing U-Net)
2. **VAE Integration**
3. **Orthogonality Regularization**

## Training Process
- **Step 1**: Train MeshHead (~10 epochs, ~2-3 hours)
- **Step 2**: Train GazeGaussian with DiT (~30 epochs, ~8-12 hours)

## Requirements
- GPU: A100 (40GB recommended) or V100 (32GB minimum)
- Dataset: ETH-XGaze training set in Google Drive
- Time: ~10-15 hours total (Step 1 + Step 2)

In [None]:
# Show the attached GPU(s), driver and CUDA version; the Requirements above assume an A100 or V100.
!nvidia-smi

In [None]:
# Mount Google Drive at /content/drive. Later cells read the ETH-XGaze data from, and copy
# checkpoints/outputs to, paths under /content/drive/MyDrive/.
from google.colab import drive
drive.mount('/content/drive')

## 1. Clone Repository

In [None]:
# Fresh clone of the project (including submodules) into /content/GazeGaussian.
# NOTE(review): `rm -rf GazeGaussian` also deletes GazeGaussian/work_dirs/, which is where the
# checkpoint-verification cells below look for training checkpoints. Re-running this cell after
# training discards them unless they were already copied to Drive.
%cd /content
!rm -rf GazeGaussian
!git clone --recursive https://github.com/kram254/GazeGaussian.git
%cd GazeGaussian
# Redundant after `clone --recursive`, but harmless; kept as a safety net for the submodules.
!git submodule update --init --recursive

## 2A. ALTERNATIVE: Install Optuna for Automated Training (Optional)

Only needed if you plan to use automated hyperparameter optimization (sections 6A, 9A, 12 and 13) instead of manual training; otherwise skip this cell.

In [None]:
# Optional: only the Optuna cells (6A, 9A, 12, 13) use these. plotly backs optuna.visualization.
# NOTE(review): versions are unpinned; pin them if results must be reproducible.
!pip install optuna optuna-dashboard plotly kaleido

## 2. Install Dependencies

In [None]:
# Upgrade packaging tools. ninja lets the CUDA extension builds in section 3 compile in parallel.
!pip install --upgrade pip setuptools wheel ninja

In [None]:
# Python dependencies used by the project and checked by the verification cell in section 4.
# NOTE(review): unpinned; pin versions for a reproducible environment.
!pip install opencv-python h5py tqdm scipy scikit-image lpips kornia tensorboardX einops trimesh plyfile

In [None]:
# Reinstall PyTorch from the CUDA 12.1 wheel index so the CUDA extensions built in section 3 are
# compiled against a matching torch/CUDA pair.
# NOTE(review): --force-reinstall also reinstalls every dependency (e.g. numpy). Colab may need a
# runtime restart afterwards. Consider pinning torch/torchvision/torchaudio versions.
!pip install --force-reinstall torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

## 3. Build CUDA Extensions

In [None]:
# Build and install the differentiable Gaussian rasterizer CUDA extension from its submodule,
# then return to the repo root.
# NOTE(review): `setup.py install` is deprecated by setuptools; `pip install .` is the usual replacement.
%cd /content/GazeGaussian/submodules/diff-gaussian-rasterization
!python setup.py install
%cd /content/GazeGaussian

In [None]:
# Build and install the simple-knn CUDA extension from its submodule, then return to the repo root.
# NOTE(review): `setup.py install` is deprecated by setuptools; `pip install .` is the usual replacement.
%cd /content/GazeGaussian/submodules/simple-knn
!python setup.py install
%cd /content/GazeGaussian

In [None]:
# Install kaolin (imported as `kaolin` by the verification cell below).
# NOTE(review): confirm "kaolin-core" is the intended package. NVIDIA's install docs use wheels
# from its own index matched to the installed torch/CUDA version.
!pip install kaolin-core

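## 4. Verify Installation

Checks that the required packages and CUDA extensions import. The same cell also contains the test-sample generation code. That code needs the final checkpoint (`gazegaussian_dit_final.pth`, copied to Drive in section 10), so re-run this cell after training to produce samples.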

In [None]:
print("\n" + "="*80)
print("VERIFICATION")
print("="*80)

all_good = True

packages = [
    ('torch', 'PyTorch'),
    ('cv2', 'OpenCV'),
    ('h5py', 'h5py'),
    ('lpips', 'LPIPS'),
    ('kornia', 'Kornia'),
]

for mod, name in packages:
    try:
        m = __import__(mod)
        v = getattr(m, '__version__', 'OK')
        print(f"✓ {name:15s} {v}")
    except ImportError as e:
        print(f"✗ {name:15s} FAILED: {str(e)[:50]}")
        all_good = False

try:
    import simple_knn
    print(f"✓ {'simple-knn':15s} OK")
except ImportError as e:
    print(f"✗ {'simple-knn':15s} FAILED: {str(e)[:50]}")
    all_good = False

try:
    import diff_gaussian_rasterization
    print(f"✓ {'diff-gauss':15s} OK")
except ImportError as e:
    print(f"✗ {'diff-gauss':15s} FAILED: {str(e)[:50]}")
    all_good = False

try:
    import kaolin
    try:
        kaolin_version = kaolin.__version__
    except AttributeError:
        kaolin_version = 'OK (version unknown)'
    print(f"✓ {'kaolin':15s} {kaolin_version}")
except ImportError as e:
    print(f"✗ {'kaolin':15s} FAILED: {str(e)[:50]}")
    all_good = False

print("="*80)

if all_good:
    print("\n✅ ALL REQUIRED PACKAGES INSTALLED SUCCESSFULLY!")
    print("   Ready for training!")
else:
    print("\n⚠ Some packages failed. Check errors above and rerun failed installations.")

import torch
import os
from pathlib import Path
import numpy as np
from PIL import Image
import torchvision.utils as vutils
from tqdm import tqdm
from IPython.display import display, Image as IPImage

from configs.gazegaussian_options import BaseOptions
from models.gaze_gaussian import GazeGaussianNet
from dataloader.eth_xgaze import get_val_loader

def save_image_grid(images, save_path, nrow=4, value_range=(-1, 1)):
    """Tile a batch of images into one grid and write it to `save_path` as an 8-bit image.

    Args:
        images: anything torchvision's make_grid accepts, presumably a (B, C, H, W) float
            tensor (TODO confirm the renderer's output layout).
        save_path: destination file; PIL infers the format from the extension.
        nrow: number of images per grid row.
        value_range: (min, max) of the input pixel values. make_grid clamps to this range
            and rescales it to [0, 1]. The default (-1, 1) matches the previous behaviour.
            Pass (0, 1) if the renderer already outputs [0, 1] images (not verified here).

    Returns:
        The saved grid as a uint8 numpy array: (H, W, 3), or (H, W) for a single
        single-channel image.
    """
    grid = vutils.make_grid(images, nrow=nrow, normalize=True, value_range=value_range)
    grid_np = grid.detach().cpu().numpy().transpose(1, 2, 0)
    # make_grid(normalize=True) has already mapped value_range onto [0, 1]. Applying the
    # [-1, 1] -> [0, 1] shift again (x * 0.5 + 0.5) would squash every pixel into [0.5, 1]
    # and produce washed-out images, so scale straight to 8-bit.
    grid_np = np.clip((grid_np * 255.0).round(), 0, 255).astype(np.uint8)
    if grid_np.shape[-1] == 1:
        # make_grid returns a lone single-channel image unchanged; PIL needs it as (H, W).
        grid_np = grid_np[..., 0]
    Image.fromarray(grid_np).save(save_path)
    return grid_np

print("="*80)
print("GENERATING TEST SAMPLES")
print("="*80)

checkpoint_path = "/content/drive/MyDrive/gazegaussian_dit_final.pth"
data_dir = "/content/drive/MyDrive/GazeGaussian_data/ETH-XGaze/test"
output_dir = "/content/test_outputs"
num_samples = 10

os.makedirs(output_dir, exist_ok=True)

print(f"\n[1/4] Loading checkpoint: {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location='cuda')
print(f"✓ Checkpoint loaded")

print(f"\n[2/4] Initializing model with DiT...")
opt = BaseOptions()

if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    state_dict = checkpoint['model_state_dict']
elif isinstance(checkpoint, dict):
    state_dict = checkpoint
else:
    state_dict = None

model = GazeGaussianNet(opt, load_state_dict=state_dict)
model = model.cuda()
model.eval()
print(f"✓ Model initialized")
print(f"  Neural Renderer: {type(model.neural_render).__name__}")

print(f"\n[3/4] Loading test data...")
opt.img_dir = data_dir
val_loader = get_val_loader(opt, data_dir=data_dir, batch_size=1, num_workers=0, evaluate=None, dataset_name='eth_xgaze')
print(f"✓ Test data loaded ({len(val_loader.dataset)} samples)")

print(f"\n[4/4] Generating {num_samples} samples...")

success = 0
with torch.no_grad():
    for idx, data in enumerate(tqdm(val_loader, total=num_samples)):
        if idx >= num_samples:
            break
        
        try:
            for key in data:
                if isinstance(data[key], torch.Tensor):
                    data[key] = data[key].cuda()
                elif isinstance(data[key], dict):
                    for sub_key in data[key]:
                        if isinstance(data[key][sub_key], torch.Tensor):
                            data[key][sub_key] = data[key][sub_key].cuda()
            
            output = model(data)
            
            gt_image = data.get('image', data.get('img', None))
            gaussian_img = output['total_render_dict']['merge_img']
            neural_img = output['total_render_dict']['merge_img_pro']
            
            if gt_image is not None:
                comparison = torch.cat([gt_image, gaussian_img, neural_img], dim=0)
                labels = "GT | Gaussian | DiT Enhanced"
            else:
                comparison = torch.cat([gaussian_img, neural_img], dim=0)
                labels = "Gaussian | DiT Enhanced"
            
            save_path = os.path.join(output_dir, f"test_sample_{idx:03d}.png")
            save_image_grid(comparison, save_path, nrow=len(comparison))
            
            save_image_grid(neural_img, os.path.join(output_dir, f"test_sample_{idx:03d}_dit.png"), nrow=1)
            
            success += 1
            
        except Exception as e:
            print(f"✗ Error on sample {idx}: {e}")
            continue

print("\n" + "="*80)
print(f"✅ Successfully generated {success}/{num_samples} test samples")
print(f"   Output directory: {output_dir}")
print("="*80)

!cp -r {output_dir} /content/drive/MyDrive/gazegaussian_test_outputs
print(f"\n✓ Copied outputs to Drive: /content/drive/MyDrive/gazegaussian_test_outputs")

print(f"\nDisplaying first 5 samples:")
for i in range(min(5, success)):
    img_path = os.path.join(output_dir, f"test_sample_{i:03d}.png")
    if os.path.exists(img_path):
        print(f"\nSample {i}:")
        display(IPImage(filename=img_path))

In [None]:
import json
from pathlib import Path


def split_train_val(files, train_fraction=0.9):
    """Split an ordered list of file names into (train, val) lists by position.

    The first ``int(len(files) * train_fraction)`` files go to training and the rest to
    validation, with two adjustments:
      * training always gets at least one file when any exist. Plain truncation gives 0
        training files for a single file at the 0.9 default.
      * with two or more files, validation keeps at least one file.

    Args:
        files: file names in the order they should be split (callers pass them sorted).
        train_fraction: approximate share of files used for training.

    Returns:
        (train_files, val_files) as new lists. val_files is empty only when fewer than
        two files are given.
    """
    files = list(files)
    if len(files) < 2:
        return files, []
    n_train = min(max(int(len(files) * train_fraction), 1), len(files) - 1)
    return files[:n_train], files[n_train:]


data_dir = Path("/content/drive/MyDrive/GazeGaussian_data/ETH-XGaze/train")
h5_files = sorted(f.name for f in data_dir.glob("*.h5"))

print(f"Found {len(h5_files)} training files")
print(f"First 5 files: {h5_files[:5]}")

if not h5_files:
    print("\n❌ No .h5 files found! Check your path.")
else:
    train_files, val_files = split_train_val(h5_files)
    if not val_files:
        # NOTE(review): an empty validation split may not be supported by the training loaders.
        print("⚠ Only one .h5 file found: it is used for training and the validation split is empty.")

    custom_config = {
        "train": train_files,
        "val": val_files,
        "val_gaze": val_files,
        "test": [],
        "test_specific": []
    }

    # Overwrites the split file shipped with the repo; the training scripts read this path.
    config_path = "/content/GazeGaussian/configs/dataset/eth_xgaze/train_test_split.json"
    with open(config_path, 'w') as f:
        json.dump(custom_config, f, indent=2)

    print(f"\n✓ Updated config")
    print(f"  - Training files: {len(train_files)}")
    print(f"  - Validation files: {len(val_files)}")

## 6. STEP 1: Train MeshHead (~10 epochs, ~2-3 hours)

This creates the canonical 3D head model.

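## 6A. ALTERNATIVE: Train MeshHead with Optuna Optimization

This will automatically search for the best hyperparameters (learning rate, MLP sizes, etc.). Run **either** the Optuna cell directly below **or** the manual training cell that follows it, not both.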

In [None]:
# Alternative to the manual Step 1 cell below: an Optuna search over 15 trials, stored in
# meshhead_optuna.db in the current directory (/content/GazeGaussian).
# --num_epochs presumably applies per trial (TODO confirm in train_meshhead_optuna.py).
# NOTE(review): run either this cell or the next one, not both. Both use --name 'meshhead', so
# section 7's `meshhead_*` glob may pick up checkpoints from either run.
%cd /content/GazeGaussian

!python train_meshhead_optuna.py \
    --batch_size 1 \
    --name 'meshhead' \
    --img_dir '/content/drive/MyDrive/GazeGaussian_data/ETH-XGaze/train' \
    --num_epochs 10 \
    --num_workers 2 \
    --early_stopping \
    --patience 5 \
    --dataset_name 'eth_xgaze' \
    --n_trials 15 \
    --study_name 'meshhead_optuna' \
    --optuna_storage 'sqlite:///meshhead_optuna.db'

In [None]:
# Step 1 (manual): train the MeshHead for up to 10 epochs with early stopping (patience 5).
# Section 7 then looks for checkpoints under work_dirs/meshhead_*/checkpoints/.
%cd /content/GazeGaussian

!python train_meshhead.py \
    --batch_size 1 \
    --name 'meshhead' \
    --img_dir '/content/drive/MyDrive/GazeGaussian_data/ETH-XGaze/train' \
    --num_epochs 10 \
    --num_workers 2 \
    --early_stopping \
    --patience 5 \
    --dataset_name 'eth_xgaze'

## 7. Verify MeshHead Checkpoint

In [None]:
import glob
import os


def find_latest_checkpoint(pattern):
    """Return the most recently modified file matching glob `pattern`, or None if none match.

    Uses modification time rather than string order. A lexicographic sort ranks e.g.
    'epoch_9.pth' after 'epoch_10.pth' and can select an older checkpoint.
    """
    matches = glob.glob(pattern)
    if not matches:
        return None
    return max(matches, key=os.path.getmtime)


latest_checkpoint = find_latest_checkpoint("/content/GazeGaussian/work_dirs/meshhead_*/checkpoints/*.pth")
if latest_checkpoint is not None:
    print(f"✓ MeshHead checkpoint found: {latest_checkpoint}")
    print(f"  Size: {os.path.getsize(latest_checkpoint) / (1024**2):.2f} MB")

    # Step 2 reads the checkpoint path back from this file.
    with open('/content/meshhead_checkpoint.txt', 'w') as f:
        f.write(latest_checkpoint)
    print(f"\n✓ Checkpoint path saved for Step 2")
else:
    print("❌ No MeshHead checkpoint found! Training may have failed.")

## 8. Verify DiT Configuration

In [None]:
from configs.gazegaussian_options import BaseOptions

opt = BaseOptions()

banner = "=" * 80
print(banner)
print("ENHANCED MODEL CONFIGURATION")
print(banner)

# (display label, BaseOptions attribute), reported in this order. Each attribute is read
# just before its line is printed.
reported_options = (
    ("Neural Renderer Type", "neural_renderer_type"),
    ("DiT Depth", "dit_depth"),
    ("DiT Num Heads", "dit_num_heads"),
    ("DiT Patch Size", "dit_patch_size"),
    ("VAE Enabled", "use_vae"),
    ("VAE Z Channels", "vae_z_channels"),
    ("VAE Frozen", "freeze_vae"),
    ("Orthogonality Loss", "use_orthogonality_loss"),
    ("Orthogonality Importance", "orthogonality_loss_importance"),
)
for position, (label, attr_name) in enumerate(reported_options):
    lead = "\n" if position == 0 else ""
    print(f"{lead}✓ {label}: {getattr(opt, attr_name)}")

# All three enhancements: DiT renderer, VAE, and orthogonality regularization.
all_enhancements_on = (
    opt.neural_renderer_type == "dit"
    and opt.use_vae
    and opt.use_orthogonality_loss
)
print("\n✅ All 3 enhancements are ACTIVE!" if all_enhancements_on else "\n⚠ Some enhancements may be disabled!")

## 9. STEP 2: Train GazeGaussian with DiT (~30 epochs, ~8-12 hours)

This trains the full pipeline with your 3 enhancements.

In [None]:
# Step 2 (manual): train the full GazeGaussian pipeline on top of the Step 1 MeshHead.
%cd /content/GazeGaussian

# Read back the MeshHead checkpoint path written by section 7. This raises FileNotFoundError
# if section 7 found no checkpoint.
with open('/content/meshhead_checkpoint.txt', 'r') as f:
    meshhead_checkpoint = f.read().strip()

print(f"Loading MeshHead from: {meshhead_checkpoint}")

# NOTE(review): the path is interpolated unquoted into the shell command and would break if it
# contained spaces.
!python train_gazegaussian.py \
    --batch_size 1 \
    --name 'gazegaussian_dit' \
    --img_dir '/content/drive/MyDrive/GazeGaussian_data/ETH-XGaze/train' \
    --num_epochs 30 \
    --num_workers 2 \
    --lr 0.0001 \
    --clip_grad \
    --load_meshhead_checkpoint {meshhead_checkpoint} \
    --dataset_name 'eth_xgaze'

## 9A. ALTERNATIVE: Train GazeGaussian with Optuna Optimization

This will automatically search for the best hyperparameters (learning rate, DiT depth/heads/patch size, loss weights, etc.). Run either section 9 or this section, not both.

In [None]:
# Alternative to section 9: an Optuna search over 20 trials, stored in gazegaussian_optuna.db in
# the current directory (/content/GazeGaussian). Section 12 reads that DB.
# NOTE(review): uses the same --name as section 9, so section 10's `gazegaussian_dit_*` glob can
# match directories from either run.
%cd /content/GazeGaussian

# Read back the MeshHead checkpoint path written by section 7.
with open('/content/meshhead_checkpoint.txt', 'r') as f:
    meshhead_checkpoint = f.read().strip()

print(f"Loading MeshHead from: {meshhead_checkpoint}")

!python train_gazegaussian_optuna.py \
    --batch_size 1 \
    --name 'gazegaussian_dit' \
    --img_dir '/content/drive/MyDrive/GazeGaussian_data/ETH-XGaze/train' \
    --num_epochs 30 \
    --num_workers 2 \
    --clip_grad \
    --load_meshhead_checkpoint {meshhead_checkpoint} \
    --dataset_name 'eth_xgaze' \
    --n_trials 20 \
    --study_name 'gazegaussian_optuna' \
    --optuna_storage 'sqlite:///gazegaussian_optuna.db'

## 10. Verify Final Checkpoint

In [None]:
import glob
import os
import shutil

final_checkpoint_dst = "/content/drive/MyDrive/gazegaussian_dit_final.pth"

# NOTE(review): this pattern also matches Optuna trial directories
# (gazegaussian_dit_trial_<n>) if section 9A was run.
checkpoints = glob.glob("/content/GazeGaussian/work_dirs/gazegaussian_dit_*/checkpoints/*.pth")
if checkpoints:
    # Newest by modification time. A lexicographic sort can rank an older file last
    # (e.g. 'epoch_9' after 'epoch_10').
    latest_checkpoint = max(checkpoints, key=os.path.getmtime)
    print(f"✓ GazeGaussian checkpoint found: {latest_checkpoint}")
    print(f"  Size: {os.path.getsize(latest_checkpoint) / (1024**2):.2f} MB")
    print(f"\n✓ Training complete!")
    print(f"\nCopy checkpoint to Drive:")
    # shutil instead of `!cp`: plain Python inside the if-block, and safe for paths with spaces.
    shutil.copy2(latest_checkpoint, final_checkpoint_dst)
    print("✓ Saved to Drive: gazegaussian_dit_final.pth")
else:
    print("❌ No GazeGaussian checkpoint found! Training may have failed.")

## 12. Analyze Optuna Results (If Using Optuna)

View the optimization history, parameter importances, and best hyperparameters. Only applicable if you trained with section 9A.

In [None]:
import glob
import json
import os
import shutil

import optuna
import optuna.visualization as vis
from IPython.display import display, HTML

# Absolute SQLite URL (note the four slashes). The Optuna training cell ran from
# /content/GazeGaussian with the relative URL 'sqlite:///gazegaussian_optuna.db', so the DB
# lives there. A relative URL here would depend on this kernel's current directory; if that
# is wrong, it points at a new, empty DB in which the study does not exist.
study_path = 'sqlite:////content/GazeGaussian/gazegaussian_optuna.db'
study_name = 'gazegaussian_optuna'
best_params_path = '/content/best_hyperparameters.json'
drive_best_checkpoint = '/content/drive/MyDrive/gazegaussian_optuna_best.pth'


def show_optuna_plot(title, make_figure):
    """Print `title`, then build and show the figure returned by `make_figure()`.

    A plot that fails to build is reported and skipped, so the remaining plots and the
    save/copy steps still run. Examples: too few completed trials, or a requested
    parameter that is not in the search space.
    """
    print(f"\n{title}")
    try:
        make_figure().show()
    except Exception as e:
        print(f"  ⚠ Plot skipped: {e}")


try:
    study = optuna.load_study(study_name=study_name, storage=study_path)
    best_trial = study.best_trial  # raises ValueError when no trial has completed yet
except Exception as e:
    study = None
    print(f"❌ Error loading study: {str(e)}")
    print("\nMake sure you ran the Optuna training first!")

if study is not None:
    print("="*80)
    print("OPTUNA OPTIMIZATION RESULTS")
    print("="*80)
    print(f"\nTotal trials: {len(study.trials)}")
    print(f"Best trial: {best_trial.number}")
    print(f"Best validation loss: {best_trial.value:.6f}")

    print(f"\n{'='*80}")
    print("BEST HYPERPARAMETERS:")
    print("="*80)
    for key, value in best_trial.params.items():
        print(f"  {key:35s}: {value}")

    print(f"\n{'='*80}")
    print("VISUALIZATIONS:")
    print("="*80)

    show_optuna_plot("1. Optimization History", lambda: vis.plot_optimization_history(study))
    show_optuna_plot("2. Parameter Importances", lambda: vis.plot_param_importances(study))
    show_optuna_plot("3. Parallel Coordinate Plot", lambda: vis.plot_parallel_coordinate(study))
    show_optuna_plot("4. Slice Plot", lambda: vis.plot_slice(study))
    show_optuna_plot("5. Contour Plot (Learning Rate vs DiT Depth)",
                     lambda: vis.plot_contour(study, params=['lr', 'dit_depth']))

    with open(best_params_path, 'w') as f:
        json.dump(best_trial.params, f, indent=2)
    print(f"\n✓ Saved best hyperparameters to: {best_params_path}")

    # Assumes the training script names trial work dirs gazegaussian_dit_trial_<n> (TODO confirm).
    best_checkpoint = f"/content/GazeGaussian/work_dirs/gazegaussian_dit_trial_{best_trial.number}/checkpoints"
    print(f"\n✓ Best checkpoint directory:")
    print(f"  {best_checkpoint}")

    # Copy exactly one file. `cp *.pth <file>` fails as soon as the directory holds more than
    # one checkpoint, so take the newest by modification time.
    best_files = glob.glob(os.path.join(best_checkpoint, "*.pth"))
    if best_files:
        shutil.copy2(max(best_files, key=os.path.getmtime), drive_best_checkpoint)
        print(f"\n✓ Copied best checkpoint to Google Drive!")
    else:
        print(f"\n⚠ No .pth files found in {best_checkpoint}; nothing copied.")

## 13. Launch Optuna Dashboard (Optional)

Launch the interactive Optuna dashboard for real-time monitoring. Requires an ngrok account auth token.

In [None]:
!pip install pyngrok

from pyngrok import ngrok
import threading
import subprocess

ngrok.set_auth_token("YOUR_NGROK_TOKEN")

port = 8080
public_url = ngrok.connect(port)

print(f"="*80)
print("OPTUNA DASHBOARD")
print(f"="*80)
print(f"Dashboard URL: {public_url}")
print(f"\nOpen this URL in your browser to view the interactive dashboard")
print(f"="*80)

def run_dashboard():
    subprocess.run([
        "optuna-dashboard",
        "sqlite:///gazegaussian_optuna.db",
        "--port", str(port),
        "--host", "0.0.0.0"
    ])

thread = threading.Thread(target=run_dashboard, daemon=True)
thread.start()

print("\n✓ Dashboard is now running!")
print("⚠️  Keep this cell running to maintain the dashboard connection")

## 11. Generate Test Samples

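Generate a few redirected gaze/pose samples for verification. (This cell is still a placeholder; the sample-generation code currently lives in the section 4 cell.)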

In [None]:
# Placeholder section. The redirected-sample generation code currently lives in the
# "4. Verify Installation" cell: it loads /content/drive/MyDrive/gazegaussian_dit_final.pth
# and writes PNGs to /content/test_outputs. Re-run that cell after section 10 has saved the
# final checkpoint.
print("Sample generation is implemented in the section 4 cell; re-run it after training (section 10).")