In [None]:
# Show the GPU(s) and driver visible to this Colab runtime.
!nvidia-smi

In [None]:
# Mount Google Drive at /content/drive; the training and checkpoint-test cells
# read data and checkpoints from /content/drive/MyDrive/...
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Fresh clone of GazeGaussian (with submodules) into /content/GazeGaussian,
# leaving the kernel's working directory at the repo root.
# NOTE(review): `rm -rf GazeGaussian` discards any previous clone, including local edits.
%cd /content
!rm -rf GazeGaussian
!git clone --recursive https://github.com/kram254/GazeGaussian.git
%cd GazeGaussian
# Redundant after `clone --recursive`, but harmless: makes sure every submodule is checked out.
!git submodule update --init --recursive

In [None]:
# Upgrade packaging/build tooling before compiling the CUDA extensions below
# (ninja is presumably there to speed up those builds).
!pip install --upgrade pip setuptools wheel ninja

In [None]:
# Pure-Python / prebuilt runtime dependencies used by training and evaluation.
# NOTE(review): versions are unpinned; pin them if this notebook must be reproducible.
!pip install opencv-python h5py tqdm scipy scikit-image lpips kornia tensorboardX einops trimesh plyfile

In [None]:
# Remove any previously installed copies of the CUDA extensions and kaolin
# (pip metadata, leftover site-packages files, and the local build dirs) so the
# later install cells rebuild them from scratch.
!pip uninstall -y diff-gaussian-rasterization simple-knn kaolin
!pip cache purge
!rm -rf /usr/local/lib/python3.*/dist-packages/diff_gaussian_rasterization*
!rm -rf /usr/local/lib/python3.*/dist-packages/simple_knn*
!rm -rf /usr/local/lib/python3.*/dist-packages/kaolin*
!rm -rf /content/GazeGaussian/submodules/diff-gaussian-rasterization/build
!rm -rf /content/GazeGaussian/submodules/simple-knn/build
# NOTE(review): this empties all of /tmp on the VM, not only build leftovers;
# consider narrowing it to the specific directories that need clearing.
!rm -rf /tmp/*

In [None]:
# Install PyTorch wheels from the CUDA 12.1 (cu121) index; the extension builds
# below compile against whatever torch this leaves installed.
# NOTE(review): --force-reinstall also reinstalls every dependency, not just the
# three named packages; if torch was already imported in this kernel, restart the runtime.
!pip install --force-reinstall torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

In [None]:
%cd /content/GazeGaussian/submodules/diff-gaussian-rasterization
!python setup.py install

In [None]:
%cd /content/GazeGaussian/submodules/simple-knn
!python setup.py install

In [None]:
# Install kaolin (imported as `kaolin` by the verification cells).
# NOTE(review): confirm the `kaolin-core` distribution actually provides the
# `kaolin` module; NVIDIA's kaolin is normally installed from its own wheel
# index matching the installed torch/CUDA versions.
!pip install kaolin-core

In [None]:
# Import-check every dependency and print a pass/fail report; `all_good` ends up
# True only if every import succeeded.
import importlib

print("="*80)
print("VERIFICATION")
print("="*80)


def check_import(module_name, label, version_default='OK'):
    """Import ``module_name`` and print a one-line pass/fail report row.

    Args:
        module_name: dotted module path passed to ``importlib.import_module``.
        label: name shown in the report's first column.
        version_default: text printed when the module has no ``__version__``.

    Returns:
        True if the import succeeded, False otherwise.
    """
    try:
        module = importlib.import_module(module_name)
    except Exception as e:
        # Broader than ImportError on purpose: a broken compiled extension or
        # shared library can raise other errors (e.g. OSError), and the report
        # should still cover every remaining package instead of aborting.
        print(f"✗ {label:15s} FAILED: {str(e)[:50]}")
        return False
    print(f"✓ {label:15s} {getattr(module, '__version__', version_default)}")
    return True


packages = [
    ('torch', 'PyTorch'),
    ('cv2', 'OpenCV'),
    ('h5py', 'h5py'),
    ('lpips', 'LPIPS'),
    ('kornia', 'Kornia'),
]

# (module to import, report label, text shown when there is no __version__).
# torch stays first so its shared libraries are loaded before the compiled
# extensions, which presumably link against them.
# The extensions are checked via their compiled `_C` submodule (assumed name,
# as in the upstream setup.py files — confirm): importing only the bare package
# can succeed even when the CUDA build failed, e.g. when a same-named source
# directory is importable from the current working directory.
checks = [(mod, name, 'OK') for mod, name in packages] + [
    ('simple_knn._C', 'simple-knn', 'OK'),
    ('diff_gaussian_rasterization._C', 'diff-gauss', 'OK'),
    ('kaolin', 'kaolin', 'OK (version unknown)'),
]

# Build the full list first (no short-circuit) so every failure gets reported.
results = [check_import(*check) for check in checks]
all_good = all(results)

print("="*80)

if all_good:
    print("\n✅ ALL REQUIRED PACKAGES INSTALLED SUCCESSFULLY!")
    print("   Ready for MeshHead and GazeGaussian training!")
else:
    print("\n⚠ Some packages failed. Check errors above and rerun failed installations.")

In [None]:
# Import-check every dependency and print a pass/fail report; `all_good` ends up
# True only if every import succeeded.
# NOTE(review): this cell repeats the preceding verification cell; keep only one.
import importlib

print("\n" + "="*80)
print("VERIFICATION")
print("="*80)

all_good = True

packages = [
    ('torch', 'PyTorch'),
    ('cv2', 'OpenCV'),
    ('h5py', 'h5py'),
    ('lpips', 'LPIPS'),
    ('kornia', 'Kornia'),
]

# (module to import, report label, text shown when there is no __version__).
# torch stays first so its shared libraries are loaded before the compiled
# extensions. The extensions are checked via their compiled `_C` submodule
# (assumed name, as in the upstream setup.py files — confirm), because importing
# only the bare package can succeed even when the CUDA build failed.
checks = [(mod, name, 'OK') for mod, name in packages] + [
    ('simple_knn._C', 'simple-knn', 'OK'),
    ('diff_gaussian_rasterization._C', 'diff-gauss', 'OK'),
    ('kaolin', 'kaolin', 'OK (version unknown)'),
]

for module_name, label, version_default in checks:
    try:
        module = importlib.import_module(module_name)
    except Exception as e:
        # Not only ImportError: broken shared libraries can raise e.g. OSError,
        # and one failure must not stop the rest of the report.
        print(f"✗ {label:15s} FAILED: {str(e)[:50]}")
        all_good = False
        continue
    print(f"✓ {label:15s} {getattr(module, '__version__', version_default)}")

print("="*80)

if all_good:
    print("\n✅ ALL REQUIRED PACKAGES INSTALLED SUCCESSFULLY!")
    print("   Ready for MeshHead and GazeGaussian training!")
else:
    print("\n⚠ Some packages failed. Check errors above and rerun failed installations.")

In [None]:
# Train the MeshHead stage on the ETH-XGaze test split stored on Drive.
# `--early_stopping --patience 5` presumably stops after 5 epochs without
# improvement — see train_meshhead.py for the exact criterion and output path.
# NOTE(review): `!cd ... &&` only changes directory for this one subshell; the
# kernel's working directory is left unchanged.
!cd /content/GazeGaussian && python train_meshhead.py \
    --img_dir /content/drive/MyDrive/GazeGaussian_data/ETH-XGaze_test/ETH-XGaze_test \
    --num_epochs 50 \
    --early_stopping \
    --patience 5

In [None]:
import os
import sys

import torch
from pathlib import Path
import numpy as np
from PIL import Image
import torchvision.utils as vutils
from tqdm import tqdm

# Earlier cells can leave the kernel's working directory elsewhere (e.g. the
# `%cd .../submodules/simple-knn` install cell), and `!cd ... &&` only affects
# its own subshell. Pin the repo root so the project packages below import no
# matter which cells ran before; chdir also keeps any repo-relative paths the
# project may use working (assumption — confirm).
REPO_DIR = "/content/GazeGaussian"
os.chdir(REPO_DIR)
if REPO_DIR not in sys.path:
    sys.path.insert(0, REPO_DIR)

from configs.gazegaussian_options import BaseOptions
from models.gaze_gaussian import GazeGaussianNet
from dataloader.eth_xgaze import get_val_loader

def save_image_grid(images, save_path, nrow=4, value_range=(-1, 1)):
    """Tile a batch of images into one grid, save it as 8-bit, and return it.

    Args:
        images: tensor accepted by ``torchvision.utils.make_grid`` — presumably
            (B, C, H, W) renders; TODO confirm against the model outputs.
        save_path: destination file; PIL picks the format from the extension.
        nrow: number of images per grid row.
        value_range: (min, max) of the input values. They are clamped to this
            range and mapped to [0, 1]. The default (-1, 1) assumes renders in
            [-1, 1] — TODO confirm.

    Returns:
        The saved grid as an (H, W, C) uint8 numpy array.
    """
    grid = vutils.make_grid(images.detach().float(), nrow=nrow, normalize=True, value_range=value_range)
    # make_grid(normalize=True) has already mapped value_range onto [0, 1].
    # An extra `* 0.5 + 0.5` would squash every pixel into the upper half of
    # 0-255 (washed-out images). Just scale to 8 bits, rounding to nearest as
    # torchvision.utils.save_image does.
    grid_u8 = grid.mul(255).add_(0.5).clamp_(0, 255).to(torch.uint8)
    grid_np = grid_u8.permute(1, 2, 0).cpu().numpy()
    Image.fromarray(grid_np).save(save_path)
    return grid_np

# Load a GazeGaussian checkpoint, render a few validation samples, and save
# side-by-side comparison grids (GT | Gaussian render | neural render) to output_dir.
import itertools
import traceback

checkpoint_path = "/content/drive/MyDrive/GazeGaussian_checkpoints/gazegaussian_ckp.pth"
data_dir = "/content/drive/MyDrive/GazeGaussian_data/ETH-XGaze_test/ETH-XGaze_test"
output_dir = "/content/test_outputs"
num_samples = 5


def move_to_cuda(obj):
    """Return ``obj`` with every tensor moved to the current CUDA device.

    Recurses through dicts, lists and plain tuples to any depth. Any other
    value is returned unchanged.
    """
    if isinstance(obj, torch.Tensor):
        return obj.cuda()
    if isinstance(obj, dict):
        return {key: move_to_cuda(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [move_to_cuda(value) for value in obj]
    if type(obj) is tuple:
        return tuple(move_to_cuda(value) for value in obj)
    return obj


print("="*80)
print("TESTING CHECKPOINT")
print("="*80)

os.makedirs(output_dir, exist_ok=True)

print("\n[1/4] Loading checkpoint...")
# weights_only=False keeps loading full pickled checkpoints on torch >= 2.6,
# where the default flipped to True. Unpickling can execute arbitrary code —
# only load checkpoints you trust.
checkpoint = torch.load(checkpoint_path, map_location='cuda', weights_only=False)
print("✓ Checkpoint loaded")

print("\n[2/4] Initializing model...")
opt = BaseOptions()

try:
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    elif isinstance(checkpoint, dict):
        state_dict = checkpoint
    else:
        state_dict = None
        # Previously silent: the model would then be built without these weights.
        print("⚠ Checkpoint is not a dict; passing load_state_dict=None to GazeGaussianNet.")

    model = GazeGaussianNet(opt, load_state_dict=state_dict)
    model = model.cuda()
    model.eval()
    print("✓ Model initialized")
    print(f"  Renderer type: {type(model.neural_render).__name__}")
except Exception as e:
    print(f"✗ Error: {e}")
    raise

print("\n[3/4] Loading validation data...")
opt.img_dir = data_dir
val_loader = get_val_loader(opt, data_dir=data_dir, batch_size=1, num_workers=0, evaluate=None, dataset_name='eth_xgaze')
print(f"✓ Data loaded ({len(val_loader.dataset)} samples)")

print(f"\n[4/4] Generating {num_samples} images...")

# islice stops the loader after num_samples batches instead of fetching one extra
# batch just to break; the progress total also stays right for short datasets.
num_batches = min(num_samples, len(val_loader))
num_saved = 0

with torch.no_grad():
    for idx, data in enumerate(tqdm(itertools.islice(val_loader, num_samples), total=num_batches)):
        try:
            data = move_to_cuda(data)

            output = model(data)

            gt_image = data.get('image', data.get('img', None))
            render_dict = output['total_render_dict']
            gaussian_img = render_dict['merge_img']
            neural_img = render_dict['merge_img_pro']

            # NOTE(review): assumes GT and renders share shape and value range — confirm.
            panels = [gaussian_img, neural_img] if gt_image is None else [gt_image, gaussian_img, neural_img]
            comparison = torch.cat(panels, dim=0)

            save_path = os.path.join(output_dir, f"test_{idx:03d}.png")
            save_image_grid(comparison, save_path, nrow=len(comparison))

            save_image_grid(neural_img, os.path.join(output_dir, f"test_{idx:03d}_dit.png"), nrow=1)
            num_saved += 1

        except Exception as e:
            # Best-effort per sample, but keep the full traceback so failures are debuggable.
            print(f"Error on sample {idx}: {e}")
            traceback.print_exc()
            continue

if num_saved:
    print(f"\n✅ Generated images saved to: {output_dir} ({num_saved}/{num_batches} samples)")
    print(f"   Files: test_000.png, test_001.png, ...")
else:
    print(f"\n⚠ No images were generated ({num_batches} samples attempted); see the errors above.")

from google.colab import files
from IPython.display import display, Image as IPImage

print(f"\nDisplaying first 3 samples:")
for i in range(min(3, num_samples)):
    img_path = os.path.join(output_dir, f"test_{i:03d}.png")
    if os.path.exists(img_path):
        display(IPImage(filename=img_path))