# GazeGaussian - Google Colab Setup

## Key Changes
- **PyTorch3D removed**: Functions replaced with native PyTorch implementations
- **Clean build process**: Ensures ABI compatibility by rebuilding extensions with correct PyTorch version
- **CUDA compatibility**: Matches PyTorch installation to Colab's CUDA version

## Setup Steps
1. Check GPU availability
2. Mount Google Drive for data/checkpoints
3. Clone repository with submodules
4. Download model config files
5. Install dependencies
6. Clean any previous builds
7. Reinstall PyTorch for proper CUDA compatibility
8. Patch simple-knn (add the missing `<cfloat>` header)
9. Build custom CUDA extensions (diff-gaussian-rasterization, simple-knn)
10. Install Kaolin (required for MeshHead training)
11. Verify all imports

Run cells in order. If any step fails, check error messages and rerun that specific cell.
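
The final import check in the setup steps follows a simple pattern, sketched here under stated assumptions: `check_imports` is a hypothetical helper (not part of the GazeGaussian codebase) that tries each import and records either a version string or a failure. The module names in the demo are placeholders chosen so the sketch runs anywhere.

```python
import importlib


def check_imports(module_names):
    """Try to import each module; map name -> version string, or None on failure."""
    results = {}
    for name in module_names:
        try:
            module = importlib.import_module(name)
        except ImportError:
            results[name] = None
        else:
            # Fall back to "OK" when the module exposes no __version__.
            results[name] = str(getattr(module, "__version__", "OK"))
    return results


# Placeholder names: one stdlib module and one that should not exist.
report = check_imports(["json", "module_that_does_not_exist_xyz"])
for name, version in report.items():
    print(f"{'✓' if version else '✗'} {name:32s} {version or 'FAILED'}")
```

In the notebook, the same helper could be called with `["torch", "cv2", "simple_knn", "diff_gaussian_rasterization", "kaolin"]` once the build cells have run.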

In [None]:
!nvidia-smi

In [None]:
from google.colab import drive
drive.mount('/content/drive')

### Re-clone Repository (if updating code)
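Run this cell only if you need a fresh copy of the code. It deletes any existing `/content/GazeGaussian` checkout before cloning, and it pulls from the `kram254` fork, whereas the first-time clone cell below pulls from the `Abiram929` fork, so run one or the other.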

In [None]:
%cd /content
!rm -rf GazeGaussian
!git clone --recursive https://github.com/kram254/GazeGaussian.git
%cd GazeGaussian
!git submodule update --init --recursive
print("✓ Repository re-cloned with latest code")

In [None]:
%cd /content
!git clone --recursive https://github.com/Abiram929/GazeGaussian.git
%cd GazeGaussian

In [None]:
!git submodule update --init --recursive

In [None]:
import torch
import sys
print(f"Python: {sys.version}")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.version.cuda}")
print(f"CUDA available: {torch.cuda.is_available()}")

### Download Required Model Files

In [None]:
%cd /content/GazeGaussian/configs
!wget https://huggingface.co/ucwxb/GazeGaussian/resolve/main/config_models.zip
!unzip -q config_models.zip
%cd /content/GazeGaussian
print("✓ Model config files downloaded")

In [None]:
!pip install --upgrade pip setuptools wheel ninja

In [None]:
!pip install opencv-python h5py tqdm scipy scikit-image lpips kornia tensorboardX einops trimesh plyfile

### Note: PyTorch3D Not Required
PyTorch3D functions have been replaced with native PyTorch implementations in the codebase.

### Clean Installation Process
The following cells will:
1. Remove any conflicting cached builds
2. Reinstall PyTorch to ensure ABI compatibility
3. Build custom CUDA extensions from scratch

### Step 1: Clean Previous Builds
Remove any previously compiled extensions that may be incompatible.
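
Before deleting anything, it can help to list what would be removed. The sketch below is a dry run under stated assumptions: `find_stale_builds` is a hypothetical helper, and the demo uses a throwaway temporary directory instead of the real `dist-packages` folder.

```python
import tempfile
from pathlib import Path


def find_stale_builds(site_packages, prefixes):
    """Return sorted paths in site_packages whose names start with any prefix."""
    root = Path(site_packages)
    matches = []
    for prefix in prefixes:
        matches.extend(root.glob(f"{prefix}*"))
    return sorted(set(matches))


# Demonstration on a temporary directory, not the real dist-packages.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "simple_knn-0.0.0.egg-info").mkdir()
    (Path(tmp) / "diff_gaussian_rasterization").mkdir()
    (Path(tmp) / "numpy").mkdir()
    prefixes = ["diff_gaussian_rasterization", "simple_knn", "kaolin"]
    stale = [p.name for p in find_stale_builds(tmp, prefixes)]
    print(stale)
```

Only names matching the given prefixes are reported, so unrelated packages such as `numpy` in the demo are left alone.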

In [None]:
!pip uninstall -y diff-gaussian-rasterization simple-knn kaolin
!pip cache purge
!rm -rf /usr/local/lib/python3.*/dist-packages/diff_gaussian_rasterization*
!rm -rf /usr/local/lib/python3.*/dist-packages/simple_knn*
!rm -rf /usr/local/lib/python3.*/dist-packages/kaolin*
!rm -rf /content/GazeGaussian/submodules/diff-gaussian-rasterization/build
!rm -rf /content/GazeGaussian/submodules/simple-knn/build
!rm -rf /tmp/*
print("✓ Cleanup complete")

### Step 2: Reinstall PyTorch
Ensure PyTorch is properly matched with the system CUDA version to avoid ABI compatibility issues.
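
The cell below hard-codes the `cu121` wheel index. As a rough sketch, the index URL can be derived from the CUDA version string instead. `pytorch_wheel_index` is a hypothetical helper, and it assumes the `cu<major><minor>` naming seen in the URL below; not every CUDA version has a published index, so check that the resulting URL exists before relying on it.

```python
def pytorch_wheel_index(cuda_version):
    """Map a CUDA version string such as '12.1' to a PyTorch wheel index URL."""
    if not cuda_version:
        # torch.version.cuda is None on CPU-only builds.
        raise ValueError("No CUDA version available")
    major, minor = cuda_version.split(".")[:2]
    return f"https://download.pytorch.org/whl/cu{major}{minor}"


print(pytorch_wheel_index("12.1"))
```

In the notebook, `torch.version.cuda` would supply the version string.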

In [None]:
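import torch
print(f"Current PyTorch: {torch.__version__}")
print(f"CUDA version: {torch.version.cuda}")
print("Ensuring PyTorch is properly installed for this CUDA version...")
# The cu121 index assumes a CUDA 12.1 build; change the suffix if the version printed above differs.
!pip install --force-reinstall torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
# torch is already loaded in this kernel, so a second `import torch` would still report the old build.
# Query the new installation from a fresh interpreter instead, and restart the runtime
# before importing torch in this notebook again if the version changed.
!python -c "import torch; print('Reinstalled PyTorch:', torch.__version__); print('CUDA version:', torch.version.cuda); print('CUDA available:', torch.cuda.is_available())"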

### Step 3: Fix simple-knn Source Code
Add missing header for FLT_MAX constant.
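
The patch amounts to an idempotent text edit: add the include once, and do nothing on a second run. A minimal sketch of that logic on plain strings follows; `ensure_include` is a hypothetical helper, and the sample source is made up for illustration.

```python
def ensure_include(source, header, anchor):
    """Return source with `#include <header>` added once, right after anchor."""
    directive = f"#include <{header}>"
    if directive in source:
        return source  # already present, nothing to do
    if anchor in source:
        return source.replace(anchor, f"{anchor}\n{directive}", 1)
    # Anchor missing: put the directive at the top of the file instead.
    return f"{directive}\n{source}"


# Made-up source snippet standing in for simple_knn.cu.
sample = "#include <vector>\nfloat big = FLT_MAX;\n"
patched = ensure_include(sample, "cfloat", "#include <vector>")
print(patched)
```

Running the helper on its own output returns the text unchanged, which is what makes it safe to rerun the cell.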

In [None]:
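from pathlib import Path

# simple_knn.cu uses FLT_MAX, which is declared in <cfloat>.
src = Path('/content/GazeGaussian/submodules/simple-knn/simple_knn.cu')
content = src.read_text()
if '#include <cfloat>' not in content:
    if '#include <vector>' in content:
        content = content.replace('#include <vector>', '#include <vector>\n#include <cfloat>', 1)
    else:
        # Fall back to prepending the header if the expected anchor is absent.
        content = '#include <cfloat>\n' + content
    src.write_text(content)
    print("✓ Added #include <cfloat> to simple_knn.cu")
else:
    print("✓ simple_knn.cu already includes <cfloat>")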

### Step 4: Build Custom CUDA Extensions
Build diff-gaussian-rasterization and simple-knn from source.

In [None]:
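%cd /content/GazeGaussian/submodules/diff-gaussian-rasterization
!python setup.py install
%cd /content/GazeGaussian/submodules/simple-knn
!python setup.py install
%cd /content/GazeGaussian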

### Step 5: Install Kaolin (Required for MeshHead Training)

In [None]:
%cd /content
!git clone --recursive https://github.com/NVIDIAGameWorks/kaolin
%cd kaolin
!python setup.py install
%cd /content/GazeGaussian
print("\n✓ Kaolin installed successfully!")

In [None]:
print("\n" + "="*80)
print("VERIFICATION")
print("="*80)

all_good = True

packages = [
    ('torch', 'PyTorch'),
    ('cv2', 'OpenCV'),
    ('h5py', 'h5py'),
    ('lpips', 'LPIPS'),
    ('kornia', 'Kornia'),
]

for mod, name in packages:
    try:
        m = __import__(mod)
        v = getattr(m, '__version__', 'OK')
        print(f"✓ {name:15s} {v}")
    except ImportError as e:
        print(f"✗ {name:15s} FAILED: {str(e)[:50]}")
        all_good = False

try:
    import simple_knn
    print(f"✓ {'simple-knn':15s} OK")
except ImportError as e:
    print(f"✗ {'simple-knn':15s} FAILED: {str(e)[:50]}")
    all_good = False

try:
    import diff_gaussian_rasterization
    print(f"✓ {'diff-gauss':15s} OK")
except ImportError as e:
    print(f"✗ {'diff-gauss':15s} FAILED: {str(e)[:50]}")
    all_good = False

try:
    import kaolin
    print(f"✓ {'kaolin':15s} {kaolin.__version__}")
except ImportError as e:
    print(f"✗ {'kaolin':15s} FAILED: {str(e)[:50]}")
    all_good = False

print("="*80)

if all_good:
    print("\n✅ ALL REQUIRED PACKAGES INSTALLED SUCCESSFULLY!")
    print("   Ready for MeshHead and GazeGaussian training!")
else:
    print("\n⚠ Some packages failed. Check errors above and rerun failed installations.")

## Training Commands

### Train MeshHead
```bash
python train_meshhead.py --img_dir /path/to/images --checkpoint_path /content/drive/MyDrive/GazeGaussian_checkpoints/meshhead
```

### Train GazeGaussian
```bash
python train.py -s /path/to/dataset -m /content/drive/MyDrive/GazeGaussian_checkpoints/model
```

Replace paths with your actual data locations. Checkpoints will be saved to Google Drive.
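
When paths contain spaces or come from variables, assembling the command as an argument list avoids quoting mistakes. The sketch below assumes only the flags shown in this notebook (`--img_dir`, `--checkpoint_path`, `--num_epochs`); `build_meshhead_command` is a hypothetical helper, and the image path is a placeholder.

```python
import shlex


def build_meshhead_command(img_dir, checkpoint_path, num_epochs=None):
    """Build the train_meshhead.py argument list from the given paths."""
    args = [
        "python", "train_meshhead.py",
        "--img_dir", str(img_dir),
        "--checkpoint_path", str(checkpoint_path),
    ]
    if num_epochs is not None:
        args += ["--num_epochs", str(num_epochs)]
    return args


cmd = build_meshhead_command(
    "/content/data/my images",  # placeholder path containing a space
    "/content/drive/MyDrive/GazeGaussian_checkpoints/meshhead",
    num_epochs=5,
)
# shlex.join quotes arguments that need it, giving a shell-safe string.
print(shlex.join(cmd))
```

The list form can be passed directly to `subprocess.run(cmd, cwd="/content/GazeGaussian")`, while the joined string is convenient for logging.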

## Verify Data Availability

In [None]:
import os
from pathlib import Path

print("="*80)
print("DATA AVAILABILITY CHECK")
print("="*80)

data_locations = [
    "/content/data",
    "/content/GazeGaussian/data",
    "/content/drive/MyDrive/GazeGaussian_data",
]

print("\n📁 Checking common data locations:\n")

for location in data_locations:
    if os.path.exists(location):
        print(f"✓ Found: {location}")
        
        contents = list(Path(location).rglob("*"))
        dirs = [f for f in contents if f.is_dir()]
        files = [f for f in contents if f.is_file()]
        images = [f for f in files if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']]
        
        print(f"  - Subdirectories: {len(dirs)}")
        print(f"  - Total files: {len(files)}")
        print(f"  - Image files: {len(images)}")
        
        if len(contents) <= 20:
            print(f"\n  Contents:")
            for item in sorted(contents)[:20]:
                print(f"    {item.relative_to(location)}")
        else:
            print(f"\n  First 10 items:")
            for item in sorted(contents)[:10]:
                print(f"    {item.relative_to(location)}")
        print()
    else:
        print(f"✗ Not found: {location}")

print("="*80)
print("\n💡 To upload data:")
print("1. From local: Use Colab's file upload or mount Google Drive")
print("2. From Drive: Copy to /content/drive/MyDrive/GazeGaussian_data")
print("3. From URL: Use !wget or !gdown commands")
print("\nExample: !gdown <google-drive-file-id> -O /content/data/dataset.zip")
print("         !unzip /content/data/dataset.zip -d /content/data/")
print("="*80)

### Create Custom Config for Your Dataset
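Your data uses the `xgaze_` prefix and different subject numbers than the default split. This cell lists the `.h5` files in your data directory, assigns the first 80% (in sorted order) to training and the rest to validation, and overwrites `configs/dataset/eth_xgaze/train_test_split.json` with the result.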
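
The split used in the cell below can be sketched in isolation. `split_train_val` is a hypothetical helper, and the file names are invented for illustration; only the `xgaze_` prefix comes from the dataset described here.

```python
def split_train_val(files, train_fraction=0.8):
    """Sort file names and cut them into (train, val) at train_fraction."""
    ordered = sorted(files)
    cut = int(len(ordered) * train_fraction)
    return ordered[:cut], ordered[cut:]


# Invented file names for illustration only.
names = [f"xgaze_subject{i:04d}.h5" for i in range(10)]
train, val = split_train_val(names)
print(len(train), len(val))
```

Because the cut index is truncated toward zero, a directory with a single file yields an empty training list, so it is worth checking both lists before training.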

In [None]:
import json
from pathlib import Path

data_dir = Path("/content/drive/MyDrive/GazeGaussian_data/ETH-XGaze_test/ETH-XGaze_test")
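h5_files = sorted(f.name for f in data_dir.glob("*.h5"))

# Stop here rather than overwrite the split config with empty lists.
if not h5_files:
    raise FileNotFoundError(f"No .h5 files found in {data_dir}")

print(f"Found {len(h5_files)} HDF5 files:\n{h5_files}")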

train_split = int(len(h5_files) * 0.8)
train_files = h5_files[:train_split]
val_files = h5_files[train_split:]

custom_config = {
    "train": train_files,
    "val": val_files,
    "val_gaze": val_files,
    "test": [],
    "test_specific": []
}

config_path = "/content/GazeGaussian/configs/dataset/eth_xgaze/train_test_split.json"
with open(config_path, 'w') as f:
    json.dump(custom_config, f, indent=2)

print(f"\n✓ Updated config at: {config_path}")
print(f"  - Training files: {len(train_files)}")
print(f"  - Validation files: {len(val_files)}")
print(f"\nFirst 5 training files: {train_files[:5]}")
print(f"Validation files: {val_files}")

### Optional: Upload/Download Data
Uncomment and modify the commands below based on your data source.
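
For the zip-based options, extraction can also be done from Python instead of `!unzip`. This is a minimal sketch using the standard-library `zipfile` module; `extract_dataset` is a hypothetical helper, and the demo builds a tiny placeholder archive in a temporary directory.

```python
import tempfile
import zipfile
from pathlib import Path


def extract_dataset(zip_path, dest_dir):
    """Extract zip_path into dest_dir and return the sorted member names."""
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(dest)
        names = archive.namelist()
    return sorted(names)


# Demo with a placeholder archive; real data would live under /content/data.
with tempfile.TemporaryDirectory() as tmp:
    archive_path = Path(tmp) / "dataset.zip"
    with zipfile.ZipFile(archive_path, "w") as archive:
        archive.writestr("subject0000.h5", b"placeholder")
    extracted = extract_dataset(archive_path, Path(tmp) / "data")
    exists = (Path(tmp) / "data" / "subject0000.h5").exists()
    print(extracted, exists)
```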

In [None]:
!mkdir -p /content/data

# Uncomment the option that matches your data source and adjust the paths.

# Option 1: Copy from Google Drive
# !cp -r /content/drive/MyDrive/YourDatasetFolder /content/data/

# Option 2: Download from URL
# !wget https://your-url.com/dataset.zip -O /content/data/dataset.zip
# !unzip /content/data/dataset.zip -d /content/data/

# Option 3: Download from Google Drive (requires gdown)
# !pip install gdown
# !gdown <FILE_ID> -O /content/data/dataset.zip
# !unzip /content/data/dataset.zip -d /content/data/

# Option 4: Upload files manually
# from google.colab import files
# uploaded = files.upload()

# Option 5: Clone from Git LFS
# !git lfs clone https://huggingface.co/datasets/your-dataset /content/data/

### Inspect HDF5 Dataset Contents
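Check what is inside the ETH-XGaze `.h5` files: list every dataset and group, then preview a few images and gaze labels if a standard key is found.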
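
The inspection cell picks the first key that exists from a short list of candidates. That lookup can be sketched as a small helper; `first_present` is a hypothetical name, and a plain dict stands in for the opened HDF5 file, since both support the `in` operator.

```python
def first_present(container, candidates):
    """Return the first candidate key found in container, or None."""
    for key in candidates:
        if key in container:
            return key
    return None


# A plain dict stands in for an opened HDF5 file.
fake_file = {"face_patch": [], "pitchyaw": []}
img_key = first_present(fake_file, ["face_patch", "image", "img"])
gaze_key = first_present(fake_file, ["gaze_label", "gaze"])
print(img_key, gaze_key)
```

In this stand-in the gaze lookup returns `None` because the labels are stored under `pitchyaw`, which shows why the candidate list has to match the actual keys in your files.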

In [None]:
import h5py
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt

data_dir = Path("/content/drive/MyDrive/GazeGaussian_data")
h5_files = list(data_dir.rglob("*.h5"))

print("="*80)
print("HDF5 DATASET INSPECTION")
print("="*80)

if not h5_files:
    print("❌ No .h5 files found!")
else:
    print(f"\n✓ Found {len(h5_files)} HDF5 files\n")
    
    sample_file = h5_files[0]
    print(f"📄 Inspecting: {sample_file.name}\n")
    
    with h5py.File(sample_file, 'r') as f:
        print("Dataset keys/groups:")
        def print_structure(name, obj):
            if isinstance(obj, h5py.Dataset):
                print(f"  📊 {name}: shape={obj.shape}, dtype={obj.dtype}")
            elif isinstance(obj, h5py.Group):
                print(f"  📁 {name}/ (group)")
        
        f.visititems(print_structure)
        
        if 'face_patch' in f or 'image' in f or 'img' in f:
            img_key = 'face_patch' if 'face_patch' in f else ('image' if 'image' in f else 'img')
            images = f[img_key]
            print(f"\n📸 Image data found!")
            print(f"   Total images: {len(images)}")
            print(f"   Image shape: {images[0].shape}")
            print(f"   Image dtype: {images.dtype}")
            
            if 'gaze_label' in f or 'gaze' in f:
                gaze_key = 'gaze_label' if 'gaze_label' in f else 'gaze'
                gazes = f[gaze_key]
                print(f"\n👁 Gaze data found!")
                print(f"   Total gaze labels: {len(gazes)}")
                print(f"   Gaze shape: {gazes[0].shape}")
            
            print(f"\n🖼 Displaying first 3 sample images...")
            fig, axes = plt.subplots(1, 3, figsize=(12, 4))
            for i in range(min(3, len(images))):
                img = images[i]
                if img.dtype == np.uint8:
                    axes[i].imshow(img)
                else:
                    axes[i].imshow(img.astype(np.uint8))
                axes[i].set_title(f"Sample {i+1}")
                axes[i].axis('off')
            plt.tight_layout()
            plt.show()
        else:
            print("\n⚠ No standard image key found. Available keys:")
            print(f"   {list(f.keys())}")

print("\n" + "="*80)
print(f"\n✅ Dataset Structure:")
print(f"   Location: {data_dir}")
print(f"   Total .h5 files: {len(h5_files)}")
print(f"   Ready for training: {'✓ Yes' if h5_files else '✗ No'}")
print("="*80)

### Dataset Summary and Training Commands
Summarize the HDF5 dataset on Google Drive and print example training commands.

In [None]:
import h5py
from pathlib import Path

data_dir = Path("/content/drive/MyDrive/GazeGaussian_data")
h5_files = list(data_dir.rglob("*.h5"))

print("="*80)
print("📊 COMPLETE DATASET SUMMARY")
print("="*80)

total_images = 0
sample_file = h5_files[0] if h5_files else None

if sample_file:
    with h5py.File(sample_file, 'r') as f:
        images_per_file = len(f['face_patch'])
        total_images = len(h5_files) * images_per_file
        img_shape = f['face_patch'][0].shape

print(f"\n📁 Dataset Location: {data_dir}")
print(f"📦 Total HDF5 Files: {len(h5_files)}")
print(f"🖼️  Images per File: {images_per_file if sample_file else 'N/A'}")
print(f"📸 Total Images: {total_images:,}")
print(f"📐 Image Resolution: {img_shape if sample_file else 'N/A'}")

if sample_file:
    with h5py.File(sample_file, 'r') as f:
        print(f"\n✅ Available Data Fields:")
        print(f"   • Face Images: face_patch [{f['face_patch'].shape}]")
        print(f"   • Gaze Labels: pitchyaw [{f['pitchyaw'].shape}]")
        print(f"   • 3D Vertices: vertice [{f['vertice'].shape}]")
        print(f"   • Camera Intrinsics: inmat [{f['inmat'].shape}]")
        print(f"   • Camera Extrinsics: c2w_Rmat, c2w_Tvec")
        print(f"   • Head Masks: head_mask [{f['head_mask'].shape}]")
        print(f"   • Eye Masks: left_eye_mask, right_eye_mask")
        print(f"   • Facial Landmarks: facial_landmarks [{f['facial_landmarks'].shape}]")
        print(f"   • Latent Codes: latent_codes [{f['latent_codes'].shape}]")

print("\n" + "="*80)
print("🚀 READY TO TRAIN!")
print("="*80)

print("\n📝 IMPORTANT: Run the 'Create Custom Config for Your Dataset' cell first to create the config matching your data!")
print("\nThen run this training command:")
print("\n# For MeshHead training:")
print("!cd /content/GazeGaussian && python train_meshhead.py \\")
print(f"    --img_dir /content/drive/MyDrive/GazeGaussian_data/ETH-XGaze_test/ETH-XGaze_test \\")
print(f"    --checkpoint_path /content/drive/MyDrive/GazeGaussian_checkpoints/meshhead \\")
print(f"    --num_epochs 5")

print("\n# For GazeGaussian training:")
print("!cd /content/GazeGaussian && python train.py \\")
print(f"    -s /content/drive/MyDrive/GazeGaussian_data/ETH-XGaze_test \\")
print(f"    -m /content/drive/MyDrive/GazeGaussian_checkpoints/model \\")
print(f"    --iterations 10000")

print("\n" + "="*80)
print("⚠️  WORKFLOW:")
print("   1. Run the 'Create Custom Config for Your Dataset' cell to create the config for your dataset")
print("   2. Run the training command above")
print("="*80)

## 🚀 Start Training

In [None]:
!cd /content/GazeGaussian && python train_meshhead.py \
    --img_dir /content/drive/MyDrive/GazeGaussian_data/ETH-XGaze_test/ETH-XGaze_test \
    --checkpoint_path /content/drive/MyDrive/GazeGaussian_checkpoints/meshhead \
    --num_epochs 5

In [None]:
!mkdir -p data
!mkdir -p /content/drive/MyDrive/GazeGaussian_checkpoints
print("\n✅ Setup complete! Ready for training.")

In [None]:
%cd /content
!rm -rf GazeGaussian
!git clone --recursive https://github.com/kram254/GazeGaussian.git
%cd GazeGaussian
!git submodule update --init --recursive

In [None]:
%cd /content/GazeGaussian/configs
!wget https://huggingface.co/ucwxb/GazeGaussian/resolve/main/config_models.zip
!unzip -q config_models.zip
%cd /content/GazeGaussian

In [None]:
!pip install --upgrade pip setuptools wheel ninja

In [None]:
!pip install opencv-python h5py tqdm scipy scikit-image lpips kornia tensorboardX einops trimesh plyfile

In [None]:
!pip uninstall -y diff-gaussian-rasterization simple-knn kaolin
!pip cache purge
!rm -rf /usr/local/lib/python3.*/dist-packages/diff_gaussian_rasterization*
!rm -rf /usr/local/lib/python3.*/dist-packages/simple_knn*
!rm -rf /usr/local/lib/python3.*/dist-packages/kaolin*
!rm -rf /content/GazeGaussian/submodules/diff-gaussian-rasterization/build
!rm -rf /content/GazeGaussian/submodules/simple-knn/build
!rm -rf /tmp/*

In [None]:
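import torch
print(f"Current PyTorch: {torch.__version__}")
print(f"CUDA version: {torch.version.cuda}")
# The cu121 index assumes a CUDA 12.1 build; change the suffix if the version printed above differs.
!pip install --force-reinstall torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
# torch is already loaded in this kernel, so a second `import torch` would still report the old build.
# Query the new installation from a fresh interpreter instead.
!python -c "import torch; print('Reinstalled PyTorch:', torch.__version__); print('CUDA version:', torch.version.cuda); print('CUDA available:', torch.cuda.is_available())"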

In [None]:
%cd submodules/diff-gaussian-rasterization
!python setup.py install
%cd ../..

In [None]:
%cd submodules/simple-knn
!python setup.py install
%cd ../..

In [None]:
%cd /content
!git clone --recursive https://github.com/NVIDIAGameWorks/kaolin
%cd kaolin
!python setup.py install
%cd /content/GazeGaussian

In [None]:
print("="*80)
print("VERIFICATION")
print("="*80)

all_good = True

packages = [
    ('torch', 'PyTorch'),
    ('cv2', 'OpenCV'),
    ('h5py', 'h5py'),
    ('lpips', 'LPIPS'),
    ('kornia', 'Kornia'),
]

for mod, name in packages:
    try:
        m = __import__(mod)
        v = getattr(m, '__version__', 'OK')
        print(f"✓ {name:15s} {v}")
    except ImportError as e:
        print(f"✗ {name:15s} FAILED: {str(e)[:50]}")
        all_good = False

try:
    import simple_knn
    print(f"✓ {'simple-knn':15s} OK")
except ImportError as e:
    print(f"✗ {'simple-knn':15s} FAILED: {str(e)[:50]}")
    all_good = False

try:
    import diff_gaussian_rasterization
    print(f"✓ {'diff-gauss':15s} OK")
except ImportError as e:
    print(f"✗ {'diff-gauss':15s} FAILED: {str(e)[:50]}")
    all_good = False

try:
    import kaolin
    print(f"✓ {'kaolin':15s} {kaolin.__version__}")
except ImportError as e:
    print(f"✗ {'kaolin':15s} FAILED: {str(e)[:50]}")
    all_good = False

print("="*80)

if all_good:
    print("\n✅ ALL REQUIRED PACKAGES INSTALLED SUCCESSFULLY!")
    print("   Ready for MeshHead and GazeGaussian training!")
else:
    print("\n⚠ Some packages failed. Check errors above and rerun failed installations.")

In [None]:
!cd /content/GazeGaussian && python train_meshhead.py \
    --img_dir /content/drive/MyDrive/GazeGaussian_data/ETH-XGaze_test/ETH-XGaze_test \
    --checkpoint_path /content/drive/MyDrive/GazeGaussian_checkpoints/meshhead \
    --num_epochs 50 \
    --early_stopping \
    --patience 5