# Neural Radiance Fields

This notebook walks through NeRF step by step, following the original NeRF code. Some functions have been refactored for readability, but the underlying logic is kept as close to the original as possible.

## Attribution and Citation
- Original paper: Mildenhall, Ben; Srinivasan, Pratul P.; Tancik, Matthew; Barron, Jonathan T.; Ramamoorthi, Ravi; Ng, Ren. “NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis.” ECCV 2020. [arXiv:2003.08934](https://arxiv.org/abs/2003.08934)
- Reference code: [bmild/nerf](https://github.com/bmild/nerf) (MIT License). Copyright © 2020 Ben Mildenhall.

```bibtex
@inproceedings{mildenhall2020nerf,
  title     = {NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis},
  author    = {Mildenhall, Ben and Srinivasan, Pratul P. and Tancik, Matthew and Barron, Jonathan T. and Ramamoorthi, Ravi and Ng, Ren},
  booktitle = {European Conference on Computer Vision (ECCV)},
  year      = {2020}
}
```

## Import Dependencies and Check CUDA

In [10]:
# Standard libraries
import os
import sys
import time

# Third party libraries
import imageio
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm, trange

import matplotlib.pyplot as plt

# Helper libraries
from render import *
from embedder import *
from nerf_helpers import *
import utils

# Loaders
from load_llff import *
from load_blender import *
from load_deepvoxels import * 
from load_LINEMOD import *

# Set seed
np.random.seed(0)
DEBUG = False
torch.set_default_dtype(torch.float32)

In [11]:
# Check device 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [12]:
# Load experiment config from yaml
EXPERIMENT_NAME = "llff_fern"
folder_path = os.path.join("yaml", EXPERIMENT_NAME)
print(f"Targeted yaml folder: {folder_path}")

# Load the YAML config; its values are used as `args` throughout the notebook.
args = utils.load_or_create_config(folder_path)

print(f"Experiment name: {args['expname']}")

Targeted yaml folder: yaml\llff_fern
Loading configuration from yaml\llff_fern
Configuration validation passed! Arguments are valid and correctly set.
Experiment name: fern_test


In [13]:
# Convert dict-style args to dot-access with recursive wrapping
class AttrDict(dict):
    """Dictionary with attribute-style access that recursively wraps nested dicts/lists.

    Example:
        d = AttrDict.from_obj({"a": {"b": 1}, "c": [{"d": 2}]})
        d.a.b == 1
        d.c[0].d == 2
        d.new_key = 3  # also writes to the underlying dict
    """
    __slots__ = ()

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as e:
            raise AttributeError(f"No such attribute: {name}") from e

    def __setattr__(self, name, value):
        self[name] = value

    @classmethod
    def from_obj(cls, obj):
        if isinstance(obj, dict):
            return cls({k: cls.from_obj(v) for k, v in obj.items()})
        if isinstance(obj, list):
            return [cls.from_obj(v) for v in obj]
        return obj

# If args came from YAML as a plain dict, wrap it for dot access
try:
    if isinstance(args, dict):
        args = AttrDict.from_obj(args)
        print("Converted args dict to AttrDict (dot-access enabled). Example: args.expname ->", args.expname)
except NameError:
    # If this cell runs before args is defined, it's a no-op
    pass


Converted args dict to AttrDict (dot-access enabled). Example: args.expname -> fern_test


In [14]:
"""
Prepare 3D sample points for a NeRF-style network by applying positional encodings,
run the network on these encodings in memory-safe chunks, and then reshape the results
back to the original sampling layout.

High-level intuition:
- We often sample many 3D points (xyz) per ray and, optionally, use a per-ray viewing
  direction. Raw coordinates are hard for small MLPs to learn high-frequency detail,
  so we first apply a positional encoding that maps them to a higher-dimensional space.
- We flatten everything to a big batch so the network can process all samples uniformly.
- To avoid running out of memory, we split this big batch into chunks and process them
  sequentially, then stitch the outputs back together and restore the original shape.
"""

def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
    """
    Prepare inputs for a NeRF-style MLP and apply the network in chunks.

    Conceptual overview:
    - Positions (xyz) and, optionally, viewing directions are first positional-encoded
      (a deterministic mapping to a higher-dimensional space using sin/cos at multiple
      frequencies). This helps the MLP represent fine details and sharp changes.
    - We flatten leading dimensions so all samples are processed as a single batch.
    - To keep memory usage in check, we process this batch in chunks (netchunk).
    - Finally, we reshape outputs to match the original sampling layout.

    Args:
        inputs (torch.Tensor): Sample positions with shape [..., Cpos], typically Cpos = 3.
            Example: [N_rays, N_samples, 3]. The leading dimensions can be any shape.
        viewdirs (Optional[torch.Tensor]): Per-ray viewing directions with shape
            [N_rays, Cdir] (typically Cdir = 3), or None if not using view-dependent effects.
            When provided, each ray direction is broadcast to all samples along that ray.
        fn (Callable[[torch.Tensor], torch.Tensor]): Neural network (e.g., NeRF MLP) that
            consumes encoded features and returns outputs per sample.
        embed_fn (Callable[[torch.Tensor], torch.Tensor]): Positional encoder for positions;
            maps [*, Cpos] -> [*, Cpos_enc].
        embeddirs_fn (Optional[Callable[[torch.Tensor], torch.Tensor]]): Positional encoder
            for directions; maps [*, Cdir] -> [*, Cdir_enc]. Only used if viewdirs is not None.
        netchunk (int): Maximum number of samples to process per chunk to limit peak memory.

    Returns:
        torch.Tensor: Network outputs with shape [..., Cout], where the leading dimensions
        match those of `inputs` (excluding its last channel), and Cout is determined by `fn`.
    """
    # Flatten all leading dimensions so we have a simple [N, Cpos] batch of positions.
    # N is the total number of samples across rays and per-ray samples.
    inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])

    # Positional-encode the flattened positions (e.g., apply sin/cos at multiple frequencies).
    # This expands each 3D input into a richer, higher-dimensional representation
    # that makes it easier for the MLP to model fine spatial detail.
    embedded = embed_fn(inputs_flat)

    # If using view-dependent appearance (e.g., specular highlights that vary with direction),
    # we also encode per-ray viewing directions and concatenate them with position encodings.
    if viewdirs is not None:
        # Insert a length-1 axis, then broadcast each ray direction across all samples on that ray
        # so that every sample point along a ray shares the same view direction.
        input_dirs = viewdirs[:, None].expand(inputs.shape)

        # Flatten directions to align with the flattened positions: [N, Cdir].
        input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])

        # Positional-encode viewing directions in the same spirit as positions.
        embedded_dirs = embeddirs_fn(input_dirs_flat)

        # Concatenate encoded positions and encoded directions along the feature/channel axis.
        embedded = torch.cat([embedded, embedded_dirs], -1)

    # Apply the network to the encoded features in memory-safe chunks along the batch dimension.
    # This prevents out-of-memory errors when the total number of samples is very large.
    outputs_flat = batchify(fn, netchunk)(embedded)

    # Restore the original leading shape (e.g., [N_rays, N_samples]) and append the output channels.
    outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]])

    return outputs
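
To make the encode, chunk, and reshape steps of `run_network` concrete, here is a minimal NumPy sketch. It is not the notebook's `get_embedder` or `batchify` (those live in the helper modules and are not shown here). `positional_encoding`, `apply_in_chunks`, and `toy_network` are hypothetical stand-ins, and the frequency choice (powers of two, with the raw input kept) is an assumption modeled on the reference implementation.

```python
import numpy as np

def positional_encoding(x, num_freqs):
    """Map [..., C] to [..., C + 2*C*num_freqs] with sin/cos at frequencies 2^0 .. 2^(num_freqs-1)."""
    feats = [x]  # assumption: the raw input is kept alongside the encoded features
    for k in range(num_freqs):
        freq = 2.0 ** k
        feats.append(np.sin(freq * x))
        feats.append(np.cos(freq * x))
    return np.concatenate(feats, axis=-1)

def apply_in_chunks(fn, inputs, chunk):
    """Apply fn to consecutive slices of at most `chunk` rows and concatenate the results."""
    pieces = [fn(inputs[i:i + chunk]) for i in range(0, inputs.shape[0], chunk)]
    return np.concatenate(pieces, axis=0)

def toy_network(features):
    """Hypothetical stand-in for the MLP: returns 4 channels per sample."""
    return np.tile(features.sum(axis=-1, keepdims=True), (1, 4))

points = np.random.default_rng(0).uniform(-1, 1, size=(8, 16, 3))  # [N_rays, N_samples, 3]
flat = points.reshape(-1, 3)                                        # [128, 3]
encoded = positional_encoding(flat, num_freqs=10)                   # [128, 63]
out_flat = apply_in_chunks(toy_network, encoded, chunk=50)          # chunks of 50, 50, 28 rows
outputs = out_flat.reshape(points.shape[:-1] + (out_flat.shape[-1],))

print(encoded.shape, outputs.shape)  # (128, 63) (8, 16, 4)
```

With 10 frequencies, each 3D point becomes 3 + 3*2*10 = 63 features. Chunking only changes peak memory: the concatenated result is the same as a single call on the full batch.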

## Instantiate NeRF
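This section defines `create_nerf`, which builds the coarse MLP (and a fine MLP when `N_importance > 0`), the Adam optimizer, and the render settings for training and testing. It also reloads the most recent checkpoint if one is found.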

In [15]:
from nerf import NeRF

def create_nerf(args):
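    """Instantiate the NeRF MLP(s), optimizer, and render settings.

    Returns the train/test render kwargs, the starting step (non-zero when a
    checkpoint is reloaded), the list of trainable parameters, and the optimizer.
    """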
    embed_fn, input_ch = get_embedder(args.multires, args.i_embed)

    input_ch_views = 0
    embeddirs_fn = None
    if args.use_viewdirs:
        embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed)
    output_ch = 5 if args.N_importance > 0 else 4
    skips = [4]
    model = NeRF(D=args.netdepth, W=args.netwidth,
                 input_ch=input_ch, output_ch=output_ch, skips=skips,
                 input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
    grad_vars = list(model.parameters())

    model_fine = None
    if args.N_importance > 0:
        model_fine = NeRF(D=args.netdepth_fine, W=args.netwidth_fine,
                          input_ch=input_ch, output_ch=output_ch, skips=skips,
                          input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
        grad_vars += list(model_fine.parameters())

    network_query_fn = lambda inputs, viewdirs, network_fn : run_network(inputs, viewdirs, network_fn,
                                                                embed_fn=embed_fn,
                                                                embeddirs_fn=embeddirs_fn,
                                                                netchunk=args.netchunk)

    # Create optimizer
    optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))

    start = 0
    basedir = args.basedir
    expname = args.expname

    ##########################

    # Load checkpoints
    if args.ft_path is not None and args.ft_path!='None':
        ckpts = [args.ft_path]
    else:
        ckpts = [os.path.join(basedir, expname, f) for f in sorted(os.listdir(os.path.join(basedir, expname))) if 'tar' in f]

    print('Found ckpts', ckpts)
    if len(ckpts) > 0 and not args.no_reload:
        ckpt_path = ckpts[-1]
        print('Reloading from', ckpt_path)
        ckpt = torch.load(ckpt_path, map_location=device)

        start = ckpt['global_step']
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])

        # Load model
        model.load_state_dict(ckpt['network_fn_state_dict'])
        if model_fine is not None:
            model_fine.load_state_dict(ckpt['network_fine_state_dict'])

    ##########################

    render_kwargs_train = {
        'network_query_fn' : network_query_fn,
        'perturb' : args.perturb,
        'N_importance' : args.N_importance,
        'network_fine' : model_fine,
        'N_samples' : args.N_samples,
        'network_fn' : model,
        'use_viewdirs' : args.use_viewdirs,
        'white_bkgd' : args.white_bkgd,
        'raw_noise_std' : args.raw_noise_std,
    }

    # NDC is only appropriate for LLFF-style forward-facing data
    if args.dataset_type != 'llff' or args.no_ndc:
        print('Not ndc!')
        render_kwargs_train['ndc'] = False
        render_kwargs_train['lindisp'] = args.lindisp

    render_kwargs_test = {k : render_kwargs_train[k] for k in render_kwargs_train}
    render_kwargs_test['perturb'] = False
    render_kwargs_test['raw_noise_std'] = 0.

    return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer
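
The checkpoint lookup in `create_nerf` relies on file naming: checkpoints are saved as zero-padded step numbers (`'{:06d}.tar'`), so a plain lexicographic sort puts the newest one last. The sketch below reproduces that selection rule on a temporary directory. `find_latest_checkpoint` is a hypothetical helper written for this illustration, not a function from the notebook or from PyTorch.

```python
import os
import tempfile

def find_latest_checkpoint(exp_dir):
    """Return the path of the last checkpoint in sorted order, or None if there is none."""
    ckpts = [os.path.join(exp_dir, f) for f in sorted(os.listdir(exp_dir)) if 'tar' in f]
    return ckpts[-1] if ckpts else None

with tempfile.TemporaryDirectory() as exp_dir:
    # Create empty placeholder files in non-sorted order
    for step in (500, 10000, 2000):
        open(os.path.join(exp_dir, '{:06d}.tar'.format(step)), 'w').close()
    open(os.path.join(exp_dir, 'args.txt'), 'w').close()  # ignored: no 'tar' in the name
    latest = os.path.basename(find_latest_checkpoint(exp_dir))

print(latest)  # 010000.tar
```

Because the filter is a substring test, any file whose name contains `tar` in the experiment folder would also be picked up, so it is safest to keep only checkpoints under that naming pattern there.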

## Training Loop
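
Each training pixel corresponds to one camera ray, built from the intrinsics matrix `K` and the camera-to-world pose. The notebook's `get_rays` and `get_rays_np` come from `nerf_helpers` and are not shown here, so the following is only a sketch of the idea. `pinhole_rays` is a hypothetical name, and the camera convention (x right, y up, looking along -z) is an assumption based on the reference implementation.

```python
import numpy as np

def pinhole_rays(H, W, K, c2w):
    """Return per-pixel ray origins and directions, each of shape [H, W, 3]."""
    # i indexes columns (x), j indexes rows (y)
    i, j = np.meshgrid(np.arange(W, dtype=np.float32),
                       np.arange(H, dtype=np.float32), indexing='xy')
    # Camera-space directions under the assumed convention (camera looks along -z)
    dirs = np.stack([(i - K[0, 2]) / K[0, 0],
                     -(j - K[1, 2]) / K[1, 1],
                     -np.ones_like(i)], axis=-1)
    rays_d = dirs @ c2w[:3, :3].T                        # rotate into world space
    rays_o = np.broadcast_to(c2w[:3, 3], rays_d.shape)   # all rays start at the camera center
    return rays_o, rays_d

H, W, focal = 4, 6, 5.0
K = np.array([[focal, 0, 0.5 * W],
              [0, focal, 0.5 * H],
              [0, 0, 1]])
c2w = np.hstack([np.eye(3), np.array([[0.0], [0.0], [2.0]])])  # identity rotation, camera at z = 2

rays_o, rays_d = pinhole_rays(H, W, K, c2w)
print(rays_o.shape, rays_d.shape)  # (4, 6, 3) (4, 6, 3)
center_dir = rays_d[H // 2, W // 2]  # pixel at the principal point looks straight along -z
```

The directions are left unnormalized in this sketch, so their length grows toward the image borders.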

In [16]:
def train():
    """
    End-to-end NeRF training loop.

    High-level overview for newcomers:
    - Load a dataset of posed images (e.g., LLFF/Blender/LINEMOD). Each image comes with a camera pose.
    - Create a NeRF model (coarse and optionally fine MLPs) and a renderer.
    - On each iteration, sample camera rays and their target RGB values from the dataset.
    - Render rays with the NeRF model via volumetric rendering (accumulate colors along the ray).
    - Compute a reconstruction loss (e.g., MSE) against ground-truth pixels and optimize the networks.
    - Periodically render validation trajectories and/or save snapshots.

    Key concepts:
    - Rays: For each pixel, we cast a ray into the scene with origin/direction computed from intrinsics and pose.
    - Sampling: We sample multiple points along each ray (coarse). Optionally resample (fine) where the scene is likely.
    - Volume rendering: Convert per-point density+color to opacity weights and composite to a final pixel color.
    - Hierarchical sampling: A second pass focuses samples where the coarse pass is confident the scene exists.
    """

    # ----------------------
    # 1) Load data and choose near/far bounds depending on dataset
    # ----------------------
    K = None
    if args.dataset_type == 'llff':
        images, poses, bds, render_poses, i_test = load_llff_data(
            args.datadir, args.factor, recenter=True, bd_factor=.75, spherify=args.spherify
        )
        hwf = poses[0,:3,-1]           # (H, W, focal)
        poses = poses[:,:3,:4]         # Only keep rotation+translation (3x4) per pose
        print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir)
        if not isinstance(i_test, list):
            i_test = [i_test]

        # Optional LLFF holdout: use every N-th image as test
        if args.llffhold > 0:
            print('Auto LLFF holdout,', args.llffhold)
            i_test = np.arange(images.shape[0])[::args.llffhold]

        i_val = i_test
        i_train = np.array([i for i in np.arange(int(images.shape[0]))
                            if (i not in i_test and i not in i_val)])

        print('DEFINING BOUNDS')
        if args.no_ndc:
            # If not using NDC (e.g., inward-facing/360 scenes), near/far from bounds
            near = np.ndarray.min(bds) * .9
            far = np.ndarray.max(bds) * 1.
        else:
            # Forward-facing (LLFF) uses NDC, so near/far are normalized
            near = 0.
            far = 1.
        print('NEAR FAR', near, far)

    elif args.dataset_type == 'blender':
        images, poses, render_poses, hwf, i_split = load_blender_data(
            args.datadir, args.half_res, args.testskip
        )
        print('Loaded blender', images.shape, render_poses.shape, hwf, args.datadir)
        i_train, i_val, i_test = i_split

        # Standard near/far for Blender synthetic scenes
        near = 2.
        far = 6.

        # Composite over white if requested (makes background white instead of black)
        if args.white_bkgd:
            images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:])
        else:
            images = images[...,:3]

    elif args.dataset_type == 'LINEMOD':
        images, poses, render_poses, hwf, K, i_split, near, far = load_LINEMOD_data(
            args.datadir, args.half_res, args.testskip
        )
        print(f'Loaded LINEMOD, images shape: {images.shape}, hwf: {hwf}, K: {K}')
        print(f'[CHECK HERE] near: {near}, far: {far}.')
        i_train, i_val, i_test = i_split

        if args.white_bkgd:
            images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:])
        else:
            images = images[...,:3]

    elif args.dataset_type == 'deepvoxels':
        images, poses, render_poses, hwf, i_split = load_dv_data(
            scene=args.shape, basedir=args.datadir, testskip=args.testskip
        )
        print('Loaded deepvoxels', images.shape, render_poses.shape, hwf, args.datadir)
        i_train, i_val, i_test = i_split

        # DeepVoxels scenes define a hemisphere radius; near/far around it
        hemi_R = np.mean(np.linalg.norm(poses[:,:3,-1], axis=-1))
        near = hemi_R-1.
        far = hemi_R+1.

    else:
        print('Unknown dataset type', args.dataset_type, 'exiting')
        return

    # ----------------------
    # 2) Prepare intrinsics (H, W, focal) and default K if not provided
    # ----------------------
    H, W, focal = hwf
    H, W = int(H), int(W)
    hwf = [H, W, focal]

    if K is None:
        # Construct a pinhole intrinsics matrix assuming principal point at image center
        K = np.array([
            [focal, 0, 0.5*W],
            [0, focal, 0.5*H],
            [0, 0, 1]
        ])

    # If we are evaluating on the test set, use the corresponding subset of poses
    if args.render_test:
        render_poses = np.array(poses[i_test])

    # ----------------------
    # 3) Logging setup and persistence of configs
    # ----------------------
    basedir = args.basedir
    expname = args.expname
    os.makedirs(os.path.join(basedir, expname), exist_ok=True)

    # Save the parsed args for reproducibility
    f = os.path.join(basedir, expname, 'args.txt')
    with open(f, 'w') as file:
        for arg, attr in sorted(args.items()):
            attr = getattr(args, arg)
            file.write('{} = {}\n'.format(arg, attr))
    # Save the YAML configuration that produced these args
    f_yaml = os.path.join(basedir, expname, 'config.yaml')
    try:
        utils.save_yaml(dict(args), f_yaml)
    except Exception:
        # Fallback: write a simple YAML dump directly
        with open(f_yaml, 'w', encoding='utf-8') as yf:
            import yaml as _yaml
            _yaml.dump(dict(args), yf, default_flow_style=False, indent=2)

    # ----------------------
    # 4) Create NeRF models and optimizer
    # ----------------------
    render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args)
    global_step = start

    # Near/far bounds are used by the renderer; update both train and test configs
    bds_dict = { 'near' : near, 'far' : far }
    render_kwargs_train.update(bds_dict)
    render_kwargs_test.update(bds_dict)

    # Move the camera trajectory used for rendering validation videos to the GPU
    render_poses = torch.Tensor(render_poses).to(device)

    # ----------------------
    # 5) Short-circuit: render only mode
    # ----------------------
    if args.render_only:
        print('RENDER ONLY')
        with torch.no_grad():
            if args.render_test:
                # Switch to test poses
                images = images[i_test]
            else:
                # Default is smoother render_poses path
                images = None

            testsavedir = os.path.join(
                basedir, expname, 'renderonly_{}_{:06d}'.format('test' if args.render_test else 'path', start)
            )
            os.makedirs(testsavedir, exist_ok=True)
            print('test poses shape', render_poses.shape)

            # Render a path and save to video
            rgbs, _ = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test,
                                   gt_imgs=images, savedir=testsavedir, render_factor=args.render_factor)
            print('Done rendering', testsavedir)
            imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'), to8b(rgbs), fps=30, quality=8)

            return

    # ----------------------
    # 6) Prepare random-ray batching (optional) and move data to GPU
    # ----------------------
    N_rand = args.N_rand
    use_batching = not args.no_batching
    if use_batching:
        # Precompute all rays for all training images, then shuffle mini-batches each step
        print('get rays')
        rays = np.stack([get_rays_np(H, W, K, p) for p in poses[:,:3,:4]], 0)  # [N, ro+rd, H, W, 3]
        print('done, concats')
        rays_rgb = np.concatenate([rays, images[:,None]], 1)                   # [N, ro+rd+rgb, H, W, 3]
        rays_rgb = np.transpose(rays_rgb, [0,2,3,1,4])                         # [N, H, W, ro+rd+rgb, 3]
        rays_rgb = np.stack([rays_rgb[i] for i in i_train], 0)                 # train images only
        rays_rgb = np.reshape(rays_rgb, [-1,3,3])                              # [N_train*H*W, ro+rd+rgb, 3]
        rays_rgb = rays_rgb.astype(np.float32)
        print('shuffle rays')
        np.random.shuffle(rays_rgb)
        print('done')
        i_batch = 0

    # Move arrays/tensors to GPU for training
    if use_batching:
        images = torch.Tensor(images).to(device)
    poses = torch.Tensor(poses).to(device)
    if use_batching:
        rays_rgb = torch.Tensor(rays_rgb).to(device)

    print('Begin')
    print('TRAIN views are', i_train)
    print('TEST views are', i_test)
    print('VAL views are', i_val)

    # ----------------------
    # 7) Main optimization loop
    # ----------------------
    # writer = SummaryWriter(os.path.join(basedir, 'summaries', expname))  # Optional TB logging

    start = start + 1
    for i in trange(start, N_iters):
        time0 = time.time()

        # Sample a batch of rays and target colors
        if use_batching:
            # Random over all images (global batching)
            batch = rays_rgb[i_batch:i_batch+N_rand]  # [B, ro+rd+rgb, 3]
            batch = torch.transpose(batch, 0, 1)
            batch_rays, target_s = batch[:2], batch[2]

            # Move sliding window; reshuffle at epoch end
            i_batch += N_rand
            if i_batch >= rays_rgb.shape[0]:
                print("Shuffle data after an epoch!")
                rand_idx = torch.randperm(rays_rgb.shape[0])
                rays_rgb = rays_rgb[rand_idx]
                i_batch = 0
        else:
            # Random rays from a randomly chosen training image (per-image batching)
            img_i = np.random.choice(i_train)
            target = images[img_i]
            target = torch.Tensor(target).to(device)
            pose = poses[img_i, :3,:4]

            if N_rand is not None:
                # Compute per-pixel rays for this image, then sample N_rand of them
                rays_o, rays_d = get_rays(H, W, K, pose)  # (H, W, 3), (H, W, 3); pose is already a tensor on `device`

                if i < args.precrop_iters:
                    # Optional: focus early training on the image center (stabilizes training)
                    dH = int(H//2 * args.precrop_frac)
                    dW = int(W//2 * args.precrop_frac)
                    coords = torch.stack(
                        torch.meshgrid(
                            torch.linspace(H//2 - dH, H//2 + dH - 1, 2*dH), 
                            torch.linspace(W//2 - dW, W//2 + dW - 1, 2*dW)
                        ), -1)
                    if i == start:
                        print(f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {args.precrop_iters}")                
                else:
                    coords = torch.stack(
                        torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, W-1, W)), -1
                    )  # (H, W, 2)

                coords = torch.reshape(coords, [-1,2])                         # (H*W, 2)
                select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False)  # (N_rand,)
                select_coords = coords[select_inds].long()                     # (N_rand, 2)
                rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]]      # (N_rand, 3)
                rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]]      # (N_rand, 3)
                batch_rays = torch.stack([rays_o, rays_d], 0)
                target_s = target[select_coords[:, 0], select_coords[:, 1]]    # (N_rand, 3)

        # ---- Core rendering + loss ----
        rgb, disp, acc, extras = render(
            H, W, K, chunk=args.chunk, rays=batch_rays, verbose=i < 10, retraw=True, **render_kwargs_train
        )

        optimizer.zero_grad()
        img_loss = img2mse(rgb, target_s)
        trans = extras['raw'][...,-1]
        loss = img_loss
        psnr = mse2psnr(img_loss)

        # If hierarchical sampling is enabled, include the coarse-pass loss
        if 'rgb0' in extras:
            img_loss0 = img2mse(extras['rgb0'], target_s)
            loss = loss + img_loss0
            psnr0 = mse2psnr(img_loss0)

        loss.backward()
        optimizer.step()

        # --- Learning rate decay (exponential) ---
        decay_rate = 0.1
        decay_steps = args.lrate_decay * 1000
        new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lrate

        dt = time.time()-time0
        # print(f"Step: {global_step}, Loss: {loss}, Time: {dt}")

        # ----------------------
        # 8) Periodic logging, checkpointing, and visualization
        # ----------------------
        if i%args.i_weights==0:
            # Save model checkpoints for resuming or analysis
            path = os.path.join(basedir, expname, '{:06d}.tar'.format(i))
            torch.save({
                'global_step': global_step,
                'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(),
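                # network_fine is None when N_importance == 0, so guard the state_dict() call
                'network_fine_state_dict': (render_kwargs_train['network_fine'].state_dict()
                                            if render_kwargs_train['network_fine'] is not None else None),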
                'optimizer_state_dict': optimizer.state_dict(),
            }, path)
            print('Saved checkpoints at', path)

        if i%args.i_video==0 and i > 0:
            # Render a validation trajectory and write MP4 previews
            with torch.no_grad():
                rgbs, disps = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test)
            print('Done, saving', rgbs.shape, disps.shape)
            moviebase = os.path.join(basedir, expname, '{}_spiral_{:06d}_'.format(expname, i))
            imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8)
            imageio.mimwrite(moviebase + 'disp.mp4', to8b(disps / np.max(disps)), fps=30, quality=8)

            # If you want to visualize view-dependent effects, you can fix the camera position
            # and vary only the view direction (see commented example in original code).

        if i%args.i_testset==0 and i > 0:
            # Render the held-out test set and save frames
            testsavedir = os.path.join(basedir, expname, 'testset_{:06d}'.format(i))
            os.makedirs(testsavedir, exist_ok=True)
            print('test poses shape', poses[i_test].shape)
            with torch.no_grad():
                render_path(torch.Tensor(poses[i_test]).to(device), hwf, K, args.chunk,
                            render_kwargs_test, gt_imgs=images[i_test], savedir=testsavedir)
            print('Saved test set')

        if i%args.i_print==0:
            tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss.item()}  PSNR: {psnr.item()}")
            # Additional TensorBoard logging code is kept in comments in the original.

        global_step += 1
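
The volume rendering step named in the `train()` docstring (density and color to opacity weights to a pixel color) happens inside `render`, which is imported from the `render` module and not shown in this notebook. As a rough illustration, the standard NeRF quadrature for a single ray can be sketched as below. `composite_ray` is a hypothetical function, and it leaves out details such as density noise, white-background compositing, and scaling the intervals by the ray direction's length.

```python
import numpy as np

def composite_ray(sigma, rgb, z_vals):
    """Composite per-sample density and color along one ray.

    sigma: [N] non-negative densities, rgb: [N, 3] colors in [0, 1],
    z_vals: [N] sample depths in increasing order.
    """
    deltas = np.diff(z_vals, append=z_vals[-1] + 1e10)   # last interval treated as very long
    alpha = 1.0 - np.exp(-sigma * deltas)                # opacity of each segment
    # Transmittance: fraction of light that survives all segments before sample i
    trans = np.cumprod(np.concatenate([[1.0], 1.0 - alpha[:-1] + 1e-10]))
    weights = alpha * trans
    color = (weights[:, None] * rgb).sum(axis=0)
    depth = (weights * z_vals).sum()
    acc = weights.sum()
    return color, depth, acc, weights

# Empty space, then a dense red surface starting at depth 4
z_vals = np.linspace(2.0, 6.0, 5)                   # [2, 3, 4, 5, 6]
sigma = np.array([0.0, 0.0, 50.0, 50.0, 0.0])
rgb = np.tile(np.array([1.0, 0.0, 0.0]), (5, 1))

color, depth, acc, weights = composite_ray(sigma, rgb, z_vals)
print(np.round(color, 3), round(float(depth), 3), round(float(acc), 3))
```

In this toy case nearly all of the weight lands on the first dense sample, so the ray renders as red with an expected depth close to 4. Those same weights are what a hierarchical (fine) pass would use to decide where to place additional samples.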

In [17]:
# Number of training iterations (+1 because the loop runs from start + 1 up to N_iters - 1)
N_iters = 500 + 1
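
For a run this short, the exponential learning-rate decay inside `train()` has very little effect. The sketch below evaluates the same formula used in the loop. The values `5e-4` and `250` are illustrative assumptions only (the real `lrate` and `lrate_decay` come from the YAML config), and `decayed_lrate` is a hypothetical helper.

```python
def decayed_lrate(base_lrate, step, lrate_decay, decay_rate=0.1):
    """Exponential decay: the rate is multiplied by `decay_rate` every `lrate_decay * 1000` steps."""
    decay_steps = lrate_decay * 1000
    return base_lrate * (decay_rate ** (step / decay_steps))

# Assumed values for illustration; not read from the notebook's config
base_lrate, lrate_decay = 5e-4, 250

for step in (0, 500, 250_000):
    print(step, decayed_lrate(base_lrate, step, lrate_decay))
```

Under these assumed values, 500 steps lower the learning rate by less than 1%, while a full 250,000 steps would bring it down to one tenth of its starting value.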

In [18]:
train()

Loaded image data (378, 504, 3, 20) [378.         504.         407.56579161]
Loaded ./data/nerf_llff_data/fern 16.985296178676084 80.00209740336334
recentered (3, 5)
[[ 1.0000000e+00  0.0000000e+00  0.0000000e+00 -2.9802323e-09]
 [ 0.0000000e+00  1.0000000e+00 -1.8730975e-09 -9.6857544e-09]
 [-0.0000000e+00  1.8730975e-09  1.0000000e+00 -7.4505807e-10]]
Data:
(20, 3, 5) (20, 378, 504, 3) (20, 2)
HOLDOUT view is 12
Loaded llff (20, 378, 504, 3) (120, 3, 5) [378.     504.     407.5658] ./data/nerf_llff_data/fern
Auto LLFF holdout, 8
DEFINING BOUNDS
NEAR FAR 0.0 1.0
Found ckpts []
get rays
done, concats
shuffle rays
done
Begin
TRAIN views are [ 1  2  3  4  5  6  7  9 10 11 12 13 14 15 17 18 19]
TEST views are [ 0  8 16]
VAL views are [ 0  8 16]


 20%|██        | 101/500 [00:09<00:36, 11.00it/s]

[TRAIN] Iter: 100 Loss: 0.0506649985909462  PSNR: 15.991083145141602


 40%|████      | 201/500 [00:18<00:27, 11.05it/s]

[TRAIN] Iter: 200 Loss: 0.04145825281739235  PSNR: 16.938528060913086


 60%|██████    | 301/500 [00:27<00:18, 11.00it/s]

[TRAIN] Iter: 300 Loss: 0.03398731350898743  PSNR: 17.717384338378906


 80%|████████  | 401/500 [00:36<00:08, 11.08it/s]

[TRAIN] Iter: 400 Loss: 0.02866814285516739  PSNR: 18.31490707397461


100%|██████████| 500/500 [00:45<00:00, 11.00it/s]

[TRAIN] Iter: 500 Loss: 0.0285334549844265  PSNR: 18.495981216430664
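
The logged `Loss` and `PSNR` are not expected to satisfy PSNR = -10 log10(Loss). In `train()`, `psnr` is computed from `img_loss` alone, while `loss` also adds the coarse term `img_loss0` whenever `'rgb0'` is returned. The sketch below splits the last logged loss under two assumptions: that `mse2psnr` (defined in `nerf_helpers`, not shown here) uses the standard definition for pixel values in [0, 1], and that hierarchical sampling is enabled in this config. `mse_to_psnr` and `psnr_to_mse` are hypothetical helpers.

```python
import math

def mse_to_psnr(mse):
    """PSNR in dB for images with pixel values in [0, 1]."""
    return -10.0 * math.log10(mse)

def psnr_to_mse(psnr):
    """Inverse of mse_to_psnr."""
    return 10.0 ** (-psnr / 10.0)

# Values copied from the iteration-500 log line above
logged_loss, logged_psnr = 0.0285334549844265, 18.495981216430664

final_mse = psnr_to_mse(logged_psnr)   # MSE of the final render implied by the logged PSNR
remainder = logged_loss - final_mse    # would be the coarse-pass MSE under the stated assumptions

print(f"MSE implied by logged PSNR:  {final_mse:.5f}")
print(f"Remainder of logged loss:    {remainder:.5f}")
print(f"PSNR computed from the total loss: {mse_to_psnr(logged_loss):.2f} dB")
```

The implied MSE is roughly half of the logged loss, which is what one would see if the coarse and fine passes contribute similar errors at this early stage of training.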



