In [None]:
# ===========================
# 1) MOUNT DRIVE & INSTALL DEPS
# ===========================
from google.colab import drive
drive.mount('/content/drive')

DRIVE_WHEELS = "/content/drive/MyDrive/wheels"

# Go to wheels dir and install
import os
assert os.path.isdir(DRIVE_WHEELS), f"‚ùå Wheels folder not found at: {DRIVE_WHEELS}"
%cd $DRIVE_WHEELS

!pip install -q portalocker-3.2.0-py3-none-any.whl
!pip install -q iopath-0.1.10-py3-none-any.whl
!pip install -q tqdm-4.67.1-py3-none-any.whl
!pip install -q typing_extensions-4.15.0-py3-none-any.whl
!pip install -q pytorch3d-0.7.8-cp312-cp312-linux_x86_64.whl

# UI deps
!pip install -q gradio pillow

print("‚úÖ Deps installed")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/wheels
‚úÖ Deps installed


In [None]:
# ===========================
# 2) CLONE REPO & SET PATH
# ===========================
# Clones the project repository into /content (only when REPO_DIR does not
# exist yet), makes it importable, and switches the working directory to it.
import os, sys

REPO_URL = "https://github.com/jintn/acit4030-3d-project.git"
REPO_DIR = "/content/acit4030-3d-project"

# `git clone` is run from /content without an explicit target directory, so
# REPO_DIR must equal /content/<repository name> for the checks below to pass.
if not os.path.isdir(REPO_DIR):
    %cd /content
    !git clone {REPO_URL}

# Append only once so re-running the cell does not grow sys.path.
if REPO_DIR not in sys.path:
    sys.path.append(REPO_DIR)

# Later cells import `nerf_model` and `utils.*` from the repository; the
# relative check below relies on the working directory being REPO_DIR.
%cd $REPO_DIR
assert os.path.exists("nerf_model.py"), "‚ùå nerf_model.py not found. Check repo folder."
print("‚úÖ Repo ready:", REPO_DIR)

/content/acit4030-3d-project
‚úÖ Repo ready: /content/acit4030-3d-project


In [None]:
# ===========================
# 3) GLOBALS & HYPERPARAMS
# ===========================
# Central configuration cell. `os` is not imported here; it must already be in
# scope from an earlier cell.
# Paths
DRIVE_EXPORT = "/content/drive/MyDrive/nerf"
os.makedirs(DRIVE_EXPORT, exist_ok=True)
# NOTE(review): MODEL_PATH / CONFIG_PATH are not referenced by any other cell
# visible in this notebook; the save/load cells build their own paths.
MODEL_PATH  = f"{DRIVE_EXPORT}/nerf_trained.pt"
CONFIG_PATH = f"{DRIVE_EXPORT}/config.json"

# Training / rendering defaults
# NUM_VIEWS / AZIMUTH_RANGE_DEG: passed to generate_cow_renders.
# N_ITERS / LR:                  training iterations and Adam learning rate.
# MC_RAYS_PER_IMAGE:             n_rays_per_image of the Monte Carlo raysampler.
# PTS_PER_RAY:                   n_pts_per_ray of both raysamplers.
# VOLUME_EXTENT_WORLD:           max_depth of both raysamplers.
# RENDER_SCALE:                  multiplier applied to the target image size to
#                                get the full-grid render size.
NUM_VIEWS               = 40
AZIMUTH_RANGE_DEG       = 180
N_ITERS                 = 1000
LR                      = 1e-3
MC_RAYS_PER_IMAGE       = 1250
PTS_PER_RAY             = 168
VOLUME_EXTENT_WORLD     = 3.0
RENDER_SCALE            = 2

# Viewer cache resolution (speed vs. quality)
# Angular spacing, in degrees, of the pre-rendered views used for previews.
AZ_STEP                 = 15
EL_STEP                 = 10

# Viewer defaults
# Initial azimuth / elevation (degrees) of the viewer sliders.
START_AZ                = 180
START_EL                = 0

import torch, json, numpy as np
# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("‚úÖ Using device:", device)

‚úÖ Using device: cuda:0


In [None]:
# ===========================
# 4) IMPORTS (REPO + RENDERERS)
# ===========================
import matplotlib.pyplot as plt
from PIL import Image, ImageFilter
import torch.nn.functional as F

from nerf_model import NeuralRadianceField
from utils.plot_image_grid import image_grid
from utils.generate_cow_renders import generate_cow_renders
from utils.helper_functions import (
    generate_rotating_nerf,
    huber,
    show_full_render,
    sample_images_at_mc_locs,
)

from pytorch3d.renderer import (
    FoVPerspectiveCameras,
    NDCMultinomialRaysampler,
    MonteCarloRaysampler,
    EmissionAbsorptionRaymarcher,
    ImplicitRenderer,
    look_at_view_transform,
)

print("‚úÖ Imports OK")

‚úÖ Imports OK


In [None]:
# -------------------------------------------------------------------------
# Hash-based NeRF variant (Instant-NGP style multi-resolution hash encoding)
# -------------------------------------------------------------------------
import math
from typing import Tuple

from nerf_model import HarmonicEmbedding
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch3d.renderer import RayBundle, ray_bundle_to_ray_points


class MultiResHashEncoder(nn.Module):
    """
    Simplified multi-resolution hash grid encoder inspired by Instant-NGP.

    Takes 3D coordinates in [0, 1]^3 and returns a concatenated feature vector.
    """

    def __init__(
        self,
        num_levels: int = 8,
        features_per_level: int = 2,
        log2_hashmap_size: int = 15,
        base_resolution: int = 16,
        finest_resolution: int = 256,
    ):
        super().__init__()

        self.num_levels = num_levels
        self.features_per_level = features_per_level
        self.log2_hashmap_size = log2_hashmap_size

        # Geometric progression of resolutions from base_resolution to finest_resolution
        b = math.exp(
            math.log(finest_resolution / base_resolution) / (num_levels - 1)
        )

        self.resolutions = []
        self.embeddings = nn.ModuleList()

        for level in range(num_levels):
            res = int(base_resolution * (b ** level))
            self.resolutions.append(res)

            # Hash table size per level (capped by res^3)
            hashmap_size = min(2 ** log2_hashmap_size, res ** 3)
            emb = nn.Embedding(hashmap_size, features_per_level)
            nn.init.uniform_(emb.weight, a=-1e-4, b=1e-4)
            self.embeddings.append(emb)

        # Fixed primes for simple 3D hash
        self.register_buffer(
            "hash_primes",
            torch.tensor([1, 2654435761, 805459861], dtype=torch.long),
        )

    def hash_coords(self, coords: torch.LongTensor, level: int) -> torch.LongTensor:
        """
        coords: (..., 3) integer grid coordinates
        return: (...) hashed indices in [0, table_size)
        """
        h = (coords * self.hash_primes).sum(dim=-1)
        h = h & 0xFFFFFFFF  # 32-bit mask
        table_size = self.embeddings[level].num_embeddings
        return h % table_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (..., 3) in [0, 1]^3
        returns: (..., num_levels * features_per_level)
        """
        assert x.shape[-1] == 3, "Input to MultiResHashEncoder must be (..., 3)."
        orig_shape = x.shape[:-1]
        x = x.view(-1, 3)  # (N, 3)

        feats_per_level = []

        for level, (res, emb) in enumerate(zip(self.resolutions, self.embeddings)):
            # Scale to grid
            pos = x * res  # in [0, res)
            pos_floor = torch.floor(pos).long()
            pos_frac = pos - pos_floor.float()

            max_coord = res - 1
            pos_floor = torch.clamp(pos_floor, 0, max_coord)

            # 8 corner offsets for trilinear interpolation
            offsets = torch.tensor(
                [
                    [0, 0, 0],
                    [1, 0, 0],
                    [0, 1, 0],
                    [1, 1, 0],
                    [0, 0, 1],
                    [1, 0, 1],
                    [0, 1, 1],
                    [1, 1, 1],
                ],
                device=x.device,
                dtype=torch.long,
            )  # (8, 3)

            # (N, 3) + (8, 3) -> (N, 8, 3)
            corner_coords = pos_floor.unsqueeze(1) + offsets.unsqueeze(0)
            corner_coords = torch.clamp(corner_coords, 0, max_coord)

            # Flatten, hash, lookup
            hashed = self.hash_coords(corner_coords.view(-1, 3), level)
            corner_feats = emb(hashed)  # (N * 8, F)
            corner_feats = corner_feats.view(-1, 8, self.features_per_level)

            # Trilinear weights
            fx, fy, fz = pos_frac.unbind(dim=-1)  # (N,)
            fx = fx.view(-1, 1)
            fy = fy.view(-1, 1)
            fz = fz.view(-1, 1)

            w000 = (1 - fx) * (1 - fy) * (1 - fz)
            w100 = (fx) * (1 - fy) * (1 - fz)
            w010 = (1 - fx) * (fy) * (1 - fz)
            w110 = (fx) * (fy) * (1 - fz)
            w001 = (1 - fx) * (1 - fy) * (fz)
            w101 = (fx) * (1 - fy) * (fz)
            w011 = (1 - fx) * (fy) * (fz)
            w111 = (fx) * (fy) * (fz)

            weights = torch.cat(
                [w000, w100, w010, w110, w001, w101, w011, w111], dim=1
            )  # (N, 8)

            feat = (corner_feats * weights.unsqueeze(-1)).sum(dim=1)  # (N, F)
            feats_per_level.append(feat)

        encoded = torch.cat(feats_per_level, dim=-1)  # (N, L*F)
        encoded = encoded.view(*orig_shape, self.num_levels * self.features_per_level)
        return encoded


class HashNeuralRadianceField(nn.Module):
    """
    NeRF variant that:
      * uses MultiResHashEncoder for 3D positions
      * uses harmonic embedding for viewing directions
    Interface matches the original NeuralRadianceField:
      forward(ray_bundle: RayBundle) -> (densities, colors)

    Args:
        num_levels, features_per_level, log2_hashmap_size, base_resolution,
        finest_resolution: forwarded to MultiResHashEncoder.
        n_hidden_neurons: width of every hidden linear layer.
        n_harmonic_functions_dir: number of harmonic functions used for the
            view-direction embedding.
        aabb_min, aabb_max: world-space bounds mapped to [0, 1] before the
            positions are hash-encoded.
    """

    def __init__(
        self,
        num_levels: int = 8,
        features_per_level: int = 2,
        log2_hashmap_size: int = 15,
        base_resolution: int = 16,
        finest_resolution: int = 256,
        n_hidden_neurons: int = 256,
        n_harmonic_functions_dir: int = 60,
        aabb_min: float = -1.0,
        aabb_max: float = 1.0,
    ):
        super().__init__()

        self.aabb_min = aabb_min
        self.aabb_max = aabb_max

        # --- position encoder (hash grid) ---
        self.pos_encoder = MultiResHashEncoder(
            num_levels=num_levels,
            features_per_level=features_per_level,
            log2_hashmap_size=log2_hashmap_size,
            base_resolution=base_resolution,
            finest_resolution=finest_resolution,
        )
        pos_embedding_dim = num_levels * features_per_level

        # --- direction encoder (HarmonicEmbedding imported from nerf_model) ---
        self.dir_embedding = HarmonicEmbedding(n_harmonic_functions_dir)
        # assumes HarmonicEmbedding emits sin and cos per function for each of
        # the 3 input channels (no raw input appended) -- TODO confirm in nerf_model
        dir_embedding_dim = n_harmonic_functions_dir * 2 * 3

        # --- shared MLP for features ---
        self.mlp = nn.Sequential(
            nn.Linear(pos_embedding_dim, n_hidden_neurons),
            nn.Softplus(beta=10.0),
            nn.Linear(n_hidden_neurons, n_hidden_neurons),
            nn.Softplus(beta=10.0),
        )

        # --- density branch ---
        self.density_layer = nn.Sequential(
            nn.Linear(n_hidden_neurons, 1),
            nn.Softplus(beta=10.0),
        )
        # Initialize bias like original NeRF
        self.density_layer[0].bias.data[0] = -1.5

        # --- color branch ---
        # The first layer consumes [features | direction embedding]; see
        # _get_colors for how the concatenation is avoided at run time.
        self.color_layer = nn.Sequential(
            nn.Linear(n_hidden_neurons + dir_embedding_dim, n_hidden_neurons),
            nn.Softplus(beta=10.0),
            nn.Linear(n_hidden_neurons, 3),
            nn.Sigmoid(),  # RGB in [0, 1]
        )

    def _normalize_points(self, points: torch.Tensor) -> torch.Tensor:
        """
        Map world coordinates from [aabb_min, aabb_max] to [0, 1] for encoding.
        points: (..., 3)

        The small epsilon in the denominator keeps the division finite when
        both bounds coincide.
        """
        return (points - self.aabb_min) / (self.aabb_max - self.aabb_min + 1e-6)

    def _get_densities(self, features: torch.Tensor) -> torch.Tensor:
        """Map latent features (..., H) to densities (..., 1) in [0, 1)."""
        raw_densities = self.density_layer(features)
        return 1 - (-raw_densities).exp()  # same mapping as original NeRF

    def _get_colors(self, features: torch.Tensor, rays_directions: torch.Tensor) -> torch.Tensor:
        """
        View-dependent color, same idea as original NeRF: harmonic embedding of directions.

        features:        (..., n_pts_per_ray, H) per-point latent features
        rays_directions: (..., 3) one direction per ray
        returns:         (..., n_pts_per_ray, 3) RGB in [0, 1]

        The first color layer is mathematically applied to the concatenation
        [features | direction embedding]. Rather than materialising that
        concatenation for every sample point (at this notebook's training
        settings it is a single ~2.9 GiB tensor), the layer's weight matrix is
        split column-wise: the feature part is applied per point and the
        direction part once per ray, then broadcast over the points of the
        ray. The parameters and the result are the same up to floating-point
        rounding; only the peak memory differs.
        """
        # Normalize directions to unit length
        rays_directions_normed = torch.nn.functional.normalize(
            rays_directions, dim=-1
        )
        rays_embed = self.dir_embedding(rays_directions_normed)

        first_layer = self.color_layer[0]
        n_feature_cols = features.shape[-1]

        # Per-point contribution of the latent features (no bias here).
        hidden = F.linear(features, first_layer.weight[:, :n_feature_cols])
        # Per-ray contribution of the direction embedding (+ bias, added once),
        # broadcast along the points-per-ray axis.
        per_ray = F.linear(
            rays_embed, first_layer.weight[:, n_feature_cols:], first_layer.bias
        )
        hidden = hidden + per_ray[..., None, :]

        # Remaining layers: Softplus -> Linear -> Sigmoid.
        return self.color_layer[1:](hidden)

    def forward(self, ray_bundle: RayBundle, **kwargs):
        """
        Matches the original NeuralRadianceField.forward:
          ray_bundle -> densities, colors

        Extra keyword arguments are accepted and ignored.
        """
        # 1) Convert ray parametrization to 3D points in world coordinates
        rays_points_world = ray_bundle_to_ray_points(ray_bundle)
        # rays_points_world: [minibatch, ..., 3]

        # 2) Normalize points for hash encoder
        pos_norm = self._normalize_points(rays_points_world)

        # 3) Hash encoding of positions
        embeds = self.pos_encoder(pos_norm)

        # 4) MLP to latent features
        features = self.mlp(embeds)

        # 5) Density & color heads
        rays_densities = self._get_densities(features)
        rays_colors = self._get_colors(features, ray_bundle.directions)

        return rays_densities, rays_colors

    def batched_forward(
        self,
        ray_bundle: RayBundle,
        n_batches: int = 16,
        **kwargs,
    ):
        """
        Same pattern as the original NeuralRadianceField.batched_forward,
        but calling this class's forward().

        The rays are flattened, split into at most `n_batches` chunks that are
        evaluated one after another, and the outputs are reassembled to the
        original spatial layout. This method does not disable gradients
        itself; callers that only render should wrap it in torch.no_grad().
        """
        n_pts_per_ray = ray_bundle.lengths.shape[-1]
        spatial_size = [*ray_bundle.origins.shape[:-1], n_pts_per_ray]

        # reshape (not view) so non-contiguous ray tensors are accepted.
        flat_origins = ray_bundle.origins.reshape(-1, 3)
        flat_directions = ray_bundle.directions.reshape(-1, 3)
        flat_lengths = ray_bundle.lengths.reshape(-1, n_pts_per_ray)

        tot_samples = ray_bundle.origins.shape[:-1].numel()
        batches = torch.chunk(
            torch.arange(tot_samples, device=ray_bundle.origins.device), n_batches
        )

        batch_outputs = [
            self.forward(
                RayBundle(
                    origins=flat_origins[batch_idx],
                    directions=flat_directions[batch_idx],
                    lengths=flat_lengths[batch_idx],
                    xys=None,
                )
            )
            for batch_idx in batches
        ]

        rays_densities, rays_colors = [
            torch.cat(
                [batch_output[output_i] for batch_output in batch_outputs],
                dim=0,
            ).view(*spatial_size, -1)
            for output_i in (0, 1)
        ]

        return rays_densities, rays_colors

In [None]:
# ===========================
# 5) SYNTHETIC DATA (COW RENDERS)
# ===========================
# Generates the training targets with the repository helper
# `generate_cow_renders`. It returns cameras, RGB images and silhouettes;
# the image tensor's second dimension is later used as the image size.
target_cameras, target_images, target_silhouettes = generate_cow_renders(
    num_views=NUM_VIEWS, azimuth_range=AZIMUTH_RANGE_DEG
)
print(f"Generated {len(target_images)} images/silhouettes/cameras.")

Generated 40 images/silhouettes/cameras.


In [None]:
# ===========================
# 6) RENDERERS (MC SAMPLER + FULL GRID)
# ===========================
# Two implicit renderers sharing the same raymarcher type:
#   * renderer_mc   - random rays per image, used by the training loop
#   * renderer_grid - one ray per pixel, used for full-image renders
# Full renders are RENDER_SCALE times the target image size per side.
render_size = target_images.shape[1] * RENDER_SCALE

# Monte Carlo sampler: MC_RAYS_PER_IMAGE random rays inside the [-1, 1] x [-1, 1]
# image plane, each with PTS_PER_RAY depth samples between 0.1 and
# VOLUME_EXTENT_WORLD.
raysampler_mc = MonteCarloRaysampler(
    min_x=-1.0, max_x=1.0,
    min_y=-1.0, max_y=1.0,
    n_rays_per_image=MC_RAYS_PER_IMAGE,
    n_pts_per_ray=PTS_PER_RAY,
    min_depth=0.1, max_depth=VOLUME_EXTENT_WORLD,
)
renderer_mc = ImplicitRenderer(
    raysampler=raysampler_mc, raymarcher=EmissionAbsorptionRaymarcher()
)

# Grid sampler: render_size x render_size rays with the same depth range and
# number of samples per ray as the Monte Carlo sampler.
raysampler_grid = NDCMultinomialRaysampler(
    image_height=render_size,
    image_width=render_size,
    n_pts_per_ray=PTS_PER_RAY,
    min_depth=0.1,
    max_depth=VOLUME_EXTENT_WORLD,
)
renderer_grid = ImplicitRenderer(
    raysampler=raysampler_grid, raymarcher=EmissionAbsorptionRaymarcher()
)

print("‚úÖ Renderers ready")

‚úÖ Renderers ready


In [None]:
# ===========================
# 7) MODEL -> DEVICE
# ===========================
# Seed before constructing the model so the parameter initialisation is
# repeatable.
torch.manual_seed(1)

# Baseline:
#neural_radiance_field = NeuralRadianceField()

# Hash-based NeRF:
# NOTE(review): the viewer cell that reloads a checkpoint constructs the model
# with the same literal arguments; keep both in sync when changing any of them.
neural_radiance_field = HashNeuralRadianceField(
    num_levels=8,
    features_per_level=2,
    log2_hashmap_size=15,
    base_resolution=16,
    finest_resolution=256,
    n_hidden_neurons=256,
    n_harmonic_functions_dir=60,
    aabb_min=-1.0,
    aabb_max=1.0,
)

# Move renderers, training targets and the model to the selected device.
renderer_grid = renderer_grid.to(device)
renderer_mc   = renderer_mc.to(device)
target_cameras      = target_cameras.to(device)
target_images       = target_images.to(device)
target_silhouettes  = target_silhouettes.to(device)
neural_radiance_field = neural_radiance_field.to(device)

print("‚úÖ HashNeuralRadianceField & tensors on device")

‚úÖ HashNeuralRadianceField & tensors on device


In [None]:
# ===========================
# 8) TRAIN LOOP
# ===========================
# Optimises the radiance field against the target images and silhouettes using
# randomly sampled rays (renderer_mc), and every 100 iterations saves a
# full-image progress figure (renderer_grid).
#
# Memory: each iteration evaluates
#   batch_size * MC_RAYS_PER_IMAGE * PTS_PER_RAY = 6 * 1250 * 168 = 1,260,000
# sample points with gradients enabled. The recorded run of this cell stopped
# with a CUDA out-of-memory error shortly after iteration 0; lowering
# batch_size, MC_RAYS_PER_IMAGE or PTS_PER_RAY reduces the footprint.
optimizer = torch.optim.Adam(neural_radiance_field.parameters(), lr=LR)
batch_size = 6
n_iter = N_ITERS

# Per-iteration loss values, also handed to show_full_render.
loss_history_color, loss_history_sil = [], []

for iteration in range(n_iter):
    # At 75% of the schedule a brand-new Adam optimizer is created with a 10x
    # smaller learning rate; this also discards Adam's accumulated state.
    if iteration == round(n_iter * 0.75):
        print("Decreasing LR 10-fold ...")
        optimizer = torch.optim.Adam(neural_radiance_field.parameters(), lr=LR * 0.1)

    optimizer.zero_grad()
    # Random subset of training views for this step.
    batch_idx = torch.randperm(len(target_cameras))[:batch_size]

    # Cameras of the selected views, rebuilt from the stored camera parameters.
    batch_cameras = FoVPerspectiveCameras(
        R=target_cameras.R[batch_idx],
        T=target_cameras.T[batch_idx],
        znear=target_cameras.znear[batch_idx],
        zfar=target_cameras.zfar[batch_idx],
        aspect_ratio=target_cameras.aspect_ratio[batch_idx],
        fov=target_cameras.fov[batch_idx],
        device=device,
    )

    # Render random rays; the last dimension holds 3 color channels followed by
    # 1 silhouette/opacity channel.
    rendered_images_silhouettes, sampled_rays = renderer_mc(
        cameras=batch_cameras,
        volumetric_function=neural_radiance_field,
    )
    rendered_images, rendered_silhouettes = rendered_images_silhouettes.split([3, 1], dim=-1)

    # Ground-truth values at the image locations of the sampled rays.
    silhouettes_at_rays = sample_images_at_mc_locs(
        target_silhouettes[batch_idx, ..., None], sampled_rays.xys
    )
    colors_at_rays = sample_images_at_mc_locs(
        target_images[batch_idx], sampled_rays.xys
    )

    # Huber errors (repository helper) for silhouette and color, summed with
    # equal weight.
    sil_err = huber(rendered_silhouettes, silhouettes_at_rays).abs().mean()
    color_err = huber(rendered_images, colors_at_rays).abs().mean()
    loss = color_err + sil_err

    loss_history_color.append(float(color_err.detach()))
    loss_history_sil.append(float(sil_err.detach()))

    loss.backward()
    optimizer.step()

    # Periodic progress report: render one random training view at full
    # resolution and save the figure into the current working directory.
    if iteration % 100 == 0:
        print(f"Iteration {iteration}/{n_iter} | loss={loss.item():.4f}")
        show_idx = torch.randperm(len(target_cameras))[:1]
        fig = show_full_render(
            neural_radiance_field,
            FoVPerspectiveCameras(
                R=target_cameras.R[show_idx],
                T=target_cameras.T[show_idx],
                znear=target_cameras.znear[show_idx],
                zfar=target_cameras.zfar[show_idx],
                aspect_ratio=target_cameras.aspect_ratio[show_idx],
                fov=target_cameras.fov[show_idx],
                device=device,
            ),
            target_images[show_idx][0],
            target_silhouettes[show_idx][0],
            renderer_grid,
            loss_history_color,
            loss_history_sil,
        )
        fig.savefig(f"intermediate_{iteration}.png")
        # Close the figure so figures do not accumulate across iterations.
        plt.close(fig)

print("‚úÖ Training complete")

Iteration 0/1000 | loss=0.3038


OutOfMemoryError: CUDA out of memory. Tried to allocate 2.89 GiB. GPU 0 has a total capacity of 14.74 GiB of which 1.87 GiB is free. Process 49052 has 12.87 GiB memory in use. Of the allocated memory 9.93 GiB is allocated by PyTorch, and 2.63 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# ===========================
# 9) SAVE MODEL & CONFIG (VERSIONED)
# ===========================
import os, json, time, torch

# Writes the current model weights and a JSON config next to each other under
# SAVE_DIR, using a timestamp in the file names so earlier exports are kept.
SAVE_DIR = "/content/drive/MyDrive/nerf"
os.makedirs(SAVE_DIR, exist_ok=True)

# Create timestamped filenames so nothing is overwritten
timestamp = time.strftime("%Y%m%d_%H%M%S")

MODEL_PATH_NEW  = f"{SAVE_DIR}/nerf_hash_{timestamp}.pt"
CONFIG_PATH_NEW = f"{SAVE_DIR}/config_hash_{timestamp}.json"

# --- save weights ---
torch.save(neural_radiance_field.state_dict(), MODEL_PATH_NEW)
# Report the class that was actually saved; the message used to say "BASELINE"
# even though this cell exports the hash-based model.
print(f"üíæ Saved {type(neural_radiance_field).__name__} model to: {MODEL_PATH_NEW}")

# --- save config (training + rendering intrinsics) ---
cfg = {
    # Camera intrinsics are taken from the first training camera.
    "fov": float(target_cameras.fov[0]),
    "aspect_ratio": float(target_cameras.aspect_ratio[0]),
    "znear": float(target_cameras.znear[0]),
    "zfar": float(target_cameras.zfar[0]),
    "render_size": int(target_images.shape[1] * RENDER_SCALE),
    "volume_extent_world": float(VOLUME_EXTENT_WORLD),

    # Also store training hyperparams (optional but useful)
    "training_iters": N_ITERS,
    "learning_rate": LR,
    "mc_rays_per_image": MC_RAYS_PER_IMAGE,
    "pts_per_ray": PTS_PER_RAY,

    # Which model class the weights belong to (additional key; readers that
    # only look up the keys above are unaffected).
    "model_class": type(neural_radiance_field).__name__,
}

with open(CONFIG_PATH_NEW, "w") as f:
    json.dump(cfg, f, indent=2)

print(f"üíæ Saved config to: {CONFIG_PATH_NEW}")

üíæ Saved BASELINE model to: /content/drive/MyDrive/nerf/nerf_hash_20251118_082228.pt
üíæ Saved config to: /content/drive/MyDrive/nerf/config_hash_20251118_082228.json


In [None]:
# ===========================
# 10b) LOAD HASH-NeRF MODEL & RENDERER FOR VIEWER
# ===========================
from pytorch3d.renderer import (
    FoVPerspectiveCameras,
    NDCMultinomialRaysampler,
    EmissionAbsorptionRaymarcher,
    ImplicitRenderer,
)
import json, torch, os

# --- specify which hash model to load ---
HASH_MODEL_PATH  = "/content/drive/MyDrive/nerf/nerf_hash_20251118_082228.pt"      # <-- INSERT YOUR CHECKPOINT
HASH_CONFIG_PATH = "/content/drive/MyDrive/nerf/config_hash_20251118_082228.json"  # <-- INSERT MATCHING CONFIG

assert os.path.exists(HASH_MODEL_PATH),  f"‚ùå Model not found: {HASH_MODEL_PATH}"
assert os.path.exists(HASH_CONFIG_PATH), f"‚ùå Config not found: {HASH_CONFIG_PATH}"

# --- load HashNeRF model ---
# The constructor arguments must match the ones used when the checkpoint was
# trained, otherwise load_state_dict() rejects the weights.
model = HashNeuralRadianceField(
    num_levels=8,
    features_per_level=2,
    log2_hashmap_size=15,
    base_resolution=16,
    finest_resolution=256,
    n_hidden_neurons=256,
    n_harmonic_functions_dir=60,
    aabb_min=-1.0,
    aabb_max=1.0,
).to(device)

# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state dict contains, instead of arbitrary pickled objects.
state = torch.load(HASH_MODEL_PATH, map_location=device, weights_only=True)
model.load_state_dict(state)
model.eval()
print("‚úÖ Loaded HASH-NeRF model:", HASH_MODEL_PATH)

# --- load config ---
with open(HASH_CONFIG_PATH, "r") as f:
    cfg = json.load(f)
print("‚úÖ Loaded HASH-NeRF config")

# --- rebuild renderer (same as baseline) ---
# Use the samples-per-ray value stored with the checkpoint so the viewer
# renders with the training-time setting; fall back to the notebook global for
# configs that do not contain the key.
raysampler_grid = NDCMultinomialRaysampler(
    image_height=int(cfg["render_size"]),
    image_width=int(cfg["render_size"]),
    n_pts_per_ray=int(cfg.get("pts_per_ray", PTS_PER_RAY)),
    min_depth=0.1,
    max_depth=float(cfg["volume_extent_world"]),
)
renderer_grid = ImplicitRenderer(
    raysampler=raysampler_grid,
    raymarcher=EmissionAbsorptionRaymarcher(),
).to(device)

# --- default viewer camera ---
# Only the intrinsics come from the config; rotation/translation are supplied
# per view when rendering.
base_cameras = FoVPerspectiveCameras(
    device=device,
    znear=torch.tensor([cfg["znear"]], device=device),
    zfar=torch.tensor([cfg["zfar"]], device=device),
    aspect_ratio=torch.tensor([cfg["aspect_ratio"]], device=device),
    fov=torch.tensor([cfg["fov"]], device=device),
)

print("‚úÖ HASH-NeRF viewer renderer ready")

‚úÖ Loaded HASH-NeRF model: /content/drive/MyDrive/nerf/nerf_hash_20251118_082228.pt
‚úÖ Loaded HASH-NeRF config
‚úÖ HASH-NeRF viewer renderer ready


In [None]:
# ===========================
# 11) GPU CACHE + INTERPOLATION + BLUR
# ===========================
from typing import Dict, Tuple

# Default cache location. A later cell reassigns CACHE_PATH to a file name
# that additionally contains the model hash.
CACHE_PATH = f"{DRIVE_EXPORT}/views_AZ{AZ_STEP}_EL{EL_STEP}.npz"

# Module-level state of the view cache, filled by precompute_cache_gpu() or
# load_cache_npz_gpu():
#   _cache_tensor : stacked pre-rendered views (one per row), or None
#   _key_to_idx   : (azimuth, elevation) in degrees -> row of _cache_tensor
#   _H, _W        : height / width of the cached views
_cache_tensor = None
_key_to_idx: Dict[Tuple[int, int], int] = {}
_H = _W = None

def to_uint8_np_from_torch(img_t: torch.Tensor) -> np.ndarray:
    """Convert a float image tensor to a uint8 NumPy array.

    Values are clamped to [0, 1], moved to the CPU, scaled by 255 and cast to
    uint8 (the cast truncates the fractional part).
    """
    unit_range = img_t.clamp(0, 1)
    host_array = unit_range.detach().cpu().numpy()
    scaled = host_array * 255
    return scaled.astype(np.uint8)

@torch.no_grad()
def render_full_gpu(elev_deg: float, azim_deg: float, dist: float = 2.7) -> torch.Tensor:
    """Render one full-resolution view of `model` with `renderer_grid`.

    The camera looks at the origin from the given elevation/azimuth (degrees)
    and distance and reuses the intrinsics of `base_cameras`. Returns the
    first three channels of the render with the leading batch dimension
    removed.
    """
    rotation, translation = look_at_view_transform(
        dist=dist, elev=elev_deg, azim=azim_deg, device=device
    )
    intrinsics = dict(
        znear=base_cameras.znear,
        zfar=base_cameras.zfar,
        aspect_ratio=base_cameras.aspect_ratio,
        fov=base_cameras.fov,
    )
    view_camera = FoVPerspectiveCameras(
        R=rotation, T=translation, device=device, **intrinsics
    )
    # batched_forward evaluates the rays in chunks to bound peak memory.
    rendered, _ = renderer_grid(
        cameras=view_camera, volumetric_function=model.batched_forward
    )
    rgb = rendered[..., :3]
    return rgb.squeeze(0)

@torch.no_grad()
def precompute_cache_gpu():
    """Render every (azimuth, elevation) grid view and keep the stack on `device`.

    Azimuths run from 0 to 360 in AZ_STEP increments and are stored modulo 360,
    so a key that was already rendered (e.g. 360 -> 0) is skipped. Elevations
    run from -30 to 30 in EL_STEP increments. Fills the module-level cache
    state (_cache_tensor, _key_to_idx, _H, _W).
    """
    global _cache_tensor, _key_to_idx, _H, _W
    azimuths = np.arange(0, 361, AZ_STEP)
    elevations = np.arange(-30, 31, EL_STEP)

    rendered = []
    _key_to_idx.clear()
    for elevation in elevations:
        clipped_elevation = int(np.clip(elevation, -30, 30))
        for azimuth in azimuths:
            key = (int(azimuth) % 360, clipped_elevation)
            if key in _key_to_idx:
                continue
            frame = render_full_gpu(elevation, azimuth)
            # The row index of a view is its position in the render list.
            _key_to_idx[key] = len(rendered)
            rendered.append(frame)

    _cache_tensor = torch.stack(rendered, dim=0)
    _H, _W = _cache_tensor.shape[1:3]
    print(f"‚ö° Cached {len(_key_to_idx)} views on {device} (AZ {AZ_STEP}¬∞, EL {EL_STEP}¬∞).")

def save_cache_npz_gpu(path: str):
    """Write the in-memory view cache to a compressed .npz file.

    The archive contains:
      * ``keys`` - (N, 2) int32 array of (azimuth, elevation) pairs
      * ``imgs`` - uint8 images, row i belonging to ``keys[i]``

    Does nothing when no cache has been built or loaded yet.

    Changes compared with the previous implementation:
      * Images are selected by the row index stored for each key. Previously
        they were permuted with ``argsort`` of those indices, which only
        pairs keys and images correctly when the keys are already in row
        order.
      * Keys are stored as a plain integer array instead of an object array,
        so the file can be read without pickle support.
    """
    if _cache_tensor is None:
        return
    imgs_u8 = _cache_tensor.clamp(0, 1).mul(255).byte().cpu().numpy()

    # Keys ordered by the cache row they point to, plus those rows.
    ordered_keys = sorted(_key_to_idx, key=_key_to_idx.get)
    row_indices = np.asarray([_key_to_idx[k] for k in ordered_keys], dtype=np.intp)

    # reshape keeps a well-formed (0, 2) array for an empty cache.
    keys = np.array(ordered_keys, dtype=np.int32).reshape(-1, 2)

    np.savez_compressed(path, keys=keys, imgs=imgs_u8[row_indices])
    print(f"üíæ Saved GPU cache ‚Üí {path} ({len(keys)} views)")

def load_cache_npz_gpu(path: str) -> bool:
    """Load a view cache written by save_cache_npz_gpu() onto `device`.

    Returns False when `path` does not exist, True after a successful load.
    Replaces the module-level cache state (_cache_tensor, _key_to_idx, _H, _W).
    """
    global _cache_tensor, _key_to_idx, _H, _W
    if not os.path.exists(path):
        return False

    # NOTE(security): allow_pickle=True is kept because cache files store their
    # keys as an object array, which NumPy can only read via pickle. Unpickling
    # can execute arbitrary code, so only load cache files you created yourself.
    # The context manager closes the archive, which was previously left open.
    with np.load(path, allow_pickle=True) as data:
        keys = data["keys"]
        imgs = data["imgs"]

    _key_to_idx.clear()
    for i, k in enumerate(keys):
        # Normalise to plain Python ints so the keys compare and hash exactly
        # like the (int, int) tuples used for lookups, whatever the stored dtype.
        _key_to_idx[tuple(int(v) for v in k)] = i

    _cache_tensor = torch.from_numpy(imgs.astype(np.float32) / 255.0).to(device)
    _H, _W = _cache_tensor.shape[1:3]
    print(f"üì• Loaded GPU cache from {path} ({len(_key_to_idx)} views).")
    return True

def _get_four_indices(az, el):
    az0 = int(np.floor(az / AZ_STEP) * AZ_STEP) % 360
    az1 = (az0 + AZ_STEP) % 360
    el0 = int(np.clip(np.floor((el + 30) / EL_STEP) * EL_STEP - 30, -30, 30))
    el1 = int(np.clip(el0 + EL_STEP, -30, 30))
    return (az0, el0), (az1, el0), (az0, el1), (az1, el1)

@torch.no_grad()
def bilinear_preview_gpu(az, el) -> torch.Tensor:
    """Approximate the view at (az, el) by blending the four nearest cached views.

    The blend is bilinear in azimuth and elevation. Requires the view cache to
    be populated; a key that is missing from the cache raises KeyError.
    """
    key_00, key_10, key_01, key_11 = _get_four_indices(az, el)
    az_low, el_low = key_00

    # Resolve all four rows first, then fetch the frames.
    row_00 = _key_to_idx[key_00]
    row_10 = _key_to_idx[key_10]
    row_01 = _key_to_idx[key_01]
    row_11 = _key_to_idx[key_11]
    frame_00 = _cache_tensor[row_00]
    frame_10 = _cache_tensor[row_10]
    frame_01 = _cache_tensor[row_01]
    frame_11 = _cache_tensor[row_11]

    # Fractional position inside the grid cell along azimuth and elevation.
    az_weight = torch.tensor(((az - az_low) % 360) / AZ_STEP, device=device).float()
    el_weight = torch.tensor((el - el_low) / max(EL_STEP, 1e-6), device=device).float()

    lower_blend = (1 - az_weight) * frame_00 + az_weight * frame_10
    upper_blend = (1 - az_weight) * frame_01 + az_weight * frame_11
    return (1 - el_weight) * lower_blend + el_weight * upper_blend

def gaussian_kernel1d(radius: int, sigma: float, device):
    """Return a normalised 1-D Gaussian kernel with 2 * radius + 1 taps.

    The taps are located at the integer offsets -radius..radius and sum to 1.
    """
    offsets = torch.arange(-radius, radius + 1, device=device)
    unnormalised = torch.exp(-(offsets ** 2) / (2 * sigma * sigma))
    return unnormalised / unnormalised.sum()

@torch.no_grad()
def blur_preview_gpu(img: torch.Tensor, radius: int = 2, sigma: float = 1.5) -> torch.Tensor:
    """Blur an (H, W, 3) image with a separable Gaussian filter.

    A horizontal pass is followed by a vertical pass; each of the three
    channels is filtered independently (grouped convolution) and the padding
    keeps the output the same size as the input.
    """
    taps = gaussian_kernel1d(radius, sigma, device=img.device)
    n_taps = taps.shape[-1]

    # (H, W, 3) -> (1, 3, H, W) as expected by conv2d.
    batch = img.permute(2, 0, 1).unsqueeze(0)

    row_kernel = taps.view(1, 1, 1, -1).expand(3, 1, 1, n_taps)
    batch = F.conv2d(batch, row_kernel, padding=(0, radius), groups=3)

    col_kernel = taps.view(1, 1, -1, 1).expand(3, 1, n_taps, 1)
    batch = F.conv2d(batch, col_kernel, padding=(radius, 0), groups=3)

    # Back to (H, W, 3).
    return batch.squeeze(0).permute(1, 2, 0)

In [None]:
# ---- Helpers to pick the latest cache automatically ----
import io, hashlib, glob, os, re, time

def model_md5(model) -> str:
    """Return the first 8 hex digits of the MD5 of the model's serialised state dict.

    The state dict is serialised with torch.save into an in-memory buffer and
    the resulting bytes are hashed.
    """
    with io.BytesIO() as stream:
        torch.save(model.state_dict(), stream)
        digest = hashlib.md5(stream.getvalue())
    return digest.hexdigest()[:8]

def find_latest_cache(cache_dir: str, az_step: int, el_step: int, preferred_hash: str | None):
    """
    Return (path, reason) where `path` is the best cache to use:
      1) exact match for AZ/EL and preferred_hash, newest mtime if multiple
      2) else newest file that matches AZ/EL regardless of hash
      3) else None
    Expected filename pattern: views_AZ{az}_EL{el}_{hash}.npz
    """
    pattern = os.path.join(cache_dir, f"views_AZ{az_step}_EL{el_step}_*.npz")
    candidates = glob.glob(pattern)
    if not candidates:
        return None, "no cache files found"

    # extract (mtime, path, hash)
    rx = re.compile(rf"views_AZ{az_step}_EL{el_step}_(?P<h>[0-9a-fA-F]+)\.npz$")
    parsed = []
    for p in candidates:
        m = rx.search(os.path.basename(p))
        h = m.group("h") if m else None
        parsed.append((os.path.getmtime(p), p, h))

    # 1) prefer exact hash match if available
    if preferred_hash:
        exact = [t for t in parsed if t[2] == preferred_hash]
        if exact:
            exact.sort(key=lambda t: t[0], reverse=True)
            return exact[0][1], "match model hash"

    # 2) otherwise newest by mtime
    parsed.sort(key=lambda t: t[0], reverse=True)
    return parsed[0][1], "newest by mtime"

In [None]:
# ===========================
# 12) BUILD/LOAD CACHE + CALLBACKS (auto-pick latest)
# ===========================
# When True, existing cache files are ignored and a new cache is rendered.
FORCE_REBUILD_CACHE = False  # set True to force regeneration

# derive preferred cache name for *this* model
# (short hash of the loaded model's state dict; it becomes part of the cache
# file name and is used to look for a matching cache)
PREFERRED_HASH = model_md5(model)
CACHE_DIR = DRIVE_EXPORT

# Optionally print which cache we intend to use
print(f"‚ÑπÔ∏è Preferred cache hash for this model: {PREFERRED_HASH}")

def build_and_save_cache():
    """Render a fresh view cache for the current model and write it to CACHE_DIR.

    The previous in-memory cache is dropped first and cached CUDA memory is
    released before rendering. Returns the path of the written file.
    """
    global _cache_tensor, _key_to_idx, _H, _W
    _key_to_idx.clear()
    _cache_tensor = None
    torch.cuda.empty_cache()

    precompute_cache_gpu()

    # The file name carries both step sizes and the model hash so caches of
    # different grids or models are kept apart.
    cache_name = f"views_AZ{AZ_STEP}_EL{EL_STEP}_{PREFERRED_HASH}.npz"
    destination = os.path.join(CACHE_DIR, cache_name)
    save_cache_npz_gpu(destination)
    return destination

# Choose the preview cache: rebuild on request, otherwise reuse the best
# existing file and only render a new cache when none can be loaded.
if FORCE_REBUILD_CACHE:
    print("üîÅ FORCE_REBUILD_CACHE=True ‚Üí rebuilding cache now‚Ä¶")
    CACHE_PATH = build_and_save_cache()
else:
    # try to find the best existing cache
    CACHE_PATH, reason = find_latest_cache(CACHE_DIR, AZ_STEP, EL_STEP, preferred_hash=PREFERRED_HASH)
    if CACHE_PATH is not None and load_cache_npz_gpu(CACHE_PATH):
        print(f"‚úÖ Loaded cache: {CACHE_PATH} ({reason})")
        # find_latest_cache falls back to the newest file when no cache carries
        # this model's hash. Such a cache was rendered from different weights,
        # so drag previews would not match the full-quality renders; make that
        # visible instead of using it silently.
        if reason != "match model hash":
            print(
                "WARNING: the loaded cache does not match this model's hash "
                f"({PREFERRED_HASH}); previews may show a different model. "
                "Set FORCE_REBUILD_CACHE = True to regenerate it."
            )
    else:
        print("‚ö†Ô∏è No suitable cache found ‚Üí building a new one‚Ä¶")
        CACHE_PATH = build_and_save_cache()

@torch.no_grad()
def on_change(azim, elev, mode, current_img):
    """Slider-change callback: return a fast preview built from the view cache.

    `mode` selects whether the interpolated preview is additionally blurred
    ("Blur preview"). `current_img` is accepted to match the event wiring but
    is not used.
    """
    preview = bilinear_preview_gpu(azim, elev)
    wants_blur = mode == "Blur preview"
    if wants_blur and preview is not None:
        preview = blur_preview_gpu(preview, radius=2, sigma=1.5)
    return to_uint8_np_from_torch(preview)

@torch.no_grad()
def on_release(azim, elev):
    """Slider-release callback: render the exact view at full quality."""
    full_render = render_full_gpu(elev, azim)
    return to_uint8_np_from_torch(full_render)

# Full-quality render of the default view, shown when the UI starts.
# Note the argument order: render_full_gpu takes (elevation, azimuth).
START_IMG_T = render_full_gpu(START_EL, START_AZ)
START_IMG = to_uint8_np_from_torch(START_IMG_T)
# Image height/width, used to size the image component of the UI.
H, W = START_IMG.shape[:2]
print("‚úÖ Viewer callbacks ready")

‚ÑπÔ∏è Preferred cache hash for this model: e12ee8e9
üì• Loaded GPU cache from /content/drive/MyDrive/nerf/views_AZ15_EL10_65f687b2.npz (168 views).
‚úÖ Loaded cache: /content/drive/MyDrive/nerf/views_AZ15_EL10_65f687b2.npz (newest by mtime)
‚úÖ Viewer callbacks ready


In [None]:
# ===========================
# 13) GRADIO UI
# ===========================
import gradio as gr

# Layout: the rendered image on the left; azimuth/elevation sliders and the
# preview-mode selector in a column on the right.
with gr.Blocks(title="NeRF Viewer") as demo:
    gr.Markdown("## üêÑ NeRF Interactive Viewer\nDrag on the image to rotate. Release to render full quality.")
    with gr.Row():
        image_out = gr.Image(
            value=START_IMG, label="Render",
            type="numpy", height=H, width=W, interactive=True
        )
        with gr.Column(scale=0):
            # Slider ranges match the cached grid: azimuth 0..360, elevation -30..30.
            az = gr.Slider(0, 360, value=START_AZ, step=1, label="Azimuth (¬∞)")
            el = gr.Slider(-30, 30, value=START_EL, step=1, label="Elevation (¬∞)")
            quality = gr.Radio(
                choices=["Bilinear preview", "Blur preview"],
                value="Bilinear preview",
                label="Drag preview"
            )

    # Client-side JS: drag the image to update sliders (which triggers previews)
    # The script locates the image and the two range inputs by CSS selectors,
    # converts mouse drags into slider values (0.5 deg/px azimuth, 0.3 deg/px
    # elevation) and dispatches 'input' / 'change' events on the sliders.
    # NOTE(review): it depends on `window.gradioApp` and on Gradio-internal CSS
    # class names, and on <script> tags inside gr.HTML being executed; verify
    # this against the installed Gradio version, as drag support silently does
    # nothing when any of these is unavailable.
    drag_bind = gr.HTML("""
    <script>
    (function(){
      const sleep = (ms) => new Promise(r => setTimeout(r, ms));
      async function bind() {
        for (let i=0;i<50;i++){
          const app = window.gradioApp?.();
          if (app) break;
          await sleep(100);
        }
        const app = window.gradioApp?.();
        if (!app) return;
        const img = app.querySelector('div.svelte-1ipelgc img, .image-container img');
        const az = app.querySelector('input[type="range"][min="0"][max="360"]');
        const el = app.querySelector('input[type="range"][min="-30"][max="30"]');
        if (!img || !az || !el) return;

        let dragging = false, lastX = 0, lastY = 0;
        const clamp = (v,min,max)=>Math.max(min,Math.min(max,v));
        const step = (v,s)=>Math.round(v/s)*s;

        img.addEventListener('mousedown', (e)=>{
          dragging = true; lastX = e.clientX; lastY = e.clientY; e.preventDefault();
        });
        window.addEventListener('mouseup', ()=>{ dragging=false; });
        window.addEventListener('mousemove', (e)=>{
          if(!dragging) return;
          const dx = e.clientX - lastX;
          const dy = e.clientY - lastY;
          lastX = e.clientX; lastY = e.clientY;

          const AZ_SENS = 0.5;
          const EL_SENS = 0.3;

          let azVal = (parseFloat(az.value) + dx * AZ_SENS) % 360;
          if(azVal < 0) azVal += 360;
          let elVal = clamp(parseFloat(el.value) - dy * EL_SENS, -30, 30);

          az.value = String(step(azVal,1));
          el.value = String(step(elVal,1));

          az.dispatchEvent(new Event('input', {bubbles:true}));
          el.dispatchEvent(new Event('input', {bubbles:true}));
        });
        window.addEventListener('mouseup', ()=>{
          az.dispatchEvent(new Event('change', {bubbles:true}));
          el.dispatchEvent(new Event('change', {bubbles:true}));
        });
      }
      bind();
    })();
    </script>
    """)

    # Event wiring:
    #   change  -> on_change  (preview from the view cache, not queued)
    #   release -> on_release (full-quality render, queued)
    az.change(on_change, inputs=[az, el, quality, image_out], outputs=image_out, queue=False)
    el.change(on_change, inputs=[az, el, quality, image_out], outputs=image_out, queue=False)
    az.release(on_release, inputs=[az, el], outputs=image_out, queue=True)
    el.release(on_release, inputs=[az, el], outputs=image_out, queue=True)

# debug=True keeps the cell running until interrupted so errors and logs stay
# visible (see the recorded output of this cell).
demo.launch(inline=True, debug=True)

It looks like you are running Gradio on a hosted Jupyter notebook, which requires `share=True`. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://f42337b1e819e70804.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://f42337b1e819e70804.gradio.live


