# Neural Radiance Fields

This notebook was written with reference to the original NeRF code, in order to walk through, step by step, how NeRF works.

## Import Dependencies and Check CUDA

In [2]:
# Standard libraries
import sys
import os

# Third party libraries
import imageio
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Set seeds for every random number generator this notebook relies on.
# NumPy and PyTorch keep independent global RNG state, so seeding only NumPy
# would leave torch-driven randomness (weight init, torch.rand/randn) varying
# from run to run.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)
DEBUG = False
# Make float32 the dtype of newly created floating-point tensors.
torch.set_default_dtype(torch.float32)

In [3]:
# Select the compute device: use CUDA when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [None]:
import utils
# Load experiment config from yaml
EXPERIMENT_NAME = "llff_fern"
folder_path = os.path.join("yaml", EXPERIMENT_NAME)
print(f"Targeted yaml folder: {folder_path}")

# Load the yaml data for use later in terms of args.
args = utils.load_or_create_config(folder_path)
# The subscript uses single quotes: reusing the f-string's own double quotes
# inside the replacement field is only legal from Python 3.12 (PEP 701) and is
# a SyntaxError on earlier interpreters.
print(f"Experiment name: {args['expname']}")


Targeted yaml folder: yaml\llff_fern
Loading configuration from yaml\llff_fern
Configuration validation passed! Arguments are valid and correctly set.
Experiment name: nerf_experiment


In [None]:
from batch import batchify
"""
Prepare 3D sample points for a NeRF-style network by applying positional encodings,
run the network on these encodings in memory-safe chunks, and then reshape the results
back to the original sampling layout.

High-level intuition:
- We often sample many 3D points (xyz) per ray and, optionally, use a per-ray viewing
  direction. Raw coordinates are hard for small MLPs to learn high-frequency detail,
  so we first apply a positional encoding that maps them to a higher-dimensional space.
- We flatten everything to a big batch so the network can process all samples uniformly.
- To avoid running out of memory, we split this big batch into chunks and process them
  sequentially, then stitch the outputs back together and restore the original shape.
"""

def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
    """
    Encode sample positions (and optionally view directions), evaluate `fn` on the
    encoded features, and return the result in the sampling layout of `inputs`.

    Steps:
    1. Collapse every leading dimension of `inputs` into one batch axis.
    2. Encode the flattened positions with `embed_fn`.
    3. If `viewdirs` is given, repeat each ray's direction for every sample on that
       ray, encode it with `embeddirs_fn`, and append it to the position features.
    4. Evaluate `fn` through `batchify(fn, netchunk)`.
    5. Reshape the flat output back to the leading dimensions of `inputs`.

    Args:
        inputs (torch.Tensor): Sample positions of shape [..., C]. When `viewdirs`
            is supplied the layout must be [N_rays, N_samples, C], because the
            directions are expanded to exactly `inputs.shape`.
        viewdirs (Optional[torch.Tensor]): Per-ray directions of shape [N_rays, C],
            or None to skip the view-direction branch entirely.
        fn (Callable): Network applied to the encoded features.
        embed_fn (Callable): Encoder applied to the flattened positions.
        embeddirs_fn (Optional[Callable]): Encoder applied to the flattened
            directions. Only called when `viewdirs` is not None.
        netchunk (int): Chunk size handed to `batchify`.

    Returns:
        torch.Tensor: Tensor of shape `inputs.shape[:-1] + [Cout]`, where Cout is
        the size of the last dimension produced by `fn`.
    """
    leading_shape = inputs.shape[:-1]
    n_channels = inputs.shape[-1]

    # One row per sample point: [N, C].
    flat_points = torch.reshape(inputs, (-1, n_channels))
    features = embed_fn(flat_points)

    if viewdirs is not None:
        # [N_rays, C] -> [N_rays, 1, C] -> broadcast to the full shape of `inputs`,
        # so every sample along a ray carries that ray's direction.
        per_sample_dirs = viewdirs[:, None].expand(inputs.shape)
        flat_dirs = torch.reshape(per_sample_dirs, (-1, per_sample_dirs.shape[-1]))

        # Encoded directions are appended after the encoded positions on the
        # feature (last) axis.
        dir_features = embeddirs_fn(flat_dirs)
        features = torch.cat((features, dir_features), dim=-1)

    # `batchify` comes from the local `batch` module; it is expected to evaluate
    # `fn` on pieces of at most `netchunk` rows to bound peak memory.
    chunked_fn = batchify(fn, netchunk)
    flat_out = chunked_fn(features)

    # Undo the flattening: leading dims of `inputs` plus the network's channel dim.
    return torch.reshape(flat_out, (*leading_shape, flat_out.shape[-1]))

## Instantiate NeRF
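This section defines `create_nerf`, which builds the coarse NeRF MLP (and a fine one when `N_importance > 0`), creates the optimizer, reloads the latest checkpoint if one is available, and returns the rendering keyword arguments for training and testing.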

In [None]:
from embedder import get_embedder
from nerf import NeRF

def create_nerf(args):
    """Instantiate NeRF's MLP model(s), the optimizer, and the render settings.

    Args:
        args: Experiment configuration. Settings are read by attribute access:
            multires, i_embed, use_viewdirs, multires_views, N_importance,
            netdepth, netwidth, netdepth_fine, netwidth_fine, netchunk, lrate,
            basedir, expname, ft_path, no_reload, perturb, N_samples,
            white_bkgd, raw_noise_std, dataset_type, no_ndc and lindisp.

    Returns:
        tuple: (render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer)
            - render_kwargs_train (dict): rendering keyword arguments for training.
            - render_kwargs_test (dict): shallow copy of the training arguments with
              'perturb' set to False and 'raw_noise_std' set to 0.
            - start (int): 0, or the 'global_step' stored in the reloaded checkpoint.
            - grad_vars (list): parameters of the coarse model followed by those of
              the fine model, if one was created.
            - optimizer (torch.optim.Adam): optimizer over grad_vars.
    """
    # get_embedder returns an encoder callable and the channel count it produces.
    embed_fn, input_ch = get_embedder(args.multires, args.i_embed)

    input_ch_views = 0
    embeddirs_fn = None
    if args.use_viewdirs:
        embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed)

    # One extra output channel is requested when a fine network is used.
    output_ch = 5 if args.N_importance > 0 else 4
    skips = [4]

    # Coarse network.
    model = NeRF(D=args.netdepth, W=args.netwidth,
                 input_ch=input_ch, output_ch=output_ch, skips=skips,
                 input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
    grad_vars = list(model.parameters())

    # Optional fine network, optimized jointly with the coarse one.
    model_fine = None
    if args.N_importance > 0:
        model_fine = NeRF(D=args.netdepth_fine, W=args.netwidth_fine,
                          input_ch=input_ch, output_ch=output_ch, skips=skips,
                          input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
        grad_vars += list(model_fine.parameters())

    def network_query_fn(inputs, viewdirs, network_fn):
        """Evaluate `network_fn` on sample points using this experiment's encoders."""
        return run_network(inputs, viewdirs, network_fn,
                           embed_fn=embed_fn,
                           embeddirs_fn=embeddirs_fn,
                           netchunk=args.netchunk)

    # Create optimizer
    optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))

    start = 0
    basedir = args.basedir
    expname = args.expname

    ##########################

    # Load checkpoints
    ckpt_dir = os.path.join(basedir, expname)
    if args.ft_path is not None and args.ft_path != 'None':
        ckpts = [args.ft_path]
    elif os.path.isdir(ckpt_dir):
        ckpts = [os.path.join(ckpt_dir, f) for f in sorted(os.listdir(ckpt_dir)) if 'tar' in f]
    else:
        # A fresh experiment has no output directory yet; treat that as
        # "no checkpoints" instead of failing inside os.listdir.
        ckpts = []

    print('Found ckpts', ckpts)
    if len(ckpts) > 0 and not args.no_reload:
        # File names are sorted, so the last entry is the most recent checkpoint
        # when names sort by training step.
        ckpt_path = ckpts[-1]
        print('Reloading from', ckpt_path)
        # map_location moves the stored tensors onto the active device, so a
        # checkpoint written on a GPU can still be reloaded on a CPU-only machine.
        ckpt = torch.load(ckpt_path, map_location=device)

        start = ckpt['global_step']
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])

        # Load model
        model.load_state_dict(ckpt['network_fn_state_dict'])
        if model_fine is not None:
            model_fine.load_state_dict(ckpt['network_fine_state_dict'])

    ##########################

    render_kwargs_train = {
        'network_query_fn' : network_query_fn,
        'perturb' : args.perturb,
        'N_importance' : args.N_importance,
        'network_fine' : model_fine,
        'N_samples' : args.N_samples,
        'network_fn' : model,
        'use_viewdirs' : args.use_viewdirs,
        'white_bkgd' : args.white_bkgd,
        'raw_noise_std' : args.raw_noise_std,
    }

    # NDC only good for LLFF-style forward facing data
    if args.dataset_type != 'llff' or args.no_ndc:
        print('Not ndc!')
        render_kwargs_train['ndc'] = False
        render_kwargs_train['lindisp'] = args.lindisp

    # Test-time rendering reuses the training settings with randomness switched off.
    render_kwargs_test = dict(render_kwargs_train)
    render_kwargs_test['perturb'] = False
    render_kwargs_test['raw_noise_std'] = 0.

    return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer