# NeRF Rendering from Checkpoint (Original Repo)

**Loads your trained model and renders test images**

**Requirements:**
1. Upload `results.zip` to this Colab
2. Mount Google Drive (it must contain the lego dataset at `MyDrive/data/nerf_synthetic/lego`)

**Time:** ~20 minutes for 3 images

## Step 1: Check GPU

In [None]:
!nvidia-smi

import torch
print(f"\nCUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## Step 2: Upload results.zip

**Click the folder icon in the left sidebar → Upload → select your `results.zip`**

In [None]:
import os

if os.path.exists('results.zip'):
    print("✓ results.zip found")
    !ls -lh results.zip
else:
    print("✗ Please upload results.zip first!")

## Step 3: Mount Google Drive

In [None]:
from google.colab import drive

# Step 4 copies the dataset from under this mount point.
MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT)

print("\n✓ Drive mounted")

## Step 4: Setup Dataset

In [None]:
# Copy dataset from Drive
# The lego scene is copied from the mounted Drive into the working directory
# at data/nerf_synthetic/lego, which is the path Steps 5 and 9 read from.
!mkdir -p data/nerf_synthetic
!cp -r /content/drive/MyDrive/data/nerf_synthetic/lego data/nerf_synthetic/

# NOTE: this message is printed unconditionally, even if `cp` failed;
# the directory listing below is the real confirmation.
print("✓ Dataset copied")
!ls data/nerf_synthetic/lego/

## Step 5: Fix Dataset JSON Files

In [None]:
import json
import os
import re

LEGO_DIR = 'data/nerf_synthetic/lego'


def _natural_key(name):
    """Sort key that compares embedded integers numerically (r_2 < r_10)."""
    return [int(tok) if tok.isdigit() else tok for tok in re.split(r'(\d+)', name)]


def fix_split_json(base_dir, split):
    """Point every frame of ``transforms_{split}.json`` at an existing RGB image.

    A frame whose ``file_path`` basename exists as ``{split}/{name}.png`` keeps
    that image (normalised to ``./{split}/{name}``), so each pose stays paired
    with its own image. Only frames whose image cannot be found fall back to the
    i-th RGB file in natural numeric order (r_0, r_1, r_2, ..., r_10), never
    lexicographic order (r_0, r_1, r_10, r_100, ...), which would pair poses with
    the wrong images. Files containing ``depth`` or ``normal`` are never used
    as RGB images.

    Args:
        base_dir: scene directory holding ``transforms_*.json`` and split dirs.
        split: split name, e.g. ``'train'``, ``'val'`` or ``'test'``.

    Returns:
        Number of frames whose ``file_path`` had to be remapped by index.

    Raises:
        FileNotFoundError: if the JSON file or the split directory is missing.
    """
    json_path = os.path.join(base_dir, f'transforms_{split}.json')
    split_dir = os.path.join(base_dir, split)

    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    rgb_names = sorted(
        (fname[:-len('.png')] for fname in os.listdir(split_dir)
         if fname.endswith('.png') and 'depth' not in fname and 'normal' not in fname),
        key=_natural_key,
    )
    available = set(rgb_names)

    remapped = 0
    for i, frame in enumerate(data['frames']):
        name = os.path.basename(frame.get('file_path', ''))
        # The loader appends '.png' itself, so stored paths must not carry it.
        if name.endswith('.png'):
            name = name[:-len('.png')]
        if name in available:
            frame['file_path'] = f'./{split}/{name}'
        elif i < len(rgb_names):
            frame['file_path'] = f'./{split}/{rgb_names[i]}'
            remapped += 1

    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2)
    return remapped


# In a notebook cell __name__ is '__main__', so this always runs there; the
# guard only stops the fix from touching disk when the code is imported.
if __name__ == '__main__':
    for split in ['train', 'val', 'test']:
        n_remapped = fix_split_json(LEGO_DIR, split)
        print(f"✓ Fixed {split} ({n_remapped} frame(s) remapped by index)")

    print("\n✓ Dataset ready")

## Step 6: Clone Original NeRF Repo

In [None]:
# Clone the ORIGINAL NeRF implementation
# (yenchenlin/nerf-pytorch). Any previous checkout is deleted first, so a
# re-run always gets a fresh clone of the repository's default branch.
# NOTE(review): the clone is unpinned; pinning a commit would keep renders
# reproducible if upstream changes.
!rm -rf nerf-pytorch
!git clone https://github.com/yenchenlin/nerf-pytorch.git

print("\n✓ Original NeRF repo cloned")
!ls nerf-pytorch/

## Step 7: Install Dependencies

In [None]:
# imageio (+ its ffmpeg plugin) is used to write the rendered PNGs in Step 9;
# configargparse is presumably needed by nerf-pytorch's modules at import time
# (NOTE(review): confirm). -q keeps pip output short.
!pip install imageio imageio-ffmpeg configargparse -q

print("✓ Dependencies installed")

## Step 8: Extract Checkpoint

In [None]:
# Extract results.zip
!unzip -q results.zip

# Find checkpoint
import glob

checkpoints = glob.glob('results/logs/lego_metrics/*.tar')
if not checkpoints:
    checkpoints = glob.glob('logs/lego_metrics/*.tar')

if checkpoints:
    checkpoint_path = checkpoints[0]
    print(f"✓ Found checkpoint: {checkpoint_path}")
else:
    print("✗ Checkpoint not found!")
    !find . -name "*.tar" -type f

## Step 9: Render Images

**This will take ~20 minutes for 3 images at half resolution (rendering runs on the CPU)**

In [None]:
import sys
sys.path.insert(0, 'nerf-pytorch')  # make the cloned repo's modules importable

import torch
import numpy as np
import imageio
from tqdm import tqdm
import time

from run_nerf_helpers import *
from run_nerf import render
from load_blender import load_blender_data

# The models below are never moved off the CPU, so rendering happens on the CPU.
device = torch.device("cpu")

print("Loading dataset at HALF resolution...")
images, poses, render_poses, hwf, i_split = load_blender_data(
    'data/nerf_synthetic/lego', half_res=True, testskip=8)
i_train, i_val, i_test = i_split
H, W, focal = hwf
H, W = map(int, (H, W))

print(f"Rendering at {H}x{W} (half resolution = faster)")

# Pinhole intrinsics with the principal point at the image centre.
K = torch.Tensor([
    [focal, 0, W / 2],
    [0, focal, H / 2],
    [0, 0, 1],
])

# Only render 3 test images
test_indices = i_test[:3]
print(f"Will render {len(test_indices)} test images")

print("\nLoading checkpoint...")
checkpoint = torch.load(checkpoint_path, map_location='cpu')
print(f"Checkpoint loaded from: {checkpoint_path}")

# Positional encoders for sample positions (10) and view directions (4).
# NOTE(review): these must match the values used at training time — confirm.
embed_fn, input_ch = get_embedder(10, 0)
embeddirs_fn, input_ch_views = get_embedder(4, 0)


def _build_nerf():
    """Return a fresh NeRF MLP with the architecture the checkpoint is assumed to have."""
    return NeRF(D=8, W=256, input_ch=input_ch, output_ch=4, skips=[4],
                input_ch_views=input_ch_views, use_viewdirs=True)


# Coarse and fine networks share an architecture but load separate weights.
model = _build_nerf()
model_fine = _build_nerf()

for net, state_key in ((model, 'network_fn_state_dict'),
                       (model_fine, 'network_fine_state_dict')):
    net.load_state_dict(checkpoint[state_key])

for net in (model, model_fine):
    net.eval()

print("✓ Models loaded")
# Network query function
def network_query_fn(inputs, viewdirs, network_fn):
    """Evaluate ``network_fn`` on encoded sample points (and view directions).

    ``inputs`` is flattened over its last axis and passed through ``embed_fn``.
    When ``viewdirs`` is given, each direction is broadcast across the sample
    axis, encoded with ``embeddirs_fn``, and concatenated to the point features.
    The raw network output is reshaped back to ``inputs.shape[:-1] + (C,)``,
    where C is the network's output width.

    NOTE(review): assumes inputs is (n_rays, n_samples, 3) and viewdirs is
    (n_rays, 3), as supplied by nerf-pytorch's renderer — confirm.
    """
    feat_dim = inputs.shape[-1]
    encoded = embed_fn(inputs.reshape(-1, feat_dim))

    if viewdirs is not None:
        dirs = viewdirs[:, None].expand(inputs.shape).reshape(-1, feat_dim)
        encoded = torch.cat([encoded, embeddirs_fn(dirs)], -1)

    raw = network_fn(encoded)
    return raw.reshape(*inputs.shape[:-1], raw.shape[-1])

# Rendering kwargs
render_kwargs_test = {
    'network_query_fn': network_query_fn,
    'network_fn': model,
    'network_fine': model_fine,
    # Coarse / fine sample counts per ray. NOTE(review): training may have used
    # more; fewer samples here trade quality for CPU time.
    'N_samples': 32,
    'N_importance': 64,
    'perturb': False,          # no random sample jitter: deterministic renders
    'white_bkgd': True,
    'raw_noise_std': 0.,
    'near': 2.,
    'far': 6.,
    'use_viewdirs': True,
    # nerf-pytorch's render() defaults to ndc=True, which warps rays into
    # normalized device coordinates for forward-facing (LLFF) scenes. Blender
    # scenes like lego are trained without NDC, so it must be off here; with
    # it on, the [near, far] = [2, 6] samples likely miss the object and the
    # image comes out as (mostly) white background.
    'ndc': False,
}

os.makedirs('nerf_renders', exist_ok=True)

print("\nStarting render (CPU - will take ~7 mins per image)...\n")

n_render = len(test_indices)
with torch.no_grad():
    for i, idx in enumerate(test_indices):
        start = time.time()
        print(f"Rendering image {i+1}/{n_render} (index {idx})...")

        # Camera-to-world matrix (top 3x4 of the stored pose).
        pose = torch.Tensor(poses[idx, :3, :4])

        # chunk = rays per batch inside render(); small to bound peak memory.
        rgb, disp, acc, extras = render(H, W, K, chunk=256, c2w=pose, **render_kwargs_test)

        # .cpu() keeps the conversion valid regardless of which device the output is on.
        rgb8 = (np.clip(rgb.cpu().numpy(), 0, 1) * 255).astype(np.uint8)
        filename = f'nerf_renders/test_{idx:03d}.png'
        imageio.imwrite(filename, rgb8)

        elapsed = time.time() - start
        print(f"✓ Image {i+1}/{n_render} saved to {filename} ({elapsed:.1f}s)\n")

print("\n" + "="*60)
print("✓✓✓ RENDERING COMPLETE!")
print("="*60)
print("\nImages saved to: nerf_renders/")

## Step 10: View Results

In [None]:
import glob
from IPython.display import Image, display

# Every PNG written by the render step, in filename order.
imgs = glob.glob('nerf_renders/*.png')
imgs.sort()

print(f"Rendered {len(imgs)} images:\n")

# Print each path before its image so the output maps back to files.
for img in imgs:
    print(f"\n{img}:")
    display(Image(img, width=500))

## Step 11: Download Images

In [None]:
from google.colab import files

# Zip renders
!zip -r nerf_final_images.zip nerf_renders/

print("✓ Images zipped")

# Download
files.download('nerf_final_images.zip')

print("\n✓ Download complete!")

---

## If Images Are Still White/Bad:

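**The checkpoint may not have trained properly** — but first confirm that the render settings match training: Blender/synthetic scenes such as lego must be rendered with `ndc=False`, and the network depth/width, positional-encoding sizes and `use_viewdirs` must match the training config. If they do, your options are: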

1. **Retrain from scratch** (~6 hours for proper training)
2. **Use metrics only** + teammate's 3DGS images
3. **Ask teammate to train NeRF** with nerfstudio (~30 mins)

The checkpoint is from 10k iterations, which may not be enough for good quality.