# GazeGaussian Enhanced (DiT) - 2-Step Training

## Overview
This notebook trains the enhanced GazeGaussian model with:
1. **DiT Neural Renderer** (replacing U-Net)
2. **VAE Integration**
3. **Orthogonality Regularization**

## Training Process
- **Step 1**: Train MeshHead (~10 epochs, ~2-3 hours)
- **Step 2**: Train GazeGaussian with DiT (~30 epochs, ~8-12 hours)

## Requirements
- GPU: A100 (40GB recommended) or V100 (32GB minimum)
- Dataset: ETH-XGaze training set in Google Drive
- Time: ~10-15 hours of training in total (Step 1 plus Step 2)
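
The GPU requirement can be checked programmatically before committing to a long run. The sketch below is illustrative and not part of the GazeGaussian repository: the helper names are hypothetical, and the 30 GiB threshold is an assumption that sits slightly under the nominal 32 GB because reported totals tend to be a little lower than the marketing size (the T4 in the `nvidia-smi` output below reports 15360 MiB).

```python
import shutil
import subprocess

# Assumed threshold: a little under the nominal 32 GB minimum from the list above.
MIN_GPU_MIB = 30 * 1024


def parse_total_memory_mib(query_output: str) -> list[int]:
    """Parse one integer (MiB) per line, as printed by
    nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits."""
    return [int(line.strip()) for line in query_output.splitlines() if line.strip()]


def gpus_below_minimum(memory_mib: list[int], minimum_mib: int = MIN_GPU_MIB) -> list[int]:
    """Return the indices of GPUs whose total memory is below the minimum."""
    return [i for i, mib in enumerate(memory_mib) if mib < minimum_mib]


def query_gpu_memory() -> list[int]:
    """Ask nvidia-smi for total memory per GPU; return [] if it is unavailable."""
    if shutil.which("nvidia-smi") is None:
        return []
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits"],
        capture_output=True, text=True, check=False,
    )
    return parse_total_memory_mib(result.stdout) if result.returncode == 0 else []


if __name__ == "__main__":
    memory = query_gpu_memory()
    if not memory:
        print("No GPU information available")
    for index in gpus_below_minimum(memory):
        print(f"GPU {index}: {memory[index]} MiB is below the {MIN_GPU_MIB} MiB threshold")
```

A GPU flagged by this check may still run with `--batch_size 1`, but out-of-memory errors become more likely.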

In [1]:
!nvidia-smi

Thu Oct 30 14:33:52 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   45C    P8             10W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
%cd /content
!rm -rf GazeGaussian
!git clone --recursive https://github.com/kram254/GazeGaussian.git
%cd GazeGaussian
!git submodule update --init --recursive

/content
Cloning into 'GazeGaussian'...
remote: Enumerating objects: 1682, done.
remote: Counting objects: 100% (1682/1682), done.
remote: Compressing objects: 100% (724/724), done.
remote: Total 1682 (delta 970), reused 1654 (delta 942), pack-reused 0 (from 0)
Receiving objects: 100% (1682/1682), 18.80 MiB | 21.00 MiB/s, done.
Resolving deltas: 100% (970/970), done.
/content/GazeGaussian


## 2. Install Dependencies

In [4]:
!pip install --upgrade pip setuptools wheel ninja

Collecting pip
  Downloading pip-25.3-py3-none-any.whl.metadata (4.7 kB)
Collecting setuptools
  Downloading setuptools-80.9.0-py3-none-any.whl.metadata (6.6 kB)
Collecting ninja
  Downloading ninja-1.13.0-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (5.1 kB)
Downloading pip-25.3-py3-none-any.whl (1.8 MB)
Downloading setuptools-80.9.0-py3-none-any.whl (1.2 MB)
Downloading ninja-1.13.0-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (180 kB)
Installing collected packages: setuptools, pip, ninja
  Attempting uninstall: setuptools
    Found existing installation: setuptools 75.2.0
    Uninstalli

In [5]:
!pip install opencv-python h5py tqdm scipy scikit-image lpips kornia tensorboardX einops trimesh plyfile

Collecting lpips
  Downloading lpips-0.1.4-py3-none-any.whl.metadata (10 kB)
Collecting kornia
  Downloading kornia-0.8.1-py2.py3-none-any.whl.metadata (17 kB)
Collecting tensorboardX
  Downloading tensorboardx-2.6.4-py3-none-any.whl.metadata (6.2 kB)
Collecting trimesh
  Downloading trimesh-4.9.0-py3-none-any.whl.metadata (18 kB)
Collecting plyfile
  Downloading plyfile-1.1.3-py3-none-any.whl.metadata (43 kB)
Collecting kornia_rs>=0.1.9 (from kornia)
  Downloading kornia_rs-0.1.9-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Downloading lpips-0.1.4-py3-none-any.whl (53 kB)
Downloading kornia-0.8.1-py2.py3-none-any.whl (1.1 MB)
Downloading tensorboardx-2.6.4-py3-none-any.whl (87 kB)
Downloading trimesh-4.9.0-py3-none-any.whl (736 kB)

In [6]:
!pip install --force-reinstall torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

Looking in indexes: https://download.pytorch.org/whl/cu121
Collecting torch
  Downloading https://download.pytorch.org/whl/cu121/torch-2.5.1%2Bcu121-cp312-cp312-linux_x86_64.whl (780.4 MB)
Collecting torchvision
  Downloading https://download.pytorch.org/whl/cu121/torchvision-0.20.1%2Bcu121-cp312-cp312-linux_x86_64.whl (7.3 MB)
Collecting torchaudio
  Downloading https://download.pytorch.org/whl/cu121/torchaudio-2.5.1%2Bcu121-cp312-cp312-linux_x86_64.whl (3.4 MB)
Collecting filelock (from torch)
  Downloading https://download.pytorch.org/whl/filelock-3.19.1-py3-none-any.whl.metadata (2.1 kB)
Collecting typing-extensions>

In [3]:
import torch
import sys
print(f"Python: {sys.version}")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.version.cuda}")
print(f"CUDA available: {torch.cuda.is_available()}")

Python: 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0]
PyTorch: 2.5.1+cu121
CUDA: 12.1
CUDA available: True


In [4]:
import os
print("Checking diff-gaussian-rasterization...")
rast_dir = "submodules/diff-gaussian-rasterization"
if os.path.exists(f"{rast_dir}/setup.py"):
    print(f"✓ setup.py found")
else:
    print(f"✗ setup.py NOT found - cloning submodule")
    !git clone https://github.com/graphdeco-inria/diff-gaussian-rasterization {rast_dir}

Checking diff-gaussian-rasterization...
✗ setup.py NOT found - cloning submodule
Cloning into 'submodules/diff-gaussian-rasterization'...
remote: Enumerating objects: 329, done.
remote: Total 329 (delta 0), reused 0 (delta 0), pack-reused 329 (from 1)
Receiving objects: 100% (329/329), 111.52 KiB | 913.00 KiB/s, done.
Resolving deltas: 100% (217/217), done.


In [8]:
%cd submodules/diff-gaussian-rasterization
!python setup.py install
%cd ../..

/content/GazeGaussian/submodules/diff-gaussian-rasterization
running install
!!

        ********************************************************************************
        Please avoid running ``setup.py`` directly.
        Instead, use pypa/build, pypa/installer or other
        standards-based tools.

        See https://blog.ganssle.io/articles/2021/10/setup-py-deprecated.html for details.
        ********************************************************************************

!!
  self.initialize_options()
!!

        ********************************************************************************
        Please avoid running ``setup.py`` and ``easy_install``.
        Instead, use pypa/build, pypa/installer or other
        standards-based tools.

        See https://github.com/pypa/setuptools/issues/917 for details.
        ********************************************************************************

!!
  self.initialize_options()
running bdist_egg
running egg_info
writ
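
The warnings above come from setuptools, which deprecates calling `setup.py install` directly. A minimal sketch of a pip-based alternative follows. The helper names are hypothetical, and using `--no-build-isolation` (so that the build can import the PyTorch version already installed in the runtime) is an assumption about what these CUDA extensions need; it has not been verified against this repository.

```python
import subprocess
import sys
from pathlib import Path


def pip_install_command(source_dir: str) -> list[str]:
    """Build the pip command that installs a package from a local source tree."""
    return [
        sys.executable, "-m", "pip", "install",
        "--no-build-isolation",  # reuse the installed torch instead of an isolated build env
        str(Path(source_dir)),
    ]


def build_extension(source_dir: str) -> int:
    """Install the extension in source_dir; return the process exit code (1 if no setup.py)."""
    if not (Path(source_dir) / "setup.py").is_file():
        print(f"setup.py not found in {source_dir}")
        return 1
    return subprocess.run(pip_install_command(source_dir), check=False).returncode
```

Calling `build_extension("submodules/diff-gaussian-rasterization")` from `/content/GazeGaussian` would replace the `%cd` / `!python setup.py install` / `%cd` sequence.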

In [6]:
# Initialize and update submodules within diff-gaussian-rasterization
%cd /content/GazeGaussian/submodules/diff-gaussian-rasterization
!git submodule update --init --recursive
%cd ../..

/content/GazeGaussian/submodules/diff-gaussian-rasterization
/content/GazeGaussian


## 3. Build CUDA Extensions

In [7]:
import os
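os.environ['MAX_JOBS'] = '2'  # limit parallel compile jobs to keep memory use low
# Target compute capabilities: 7.0 (V100), 7.5 (T4), 8.0 (A100), 8.6 (RTX 30xx / A10)
os.environ['TORCH_CUDA_ARCH_LIST'] = '7.0;7.5;8.0;8.6'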

%cd /content/GazeGaussian
%cd submodules/diff-gaussian-rasterization
!python setup.py install
%cd ../..

/content/GazeGaussian
/content/GazeGaussian/submodules/diff-gaussian-rasterization
running install
!!

        ********************************************************************************
        Please avoid running ``setup.py`` directly.
        Instead, use pypa/build, pypa/installer or other
        standards-based tools.

        See https://blog.ganssle.io/articles/2021/10/setup-py-deprecated.html for details.
        ********************************************************************************

!!
  self.initialize_options()
!!

        ********************************************************************************
        Please avoid running ``setup.py`` and ``easy_install``.
        Instead, use pypa/build, pypa/installer or other
        standards-based tools.

        See https://github.com/pypa/setuptools/issues/917 for details.
        ********************************************************************************

!!
  self.initialize_options()
running bdist_egg
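
`TORCH_CUDA_ARCH_LIST` decides which GPU architectures the extension is compiled for. A T4 has compute capability 7.5, an A100 has 8.0, and a V100 has 7.0, so a hard-coded list needs to cover whichever card the runtime provides. The sketch below derives the list from the detected devices instead. The helper names are hypothetical, and the PyTorch import is guarded so the snippet also runs where PyTorch is absent.

```python
import os


def arch_list(capabilities: list[tuple[int, int]]) -> str:
    """Format (major, minor) pairs as a TORCH_CUDA_ARCH_LIST string, e.g. '7.5;8.0'."""
    unique = sorted(set(capabilities))
    return ";".join(f"{major}.{minor}" for major, minor in unique)


def detected_capabilities() -> list[tuple[int, int]]:
    """Return the compute capability of every visible GPU, or [] without CUDA/PyTorch."""
    try:
        import torch
    except ImportError:
        return []
    if not torch.cuda.is_available():
        return []
    return [torch.cuda.get_device_capability(i) for i in range(torch.cuda.device_count())]


if __name__ == "__main__":
    capabilities = detected_capabilities()
    if capabilities:
        os.environ["TORCH_CUDA_ARCH_LIST"] = arch_list(capabilities)
        print("TORCH_CUDA_ARCH_LIST =", os.environ["TORCH_CUDA_ARCH_LIST"])
    else:
        print("No CUDA device detected; keeping the existing setting")
```

Compiling only for the detected architecture also shortens the build, which helps with `MAX_JOBS` set to 2.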

In [10]:
import os

# Change to the main GazeGaussian directory first
os.chdir('/content/GazeGaussian')

# Clone the simple-knn submodule if it doesn't exist
simple_knn_dir = "submodules/simple-knn"
if not os.path.exists(simple_knn_dir):
    print(f"Cloning simple-knn submodule into {simple_knn_dir}...")
    os.system(f'git clone https://github.com/rusty1s/simple-knn {simple_knn_dir}')
else:
    print(f"simple-knn submodule already exists at {simple_knn_dir}")

# Now change to the simple-knn directory
os.chdir(simple_knn_dir)

with open('simple_knn.cu', 'r') as f:
    content = f.read()

if '#include <cfloat>' not in content:
    content = content.replace(
        '#include <vector>',
        '#include <vector>\n#include <cfloat>'
    )
    with open('simple_knn.cu', 'w') as f:
        f.write(content)
    print("✓ Added cfloat header")
else:
    print("✓ cfloat header already present")

os.system('python setup.py install')
os.chdir('/content/GazeGaussian')

simple-knn submodule already exists at submodules/simple-knn
✓ cfloat header already present
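
The header patch in the cell above relies on `#include <vector>` being present in `simple_knn.cu`. If that line is missing, `str.replace` changes nothing and the cell still reports success. A minimal sketch of a stricter, idempotent variant follows; the function name `ensure_include` is hypothetical, and prepending the directive when the anchor is absent is a design choice, not something taken from the repository.

```python
def ensure_include(source: str,
                   header: str = "<cfloat>",
                   anchor: str = "#include <vector>") -> tuple[str, bool]:
    """Return (patched_source, changed).

    The directive is inserted after the first occurrence of `anchor`, or at the
    top of the file when the anchor is absent. Re-running is a no-op.
    """
    directive = f"#include {header}"
    if directive in source:
        return source, False
    if anchor in source:
        return source.replace(anchor, f"{anchor}\n{directive}", 1), True
    return f"{directive}\n{source}", True


def patch_file(path: str) -> bool:
    """Apply ensure_include to a file in place; return True if the file was modified."""
    with open(path, "r") as f:
        original = f.read()
    patched, changed = ensure_include(original)
    if changed:
        with open(path, "w") as f:
            f.write(patched)
    return changed
```

`patch_file('simple_knn.cu')` returns `True` only when the file was actually rewritten, so the success message can be tied to a real change.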


In [11]:
import os
os.chdir('/content/GazeGaussian/submodules/simple-knn')
with open('simple_knn.cu', 'r') as f:
    content = f.read()
if '#include <cfloat>' not in content:
    content = content.replace('#include <vector>', '#include <vector>\n#include <cfloat>')
    with open('simple_knn.cu', 'w') as f:
        f.write(content)
os.chdir('/content/GazeGaussian')

%cd submodules/simple-knn
!python setup.py install
%cd ../..

/content/GazeGaussian/submodules/simple-knn
running install
!!

        ********************************************************************************
        Please avoid running ``setup.py`` directly.
        Instead, use pypa/build, pypa/installer or other
        standards-based tools.

        See https://blog.ganssle.io/articles/2021/10/setup-py-deprecated.html for details.
        ********************************************************************************

!!
  self.initialize_options()
!!

        ********************************************************************************
        Please avoid running ``setup.py`` and ``easy_install``.
        Instead, use pypa/build, pypa/installer or other
        standards-based tools.

        See https://github.com/pypa/setuptools/issues/917 for details.
        ********************************************************************************

!!
  self.initialize_options()
running bdist_egg
running egg_info
writing simple_knn.eg

In [12]:
%cd /content
!git clone --recursive https://github.com/NVIDIAGameWorks/kaolin
%cd kaolin
!python setup.py install
%cd /content/GazeGaussian

/content
Cloning into 'kaolin'...
remote: Enumerating objects: 6543, done.
remote: Counting objects: 100% (722/722), done.
remote: Compressing objects: 100% (276/276), done.
remote: Total 6543 (delta 561), reused 446 (delta 446), pack-reused 5821 (from 2)
Receiving objects: 100% (6543/6543), 133.24 MiB | 38.85 MiB/s, done.
Resolving deltas: 100% (3570/3570), done.
Submodule 'third_party/cub' (https://github.com/NVIDIA/cub) registered for path 'third_party/cub'
Cloning into '/content/kaolin/third_party/cub'...
remote: Enumerating objects: 33392, done.        
remote: Counting objects: 100% (247/247), done.        
remote: Compressing objects: 100% (63/63), done.        
remote: Total 33392 (delta 209), reused 184 (delta 184), pack-reused 33145 (from 4)        
Receiving objects: 100% (33392/33392), 18.00 MiB | 22.13 MiB/s, done.
Resolving deltas: 100% (27972/27972), done.
Submodule path 'third_party/cub': checked out '499a7bad3416fcc71a7c50351d6b3cdbf3fbbc27'
/content/kaolin

In [18]:
import os
os.environ['IGNORE_TORCH_VER'] = '1'
os.environ['FORCE_CUDA'] = '1'

%cd /content/kaolin
!python setup.py clean --all
!python setup.py install
%cd /content/GazeGaussian
print("✓ Kaolin reinstalled from source")

/content/kaolin
  from pkg_resources import parse_version
INFO - running clean
INFO - removing 'build/temp.linux-x86_64-cpython-312' (and everything under it)
INFO - removing 'build/lib.linux-x86_64-cpython-312' (and everything under it)
INFO - removing 'build/bdist.linux-x86_64' (and everything under it)
INFO - removing 'build/scripts-3.12' (and everything under it)
INFO - removing 'build'
  from pkg_resources import parse_version
INFO - running install
!!

        ********************************************************************************
        Please avoid running ``setup.py`` directly.
        Instead, use pypa/build, pypa/installer or other
        standards-based tools.

        See https://blog.ganssle.io/articles/2021/10/setup-py-deprecated.html for details.
        ********************************************************************************

!!
  self.initialize_options()
!!

        ********************************************************************************
  

## 4. Verify Installation

In [None]:
print("\n" + "="*80)
print("VERIFICATION")
print("="*80)

all_good = True

packages = [
    ('torch', 'PyTorch'),
    ('cv2', 'OpenCV'),
    ('h5py', 'h5py'),
    ('lpips', 'LPIPS'),
    ('kornia', 'Kornia'),
]

for mod, name in packages:
    try:
        m = __import__(mod)
        v = getattr(m, '__version__', 'OK')
        print(f"✓ {name:15s} {v}")
    except ImportError as e:
        print(f"✗ {name:15s} FAILED: {str(e)[:50]}")
        all_good = False

try:
    import simple_knn
    print(f"✓ {'simple-knn':15s} OK")
except ImportError as e:
    print(f"✗ {'simple-knn':15s} FAILED: {str(e)[:50]}")
    all_good = False

try:
    import diff_gaussian_rasterization
    print(f"✓ {'diff-gauss':15s} OK")
except ImportError as e:
    print(f"✗ {'diff-gauss':15s} FAILED: {str(e)[:50]}")
    all_good = False

try:
    import kaolin
    try:
        kaolin_version = kaolin.__version__
    except AttributeError:
        kaolin_version = 'OK (version unknown)'
    print(f"✓ {'kaolin':15s} {kaolin_version}")
except ImportError as e:
    print(f"✗ {'kaolin':15s} FAILED: {str(e)[:50]}")
    all_good = False

print("="*80)

if all_good:
    print("\n✅ ALL REQUIRED PACKAGES INSTALLED SUCCESSFULLY!")
    print("   Ready for training!")
else:
    print("\n⚠ Some packages failed. Check errors above and rerun failed installations.")

## 5. Configure Dataset

In [None]:
%cd /content/GazeGaussian
!rm -rf data
!mkdir -p data
print("Cleaned and recreated data directory.")

In [None]:
%cd /content/GazeGaussian

!mkdir -p data
%cd data

!wget https://huggingface.co/ucwxb/GazeGaussian/resolve/main/ETH-XGaze.zip.partaa
!wget https://huggingface.co/ucwxb/GazeGaussian/resolve/main/ETH-XGaze.zip.partab
!wget https://huggingface.co/ucwxb/GazeGaussian/resolve/main/ETH-XGaze.zip.partac
!wget https://huggingface.co/ucwxb/GazeGaussian/resolve/main/ETH-XGaze.zip.partad
!wget https://huggingface.co/ucwxb/GazeGaussian/resolve/main/ETH-XGaze.zip.partae
!wget https://huggingface.co/ucwxb/GazeGaussian/resolve/main/ETH-XGaze.zip.partaf
# !wget https://huggingface.co/ucwxb/GazeGaussian/resolve/main/ETH-XGaze.zip.partag
# !wget https://huggingface.co/ucwxb/GazeGaussian/resolve/main/ETH-XGaze.zip.partah
# !wget https://huggingface.co/ucwxb/GazeGaussian/resolve/main/ETH-XGaze.zip.partai
# !wget https://huggingface.co/ucwxb/GazeGaussian/resolve/main/ETH-XGaze.zip.partaj
# !wget https://huggingface.co/ucwxb/GazeGaussian/resolve/main/ETH-XGaze.zip.partak
# !wget https://huggingface.co/ucwxb/GazeGaussian/resolve/main/ETH-XGaze.zip.partal
# !wget https://huggingface.co/ucwxb/GazeGaussian/resolve/main/ETH-XGaze.zip.partam
# !wget https://huggingface.co/ucwxb/GazeGaussian/resolve/main/ETH-XGaze.zip.partan

!cat ETH-XGaze.zip.part* > ETH-XGaze.zip && echo "Concatenation complete"
!unzip ETH-XGaze.zip

%cd /content/GazeGaussian/configs
!wget https://huggingface.co/ucwxb/GazeGaussian/resolve/main/config_models.zip
!unzip config_models.zip

%cd /content/GazeGaussian
!mkdir -p checkpoint
%cd checkpoint
!wget https://huggingface.co/ucwxb/GazeGaussian/resolve/main/gazegaussian_ckp.pth

%cd /content/GazeGaussian
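
The archive is split into parts that must all be present, in order, before concatenation; joining only some of them yields a truncated zip. The sketch below lists the missing parts and joins the files in Python. It assumes 14 parts (`partaa` through `partan`), a count inferred only from the URLs in the cell above, and the helper names are hypothetical.

```python
import shutil
import string
from itertools import islice, product


def expected_parts(prefix: str, count: int) -> list[str]:
    """Names of the first `count` parts using two-letter suffixes: aa, ab, ac, ..."""
    suffixes = ("".join(pair) for pair in product(string.ascii_lowercase, repeat=2))
    return [f"{prefix}{suffix}" for suffix in islice(suffixes, count)]


def missing_parts(present: list[str], prefix: str, count: int) -> list[str]:
    """Return the expected part names that are not in `present`, in order."""
    available = set(present)
    return [name for name in expected_parts(prefix, count) if name not in available]


def concatenate(parts: list[str], destination: str) -> None:
    """Join the part files, in the given order, into one destination file."""
    with open(destination, "wb") as out:
        for part in parts:
            with open(part, "rb") as src:
                shutil.copyfileobj(src, out)
```

Running `missing_parts(os.listdir('.'), 'ETH-XGaze.zip.part', 14)` in the `data` directory before the `cat` step shows which downloads are still needed.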

In [None]:
import shutil
import os

print("Copying data to Google Drive (this takes time but saves redownloading)...")
os.makedirs('/content/drive/MyDrive/GazeGaussian_data', exist_ok=True)

if os.path.exists('/content/GazeGaussian/data/ETH-XGaze'):
    shutil.copytree('/content/GazeGaussian/data/ETH-XGaze',
                    '/content/drive/MyDrive/GazeGaussian_data/ETH-XGaze',
                    dirs_exist_ok=True)
    print("✓ Training data saved to Drive")

if os.path.exists('/content/GazeGaussian/data/ETH-XGaze_test'):
    shutil.copytree('/content/GazeGaussian/data/ETH-XGaze_test',
                    '/content/drive/MyDrive/GazeGaussian_data/ETH-XGaze_test',
                    dirs_exist_ok=True)
    print("✓ Test data saved to Drive")

if os.path.exists('/content/GazeGaussian/configs/config_models'):
    shutil.copytree('/content/GazeGaussian/configs/config_models',
                    '/content/drive/MyDrive/GazeGaussian_data/config_models',
                    dirs_exist_ok=True)
    print("✓ Config models saved to Drive")

In [None]:
import os

os.makedirs('/content/GazeGaussian/data', exist_ok=True)
os.makedirs('/content/GazeGaussian/configs', exist_ok=True)

if os.path.exists('/content/drive/MyDrive/GazeGaussian_data/ETH-XGaze'):
    !ln -s /content/drive/MyDrive/GazeGaussian_data/ETH-XGaze /content/GazeGaussian/data/ETH-XGaze
    print("✓ Linked training data from Drive")

if os.path.exists('/content/drive/MyDrive/GazeGaussian_data/ETH-XGaze_test'):
    !ln -s /content/drive/MyDrive/GazeGaussian_data/ETH-XGaze_test /content/GazeGaussian/data/ETH-XGaze_test
    print("✓ Linked test data from Drive")

if os.path.exists('/content/drive/MyDrive/GazeGaussian_data/config_models'):
    !ln -s /content/drive/MyDrive/GazeGaussian_data/config_models /content/GazeGaussian/configs/config_models
    print("✓ Linked config models from Drive")

In [None]:
import os

print("\n" + "="*80)
print("DATA VERIFICATION")
print("="*80)

all_data_present = True

# Check for training data directory and a sample file
# Adjusted path and filename format based on actual extraction
train_data_path = '/content/GazeGaussian/data/ETH-XGaze'
sample_train_file = os.path.join(train_data_path, 'xgaze_subject0000.h5')
if os.path.exists(train_data_path) and os.path.exists(sample_train_file):
    print(f"✓ Training data directory and sample file found: {sample_train_file}")
    # Optional: check for number of files to be more robust
    # expected_num_files = 114 # Based on original description
    # if len(os.listdir(train_data_path)) >= expected_num_files:
    #     print(f"✓ Found at least {expected_num_files} training data files.")
    # else:
    #     print(f"⚠ Expected number of training files not found. Found {len(os.listdir(train_data_path))}")
    #     all_data_present = False
else:
    print(f"✗ Training data not found or incomplete. Expected: {sample_train_file}")
    all_data_present = False

# Check for config models directory
config_models_path = '/content/GazeGaussian/configs/config_models'
if os.path.exists(config_models_path):
    print(f"✓ Config models directory found: {config_models_path}")
else:
    print(f"✗ Config models not found. Expected: {config_models_path}")
    all_data_present = False

# Check for checkpoint file
checkpoint_path = '/content/GazeGaussian/checkpoint/gazegaussian_ckp.pth'
if os.path.exists(checkpoint_path):
    print(f"✓ Checkpoint file found: {checkpoint_path}")
else:
    print(f"✗ Checkpoint file not found. Expected: {checkpoint_path}")
    all_data_present = False

print("="*80)

if all_data_present:
    print("\n✅ All necessary data appears to be present. Ready to proceed with training.")
else:
    print("\n⚠ Some necessary data is missing. Please ensure all download and extraction steps were successful.")

In [None]:
import shutil
import os

print("Copying data to Google Drive (this takes time but saves redownloading)...")
os.makedirs('/content/drive/MyDrive/GazeGaussian_data', exist_ok=True)

# Copy the ETH-XGaze_test data
test_data_src = '/content/GazeGaussian/data/ETH-XGaze_test'
test_data_dest = '/content/drive/MyDrive/GazeGaussian_data/ETH-XGaze_test'
if os.path.exists(test_data_src):
    shutil.copytree(test_data_src, test_data_dest, dirs_exist_ok=True)
    print("✓ Test data saved to Drive")
else:
    print("✗ Test data not found at source")

# Copy the config models
config_models_src = '/content/GazeGaussian/configs/config_models'
config_models_dest = '/content/drive/MyDrive/GazeGaussian_data/config_models'
if os.path.exists(config_models_src):
    shutil.copytree(config_models_src, config_models_dest, dirs_exist_ok=True)
    print("✓ Config models saved to Drive")
else:
    print("✗ Config models not found at source")


# Copy the checkpoint file
checkpoint_src = '/content/GazeGaussian/checkpoint/gazegaussian_ckp.pth'
checkpoint_dest_dir = '/content/drive/MyDrive/GazeGaussian_checkpoints'
checkpoint_dest_file = os.path.join(checkpoint_dest_dir, 'gazegaussian_ckp.pth')
if os.path.exists(checkpoint_src):
    os.makedirs(checkpoint_dest_dir, exist_ok=True)
    shutil.copy(checkpoint_src, checkpoint_dest_file)
    print("✓ Checkpoint file saved to Drive")
else:
    print("✗ Checkpoint file not found at source")

In [None]:
import os

os.makedirs('/content/GazeGaussian/data', exist_ok=True)
os.makedirs('/content/GazeGaussian/configs', exist_ok=True)
os.makedirs('/content/GazeGaussian/checkpoint', exist_ok=True)

# Link the ETH-XGaze_test data
test_data_src = '/content/drive/MyDrive/GazeGaussian_data/ETH-XGaze_test'
test_data_dest = '/content/GazeGaussian/data/ETH-XGaze_test'
if os.path.exists(test_data_src) and not os.path.exists(test_data_dest):
    !ln -s {test_data_src} {test_data_dest}
    print("✓ Linked test data from Drive")
elif os.path.exists(test_data_dest):
    print("✓ Test data directory already exists")
else:
    print("✗ Test data not found in Drive to link")


# Link the config models
config_models_src = '/content/drive/MyDrive/GazeGaussian_data/config_models'
config_models_dest = '/content/GazeGaussian/configs/config_models'
if os.path.exists(config_models_src) and not os.path.exists(config_models_dest):
    !ln -s {config_models_src} {config_models_dest}
    print("✓ Linked config models from Drive")
elif os.path.exists(config_models_dest):
    print("✓ Config models directory already exists")
else:
    print("✗ Config models not found in Drive to link")

# Link the checkpoint file
checkpoint_src = '/content/drive/MyDrive/GazeGaussian_checkpoints/gazegaussian_ckp.pth'
checkpoint_dest = '/content/GazeGaussian/checkpoint/gazegaussian_ckp.pth'
if os.path.exists(checkpoint_src) and not os.path.exists(checkpoint_dest):
    !ln -s {checkpoint_src} {checkpoint_dest}
    print("✓ Linked checkpoint file from Drive")
elif os.path.exists(checkpoint_dest):
    print("✓ Checkpoint file already exists")
else:
     print("✗ Checkpoint file not found in Drive to link")

# Link the ETH-XGaze training data if it exists in Drive
train_data_src = '/content/drive/MyDrive/GazeGaussian_data/ETH-XGaze'
train_data_dest = '/content/GazeGaussian/data/ETH-XGaze'
if os.path.exists(train_data_src) and not os.path.exists(train_data_dest):
    !ln -s {train_data_src} {train_data_dest}
    print("✓ Linked training data from Drive")
elif os.path.exists(train_data_dest):
    print("✓ Training data directory already exists")
else:
    print("✗ Training data not found in Drive to link")
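
The four link blocks above share one pattern, which could be factored into a helper. The following is a sketch with a hypothetical function name. It uses `os.path.lexists` so that a dangling symlink at the destination (for example when Drive is not mounted) is reported as existing rather than triggering a failed `ln -s`.

```python
import os


def link_once(src: str, dst: str) -> str:
    """Create a symlink dst -> src unless dst already exists.

    Returns 'exists' if dst is already there (including a dangling symlink),
    'missing' if src does not exist, or 'linked' after creating the link.
    """
    if os.path.lexists(dst):
        return "exists"
    if not os.path.exists(src):
        return "missing"
    os.makedirs(os.path.dirname(dst) or ".", exist_ok=True)
    os.symlink(src, dst)
    return "linked"
```

For example, `link_once('/content/drive/MyDrive/GazeGaussian_data/ETH-XGaze', '/content/GazeGaussian/data/ETH-XGaze')` covers the training-data case, and the return value can drive the status message.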



## 6. STEP 1: Train MeshHead (~10 epochs, ~2-3 hours)

This creates the canonical 3D head model.

In [None]:
!pip install pygltflib

In [None]:
%cd /content/GazeGaussian

!python train_meshhead.py \
    --batch_size 1 \
    --name 'meshhead' \
    --img_dir '/content/drive/MyDrive/GazeGaussian_data/ETH-XGaze/train' \
    --num_epochs 10 \
    --num_workers 2 \
    --early_stopping \
    --patience 5 \
    --dataset_name 'eth_xgaze'

## 7. Verify MeshHead Checkpoint

In [None]:
import glob
import os

checkpoints = glob.glob("/content/GazeGaussian/work_dirs/meshhead_*/checkpoints/*.pth")
if checkpoints:
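    # Pick the newest file by modification time; a plain string sort can misorder epochs (9 vs 10)
    latest_checkpoint = max(checkpoints, key=os.path.getmtime)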
    print(f"✓ MeshHead checkpoint found: {latest_checkpoint}")
    print(f"  Size: {os.path.getsize(latest_checkpoint) / (1024**2):.2f} MB")

    with open('/content/meshhead_checkpoint.txt', 'w') as f:
        f.write(latest_checkpoint)
    print(f"\n✓ Checkpoint path saved for Step 2")
else:
    print("❌ No MeshHead checkpoint found! Training may have failed.")
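
Checkpoint file names often embed an epoch number, and comparing such names as strings places `epoch_9` after `epoch_10`. The sketch below selects the latest checkpoint by the last number in the file name. The `epoch_N.pth` naming is a hypothetical example; the naming scheme actually used by `train_meshhead.py` is not shown in this notebook.

```python
import os
import re
from typing import Optional


def epoch_number(path: str) -> int:
    """Return the last integer found in the file name, or -1 if there is none."""
    numbers = re.findall(r"\d+", os.path.basename(path))
    return int(numbers[-1]) if numbers else -1


def latest_by_epoch(paths: list[str]) -> Optional[str]:
    """Return the path with the highest embedded number, or None for an empty list."""
    return max(paths, key=epoch_number) if paths else None


# Hypothetical file names, used only to illustrate the ordering problem.
example = ["ckpt/epoch_9.pth", "ckpt/epoch_10.pth", "ckpt/epoch_2.pth"]
print(sorted(example)[-1])       # string order picks epoch_9
print(latest_by_epoch(example))  # numeric order picks epoch_10
```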

## 8. Verify DiT Configuration

In [None]:
from configs.gazegaussian_options import BaseOptions

opt = BaseOptions()

print("="*80)
print("ENHANCED MODEL CONFIGURATION")
print("="*80)
print(f"\n✓ Neural Renderer Type: {opt.neural_renderer_type}")
print(f"✓ DiT Depth: {opt.dit_depth}")
print(f"✓ DiT Num Heads: {opt.dit_num_heads}")
print(f"✓ DiT Patch Size: {opt.dit_patch_size}")
print(f"✓ VAE Enabled: {opt.use_vae}")
print(f"✓ VAE Z Channels: {opt.vae_z_channels}")
print(f"✓ VAE Frozen: {opt.freeze_vae}")
print(f"✓ Orthogonality Loss: {opt.use_orthogonality_loss}")
print(f"✓ Orthogonality Importance: {opt.orthogonality_loss_importance}")

if opt.neural_renderer_type == "dit" and opt.use_vae and opt.use_orthogonality_loss:
    print("\n✅ All 3 enhancements are ACTIVE!")
else:
    print("\n⚠ Some enhancements may be disabled!")
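
The check above reports only that some enhancement may be disabled. The sketch below names the options that differ from the expected values, using the option names printed by the cell. `SimpleNamespace` stands in for `BaseOptions()` so the example runs outside the repository, and the `"unet"` value is a hypothetical placeholder.

```python
from types import SimpleNamespace

# Expected values for the three enhancements, keyed by option name
EXPECTED = {
    "neural_renderer_type": "dit",
    "use_vae": True,
    "use_orthogonality_loss": True,
}

def find_mismatches(opt, expected=EXPECTED) -> dict:
    """Return {option_name: actual_value} for every option that differs from `expected`."""
    missing = object()
    mismatches = {}
    for name, want in expected.items():
        got = getattr(opt, name, missing)
        if got is missing:
            mismatches[name] = "<not set>"
        elif got != want:
            mismatches[name] = got
    return mismatches

# Stand-in for BaseOptions(); the values are made up for the demo
demo_opt = SimpleNamespace(
    neural_renderer_type="unet",
    use_vae=True,
    use_orthogonality_loss=False,
)
problems = find_mismatches(demo_opt)
for name, got in problems.items():
    print(f"⚠ {name} = {got!r} (expected {EXPECTED[name]!r})")
```

In the notebook, `find_mismatches(opt)` could be called on the real options object; an empty result would mean all three enhancements are active.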

## 9. STEP 2: Train GazeGaussian with DiT (~30 epochs, ~8-12 hours)

This trains the full pipeline with the three enhancements: the DiT neural renderer, VAE integration, and orthogonality regularization.

In [None]:
%cd /content/GazeGaussian

# Read the MeshHead checkpoint path saved by the verification cell (section 7)
with open('/content/meshhead_checkpoint.txt', 'r') as f:
    meshhead_checkpoint = f.read().strip()

print(f"Loading MeshHead from: {meshhead_checkpoint}")

!python train_gazegaussian.py \
    --batch_size 1 \
    --name 'gazegaussian_dit' \
    --img_dir '/content/drive/MyDrive/GazeGaussian_data/ETH-XGaze/train' \
    --num_epochs 30 \
    --num_workers 2 \
    --lr 0.0001 \
    --clip_grad \
    --load_meshhead_checkpoint {meshhead_checkpoint} \
    --dataset_name 'eth_xgaze'
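
Step 2 depends on `/content/meshhead_checkpoint.txt`, which disappears whenever the Colab runtime resets. A minimal sketch of reading that pointer file with explicit checks, so a stale or missing path fails before the multi-hour training command starts. The helper name `read_checkpoint_path` is illustrative and not part of the repository.

```python
import os
import tempfile

def read_checkpoint_path(path_file: str) -> str:
    """Return the checkpoint path stored in `path_file`, or raise a descriptive error."""
    if not os.path.isfile(path_file):
        raise FileNotFoundError(
            f"{path_file} is missing; run the MeshHead verification cell first"
        )
    with open(path_file, "r") as f:
        checkpoint = f.read().strip()
    if not checkpoint:
        raise ValueError(f"{path_file} is empty")
    if not os.path.isfile(checkpoint):
        raise FileNotFoundError(
            f"checkpoint listed in {path_file} does not exist: {checkpoint}"
        )
    return checkpoint

# Demo with temporary files instead of /content/meshhead_checkpoint.txt
with tempfile.TemporaryDirectory() as tmp:
    ckpt = os.path.join(tmp, "meshhead.pth")
    open(ckpt, "wb").close()
    pointer = os.path.join(tmp, "meshhead_checkpoint.txt")
    with open(pointer, "w") as f:
        f.write(ckpt + "\n")
    resolved = read_checkpoint_path(pointer)

    os.remove(ckpt)  # simulate a checkpoint lost to a runtime reset
    try:
        read_checkpoint_path(pointer)
        stale_error = None
    except FileNotFoundError as e:
        stale_error = str(e)
    print(resolved == ckpt, stale_error is not None)
```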

## 10. Verify Final Checkpoint

## 11. Generate Test Samples

Generate a few redirected gaze/pose samples for verification.

In [None]:
# Placeholder cell: `generate_samples.py` and its flags are assumed, not confirmed.
# Replace the command with the project's actual inference script before running.

import os

# Path to the trained checkpoint in Google Drive
checkpoint_path = '/content/drive/MyDrive/GazeGaussian_checkpoints/gazegaussian_ckp.pth'

# Output directory for generated samples
output_dir = '/content/GazeGaussian/generated_samples'
os.makedirs(output_dir, exist_ok=True)

%cd /content/GazeGaussian

print(f"Generating samples using checkpoint: {checkpoint_path}")
# Set --dataset_name to whichever dataset is used for inference
!python generate_samples.py \
    --checkpoint_path {checkpoint_path} \
    --output_dir {output_dir} \
    --dataset_name 'eth_xgaze'

In [None]:
import os

checkpoint_dir = '/content/drive/MyDrive/GazeGaussian_checkpoints'

print(f"Checking for checkpoints in: {checkpoint_dir}")

if os.path.exists(checkpoint_dir):
    checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith('.pth')]
    if checkpoints:
        print("\nFound the following checkpoints:")
        for ckp in checkpoints:
            print(f"- {ckp}")
    else:
        print("\nNo .pth checkpoint files found in the directory.")
else:
    print("\nCheckpoint directory not found in Google Drive.")

Testing checkpoints

In [None]:
# ============================================================================
# GAZEGAUSSIAN CHECKPOINT TESTING - SINGLE CELL
# ============================================================================

import torch
import os
from pathlib import Path
import numpy as np
from PIL import Image
import torchvision.utils as vutils
from tqdm import tqdm
from IPython.display import display, Image as IPImage

%cd /content/GazeGaussian

from configs.gazegaussian_options import BaseOptions
from models.gaze_gaussian import GazeGaussianNet
from dataloader.eth_xgaze import get_val_loader

def save_image_grid(images, save_path, nrow=4):
    """Save a grid of images (expected in [-1, 1]) as an 8-bit image file."""
    # normalize=True with value_range=(-1, 1) already rescales the grid to [0, 1]
    grid = vutils.make_grid(images, nrow=nrow, normalize=True, value_range=(-1, 1))
    grid_np = grid.cpu().numpy().transpose(1, 2, 0)
    # Scale [0, 1] to 8-bit [0, 255]
    grid_np = np.clip(grid_np * 255, 0, 255).astype(np.uint8)
    Image.fromarray(grid_np).save(save_path)
    return grid_np

# ============================================================================
# CONFIGURE YOUR PATHS HERE
# ============================================================================
checkpoint_path = "/content/drive/MyDrive/GazeGaussian_checkpoints/gazegaussian_ckp.pth"
data_dir = "/content/drive/MyDrive/GazeGaussian_data/ETH-XGaze_test/ETH-XGaze_test"
output_dir = "/content/test_outputs"
num_samples = 10  # Change this to generate more/fewer samples
device = 'cuda'
# ============================================================================

print("=" * 80)
print("GAZEGAUSSIAN CHECKPOINT TESTING")
print("=" * 80)

os.makedirs(output_dir, exist_ok=True)

# 1. Load checkpoint
print(f"\n[1/5] Loading checkpoint...")
if not os.path.exists(checkpoint_path):
    print(f"✗ Checkpoint not found: {checkpoint_path}")
    raise FileNotFoundError(checkpoint_path)

checkpoint = torch.load(checkpoint_path, map_location=device)
print(f"✓ Checkpoint loaded")

if isinstance(checkpoint, dict) and 'epoch' in checkpoint:
    print(f"  - Epoch: {checkpoint['epoch']}")
if isinstance(checkpoint, dict) and 'loss' in checkpoint:
    print(f"  - Loss: {checkpoint['loss']:.4f}")

# 2. Initialize model
print(f"\n[2/5] Initializing model...")
opt = BaseOptions()

if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    state_dict = checkpoint['model_state_dict']
elif isinstance(checkpoint, dict):
    state_dict = checkpoint
else:
    state_dict = None

model = GazeGaussianNet(opt, load_state_dict=state_dict)
model = model.to(device)
model.eval()
print("✓ Model initialized")

if hasattr(model, 'neural_render'):
    renderer_type = type(model.neural_render).__name__
    print(f"  - Neural Renderer: {renderer_type}")

# 3. Load validation data
print(f"\n[3/5] Loading validation data...")
opt.img_dir = data_dir
val_loader = get_val_loader(
    opt,
    data_dir=data_dir,
    batch_size=1,
    num_workers=0,
    evaluate=None,
    dataset_name='eth_xgaze'
)
print(f"✓ Data loaded ({len(val_loader.dataset)} samples)")

# 4. Generate images
print(f"\n[4/5] Generating {num_samples} images...")

success_count = 0

with torch.no_grad():
    for idx, data in enumerate(tqdm(val_loader, total=min(num_samples, len(val_loader)), desc="Generating")):
        if idx >= num_samples:
            break

        try:
            # Move data to device
            for key in data:
                if isinstance(data[key], torch.Tensor):
                    data[key] = data[key].to(device)
                elif isinstance(data[key], dict):
                    for sub_key in data[key]:
                        if isinstance(data[key][sub_key], torch.Tensor):
                            data[key][sub_key] = data[key][sub_key].to(device)

            # Forward pass
            output = model(data)

            # Get images
            gt_image = data.get('image', None)
            if gt_image is None:
                gt_image = data.get('img', None)

            # Get rendered images
            gaussian_img = output['total_render_dict']['merge_img']
            neural_img = output['total_render_dict']['merge_img_pro']

            # Create comparison
            if gt_image is not None:
                comparison = torch.cat([gt_image, gaussian_img, neural_img], dim=0)
            else:
                comparison = torch.cat([gaussian_img, neural_img], dim=0)

            # Save images
            save_path = os.path.join(output_dir, f"test_sample_{idx:03d}.png")
            save_image_grid(comparison, save_path, nrow=len(comparison))
            save_image_grid(gaussian_img, os.path.join(output_dir, f"test_sample_{idx:03d}_gaussian.png"), nrow=1)
            save_image_grid(neural_img, os.path.join(output_dir, f"test_sample_{idx:03d}_dit.png"), nrow=1)

            success_count += 1

        except Exception as e:
            print(f"\n✗ Error on sample {idx}: {e}")
            continue

# 5. Summary
print(f"\n[5/5] Summary")
print("=" * 80)
print(f"✅ Successfully generated {success_count}/{num_samples} images")
print(f"   Output: {output_dir}")
print("=" * 80)

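# Copy to Drive (copying the directory contents avoids nesting on reruns)
drive_output_dir = '/content/drive/MyDrive/gazegaussian_test_outputs'
os.makedirs(drive_output_dir, exist_ok=True)
!cp -r {output_dir}/. {drive_output_dir}/
print(f"✓ Saved to Drive: {drive_output_dir}")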

# Display first 5 samples
print(f"\n{'='*80}\nDISPLAYING RESULTS\n{'='*80}")
for i in range(min(5, success_count)):
    img_path = os.path.join(output_dir, f"test_sample_{i:03d}.png")
    if os.path.exists(img_path):
        print(f"\n--- Sample {i} ---")
        display(IPImage(filename=img_path))

print(f"\n{'='*80}\n✅ TESTING COMPLETE!\n{'='*80}")
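
The generation loop moves tensors to the device by hand and handles one level of nested dictionaries. A minimal recursive sketch of the same idea that also covers lists, tuples, and deeper nesting. `FakeTensor` is a stand-in so the example runs without PyTorch; the keys `camera`, `pose`, and `masks` are hypothetical and not taken from the ETH-XGaze loader.

```python
def to_device(obj, device):
    """Recursively move anything with a .to() method; leave other values untouched."""
    if hasattr(obj, "to") and callable(obj.to):
        return obj.to(device)
    if isinstance(obj, dict):
        return {k: to_device(v, device) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(to_device(v, device) for v in obj)
    return obj

class FakeTensor:
    """Stand-in for torch.Tensor so the sketch runs without PyTorch installed."""

    def __init__(self, device="cpu"):
        self.device = device

    def to(self, device):
        return FakeTensor(device)

# Hypothetical batch layout, for illustration only
batch = {
    "image": FakeTensor(),
    "camera": {"pose": FakeTensor(), "name": "cam00"},
    "masks": [FakeTensor(), FakeTensor()],
}
moved = to_device(batch, "cuda")
print(moved["image"].device, moved["camera"]["pose"].device, moved["masks"][1].device)
```

Unlike the in-place loop in the cell, this version returns a new container, so the original batch stays unchanged.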

In [None]:
import os
os.environ['IGNORE_TORCH_VER'] = '1'
os.environ['FORCE_CUDA'] = '1'

%cd /content/kaolin
!python setup.py clean --all
!python setup.py install
%cd /content/GazeGaussian
print("✓ Kaolin reinstalled from source")

In [None]:
print("\n" + "="*80)
print("VERIFICATION")
print("="*80)

all_good = True

packages = [
    ('torch', 'PyTorch'),
    ('cv2', 'OpenCV'),
    ('h5py', 'h5py'),
    ('lpips', 'LPIPS'),
    ('kornia', 'Kornia'),
]

for mod, name in packages:
    try:
        m = __import__(mod)
        v = getattr(m, '__version__', 'OK')
        print(f"✓ {name:15s} {v}")
    except ImportError as e:
        print(f"✗ {name:15s} FAILED: {str(e)[:50]}")
        all_good = False

try:
    import simple_knn
    print(f"✓ {'simple-knn':15s} OK")
except ImportError as e:
    print(f"✗ {'simple-knn':15s} FAILED: {str(e)[:50]}")
    all_good = False

try:
    import diff_gaussian_rasterization
    print(f"✓ {'diff-gauss':15s} OK")
except ImportError as e:
    print(f"✗ {'diff-gauss':15s} FAILED: {str(e)[:50]}")
    all_good = False

try:
    import kaolin
    try:
        kaolin_version = kaolin.__version__
    except AttributeError:
        kaolin_version = 'OK (version unknown)'
    print(f"✓ {'kaolin':15s} {kaolin_version}")
except ImportError as e:
    print(f"✗ {'kaolin':15s} FAILED: {str(e)[:50]}")
    all_good = False

print("="*80)

if all_good:
    print("\n✅ ALL REQUIRED PACKAGES INSTALLED SUCCESSFULLY!")
    print("   Ready for training!")
else:
    print("\n⚠ Some packages failed. Check errors above and rerun failed installations.")
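
The verification cell repeats the same try/except block for each package. A minimal sketch of folding that into one function with `importlib.import_module`. The demo uses standard-library modules plus one deliberately missing name so it runs anywhere; in the notebook the list would hold the real module names (`torch`, `cv2`, `kaolin`, and so on).

```python
import importlib

def check_packages(packages):
    """Try to import each (module_name, label) pair; return {label: version or None}."""
    results = {}
    for module_name, label in packages:
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            results[label] = None
            continue
        # Some packages do not expose __version__; report "OK" for those
        results[label] = str(getattr(module, "__version__", "OK"))
    return results

demo = check_packages([
    ("json", "json"),
    ("csv", "csv"),
    ("surely_not_installed_pkg", "missing-example"),
])
for label, version in demo.items():
    mark = "✓" if version is not None else "✗"
    print(f"{mark} {label:15s} {version or 'FAILED'}")
failed = [label for label, version in demo.items() if version is None]
```

An empty `failed` list corresponds to `all_good` being true in the cell above.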