# Cache NeRF Model for Fast Rendering
## Configured for LLFF Scenes (Fern, Horns, …)

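This notebook evaluates a trained NeRF model on dense 3D grids and saves the results as cache files (a sparse position grid plus a view-direction grid). A real-time renderer can load these files for ~30 FPS rendering.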

In [2]:
import os
# Select the GPU *before* importing torch and the project modules. The CUDA
# runtime reads CUDA_VISIBLE_DEVICES / CUDA_DEVICE_ORDER only when it
# initializes, so setting them after something has already touched CUDA has
# no effect.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # enumerate GPUs in PCI bus order
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Use GPU 0 (change if you have multiple GPUs)

import torch
from collections import defaultdict
import numpy as np
import mcubes
import trimesh

# Star imports (presumably the source of Embedding and NeRF used below). They
# are kept in their original order because a later star import can shadow
# names exported by an earlier one.
from models.rendering import *
from models.nerf import *

from datasets import dataset_dict

from utils import load_ckpt
# Imported after the star imports so the `tqdm` module (tqdm.trange is used
# below) cannot be shadowed by a name re-exported from the project modules.
import tqdm

# STEP 1: Load Model and Data

**IMPORTANT: Modify these paths for YOUR scene!**

In [3]:
###############################################################################
# MODIFY THESE FOR YOUR SCENE
###############################################################################

# Image resolution (width, height) -- must match the resolution used for training!
img_wh = (504, 378)  # LLFF default resolution used here
# If you trained with different resolution, change this!

dataset_name = 'llff'  # We're using LLFF format

# Scene folder name. It is also used to locate the checkpoint and to name the
# output files, so switching scenes only requires changing this one value.
scene_name = 'horns'

# LLFF loader option. It presumably has to match the setting the model was
# trained with -- TODO confirm against the training config.
spheric_poses = True

# Dataset directory (only needs poses_bounds.npy, not images!).
# os.path.join keeps the path portable across Windows and POSIX.
root_dir = os.path.join('..', 'data', 'nerf_llff_data', scene_name)

# Path to your TRAINED model
ckpt_path = os.path.join('ckpts', scene_name, 'epoch=4.ckpt')  # Adjust epoch number if needed

###############################################################################
# END OF MODIFICATIONS
###############################################################################

print(f"Loading scene: {scene_name}")
print(f"Dataset from: {root_dir}")
print(f"Model from: {ckpt_path}")

# Load dataset (metadata only; its printed near/far values can help when
# choosing the bounding box in Step 2).
kwargs = {'root_dir': root_dir,
          'img_wh': img_wh}
if dataset_name == 'llff':
    kwargs['spheric_poses'] = spheric_poses
    kwargs['split'] = 'test'
else:
    kwargs['split'] = 'train'

chunk = 1024*32  # Number of points sent through the network per batch
dataset = dataset_dict[dataset_name](**kwargs)

# Input encodings for position and view direction. The second arguments
# (10 and 4) presumably must match the values used at training time -- TODO confirm.
embedding_xyz = Embedding(3, 10)
embedding_dir = Embedding(3, 4)

# Default-constructed network; its architecture must match the checkpoint.
nerf_fine = NeRF()
load_ckpt(nerf_fine, ckpt_path, model_name='nerf_fine')
nerf_fine.cuda().eval()

# Create output directory
os.makedirs('output', exist_ok=True)

print("✓ Model loaded successfully!")

Loading scene: horns
Dataset from: ..\data\nerf_llff_data\horns
Model from: ckpts\horns\epoch=4.ckpt
H = 378, W = 504, focal = 421.10296470036104, near = 1.3333333333333333, far = 21.334730896952106
example c2w:
[[ 0.9840111   0.0366225  -0.17430131 -0.60779285]
 [-0.01882579  0.99453607  0.10268202  0.36733163]
 [ 0.17710941 -0.09775888  0.97932398  0.17471896]]
[INFO] Built fastnerf: 8 x 256 | 4 x 128 | 8
✓ Model loaded successfully!


# STEP 2: Find Scene Bounds (Trial and Error)

**What this does:**
- Evaluates the model on a dense grid inside the bounding box you choose below
- Start with conservative bounds. If the scene looks cut off (e.g. in the Step 5 mesh), adjust them and re-run

**For LLFF scenes such as fern or horns, these defaults are a reasonable starting point.**

In [4]:
###############################################################################
# STEP 2A: Calculate 3D Grid (UVWs and Sigma)
###############################################################################

# Resolution of the 3D grid (samples per axis).
# Start LOW (256-512) for testing, then increase (768-1024) for final quality.
# Memory: the dense grid computed below holds N^3 x (uvw + sigma) float32 values
# (25 channels for this model, i.e. ~13 GB at N=512), so budget RAM accordingly.
N = 512  # Lower = faster but lower quality. For testing use 256-512.
         # For final caching use 768-1024

print(f"Grid resolution: {N}^3 = {N**3:,} voxels")

###############################################################################
# ADJUST THESE BOUNDS IF YOUR SCENE IS CUT OFF
###############################################################################
# These define the axis-aligned 3D bounding box (in the model's input
# coordinates) that gets voxelized. These are the defaults used for LLFF scenes.

xmin, xmax = -1.2, 1.2  # left/right range
ymin, ymax = -1.2, 1.2  # forward/backward range
zmin, zmax = -1.2, 1.2  # up/down range

print(f"Scene bounds: X[{xmin}, {xmax}], Y[{ymin}, {ymax}], Z[{zmin}, {zmax}]")

# IMPORTANT: All ranges must have the same length (cubic voxels)!
# Compare with a tolerance: float subtraction can differ in the last bit for
# bounds whose lengths are mathematically equal.
assert np.isclose(xmax - xmin, ymax - ymin) and np.isclose(xmax - xmin, zmax - zmin), \
    "Bounds must be cubic!"

###############################################################################
# Query the model on a dense N x N x N grid inside the bounding box.
#
# NOTE: np.meshgrid uses indexing='xy' by default, so with inputs (x, y, z)
# the first two grid axes are (y, x): uvw/sigma below are indexed
# [iy, ix, iz]. The cache files written in Step 3 inherit this layout; it is
# kept unchanged on purpose (presumably the renderer expects it -- TODO
# confirm before changing).
###############################################################################

# One sample per cell at its lower corner (endpoint=False). Casting the 1-D
# axes to float32 first and stacking broadcast views (copy=False) means only a
# single float32 copy of the N^3 x 3 coordinate array is ever materialized.
x = np.linspace(xmin, xmax, N, endpoint=False).astype(np.float32)
y = np.linspace(ymin, ymax, N, endpoint=False).astype(np.float32)
z = np.linspace(zmin, zmax, N, endpoint=False).astype(np.float32)
xyz_ = torch.from_numpy(np.stack(np.meshgrid(x, y, z, copy=False), -1).reshape(-1, 3))

print(f"Processing {xyz_.shape[0]:,} points...")

# Extract cached data from the model
with torch.no_grad():
    B = xyz_.shape[0]
    # Viewing direction doesn't matter for uvw/sigma, so every point gets a zero
    # direction. One chunk-sized zero tensor on the GPU is reused instead of
    # materializing (and re-uploading) an N^3 x 3 tensor of zeros.
    zero_dirs = torch.zeros(chunk, 3, device='cuda')
    # Outputs are preallocated after the first chunk (once the channel counts
    # are known) and filled in place. This avoids the list + concatenate +
    # astype copies of these multi-GB grids.
    uvw = sigma = None
    for i in tqdm.trange(0, B, chunk):
        # Query the model at these 3D positions
        xyz_chunk = xyz_[i:i+chunk].cuda()
        n = xyz_chunk.shape[0]
        xyz_embedded = embedding_xyz(xyz_chunk)
        dir_embedded = embedding_dir(zero_dirs[:n])
        xyzdir_embedded = torch.cat([xyz_embedded, dir_embedded], 1)

        # Get the cached components: uvw (color/appearance) and sigma (density).
        # beta (view-dependent) is cached separately in Step 4.
        uvw_, _, sigma_ = nerf_fine(xyzdir_embedded, return_components=True)
        uvw_ = uvw_.reshape(n, -1).cpu().numpy()
        sigma_ = sigma_.reshape(n, -1).cpu().numpy()
        if uvw is None:
            uvw = np.empty((B, uvw_.shape[1]), dtype=np.float32)
            sigma = np.empty((B, sigma_.shape[1]), dtype=np.float32)
        uvw[i:i+n] = uvw_
        sigma[i:i+n] = sigma_

    # Reshape to 3D grid: (N, N, N, channels)
    uvw = uvw.reshape(N, N, N, -1)
    sigma = sigma.reshape(N, N, N, -1)

np.maximum(sigma, 0, out=sigma)  # Ensure non-negative density (in place)

print(f"✓ UVW shape: {uvw.shape}, dtype: {uvw.dtype}")
print(f"✓ Sigma shape: {sigma.shape}, dtype: {sigma.dtype}")
print(f"✓ Sigma range: [{sigma.min():.2f}, {sigma.max():.2f}]")

Grid resolution: 512^3 = 134,217,728 voxels
Scene bounds: X[-1.2, 1.2], Y[-1.2, 1.2], Z[-1.2, 1.2]
Processing 134,217,728 points...


100%|██████████| 4096/4096 [00:27<00:00, 146.57it/s]


✓ UVW shape: (512, 512, 512, 24), dtype: float32
✓ Sigma shape: (512, 512, 512, 1), dtype: float32
✓ Sigma range: [0.00, 1900.45]


# STEP 3: Save Sparse Cache (Memory Efficient)

**What this does:**
- Only saves voxels where something exists (sigma > threshold)
- Saves memory and loading time

In [5]:
###############################################################################
# Density threshold: voxels below this are considered empty
###############################################################################
sigma_thresh = 0  # Start with 0, increase if too much empty space is cached

print(f"Density threshold: {sigma_thresh}")

# Occupancy mask over the (N, N, N) grid: a voxel is kept if sigma > sigma_thresh.
mask = sigma[..., 0] > sigma_thresh

# Indices (tuple of 3 arrays, row-major order) of the non-empty voxels
coords = np.nonzero(mask)
nnz = coords[0].shape[0]
sparsity = nnz / mask.size  # fraction of voxels that are kept

print(f"Non-empty voxels: {nnz:,} / {mask.size:,} ({sparsity*100:.1f}%)")

# Dense index map for fast lookup during rendering. -1 marks an empty voxel;
# otherwise the value is that voxel's row in `uvws`. The map is int32, so the
# number of stored voxels must fit in int32, or the assignment would overflow.
assert nnz <= np.iinfo(np.int32).max, "Too many non-empty voxels for an int32 index map"
inds = np.full(mask.shape, -1, dtype=np.int32)
inds[coords] = np.arange(nnz, dtype=np.int32)

# Extract data only for non-empty voxels: one (nnz, uvw_channels + 1) float32
# array whose row k belongs to the voxel with inds == k.
uvws = np.concatenate([
    uvw[coords],   # Color data
    sigma[coords], # Density
], axis=1).astype(np.float32, copy=False)  # already float32 -> no extra multi-GB copy

print(f"Index map shape: {inds.shape}")
print(f"Sparse data shape: {uvws.shape}")

# Save the cached data (paths built once so the saved and printed names agree)
output_prefix = f'output/{scene_name}'
inds_path = f'{output_prefix}_inds_{N}_{sigma_thresh}.npy'
uvws_path = f'{output_prefix}_uvws_{N}_{sigma_thresh}.npy'
np.save(inds_path, inds)
np.save(uvws_path, uvws)

print("✓ Saved cache files:")
print(f"  - {inds_path}")
print(f"  - {uvws_path}")

Density threshold: 0
Non-empty voxels: 34,083,785 / 134,217,728 (25.4%)
Index map shape: (512, 512, 512)
Sparse data shape: (34083785, 25)
✓ Saved cache files:
  - output/horns_inds_512_0.npy
  - output/horns_uvws_512_0.npy


# STEP 4: Cache Directional Data (Beta)

**What this does:**
- Stores how colors change based on viewing direction
- Needed for realistic reflections and lighting effects

In [6]:
###############################################################################
# Cache directional appearance (cartesian version - more stable)
###############################################################################

M = 200  # Resolution for viewing directions
         # Higher = better quality but slower. 200 is good balance.

print(f"Caching {M}^3 = {M**3:,} view directions...")

# M^3 grid over [-1, 1)^3 with every sample normalized to unit length; the 1e-6
# guard avoids division by zero at the origin sample. As in Step 2, meshgrid's
# default 'xy' indexing makes the grid axes (ny, nx, nz).
# Distinct variable names keep the grid tensors from Step 2 intact.
nx = np.linspace(-1, 1, M, endpoint=False)
ny = np.linspace(-1, 1, M, endpoint=False)
nz = np.linspace(-1, 1, M, endpoint=False)
beta_dirs = np.stack(np.meshgrid(nx, ny, nz), -1).reshape(-1, 3)
beta_dirs = beta_dirs / (np.linalg.norm(beta_dirs, ord=2, axis=-1, keepdims=True) + 1e-6)
beta_dirs = torch.FloatTensor(beta_dirs).cuda()

# Batch size: reuse `chunk` from the setup cell.
with torch.no_grad():
    B = beta_dirs.shape[0]
    # Position doesn't matter for beta: reuse one chunk of zero positions.
    zero_xyz = torch.zeros(chunk, 3, device=beta_dirs.device)
    beta = []
    for i in tqdm.trange(0, B, chunk):
        dir_chunk = beta_dirs[i:i+chunk]  # already on the GPU
        xyz_embedded = embedding_xyz(zero_xyz[:dir_chunk.shape[0]])
        dir_embedded = embedding_dir(dir_chunk)
        xyzdir_embedded = torch.cat([xyz_embedded, dir_embedded], 1)

        # Only the directional component is needed here
        _, beta_, _ = nerf_fine(xyzdir_embedded, return_components=True)
        beta.append(beta_.cpu())

    beta = torch.cat(beta, 0).numpy().astype(np.float32, copy=False).reshape(M, M, M, -1)

print(f"✓ Beta shape: {beta.shape}, dtype: {beta.dtype}")

# Save beta cache
np.save(f'{output_prefix}_beta_{M}_cart.npy', beta)
print(f"✓ Saved: {output_prefix}_beta_{M}_cart.npy")

Caching 200^3 = 8,000,000 view directions...


100%|██████████| 245/245 [00:01<00:00, 134.76it/s]


✓ Beta shape: (200, 200, 200, 8), dtype: float32
✓ Saved: output/horns_beta_200_cart.npy


# STEP 5: Visualize 3D Mesh (Optional)

**What this does:**
- Extracts a 3D mesh from the density field
- Lets you see if the scene bounds are correct

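**Optional step.** `mcubes` and `trimesh` are already imported in the first code cell, so the notebook cannot run without them. Skip this step if you don't need to inspect the mesh.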

In [7]:
###############################################################################
# Visualize the density as a 3D mesh
###############################################################################

# Try different thresholds to see the object clearly
sigma_thresh_vis = 20  # Increase if too much noise, decrease if object is missing

print(f"Extracting mesh with threshold {sigma_thresh_vis}...")

try:
    vertices, triangles = mcubes.marching_cubes(sigma[:, :, :, 0], sigma_thresh_vis)

    if len(triangles) == 0:
        # Nothing crosses the iso-level. Report that directly instead of passing
        # an empty mesh on to trimesh and the viewer.
        print(f"No surface found at threshold {sigma_thresh_vis} "
              f"(sigma range [{sigma.min():.2f}, {sigma.max():.2f}]); "
              "try a lower sigma_thresh_vis.")
    else:
        # marching_cubes returns vertices in voxel-index units (0 .. N-1);
        # dividing by N maps them into [0, 1). Axis order follows the grid
        # built in Step 2.
        vertices = vertices / N

        mesh = trimesh.Trimesh(vertices, triangles)

        print(f"✓ Mesh: {len(vertices)} vertices, {len(triangles)} faces")
        print("Opening 3D viewer...")

        mesh.show()  # Opens interactive 3D viewer

except Exception as e:
    print(f"Could not visualize mesh: {e}")
    print("Install mcubes and trimesh: pip install PyMCubes trimesh")

Extracting mesh with threshold 20...
✓ Mesh: 5279092 vertices, 10525106 faces
Opening 3D viewer...


# DONE!

## Summary of Generated Files:

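```
output/
├── <scene_name>_inds_<N>_<sigma_thresh>.npy   ← Index map (which voxels are occupied)
├── <scene_name>_uvws_<N>_<sigma_thresh>.npy   ← Color/density data for each occupied voxel
└── <scene_name>_beta_<M>_cart.npy             ← Directional appearance data
```

For the run above, these are `horns_inds_512_0.npy`, `horns_uvws_512_0.npy` and `horns_beta_200_cart.npy`.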

## Next Steps:

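1. **For higher quality**: Re-run with N=768 or N=1024. Note that the dense grid from Step 2 needs at least N³ × 25 × 4 bytes of RAM for this model: about 13 GB at N=512, 45 GB at N=768 and 107 GB at N=1024.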
2. **For interactive rendering**: Use these files with the Qt renderer
3. **If scene is cut off**: Adjust xmin/xmax/ymin/ymax/zmin/zmax bounds

## Using the Cache:

The Qt renderer will load these files for 30 FPS real-time rendering!

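## File Sizes:
- Index map: N³ × 4 bytes (int32). That is ~0.5 GB at N=512, ~1.8 GB at N=768 and ~4.3 GB at N=1024.
- Sparse data: (non-empty voxels) × 25 × 4 bytes (float32). That is ~3.4 GB for the horns run above (34M voxels at `sigma_thresh = 0`). Raise `sigma_thresh` to shrink it.
- Beta cache: M³ × 8 × 4 bytes, which is ~256 MB at M=200.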