# GazeGaussian Enhanced (DiT) - 2-Step Training

## Overview
This notebook trains the enhanced GazeGaussian model with:
1. **DiT Neural Renderer** (replacing U-Net)
2. **VAE Integration**
3. **Orthogonality Regularization**

## Training Process
- **Step 1**: Train MeshHead (~10 epochs, ~2-3 hours)
- **Step 2**: Train GazeGaussian with DiT (~30 epochs, ~8-12 hours)

## Requirements
- GPU: A100 (40GB recommended) or V100 (32GB minimum)
- Dataset: ETH-XGaze training set in Google Drive
- Time: ~10-15 hours total (Step 1 plus Step 2)

In [None]:
!nvidia-smi
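
The GPU requirement listed above can also be checked from Python instead of reading the `nvidia-smi` table by eye. The following is a minimal sketch, not part of the repository: `classify_gpu` and `describe_current_gpu` are hypothetical helpers, and the thresholds are simply the 32 GB and 40 GB figures from the requirements. Memory is compared in decimal gigabytes (an assumption) because the total reported for a card is usually a little below its nominal size in binary gigabytes.

```python
def classify_gpu(total_memory_bytes, minimum_gb=32, recommended_gb=40):
    """Label a GPU by total memory against the notebook's stated requirements."""
    total_gb = total_memory_bytes / 1e9  # decimal GB (assumption, see text above)
    if total_gb >= recommended_gb:
        return "recommended"
    if total_gb >= minimum_gb:
        return "minimum"
    return "insufficient"


def describe_current_gpu():
    """Return a one-line summary of GPU 0, or a hint if none is usable."""
    try:
        import torch
    except ImportError:
        return "PyTorch is not installed yet"
    if not torch.cuda.is_available():
        return "No CUDA GPU visible; check the Colab runtime type"
    props = torch.cuda.get_device_properties(0)
    label = classify_gpu(props.total_memory)
    return f"{props.name}: {props.total_memory / 1e9:.1f} GB ({label})"


print(describe_current_gpu())
```

An "insufficient" result does not necessarily mean training will fail, only that the card is below the sizes this notebook was written for.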

In [None]:
from google.colab import drive
drive.mount('/content/drive')

## 1. Clone Repository

In [None]:
%cd /content
!rm -rf GazeGaussian
!git clone --recursive https://github.com/kram254/GazeGaussian.git
%cd GazeGaussian
!git submodule update --init --recursive

## 2. Install Dependencies

In [None]:
!pip install --upgrade pip setuptools wheel ninja

In [None]:
!pip install opencv-python h5py tqdm scipy scikit-image lpips kornia tensorboardX einops trimesh plyfile

In [None]:
!pip install --force-reinstall torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

## 3. Build CUDA Extensions

In [None]:
%cd /content/GazeGaussian/submodules/diff-gaussian-rasterization
!python setup.py install
%cd /content/GazeGaussian

In [None]:
%cd /content/GazeGaussian/submodules/simple-knn
!python setup.py install
%cd /content/GazeGaussian

In [None]:
!pip install kaolin-core

## 4. Verify Installation

In [None]:
print("\n" + "="*80)
print("VERIFICATION")
print("="*80)

all_good = True

packages = [
    ('torch', 'PyTorch'),
    ('cv2', 'OpenCV'),
    ('h5py', 'h5py'),
    ('lpips', 'LPIPS'),
    ('kornia', 'Kornia'),
]

for mod, name in packages:
    try:
        m = __import__(mod)
        v = getattr(m, '__version__', 'OK')
        print(f"✓ {name:15s} {v}")
    except ImportError as e:
        print(f"✗ {name:15s} FAILED: {str(e)[:50]}")
        all_good = False

try:
    import simple_knn
    print(f"✓ {'simple-knn':15s} OK")
except ImportError as e:
    print(f"✗ {'simple-knn':15s} FAILED: {str(e)[:50]}")
    all_good = False

try:
    import diff_gaussian_rasterization
    print(f"✓ {'diff-gauss':15s} OK")
except ImportError as e:
    print(f"✗ {'diff-gauss':15s} FAILED: {str(e)[:50]}")
    all_good = False

try:
    import kaolin
    try:
        kaolin_version = kaolin.__version__
    except AttributeError:
        kaolin_version = 'OK (version unknown)'
    print(f"✓ {'kaolin':15s} {kaolin_version}")
except ImportError as e:
    print(f"✗ {'kaolin':15s} FAILED: {str(e)[:50]}")
    all_good = False

print("="*80)

if all_good:
    print("\n✅ ALL REQUIRED PACKAGES INSTALLED SUCCESSFULLY!")
    print("   Ready for training!")
else:
    print("\n⚠ Some packages failed. Check errors above and rerun failed installations.")

## 5. Configure Dataset

In [None]:
import json
from pathlib import Path

data_dir = Path("/content/drive/MyDrive/GazeGaussian_data/ETH-XGaze/train")
h5_files = sorted([f.name for f in data_dir.glob("*.h5")])

print(f"Found {len(h5_files)} training files")
print(f"First 5 files: {h5_files[:5]}")

if not h5_files:
    print("\n❌ No .h5 files found! Check your path.")
else:
    # 90/10 split by file; keep at least one training file even for very small datasets.
    train_split = max(1, int(len(h5_files) * 0.9))
    train_files = h5_files[:train_split]
    val_files = h5_files[train_split:]

    custom_config = {
        "train": train_files,
        "val": val_files,
        "val_gaze": val_files,
        "test": [],
        "test_specific": []
    }

    config_path = "/content/GazeGaussian/configs/dataset/eth_xgaze/train_test_split.json"
    with open(config_path, 'w') as f:
        json.dump(custom_config, f, indent=2)

    print(f"\n✓ Updated config")
    print(f"  - Training files: {len(train_files)}")
    print(f"  - Validation files: {len(val_files)}")
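
The split logic in the cell above can be isolated into a small function to see how it behaves on different dataset sizes. This is an illustrative sketch: `split_files` is a hypothetical helper and the file names are made up for the example.

```python
def split_files(file_names, train_fraction=0.9):
    """Split a list of file names into (train, val) by position after sorting."""
    names = sorted(file_names)
    if not names:
        return [], []
    # int() truncates, so the validation side gets any remainder;
    # max(1, ...) keeps at least one training file for tiny datasets.
    cut = max(1, int(len(names) * train_fraction))
    return names[:cut], names[cut:]


# Hypothetical file names, used only to show the resulting sizes.
example = [f"subject{i:04d}.h5" for i in range(80)]
train, val = split_files(example)
print(len(train), len(val))  # 72 8
```

Because the split is positional on sorted names, the last files alphabetically always end up in validation. Shuffling with a fixed seed before splitting would be an alternative if that ordering is a concern.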

## 6. STEP 1: Train MeshHead (~10 epochs, ~2-3 hours)

This creates the canonical 3D head model.

In [None]:
%cd /content/GazeGaussian

!python train_meshhead.py \
    --batch_size 1 \
    --name 'meshhead' \
    --img_dir '/content/drive/MyDrive/GazeGaussian_data/ETH-XGaze/train' \
    --num_epochs 10 \
    --num_workers 2 \
    --early_stopping \
    --patience 5 \
    --dataset_name 'eth_xgaze'

## 7. Verify MeshHead Checkpoint

In [None]:
import glob
import os

checkpoints = glob.glob("/content/GazeGaussian/work_dirs/meshhead_*/checkpoints/*.pth")
if checkpoints:
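    # Pick the most recently written file; a plain string sort would rank a name
    # like "epoch_9" above "epoch_10".
    latest_checkpoint = max(checkpoints, key=os.path.getmtime)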
    print(f"✓ MeshHead checkpoint found: {latest_checkpoint}")
    print(f"  Size: {os.path.getsize(latest_checkpoint) / (1024**2):.2f} MB")
    
    with open('/content/meshhead_checkpoint.txt', 'w') as f:
        f.write(latest_checkpoint)
    print(f"\n✓ Checkpoint path saved for Step 2")
else:
    print("❌ No MeshHead checkpoint found! Training may have failed.")
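
Choosing the "latest" checkpoint by sorting paths as strings is fragile when file names contain unpadded numbers, because `"10"` sorts before `"9"`. A natural sort key avoids this. The sketch below assumes checkpoint names embed an epoch number. The names shown are hypothetical, since the repository's actual naming scheme is not stated here.

```python
import re


def natural_key(path):
    """Split a path into text and integer chunks so that 10 sorts after 9."""
    return [int(chunk) if chunk.isdigit() else chunk
            for chunk in re.split(r"(\d+)", path)]


# Hypothetical checkpoint names.
names = ["epoch_9.pth", "epoch_10.pth", "epoch_2.pth"]
print(sorted(names)[-1])                   # epoch_9.pth
print(sorted(names, key=natural_key)[-1])  # epoch_10.pth
```

Sorting by modification time (`os.path.getmtime`) is another option that does not depend on the naming scheme at all.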

## 8. Verify DiT Configuration

In [None]:
from configs.gazegaussian_options import BaseOptions

opt = BaseOptions()

print("="*80)
print("ENHANCED MODEL CONFIGURATION")
print("="*80)
print(f"\n✓ Neural Renderer Type: {opt.neural_renderer_type}")
print(f"✓ DiT Depth: {opt.dit_depth}")
print(f"✓ DiT Num Heads: {opt.dit_num_heads}")
print(f"✓ DiT Patch Size: {opt.dit_patch_size}")
print(f"✓ VAE Enabled: {opt.use_vae}")
print(f"✓ VAE Z Channels: {opt.vae_z_channels}")
print(f"✓ VAE Frozen: {opt.freeze_vae}")
print(f"✓ Orthogonality Loss: {opt.use_orthogonality_loss}")
print(f"✓ Orthogonality Importance: {opt.orthogonality_loss_importance}")

if opt.neural_renderer_type == "dit" and opt.use_vae and opt.use_orthogonality_loss:
    print("\n✅ All 3 enhancements are ACTIVE!")
else:
    print("\n⚠ Some enhancements may be disabled!")
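
The notebook does not spell out how the orthogonality term is computed. For orientation only, one common form of such a regularizer penalizes the squared Frobenius distance between the Gram matrix of a set of normalized feature vectors and the identity matrix. The sketch below illustrates that generic form in NumPy. Whether the repository uses this exact formulation, and which features it applies it to, is an assumption. `orthogonality_penalty` is a hypothetical name.

```python
import numpy as np


def orthogonality_penalty(features):
    """Squared Frobenius distance between the rows' Gram matrix and identity.

    Generic illustration only; rows are assumed to be non-zero vectors.
    """
    f = np.asarray(features, dtype=np.float64)
    f = f / np.linalg.norm(f, axis=1, keepdims=True)  # unit-length rows
    gram = f @ f.T
    return float(np.sum((gram - np.eye(len(f))) ** 2))


orthogonal = [[1.0, 0.0], [0.0, 1.0]]
parallel = [[1.0, 0.0], [1.0, 0.0]]
print(orthogonality_penalty(orthogonal))  # 0.0
print(orthogonality_penalty(parallel))    # 2.0
```

In a setup like this, a weight such as `orthogonality_loss_importance` would presumably scale the penalty before it is added to the main loss, so larger values push the vectors harder toward being mutually orthogonal.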

## 9. STEP 2: Train GazeGaussian with DiT (~30 epochs, ~8-12 hours)

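This trains the full pipeline with the three enhancements listed in the overview: the DiT neural renderer, VAE integration, and orthogonality regularization.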

In [None]:
%cd /content/GazeGaussian

with open('/content/meshhead_checkpoint.txt', 'r') as f:
    meshhead_checkpoint = f.read().strip()

print(f"Loading MeshHead from: {meshhead_checkpoint}")

!python train_gazegaussian.py \
    --batch_size 1 \
    --name 'gazegaussian_dit' \
    --img_dir '/content/drive/MyDrive/GazeGaussian_data/ETH-XGaze/train' \
    --num_epochs 30 \
    --num_workers 2 \
    --lr 0.0001 \
    --clip_grad \
    --load_meshhead_checkpoint '{meshhead_checkpoint}' \
    --dataset_name 'eth_xgaze'
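
The Step 2 cell reads the MeshHead checkpoint path from a small text file written earlier. If the Colab session was restarted, or Step 1 did not finish, that file or the checkpoint it points to may be missing, and the failure would only surface once training starts. A defensive reader could look like the following sketch. `read_checkpoint_path` is a hypothetical helper, and the demo uses throwaway files in place of the real ones.

```python
import tempfile
from pathlib import Path


def read_checkpoint_path(pointer_file):
    """Return the checkpoint path stored in pointer_file, checking both exist."""
    pointer = Path(pointer_file)
    if not pointer.is_file():
        raise FileNotFoundError(
            f"{pointer} not found; run the MeshHead verification cell first")
    checkpoint = Path(pointer.read_text().strip())
    if not checkpoint.is_file():
        raise FileNotFoundError(
            f"Checkpoint listed in {pointer} does not exist: {checkpoint}")
    return str(checkpoint)


# Demo with throwaway files standing in for the real pointer file and checkpoint.
with tempfile.TemporaryDirectory() as tmp:
    fake_checkpoint = Path(tmp) / "meshhead.pth"
    fake_checkpoint.write_bytes(b"")
    pointer = Path(tmp) / "meshhead_checkpoint.txt"
    pointer.write_text(f"{fake_checkpoint}\n")
    resolved = read_checkpoint_path(pointer)
    matches = resolved == str(fake_checkpoint)
print(matches)  # True
```

In the notebook, the call would be `read_checkpoint_path('/content/meshhead_checkpoint.txt')` in place of the bare `open(...).read()`.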

## 10. Verify Final Checkpoint

In [None]:
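import glob
import os
import shutil

checkpoints = glob.glob("/content/GazeGaussian/work_dirs/gazegaussian_dit_*/checkpoints/*.pth")
if checkpoints:
    # Pick the most recently written file; a plain string sort would rank a name
    # like "epoch_9" above "epoch_10".
    latest_checkpoint = max(checkpoints, key=os.path.getmtime)
    print(f"✓ GazeGaussian checkpoint found: {latest_checkpoint}")
    print(f"  Size: {os.path.getsize(latest_checkpoint) / (1024**2):.2f} MB")
    print("\n✓ Training complete!")
    # shutil.copy2 raises if the copy fails, so the success message is only
    # printed after the file has actually been written.
    drive_path = "/content/drive/MyDrive/gazegaussian_dit_final.pth"
    print(f"\nCopying checkpoint to Drive: {drive_path}")
    shutil.copy2(latest_checkpoint, drive_path)
    print("✓ Saved to Drive: gazegaussian_dit_final.pth")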
else:
    print("❌ No GazeGaussian checkpoint found! Training may have failed.")
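
After many hours of training, it may be worth confirming that the copy on Drive is byte-for-byte identical to the local checkpoint before the Colab session ends. One way to do that is to compare SHA-256 digests of the two files. The following is a minimal sketch with hypothetical helper names (`sha256_of`, `copy_and_verify`), demonstrated on a throwaway file instead of a real checkpoint.

```python
import hashlib
import shutil
import tempfile
from pathlib import Path


def sha256_of(path, chunk_size=1 << 20):
    """Hash a file in chunks so large checkpoints do not need to fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def copy_and_verify(src, dst):
    """Copy src to dst and raise if the two files do not hash identically."""
    shutil.copy2(src, dst)
    if sha256_of(src) != sha256_of(dst):
        raise IOError(f"Copy of {src} to {dst} is incomplete or corrupted")
    return dst


# Demo on a throwaway file.
with tempfile.TemporaryDirectory() as tmp:
    src = Path(tmp) / "checkpoint.pth"
    src.write_bytes(b"\x00" * 4096)
    dst = copy_and_verify(src, Path(tmp) / "copy.pth")
    copied_ok = dst.read_bytes() == src.read_bytes()
print(copied_ok)  # True
```

Hashing reads both files in full, so for a large checkpoint on a mounted Drive folder this can take a noticeable amount of time.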

## 11. Generate Test Samples

Generate a few redirected gaze/pose samples for verification.

In [None]:
# TODO: Add inference code to generate samples
# This will be added after confirming training works
print("Sample generation coming in next update...")