# Neural Radiance Fields

This notebook was written alongside the original NeRF code to visualize, step by step, how NeRF works. Some functions have been refactored for readability, but wherever possible the actual code is left unchanged.

## Attribution and Citation
- Original paper: Mildenhall, Ben; Srinivasan, Pratul P.; Tancik, Matthew; Barron, Jonathan T.; Ramamoorthi, Ravi; Ng, Ren. “NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis.” ECCV 2020. [arXiv:2003.08934](https://arxiv.org/abs/2003.08934)
- Reference code: [bmild/nerf](https://github.com/bmild/nerf) (MIT License). Copyright © 2020 Ben Mildenhall.

```bibtex
@inproceedings{mildenhall2020nerf,
  title     = {NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis},
  author    = {Mildenhall, Ben and Srinivasan, Pratul P. and Tancik, Matthew and Barron, Jonathan T. and Ramamoorthi, Ravi and Ng, Ren},
  booktitle = {European Conference on Computer Vision (ECCV)},
  year      = {2020}
}
```

## Import Dependencies and Check CUDA

In [1]:
# Standard libraries
import sys
import os
import time

# Third party libraries
import imageio
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm, trange

import matplotlib.pyplot as plt

# Helper libraries
from embedder import *
from nerf_helpers import *
import utils

# Loaders
from load_llff import *
from load_blender import *
from load_deepvoxels import * 
from load_LINEMOD import *

# Set seed
np.random.seed(0)
DEBUG = False
torch.set_default_dtype(torch.float32)

In [2]:
# Check device 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [3]:
# Load experiment config from yaml
EXPERIMENT_NAME = "ship_blender200k_fullres_higher_samples"
folder_path = os.path.join("yaml", EXPERIMENT_NAME)
print(f"Targeted yaml folder: {folder_path}")

# Load the yaml data for use later in terms of args.
args = utils.load_or_create_config(folder_path)

# Single quotes inside the f-string keep this line valid on Python versions before 3.12.
print(f"Experiment name: {args['expname']}")

Targeted yaml folder: yaml\ship_blender200k_fullres_higher_samples
Loading configuration from yaml\ship_blender200k_fullres_higher_samples
Configuration validation passed! Arguments are valid and correctly set.
Experiment name: ship_blender200k_fullres_higher_samples


In [4]:
# Convert dict-style args to dot-access with recursive wrapping
class AttrDict(dict):
    """Dictionary with attribute-style access that recursively wraps nested dicts/lists.

    Example:
        d = AttrDict.from_obj({"a": {"b": 1}, "c": [{"d": 2}]})
        d.a.b == 1
        d.c[0].d == 2
        d.new_key = 3  # also writes to the underlying dict
    """
    __slots__ = ()

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as e:
            raise AttributeError(f"No such attribute: {name}") from e

    def __setattr__(self, name, value):
        self[name] = value

    @classmethod
    def from_obj(cls, obj):
        if isinstance(obj, dict):
            return cls({k: cls.from_obj(v) for k, v in obj.items()})
        if isinstance(obj, list):
            return [cls.from_obj(v) for v in obj]
        return obj

# If args came from YAML as a plain dict, wrap it for dot access
try:
    if isinstance(args, dict):
        args = AttrDict.from_obj(args)
        print("Converted args dict to AttrDict (dot-access enabled). Example: args.expname ->", args.expname)
except NameError:
    # If this cell runs before args is defined, it's a no-op
    pass


Converted args dict to AttrDict (dot-access enabled). Example: args.expname -> ship_blender200k_fullres_higher_samples


## Render

NeRF rendering utilities

This part contains helper functions to render volumes using Neural Radiance Fields (NeRF).
If you're new to NeRF, here's the high-level idea you'll see reflected in the code:

- A NeRF model takes a 3D point (and often a viewing direction) and predicts color (RGB)
  and density (sigma) at that point.
- To render an image, we cast a ray through each pixel, sample many points along the ray,
  evaluate the network at those points, and composite the colors using volume rendering.
- We optionally do this in two passes: a coarse pass to find where the scene is, then a
  fine pass that samples more densely in the important regions (importance sampling).
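
The compositing step in the list above is commonly written as the discrete quadrature from the NeRF paper. Here $\sigma_i$ and $\mathbf{c}_i$ are the density and color predicted at sample $i$ along the ray, and $\delta_i = t_{i+1} - t_i$ is the distance to the next sample:

$$
\hat{C}(\mathbf{r}) = \sum_{i=1}^{N} T_i \, \alpha_i \, \mathbf{c}_i,
\qquad
\alpha_i = 1 - \exp(-\sigma_i \delta_i),
\qquad
T_i = \prod_{j=1}^{i-1} (1 - \alpha_j) = \exp\Big(-\sum_{j=1}^{i-1} \sigma_j \delta_j\Big)
$$

The product $w_i = T_i \alpha_i$ is the per-sample weight that `raw2outputs` returns as `weights`.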

The functions below help with:
- Splitting big tensors/ray sets into smaller chunks to avoid out-of-memory issues.
- Converting raw network outputs into rendered RGB/depth/opacity via volume rendering.
- Orchestrating the full per-ray rendering with optional hierarchical sampling.
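
Hierarchical sampling, mentioned in the list above, can be pictured as inverse-transform sampling from a piecewise-constant PDF built from the coarse weights. The sketch below is a dependency-free illustration for a single ray, not the `sample_pdf` helper imported from `nerf_helpers`; the name `sample_pdf_sketch`, the `eps` value, and the midpoint quantiles used in the deterministic case are assumptions made for this example.

```python
import bisect
import random


def sample_pdf_sketch(bins, weights, n_samples, rng=None):
    """Draw n_samples depths from a piecewise-constant PDF over one ray.

    bins:    len(weights) + 1 increasing bin edges along the ray.
    weights: non-negative per-bin weights (for example, coarse-pass weights).
    rng:     optional random.Random; if None, evenly spaced quantiles are used.
    """
    eps = 1e-5  # assumed small constant; keeps the PDF valid if all weights are zero
    total = sum(w + eps for w in weights)

    # Cumulative distribution over bin edges, starting at 0 and ending at 1.
    cdf = [0.0]
    for w in weights:
        cdf.append(cdf[-1] + (w + eps) / total)
    cdf[-1] = 1.0  # guard against floating-point round-off

    if rng is None:
        us = [(i + 0.5) / n_samples for i in range(n_samples)]
    else:
        us = [rng.random() for _ in range(n_samples)]

    samples = []
    for u in us:
        # Locate the bin k with cdf[k] <= u < cdf[k + 1], then interpolate inside it.
        k = min(bisect.bisect_right(cdf, u) - 1, len(weights) - 1)
        frac = (u - cdf[k]) / (cdf[k + 1] - cdf[k])
        samples.append(bins[k] + frac * (bins[k + 1] - bins[k]))
    return samples


bins = [2.0, 3.0, 4.0, 5.0, 6.0]
weights = [0.0, 0.1, 0.8, 0.1]  # most of the mass sits in the bin [4, 5]
fine = sample_pdf_sketch(bins, weights, n_samples=10)
```

With these weights, most of the new samples fall between depths 4 and 5, which is the behavior the fine pass relies on.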


In [5]:
"""
batchify wraps a function fn so that it processes a large input tensor in smaller chunks,
reducing peak memory usage, and then stitches the results back together.

Inputs are split along the first dimension into slices of size chunk, fn is applied to each
slice, and the results are concatenated along the first dimension.

This lowers peak GPU/CPU memory when fn is heavy (e.g., a neural network forward pass) and
preserves autograd, since gradients flow through torch.cat and slicing.
"""
def batchify(fn, chunk):
    """Wrap a function so it runs on smaller input chunks to reduce peak memory.

    This is useful for heavy functions like neural network forward passes. Instead of
    evaluating the entire input tensor in one go, we split it along the first dimension
    into slices of size ``chunk`` and then concatenate the results.

    Args:
        fn (Callable[[Tensor], Tensor]): The function to run on chunks.
        chunk (Optional[int]): Number of rows to process per call. If ``None``,
            no chunking is performed.

    Returns:
        Callable[[Tensor], Tensor]: A wrapper that applies ``fn`` on input chunks and
        stitches the outputs along the first dimension.
    """
    if chunk is None:
        return fn
    def ret(inputs):
        return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0)
    return ret
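
To see where the chunk boundaries fall, here is a list-based analog of `batchify` that needs no tensors. The names `batchify_list`, `double_all`, and `call_sizes` are hypothetical and exist only for this illustration; the real wrapper above concatenates tensors with `torch.cat` instead of extending a list.

```python
def batchify_list(fn, chunk):
    """List-based analog of batchify: apply fn to slices of at most `chunk` items."""
    if chunk is None:
        return fn

    def wrapped(inputs):
        out = []
        for i in range(0, len(inputs), chunk):
            out.extend(fn(inputs[i:i + chunk]))
        return out

    return wrapped


call_sizes = []


def double_all(xs):
    call_sizes.append(len(xs))  # record how many items each call received
    return [2 * x for x in xs]


chunked = batchify_list(double_all, chunk=4)
result = chunked(list(range(10)))
```

Ten inputs with `chunk=4` lead to three calls of sizes 4, 4, and 2, and the stitched result matches what a single unchunked call would return.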


In [6]:
def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False, pytest=False):
    """
    Convert raw network predictions along rays into rendered outputs using
    volumetric rendering (accumulation over samples).

    High-level intuition (NeRF volume rendering):
    - The network predicts, per sampled point along a ray, an RGB value and a density
      (often called sigma). Density indicates how much light is absorbed/emitted there.
    - We turn densities into per-sample opacities (alpha) based on the distance between
      adjacent samples along the ray.
    - We compute transmittance (how much light makes it to a sample without being blocked)
      via a cumulative product, then form per-sample weights = transmittance * alpha.
    - We composite colors, depths, and other quantities by weighted sums over samples.

    Args:
        raw (torch.Tensor): [N_rays, N_samples, 4] raw predictions per sample. The first
            3 channels are RGB logits (before sigmoid), and the last channel is density (sigma).
        z_vals (torch.Tensor): [N_rays, N_samples] sample depths or t-values along each ray.
        rays_d (torch.Tensor): [N_rays, 3] direction vectors for each ray. Used to scale
            step sizes from parametric units to metric distances.
        raw_noise_std (float): Stddev of Gaussian noise added to sigma during training for
            regularization. Set to 0.0 at eval time.
        white_bkgd (bool): If True, composite the result over a white background (useful
            for synthetic datasets rendered on white).
        pytest (bool): If True, use deterministic numpy noise for reproducible tests.

    Returns:
        tuple:
            - rgb_map (torch.Tensor): [N_rays, 3] rendered RGB color per ray.
            - disp_map (torch.Tensor): [N_rays] disparity (inverse depth) per ray.
            - acc_map (torch.Tensor): [N_rays] accumulated opacity per ray (sum of weights).
            - weights (torch.Tensor): [N_rays, N_samples] per-sample contribution weights.
            - depth_map (torch.Tensor): [N_rays] expected depth per ray.
    """
    # Map density (sigma) and step size (distance between samples) to opacity (alpha):
    #   alpha = 1 - exp(-relu(sigma) * delta)
    # relu ensures sigma is non-negative, as negative density is not physical.
    raw2alpha = lambda raw_sigma, dists, act_fn=F.relu: 1.0 - torch.exp(-act_fn(raw_sigma) * dists)

    # Compute distances between adjacent samples along each ray in z (or t) space.
    # Shape after diff: [N_rays, N_samples-1]
    dists = z_vals[..., 1:] - z_vals[..., :-1]

    # For the last sample on each ray, append a very large distance (1e10). The final
    # interval is treated as unbounded, so any positive density there gives alpha close to 1.
    # Resulting shape: [N_rays, N_samples]
    dists = torch.cat([dists, torch.tensor([1e10], device=z_vals.device, dtype=z_vals.dtype).expand(dists[..., :1].shape)], dim=-1)

    # Convert parametric distances to metric distances by multiplying by the ray length.
    # This accounts for non-unit ray directions. rays_d[..., None, :] has shape [N_rays, 1, 3]
    # and we take its L2 norm to scale each sample distance.
    dists = dists * torch.norm(rays_d[..., None, :], dim=-1)

    # Convert raw RGB logits to [0,1] colors per sample.
    rgb = torch.sigmoid(raw[..., :3])  # [N_rays, N_samples, 3]

    # Optional noise added to densities during training for regularization.
    # Ensure raw_noise_std is a float (configs may pass strings).
    try:
        noise_std = float(raw_noise_std)
    except (TypeError, ValueError):
        noise_std = 0.0

    noise = 0.0
    if noise_std > 0.0:
        noise = torch.randn(raw[..., 3].shape, device=raw.device, dtype=raw.dtype) * noise_std

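        # Deterministic noise path for unit tests. Note that np.random.rand draws uniform
        # samples in [0, 1), so this path is not Gaussian like the torch.randn path above.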
        if pytest:
            np.random.seed(0)
            noise_np = np.random.rand(*list(raw[..., 3].shape)) * noise_std
            noise = torch.tensor(noise_np, device=raw.device, dtype=raw.dtype)

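    # Opacity per sample from density and distance.
    alpha = raw2alpha(raw[..., 3] + noise, dists)  # [N_rays, N_samples]

    # Transmittance T_i = prod_{j<i} (1 - alpha_j): prepend 1 (nothing blocks the first
    # sample), take the cumulative product, and drop the last entry to make it exclusive.
    # The 1e-10 offset keeps the product from collapsing to exactly zero.
    # weights_i = alpha_i * T_i is the contribution of sample i.  [N_rays, N_samples]
    weights = alpha * torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1), device=alpha.device, dtype=alpha.dtype), 1.-alpha + 1e-10], -1), -1)[:, :-1]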

    # Rendered color is the weighted sum of per-sample colors along the ray.
    rgb_map = torch.sum(weights[..., None] * rgb, dim=-2)  # [N_rays, 3]

    # Expected depth is the weighted sum of sample depths.
    depth_map = torch.sum(weights * z_vals, dim=-1)

    # Disparity is inverse depth. We divide expected depth by total weight (visibility)
    # and guard with epsilon to avoid divide-by-zero when the ray hits nothing.
    denom = torch.max(1e-10 * torch.ones_like(depth_map), torch.sum(weights, dim=-1))
    disp_map = 1.0 / torch.clamp(depth_map / denom, min=1e-10)

    # Accumulated opacity along the ray (how much of the ray got "stopped").
    acc_map = torch.sum(weights, dim=-1)

    # If the scene assumes a white background, composite the missing transmittance as white.
    if white_bkgd:
        rgb_map = rgb_map + (1.0 - acc_map[..., None])

    return rgb_map, disp_map, acc_map, weights, depth_map
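
A small worked example may make the alpha, transmittance, and weight bookkeeping in `raw2outputs` easier to follow. The sketch below composites a single ray with scalar "gray" colors in plain Python. It assumes a unit-length ray direction, no noise, and colors already in [0, 1], and it omits the `1e-10` stabilizer; the function name `composite` is hypothetical.

```python
import math


def composite(sigmas, colors, z_vals, white_bkgd=False):
    """Composite one ray with scalar colors, following the raw2outputs recipe."""
    # Distances between adjacent samples; the last interval is effectively unbounded.
    dists = [z_vals[i + 1] - z_vals[i] for i in range(len(z_vals) - 1)] + [1e10]
    alphas = [1.0 - math.exp(-max(s, 0.0) * d) for s, d in zip(sigmas, dists)]

    weights, transmittance = [], 1.0
    for a in alphas:
        weights.append(transmittance * a)  # light that reaches this sample and stops here
        transmittance *= 1.0 - a           # light that continues past this sample

    color = sum(w * c for w, c in zip(weights, colors))
    acc = sum(weights)
    depth = sum(w * z for w, z in zip(weights, z_vals))
    if white_bkgd:
        color += 1.0 - acc                 # leftover transmittance shows the white background
    return color, acc, depth, weights


# Empty space, then one fairly dense sample, then empty space again.
color, acc, depth, weights = composite(
    sigmas=[0.0, 5.0, 0.0],
    colors=[0.2, 0.6, 0.9],
    z_vals=[2.0, 3.0, 4.0],
    white_bkgd=True,
)
```

Only the middle sample receives a non-zero weight, so the expected depth lands close to 3 and the small remaining transmittance is filled with white.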

In [7]:
def render_rays(ray_batch,
                network_fn,
                network_query_fn,
                N_samples,
                retraw=False,
                lindisp=False,
                perturb=0.,
                N_importance=0,
                network_fine=None,
                white_bkgd=False,
                raw_noise_std=0.,
                verbose=False,
                pytest=False):
    """Render a bundle of rays using NeRF-style volume rendering.

    High-level steps for each ray:
    1) Sample N points between near and far along the ray (evenly in depth or inverse depth).
    2) Query the network at those 3D points (and optionally view directions) to get raw RGB+sigma.
    3) Convert raw predictions to colors via volume rendering (accumulate with alphas/weights).
    4) If enabled, run hierarchical (importance) sampling: take a second set of samples drawn
       from a PDF defined by the coarse weights, re-evaluate the network (fine), and re-render.

    Args:
        ray_batch (torch.Tensor): [num_rays, Cray]. Per-ray data packed together. The first
            3 entries are ray origins, next 3 are ray directions, next 2 are near/far bounds,
            and the last 3 (if present) are unit view directions for view-dependent effects.
        network_fn (Callable): The coarse NeRF MLP. Given points (and viewdirs), predicts
            raw RGB (logits) and density (sigma).
        network_query_fn (Callable): A helper that formats inputs and calls the network.
        N_samples (int): Number of stratified samples for the coarse pass.
        retraw (bool): If True, also return the raw outputs from the last pass.
        lindisp (bool): If True, sample uniformly in inverse depth (disparity) instead of depth.
            This concentrates samples near the camera, helpful for scenes with large depth ranges.
        perturb (float): If > 0, enable stratified sampling noise during training for anti-aliasing.
        N_importance (int): Extra samples for the fine pass (hierarchical sampling). 0 disables it.
        network_fine (Optional[Callable]): A separate fine MLP. If None, reuse ``network_fn``.
        white_bkgd (bool): If True, composite the result over a white background.
        raw_noise_std (float): Stddev of Gaussian noise added to density during training.
        verbose (bool): If True, print additional debug information.
        pytest (bool): If True, make randomness deterministic for unit tests.

    Returns:
        Dict[str, torch.Tensor]: Always includes:
            - ``rgb_map`` [num_rays, 3]: Rendered color from the last pass (fine if enabled).
            - ``disp_map`` [num_rays]: Disparity (1/depth) from the last pass.
            - ``acc_map`` [num_rays]: Accumulated opacity from the last pass.
        Optionally includes:
            - ``raw`` [num_rays, num_samples, 4]: Raw outputs of the last pass if ``retraw``.
            - ``rgb0``, ``disp0``, ``acc0``: Coarse pass results when ``N_importance > 0``.
            - ``z_std`` [num_rays]: Std. dev. of fine samples (measures sampling concentration).
    """
    N_rays = ray_batch.shape[0]
    # Unpack packed ray data: origins, directions, near/far bounds, and optional viewdirs.
    rays_o, rays_d = ray_batch[:,0:3], ray_batch[:,3:6] # [N_rays, 3] each
    viewdirs = ray_batch[:,-3:] if ray_batch.shape[-1] > 8 else None
    bounds = torch.reshape(ray_batch[...,6:8], [-1,1,2])
    near, far = bounds[...,0], bounds[...,1] # each [N_rays, 1]

    # Step 1: Choose parametric sample positions t in [0, 1] and map to depths z in [near, far].
    t_vals = torch.linspace(0., 1., steps=N_samples, device=ray_batch.device)
    if not lindisp:
        # Uniform samples in depth
        z_vals = near * (1.-t_vals) + far * (t_vals)
    else:
        # Uniform samples in inverse depth (places more samples closer to the camera)
        z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals))

    z_vals = z_vals.expand([N_rays, N_samples])

    if perturb > 0.:
        # During training, jitter samples within each interval for stratified sampling.
        # This reduces aliasing and improves robustness.
        # Get intervals between samples
        mids = .5 * (z_vals[...,1:] + z_vals[...,:-1])
        upper = torch.cat([mids, z_vals[...,-1:]], -1)
        lower = torch.cat([z_vals[...,:1], mids], -1)
        # Draw stratified samples inside those intervals
        t_rand = torch.rand(z_vals.shape, device=z_vals.device)

        # For pytest, overwrite t_rand with fixed random numbers from numpy
        if pytest:
            np.random.seed(0)
            t_rand = np.random.rand(*list(z_vals.shape))
            t_rand = torch.tensor(t_rand, device=z_vals.device, dtype=z_vals.dtype)

        z_vals = lower + (upper - lower) * t_rand

    # Compute 3D sample locations along each ray: o + t*d for each sampled depth t (z_vals).
    pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples, 3]

    # raw = run_network(pts)
    # Query the (coarse) network at all sample points. ``viewdirs`` enables view-dependent effects.
    raw = network_query_fn(pts, viewdirs, network_fn)
    rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)

    if N_importance > 0:

        rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map

        # Hierarchical sampling (importance sampling):
        # Build a PDF from coarse weights and draw additional samples where the scene is likely.
        z_vals_mid = .5 * (z_vals[...,1:] + z_vals[...,:-1])
        # Exclude the first and last weights to avoid boundary artifacts when forming the PDF.
        z_samples = sample_pdf(z_vals_mid, weights[...,1:-1], N_importance, det=(perturb==0.), pytest=pytest)
        # Detach so gradients do not flow through the sampling operation.
        z_samples = z_samples.detach()

        # Merge coarse and fine samples, then sort along the ray.
        z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
        pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples + N_importance, 3]

        # Use dedicated fine network if provided, else reuse the coarse network.
        run_fn = network_fn if network_fine is None else network_fine
        # raw = run_network(pts, fn=run_fn)
        raw = network_query_fn(pts, viewdirs, run_fn)

        rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)

    ret = {'rgb_map' : rgb_map, 'disp_map' : disp_map, 'acc_map' : acc_map}
    if retraw:
        ret['raw'] = raw
    if N_importance > 0:
        # Include coarse pass results for potential losses or visualization.
        ret['rgb0'] = rgb_map_0
        ret['disp0'] = disp_map_0
        ret['acc0'] = acc_map_0
        # Standard deviation of fine samples: indicates how concentrated sampling is per ray.
        ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False)  # [N_rays]

    for k in ret:
        if (torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any()) and DEBUG:
            print(f"! [Numerical Error] {k} contains nan or inf.")

    return ret
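
Step 1 of `render_rays` (sample spacing plus stratified jitter) can be tried on a single ray without any tensors. The sketch below mirrors the `lindisp` branch and the midpoint-based bins used above; the helper names `depth_samples` and `stratify` are hypothetical, and `random.Random` stands in for `torch.rand`.

```python
import random


def depth_samples(near, far, n, lindisp=False):
    """Evenly spaced samples in depth, or in inverse depth when lindisp is True."""
    ts = [i / (n - 1) for i in range(n)]
    if lindisp:
        return [1.0 / ((1.0 - t) / near + t / far) for t in ts]
    return [near * (1.0 - t) + far * t for t in ts]


def stratify(z_vals, rng):
    """Jitter each sample uniformly inside the bin bounded by neighboring midpoints."""
    mids = [0.5 * (a + b) for a, b in zip(z_vals[:-1], z_vals[1:])]
    lower = [z_vals[0]] + mids
    upper = mids + [z_vals[-1]]
    return [lo + (hi - lo) * rng.random() for lo, hi in zip(lower, upper)]


linear = depth_samples(2.0, 6.0, 5)                 # 2, 3, 4, 5, 6
inverse = depth_samples(2.0, 6.0, 5, lindisp=True)  # roughly 2, 2.4, 3, 4, 6
jittered = stratify(linear, random.Random(0))
```

Sampling in inverse depth pulls the interior samples toward the near bound, while the jitter keeps exactly one sample per bin so the ray stays covered from near to far.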

In [8]:
"""
Render a large set of rays by splitting them into manageable minibatches to avoid
out-of-memory (OOM) issues, then stitch the per-batch results back together.

High-level intuition:
- Rendering involves evaluating many rays (often millions). Each ray requires sampling
  points, running an MLP, and compositing colors/densities, which can be memory-heavy.
- Instead of rendering all rays at once, we process them in chunks (minibatches), which
  keeps peak memory usage bounded.
- We collect and concatenate the results for each output quantity (e.g., rgb, depth, acc).
"""

def batchify_rays(rays_flat, chunk=1024*32, **kwargs):
    """
    Render rays in smaller minibatches to avoid OOM.

    Args:
        rays_flat (torch.Tensor): Rays flattened along the batch dimension, shape [N, Cray].
            Each row encodes a single ray's data (e.g., origin, direction, near/far, etc.).
        chunk (int): Number of rays to render per minibatch.
        **kwargs: Additional keyword arguments forwarded to `render_rays` (e.g., models,
            sampling counts, noise parameters).

    Returns:
        Dict[str, torch.Tensor]: A dictionary where each key corresponds to a rendered
        quantity (e.g., 'rgb_map', 'disp_map', 'acc_map', etc.), and each tensor has
        shape [N, ...], formed by concatenating per-chunk outputs along the first dimension.
    """
    # Accumulate lists of per-chunk outputs in a dictionary keyed by output name.
    all_ret = {}

    # Iterate over rays in chunks: [0:chunk], [chunk:2*chunk], ...
    for i in range(0, rays_flat.shape[0], chunk):
        # Render a minibatch of rays using the provided rendering function and settings.
        ret = render_rays(rays_flat[i:i+chunk], **kwargs)

        # For each output field produced by the renderer, append the minibatch result
        # to a growing list so we can concatenate later.
        for k in ret:
            if k not in all_ret:
                all_ret[k] = []
            all_ret[k].append(ret[k])

    # Concatenate lists of chunk results into full [N, ...] tensors for each field.
    all_ret = {k: torch.cat(all_ret[k], dim=0) for k in all_ret}

    return all_ret
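
For a sense of scale, the sketch below counts the chunks visited by a loop like the one in `batchify_rays` and shows the key-by-key merge with lists standing in for tensors. The 800x800 image size is an assumption chosen for illustration, and `chunk_ranges` and `merge_chunks` are hypothetical helper names.

```python
def chunk_ranges(n_rays, chunk=1024 * 32):
    """Return the (start, stop) row ranges that a chunked loop would visit."""
    return [(i, min(i + chunk, n_rays)) for i in range(0, n_rays, chunk)]


def merge_chunks(per_chunk_dicts):
    """Concatenate per-chunk outputs key by key (lists stand in for tensors)."""
    merged = {}
    for ret in per_chunk_dicts:
        for key, values in ret.items():
            merged.setdefault(key, []).extend(values)
    return merged


ranges = chunk_ranges(800 * 800)  # one ray per pixel of an assumed 800x800 image
last_size = ranges[-1][1] - ranges[-1][0]
merged = merge_chunks([{"acc_map": [0.1, 0.2]}, {"acc_map": [0.3]}])
```

Under that assumption, 640,000 rays split into 20 chunks: 19 full chunks of 32,768 rays and a final chunk of 17,408.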

In [9]:
def render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True,
                  near=0., far=1.,
                  use_viewdirs=False, c2w_staticcam=None,
                  **kwargs):
    """Render a full image or a provided set of rays using NeRF.

    There are two common ways to call this function:
    - Full image: pass a camera-to-world matrix ``c2w`` and camera intrinsics ``K``.
      The function will generate one ray per pixel and render all of them.
    - Custom rays: pass ``rays=(rays_o, rays_d)`` to render only those rays.

    Args:
        H (int): Image height in pixels.
        W (int): Image width in pixels.
        K (torch.Tensor or np.ndarray): 3x3 camera intrinsics matrix. We use ``K[0,0]``
            (focal length in pixels) when converting to NDC for forward-facing scenes.
        chunk (int): Max number of rays to process per minibatch to control memory usage.
        rays (tuple[Tensor, Tensor], optional): Tuple ``(rays_o, rays_d)`` with shapes
            [..., 3] each, giving ray origins and directions. If provided, ``c2w`` is ignored.
        c2w (torch.Tensor or np.ndarray, optional): [3,4] camera-to-world matrix. If provided,
            rays for the full image are generated via ``get_rays``.
        ndc (bool): If True, convert rays to normalized device coordinates (recommended for
            forward-facing scenes as in the original NeRF paper).
        near (float): Near plane distance used to initialize per-ray near bounds.
        far (float): Far plane distance used to initialize per-ray far bounds.
        use_viewdirs (bool): If True, pass unit viewing directions to the network to enable
            view-dependent appearance (specularities).
        c2w_staticcam (torch.Tensor or np.ndarray, optional): If provided with ``use_viewdirs``
            enabled, generate rays from ``c2w_staticcam`` but keep view directions from ``c2w``.
            This is useful to visualize how view-dependent effects change with direction.
        **kwargs: Forwarded to ``render_rays`` (e.g., networks, sample counts, noise settings).

    Returns:
        list: ``[rgb_map, disp_map, acc_map, extras]``
            - rgb_map: [H, W, 3] rendered colors
            - disp_map: [H, W] disparity (1/depth)
            - acc_map: [H, W] accumulated opacity
            - extras: dict with any additional outputs from ``render_rays``
    """
    # Infer rendering device from the model to keep all tensors consistent
    model_device = next(kwargs['network_fn'].parameters()).device

    if c2w is not None:
        # Special case: render a full image by generating one ray per pixel.
        rays_o, rays_d = get_rays(H, W, K, c2w)
    else:
        # Use the provided custom ray batch.
        rays_o, rays_d = rays

    if use_viewdirs:
        # Provide normalized ray directions to the network for view-dependent effects.
        viewdirs = rays_d
        if c2w_staticcam is not None:
            # Visualize only the effect of changing view direction while keeping camera fixed.
            rays_o, rays_d = get_rays(H, W, K, c2w_staticcam)
        viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
        viewdirs = torch.reshape(viewdirs, [-1,3]).float().to(model_device)

    sh = rays_d.shape # [..., 3]
    if ndc:
        # Convert to NDC (assumes a pinhole camera model), commonly used for LLFF/forward-facing scenes.
        rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d)

    # Create ray batch
    rays_o = torch.reshape(rays_o, [-1,3]).float().to(model_device)
    rays_d = torch.reshape(rays_d, [-1,3]).float().to(model_device)

    # Initialize per-ray near/far bounds and pack rays into a single tensor expected by render_rays.
    near, far = near * torch.ones_like(rays_d[...,:1]), far * torch.ones_like(rays_d[...,:1])
    rays = torch.cat([rays_o, rays_d, near, far], -1)
    if use_viewdirs:
        rays = torch.cat([rays, viewdirs], -1)

    # Render all rays in memory-friendly chunks, then reshape results back to image grids.
    all_ret = batchify_rays(rays, chunk, **kwargs)
    for k in all_ret:
        k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])
        all_ret[k] = torch.reshape(all_ret[k], k_sh)

    k_extract = ['rgb_map', 'disp_map', 'acc_map']
    ret_list = [all_ret[k] for k in k_extract]
    ret_dict = {k : all_ret[k] for k in all_ret if k not in k_extract}
    return ret_list + [ret_dict]
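
The packed ray layout that `render` builds and `render_rays` unpacks can be checked on a single ray. The sketch below uses plain lists; `pack_ray` and `unpack_ray` are hypothetical names, and the column order follows the `ray_batch` description in the `render_rays` docstring.

```python
import math


def pack_ray(origin, direction, near, far, use_viewdirs=True):
    """Pack one ray as [origin(3), direction(3), near, far, viewdir(3, optional)]."""
    row = list(origin) + list(direction) + [near, far]
    if use_viewdirs:
        norm = math.sqrt(sum(c * c for c in direction))
        row += [c / norm for c in direction]  # unit-length view direction
    return row


def unpack_ray(row):
    """Inverse of pack_ray, using the same column slices as render_rays."""
    origin, direction = row[0:3], row[3:6]
    near, far = row[6], row[7]
    viewdir = row[-3:] if len(row) > 8 else None
    return origin, direction, near, far, viewdir


row = pack_ray(origin=[0.0, 0.0, 4.0], direction=[0.0, 3.0, -4.0], near=2.0, far=6.0)
origin, direction, near, far, viewdir = unpack_ray(row)
```

A row has 8 columns without view directions and 11 with them, which is why `render_rays` tests `ray_batch.shape[-1] > 8` before reading the last three columns.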


In [10]:
def render_path(render_poses, hwf, K, chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0, 
                calculate_metrics=False, metrics_include_lpips=True, metrics_device='cuda'):
    """Render a sequence of camera poses to produce a video or trajectory.

    This convenience function loops over a list/array of camera-to-world matrices and calls
    ``render`` for each pose. Optionally writes frames to ``savedir`` and/or renders at a
    lower resolution for speed.

    Args:
        render_poses (Iterable[Tensor or np.ndarray]): Sequence of [3,4] camera-to-world matrices.
        hwf (tuple): ``(H, W, focal)`` from dataset metadata. Only ``H`` and ``W`` are used here.
        K (Tensor or np.ndarray): 3x3 intrinsics matrix passed through to ``render``.
        chunk (int): Chunk size forwarded to ``render``.
        render_kwargs (dict): Keyword args forwarded to ``render`` (e.g., networks and settings).
        gt_imgs (optional): Ground-truth images; if provided, you can compute metrics.
        savedir (str, optional): If provided, write each rendered RGB frame as a PNG to this folder.
        render_factor (int): If > 0, downsample H and W by this factor to render faster.
        calculate_metrics (bool): If True and gt_imgs is provided, calculate image quality metrics.
            Per-frame metrics are only computed when ``render_factor`` is 0.
        metrics_include_lpips (bool): Whether to include LPIPS in metrics calculation.
        metrics_device (str): Device to use for LPIPS calculation.

    Returns:
        If calculate_metrics=False: Tuple[np.ndarray, np.ndarray]: ``(rgbs, disps)``
        If calculate_metrics=True: Tuple[np.ndarray, np.ndarray, dict]: ``(rgbs, disps, metrics)``
        where metrics contains averaged PSNR, SSIM, and optionally LPIPS values.
    """

    H, W, focal = hwf

    if render_factor!=0:
        # Render downsampled for speed by reducing the resolution and focal length proportionally.
        # Note: K is passed to render() unchanged below, so the rescaled focal is not applied to it.
        H = H//render_factor
        W = W//render_factor
        focal = focal/render_factor

    rgbs = []
    disps = []
    
    # Initialize metrics collection if requested
    if calculate_metrics and gt_imgs is not None:
        from nerf_helpers import calculate_metrics as calc_metrics
        all_metrics = {'psnr': [], 'ssim': [], 'mse': []}
        if metrics_include_lpips:
            all_metrics['lpips'] = []

    t = time.time()
    for i, c2w in enumerate(tqdm(render_poses)):
        # Simple timing print to monitor rendering speed
        print(i, time.time() - t)
        t = time.time()

        # Render the current pose; we discard the accumulated opacity and extras here
        rgb, disp, acc, _ = render(H, W, K, chunk=chunk, c2w=c2w[:3,:4], **render_kwargs)
        rgbs.append(rgb.cpu().numpy())
        disps.append(disp.cpu().numpy())
        if i==0:
            print(rgb.shape, disp.shape)

        # Calculate metrics vs. ground truth if requested
        if calculate_metrics and gt_imgs is not None and render_factor==0 and i < len(gt_imgs):
            gt_img = gt_imgs[i]
            rendered_img = rgb.cpu().numpy()
            
            # Calculate comprehensive metrics
            metrics = calc_metrics(rendered_img, gt_img, 
                                 include_lpips=metrics_include_lpips, 
                                 device=metrics_device)
            
            # Store metrics
            for key in all_metrics:
                if key in metrics and metrics[key] is not None:
                    all_metrics[key].append(metrics[key])
            
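            # Print metrics for this frame. Missing values are shown as N/A; formatting the
            # string 'N/A' with ':.2f' would raise a ValueError.
            psnr, ssim = metrics.get('psnr'), metrics.get('ssim')
            psnr_str = f"{psnr:.2f}" if psnr is not None else "N/A"
            ssim_str = f"{ssim:.4f}" if ssim is not None else "N/A"
            print(f"Frame {i} - PSNR: {psnr_str}, SSIM: {ssim_str}", end="")
            if metrics_include_lpips and metrics.get('lpips') is not None:
                print(f", LPIPS: {metrics['lpips']:.4f}")
            else:
                print()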

        # Optionally write the frame to disk as an 8-bit PNG.
        if savedir is not None:
            rgb8 = to8b(rgbs[-1])
            filename = os.path.join(savedir, '{:03d}.png'.format(i))
            imageio.imwrite(filename, rgb8)

    # Stack lists into contiguous arrays with a time dimension.
    rgbs = np.stack(rgbs, 0)
    disps = np.stack(disps, 0)

    # Calculate average metrics if requested
    if calculate_metrics and gt_imgs is not None:
        avg_metrics = {}
        for key, values in all_metrics.items():
            if values:  # Only average if we have values
                avg_metrics[f'avg_{key}'] = np.mean(values)
                avg_metrics[f'std_{key}'] = np.std(values)
        
        # Print summary
        print("\n=== METRICS SUMMARY ===")
        for key in ['psnr', 'ssim', 'lpips']:
            if f'avg_{key}' in avg_metrics:
                print(f"Average {key.upper()}: {avg_metrics[f'avg_{key}']:.4f} ± {avg_metrics[f'std_{key}']:.4f}")
        print("=======================")
        
        return rgbs, disps, avg_metrics

    return rgbs, disps
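
The metrics reported by `render_path` come from a `calculate_metrics` helper in `nerf_helpers` that is not shown in this notebook. As a rough guide to what the PSNR number means, the sketch below computes it from the mean squared error for pixel values in [0, 1], and shows an 8-bit conversion similar in spirit to `to8b`. All three function names are hypothetical, and the helper's actual implementation may differ.

```python
import math


def mse(img_a, img_b):
    """Mean squared error between two equally sized flat lists of values in [0, 1]."""
    return sum((a - b) ** 2 for a, b in zip(img_a, img_b)) / len(img_a)


def psnr_from_mse(mse_value):
    """PSNR in dB for a peak value of 1.0; infinite for identical images."""
    return float("inf") if mse_value == 0 else -10.0 * math.log10(mse_value)


def to_uint8(values):
    """Clip to [0, 1] and scale to the 0..255 range, truncating toward zero."""
    return [int(255 * min(max(v, 0.0), 1.0)) for v in values]


rendered = [0.5, 0.5, 0.5, 0.5]
target = [0.6, 0.4, 0.6, 0.4]
score = psnr_from_mse(mse(rendered, target))  # every pixel is off by 0.1
frame8 = to_uint8([-0.1, 0.5, 1.2])
```

A uniform error of 0.1 per pixel corresponds to an MSE of about 0.01, which is roughly 20 dB; higher PSNR means a closer match to the ground truth.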


In [11]:
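
The `run_network` cell that follows relies on positional encoding, which is supplied by the `embedder` module and not shown here. As a minimal sketch of the idea, assuming power-of-two frequencies without a factor of pi and with the raw input kept in the output, the encoding of one 3D point could look like this (`positional_encoding` is a hypothetical name, not the notebook's embedder):

```python
import math


def positional_encoding(x, num_freqs, include_input=True):
    """Encode coordinates with sin/cos at frequencies 2^0 .. 2^(num_freqs - 1)."""
    out = list(x) if include_input else []
    for k in range(num_freqs):
        freq = 2.0 ** k
        out += [math.sin(freq * v) for v in x]
        out += [math.cos(freq * v) for v in x]
    return out


encoded = positional_encoding([0.1, -0.2, 0.3], num_freqs=10)
```

With 10 frequencies, a 3D point expands to 3 + 3 * 2 * 10 = 63 features, which gives the MLP access to both coarse and fine spatial variation.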
"""
Prepare 3D sample points for a NeRF-style network by applying positional encodings,
run the network on these encodings in memory-safe chunks, and then reshape the results
back to the original sampling layout.

High-level intuition:
- We often sample many 3D points (xyz) per ray and, optionally, use a per-ray viewing
  direction. Small MLPs struggle to learn high-frequency detail from raw coordinates,
  so we first apply a positional encoding that maps them to a higher-dimensional space.
- We flatten everything to a big batch so the network can process all samples uniformly.
- To avoid running out of memory, we split this big batch into chunks and process them
  sequentially, then stitch the outputs back together and restore the original shape.
"""

def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
    """
    Prepare inputs for a NeRF-style MLP and apply the network in chunks.

    Conceptual overview:
    - Positions (xyz) and, optionally, viewing directions are first positional-encoded
      (a deterministic mapping to a higher-dimensional space using sin/cos at multiple
      frequencies). This helps the MLP represent fine details and sharp changes.
    - We flatten leading dimensions so all samples are processed as a single batch.
    - To keep memory usage in check, we process this batch in chunks (netchunk).
    - Finally, we reshape outputs to match the original sampling layout.

    Args:
        inputs (torch.Tensor): Sample positions with shape [..., Cpos], typically Cpos = 3.
            Example: [N_rays, N_samples, 3]. The leading dimensions can be any shape.
        viewdirs (Optional[torch.Tensor]): Per-ray viewing directions with shape
            [N_rays, Cdir] (typically Cdir = 3), or None if not using view-dependent effects.
            When provided, each ray direction is broadcast to all samples along that ray.
        fn (Callable[[torch.Tensor], torch.Tensor]): Neural network (e.g., NeRF MLP) that
            consumes encoded features and returns outputs per sample.
        embed_fn (Callable[[torch.Tensor], torch.Tensor]): Positional encoder for positions;
            maps [*, Cpos] -> [*, Cpos_enc].
        embeddirs_fn (Optional[Callable[[torch.Tensor], torch.Tensor]]): Positional encoder
            for directions; maps [*, Cdir] -> [*, Cdir_enc]. Only used if viewdirs is not None.
        netchunk (int): Maximum number of samples to process per chunk to limit peak memory.

    Returns:
        torch.Tensor: Network outputs with shape [..., Cout], where the leading dimensions
        match those of `inputs` (excluding its last channel), and Cout is determined by `fn`.
    """
    # Flatten all leading dimensions so we have a simple [N, Cpos] batch of positions.
    # N is the total number of samples across rays and per-ray samples.
    inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])

    # Positional-encode the flattened positions (e.g., apply sin/cos at multiple frequencies).
    # This expands each 3D input into a richer, higher-dimensional representation
    # that makes it easier for the MLP to model fine spatial detail.
    embedded = embed_fn(inputs_flat)

    # If using view-dependent appearance (e.g., specular highlights that vary with direction),
    # we also encode per-ray viewing directions and concatenate them with position encodings.
    if viewdirs is not None:
        # Insert a length-1 axis, then broadcast each ray direction across all samples on that ray
        # so that every sample point along a ray shares the same view direction.
        input_dirs = viewdirs[:, None].expand(inputs.shape)

        # Flatten directions to align with the flattened positions: [N, Cdir].
        input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])

        # Positional-encode viewing directions in the same spirit as positions.
        embedded_dirs = embeddirs_fn(input_dirs_flat)

        # Concatenate encoded positions and encoded directions along the feature/channel axis.
        embedded = torch.cat([embedded, embedded_dirs], -1)

    # Apply the network to the encoded features in memory-safe chunks along the batch dimension.
    # This prevents out-of-memory errors when the total number of samples is very large.
    # Ensure inputs are on the same device as the model parameters.
    model_device = next(fn.parameters()).device
    embedded = embedded.to(model_device)
    outputs_flat = batchify(fn, netchunk)(embedded)

    # Restore the original leading shape (e.g., [N_rays, N_samples]) and append the output channels.
    outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]])

    return outputs
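
The flatten, chunk, and reshape steps in `run_network` can be sketched without PyTorch. This is a minimal NumPy illustration, not the notebook's implementation: `apply_in_chunks` and `toy_network` are hypothetical stand-ins for `batchify` and the NeRF MLP, and the positional encoding is skipped so the shapes stay easy to follow.

```python
import numpy as np

def apply_in_chunks(fn, x, chunk):
    """Apply fn to rows of x in slices of at most `chunk` rows, then concatenate."""
    return np.concatenate([fn(x[i:i + chunk]) for i in range(0, x.shape[0], chunk)], axis=0)

def toy_network(features):
    # Hypothetical stand-in for the MLP: maps [N, C] -> [N, 4] (RGB + density).
    rgb = features[:, :3]
    sigma = features.sum(axis=-1, keepdims=True)
    return np.concatenate([rgb, sigma], axis=-1)

n_rays, n_samples = 5, 8
pts = np.random.default_rng(0).normal(size=(n_rays, n_samples, 3))
viewdirs = np.random.default_rng(1).normal(size=(n_rays, 3))

# Broadcast each ray's direction to every sample on that ray, then flatten both.
dirs = np.broadcast_to(viewdirs[:, None, :], pts.shape)
flat = np.concatenate([pts.reshape(-1, 3), dirs.reshape(-1, 3)], axis=-1)  # [40, 6]

# 40 rows are processed as chunks of 16, 16, and 8 rows.
out_flat = apply_in_chunks(toy_network, flat, chunk=16)

# Restore the [n_rays, n_samples] layout and append the output channels.
out = out_flat.reshape(list(pts.shape[:-1]) + [out_flat.shape[-1]])
print(out.shape)  # (5, 8, 4)
```

Chunking changes only peak memory, not the result: the chunked output equals what a single call on the full batch would return.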

## Instantiate NeRF
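This section defines `create_nerf`, which builds the coarse and (optionally) fine MLPs, the optimizer, and the render settings, and reloads the latest checkpoint if one exists.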
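
`create_nerf` sizes the MLP inputs from the channel counts that `get_embedder` reports. The sketch below shows where such counts come from, assuming a sin/cos encoding with power-of-two frequencies that also keeps the raw input. `positional_encoding` is a hypothetical helper; the notebook's `get_embedder` lives in `embedder.py` and may differ in detail. The frequency counts 10 and 4 are the defaults from the paper, not values read from this experiment's yaml.

```python
import numpy as np

def positional_encoding(x, num_freqs, include_input=True):
    """Map [..., C] -> [..., C * (2 * num_freqs + include_input)]."""
    freqs = 2.0 ** np.arange(num_freqs)  # 1, 2, 4, ..., 2^(num_freqs - 1)
    parts = [x] if include_input else []
    for f in freqs:
        parts.append(np.sin(x * f))
        parts.append(np.cos(x * f))
    return np.concatenate(parts, axis=-1)

xyz = np.zeros((4, 3))
pos_enc = positional_encoding(xyz, num_freqs=10)  # assumed multires = 10
dir_enc = positional_encoding(xyz, num_freqs=4)   # assumed multires_views = 4
print(pos_enc.shape)  # (4, 63)
print(dir_enc.shape)  # (4, 27)
```

With these assumed settings, a 3D position expands to 3 * (2 * 10 + 1) = 63 channels and a direction to 3 * (2 * 4 + 1) = 27 channels.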

In [12]:
from nerf import NeRF

def create_nerf(args):
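    """Instantiate the NeRF MLP(s), the optimizer, and the render kwargs.

    Returns:
        (render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer),
        where `start` is the global step restored from a checkpoint (0 if none).
    """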
    embed_fn, input_ch = get_embedder(args.multires, args.i_embed)

    input_ch_views = 0
    embeddirs_fn = None
    if args.use_viewdirs:
        embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed)
    output_ch = 5 if args.N_importance > 0 else 4
    skips = [4]
    model = NeRF(D=args.netdepth, W=args.netwidth,
                 input_ch=input_ch, output_ch=output_ch, skips=skips,
                 input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
    grad_vars = list(model.parameters())

    model_fine = None
    if args.N_importance > 0:
        model_fine = NeRF(D=args.netdepth_fine, W=args.netwidth_fine,
                          input_ch=input_ch, output_ch=output_ch, skips=skips,
                          input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
        grad_vars += list(model_fine.parameters())

    network_query_fn = lambda inputs, viewdirs, network_fn : run_network(inputs, viewdirs, network_fn,
                                                                embed_fn=embed_fn,
                                                                embeddirs_fn=embeddirs_fn,
                                                                netchunk=args.netchunk)

    # Create optimizer
    optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))

    start = 0
    basedir = args.basedir
    expname = args.expname

    ##########################

    # Create the checkpoint directory, then look for checkpoints to resume from
    checkpoint_dir = os.path.join(basedir, expname, "checkpoints")
    os.makedirs(checkpoint_dir, exist_ok=True)

    if args.ft_path is not None and args.ft_path!='None':
        ckpts = [args.ft_path]
    else:
        ckpts = [os.path.join(checkpoint_dir, f) for f in sorted(os.listdir(checkpoint_dir)) if 'tar' in f]

    print('Found ckpts', ckpts)
    if len(ckpts) > 0 and not args.no_reload:
        ckpt_path = ckpts[-1]
        print('Reloading from', ckpt_path)
        # map_location lets a checkpoint saved on one device load on another
        ckpt = torch.load(ckpt_path, map_location=device)

        start = ckpt['global_step']
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])

        # Load model
        model.load_state_dict(ckpt['network_fn_state_dict'])
        if model_fine is not None:
            model_fine.load_state_dict(ckpt['network_fine_state_dict'])

    ##########################

    render_kwargs_train = {
        'network_query_fn' : network_query_fn,
        'perturb' : args.perturb,
        'N_importance' : args.N_importance,
        'network_fine' : model_fine,
        'N_samples' : args.N_samples,
        'network_fn' : model,
        'use_viewdirs' : args.use_viewdirs,
        'white_bkgd' : args.white_bkgd,
        'raw_noise_std' : args.raw_noise_std,
    }

    # NDC only good for LLFF-style forward facing data
    if args.dataset_type != 'llff' or args.no_ndc:
        print('Not ndc!')
        render_kwargs_train['ndc'] = False
        render_kwargs_train['lindisp'] = args.lindisp

    render_kwargs_test = {k : render_kwargs_train[k] for k in render_kwargs_train}
    render_kwargs_test['perturb'] = False
    render_kwargs_test['raw_noise_std'] = 0.

    return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer
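
Resuming relies on checkpoint filenames being zero-padded step numbers (`{:06d}.tar`), so a plain lexicographic sort puts the newest file last. A small self-contained sketch of that selection follows. `latest_checkpoint` is a hypothetical helper, and it matches on the `.tar` suffix, which is slightly stricter than the `'tar' in f` test used in `create_nerf`.

```python
import os
import tempfile

def latest_checkpoint(checkpoint_dir):
    """Return the path of the newest '*.tar' checkpoint, or None if there is none."""
    names = sorted(f for f in os.listdir(checkpoint_dir) if f.endswith(".tar"))
    return os.path.join(checkpoint_dir, names[-1]) if names else None

with tempfile.TemporaryDirectory() as tmp:
    empty_result = latest_checkpoint(tmp)
    # Create empty placeholder files in a deliberately unsorted order.
    for step in (10000, 200000, 30000):
        open(os.path.join(tmp, f"{step:06d}.tar"), "w").close()
    open(os.path.join(tmp, "args.txt"), "w").close()
    newest = os.path.basename(latest_checkpoint(tmp))

print(empty_result)  # None
print(newest)        # 200000.tar
```

Without the zero padding, `30000.tar` would sort after `200000.tar` and the wrong checkpoint would be reloaded.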

## Training Loop
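
The training loop relies on volume rendering to turn per-sample densities and colors into one pixel color. The sketch below illustrates the standard compositing rule from the paper for a single ray: opacity is `alpha = 1 - exp(-sigma * delta)`, and each sample's weight is its opacity times the transmittance of everything in front of it. `composite` is a hypothetical NumPy helper for illustration; the notebook's `render` function is defined elsewhere and operates on batched tensors.

```python
import numpy as np

def composite(rgb, sigma, z_vals):
    """Composite per-sample colors [S, 3] and densities [S] along one ray."""
    # Distance between adjacent samples; the last interval is treated as very long.
    deltas = np.concatenate([z_vals[1:] - z_vals[:-1], [1e10]])
    # Opacity of each sample; negative densities are clamped to zero.
    alpha = 1.0 - np.exp(-np.maximum(sigma, 0.0) * deltas)
    # Transmittance: fraction of light that reaches sample i unabsorbed.
    trans = np.concatenate([[1.0], np.cumprod(1.0 - alpha + 1e-10)[:-1]])
    weights = alpha * trans
    return (weights[:, None] * rgb).sum(axis=0), weights

z_vals = np.linspace(2.0, 6.0, 5)
sigma = np.array([0.0, 0.0, 50.0, 0.0, 0.0])  # one nearly opaque sample at depth 4
rgb = np.tile(np.array([[1.0, 0.0, 0.0]]), (5, 1))
rgb[2] = [0.0, 1.0, 0.0]                      # the opaque sample is green
color, weights = composite(rgb, sigma, z_vals)
print(np.round(color, 3))
print(np.round(weights, 3))
```

Almost all of the weight lands on the single dense sample, so the ray comes out green even though every other sample is red. Hierarchical sampling reuses these weights to decide where the fine pass should place extra samples.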

In [13]:

def train():
    """
    End-to-end NeRF training loop.

    High-level overview for newcomers:
    - Load a dataset of posed images (e.g., LLFF/Blender/LINEMOD/DeepVoxels). Each image comes with a camera pose.
    - Create a NeRF model (coarse and optionally fine MLPs) and a renderer.
    - On each iteration, sample camera rays and their target RGB values from the dataset.
    - Render rays with the NeRF model via volumetric rendering (accumulate colors along the ray).
    - Compute a reconstruction loss (MSE) against ground-truth pixels and optimize the networks.
    - Periodically save checkpoints, render validation trajectories, and evaluate the test set.

    Key concepts:
    - Rays: For each pixel, we cast a ray into the scene with origin/direction computed from intrinsics and pose.
    - Sampling: We sample multiple points along each ray (coarse). Optionally resample (fine) where the scene is likely.
    - Volume rendering: Convert per-point density+color to opacity weights and composite to a final pixel color.
    - Hierarchical sampling: A second pass concentrates samples where the coarse pass assigned high weight.
    """
    # Imports come after the docstring so that Python still treats it as the function's docstring.
    import os
    import skimage

    # ----------------------
    # 1) Load data and choose near/far bounds depending on dataset
    # ----------------------
    K = None
    if args.dataset_type == 'llff':
        images, poses, bds, render_poses, i_test = load_llff_data(
            args.datadir, args.factor, recenter=True, bd_factor=.75, spherify=args.spherify
        )
        hwf = poses[0,:3,-1]           # (H, W, focal)
        poses = poses[:,:3,:4]         # Only keep rotation+translation (3x4) per pose
        print('Loaded llff', images.shape, render_poses.shape, hwf, args.datadir)
        if not isinstance(i_test, list):
            i_test = [i_test]

        # Optional LLFF holdout: use every N-th image as test
        if args.llffhold > 0:
            print('Auto LLFF holdout,', args.llffhold)
            i_test = np.arange(images.shape[0])[::args.llffhold]

        i_val = i_test
        i_train = np.array([i for i in np.arange(int(images.shape[0]))
                            if (i not in i_test and i not in i_val)])

        print('DEFINING BOUNDS')
        if args.no_ndc:
            # If not using NDC (e.g., inward-facing/360 scenes), near/far from bounds
            near = np.ndarray.min(bds) * .9
            far = np.ndarray.max(bds) * 1.
        else:
            # Forward-facing (LLFF) uses NDC, so near/far are normalized
            near = 0.
            far = 1.
        print('NEAR FAR', near, far)

    elif args.dataset_type == 'blender':
        images, poses, render_poses, hwf, i_split = load_blender_data(
            args.datadir, args.half_res, args.testskip
        )
        print('Loaded blender', images.shape, render_poses.shape, hwf, args.datadir)
        i_train, i_val, i_test = i_split

        # Standard near/far for Blender synthetic scenes
        near = 2.
        far = 6.

        # Composite over white if requested (makes background white instead of black)
        if args.white_bkgd:
            images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:])
        else:
            images = images[...,:3]

    elif args.dataset_type == 'LINEMOD':
        images, poses, render_poses, hwf, K, i_split, near, far = load_LINEMOD_data(
            args.datadir, args.half_res, args.testskip
        )
        print(f'Loaded LINEMOD, images shape: {images.shape}, hwf: {hwf}, K: {K}')
        print(f'[CHECK HERE] near: {near}, far: {far}.')
        i_train, i_val, i_test = i_split

        if args.white_bkgd:
            images = images[...,:3]*images[...,-1:] + (1.-images[...,-1:])
        else:
            images = images[...,:3]

    elif args.dataset_type == 'deepvoxels':
        images, poses, render_poses, hwf, i_split = load_dv_data(
            scene=args.shape, basedir=args.datadir, testskip=args.testskip
        )
        print('Loaded deepvoxels', images.shape, render_poses.shape, hwf, args.datadir)
        i_train, i_val, i_test = i_split

        # DeepVoxels scenes define a hemisphere radius; near/far around it
        hemi_R = np.mean(np.linalg.norm(poses[:,:3,-1], axis=-1))
        near = hemi_R-1.
        far = hemi_R+1.

    else:
        print('Unknown dataset type', args.dataset_type, 'exiting')
        return

    # ----------------------
    # 2) Prepare intrinsics (H, W, focal) and default K if not provided
    # ----------------------
    H, W, focal = hwf
    H, W = int(H), int(W)
    hwf = [H, W, focal]

    if K is None:
        # Construct a pinhole intrinsics matrix assuming principal point at image center
        K = np.array([
            [focal, 0, 0.5*W],
            [0, focal, 0.5*H],
            [0, 0, 1]
        ])

    # If we are evaluating on the test set, use the corresponding subset of poses
    if args.render_test:
        render_poses = np.array(poses[i_test])

    # ----------------------
    # 3) Logging setup and persistence of configs
    # ----------------------
    basedir = args.basedir
    expname = args.expname
    os.makedirs(os.path.join(basedir, expname), exist_ok=True)

    # Save the parsed args for reproducibility
    f = os.path.join(basedir, expname, 'args.txt')
    with open(f, 'w') as file:
        for arg, attr in sorted(args.items()):
            attr = getattr(args, arg)
            file.write('{} = {}\n'.format(arg, attr))
    # Save the YAML configuration that produced these args
    f_yaml = os.path.join(basedir, expname, 'config.yaml')
    try:
        utils.save_yaml(dict(args), f_yaml)
    except Exception:
        # Fallback: write a simple YAML dump directly
        with open(f_yaml, 'w', encoding='utf-8') as yf:
            import yaml as _yaml
            _yaml.dump(dict(args), yf, default_flow_style=False, indent=2)

    # ----------------------
    # 4) Create NeRF models and optimizer
    # ----------------------
    render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args)
    global_step = start

    # Near/far bounds are used by the renderer; update both train and test configs
    bds_dict = { 'near' : near, 'far' : far }
    render_kwargs_train.update(bds_dict)
    render_kwargs_test.update(bds_dict)

    # Move the camera trajectory used for rendering validation videos to the GPU
    render_poses = torch.Tensor(render_poses).to(device)

    # ----------------------
    # 5) Short-circuit: render only mode
    # ----------------------
    if args.render_only:
        print('RENDER ONLY')
        with torch.no_grad():
            if args.render_test:
                # Switch to test poses
                images = images[i_test]
            else:
                # Default is smoother render_poses path
                images = None

            testsavedir = os.path.join(
                basedir, expname, 'renderonly_{}_{:06d}'.format('test' if args.render_test else 'path', start)
            )
            os.makedirs(testsavedir, exist_ok=True)
            print('test poses shape', render_poses.shape)

            # Render a path and save to video
            rgbs, _ = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test,
                                   gt_imgs=images, savedir=testsavedir, render_factor=args.render_factor)
            print('Done rendering', testsavedir)
            imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'), to8b(rgbs), fps=30, quality=8)

            return

    # ----------------------
    # 6) Prepare random-ray batching (optional) and move data to GPU
    # ----------------------
    N_rand = args.N_rand
    use_batching = not args.no_batching
    if use_batching:
        # Precompute all rays for all training images, then shuffle mini-batches each step
        print('get rays')
        rays = np.stack([get_rays_np(H, W, K, p) for p in poses[:,:3,:4]], 0)  # [N, ro+rd, H, W, 3]
        print('done, concats')
        rays_rgb = np.concatenate([rays, images[:,None]], 1)                   # [N, ro+rd+rgb, H, W, 3]
        rays_rgb = np.transpose(rays_rgb, [0,2,3,1,4])                         # [N, H, W, ro+rd+rgb, 3]
        rays_rgb = np.stack([rays_rgb[i] for i in i_train], 0)                 # train images only
        rays_rgb = np.reshape(rays_rgb, [-1,3,3])                              # [N_train*H*W, ro+rd+rgb, 3]
        rays_rgb = rays_rgb.astype(np.float32)
        print('shuffle rays')
        np.random.shuffle(rays_rgb)
        print('done')
        i_batch = 0

    # Move arrays/tensors to GPU for training
    if use_batching:
        images = torch.Tensor(images).to(device)
    poses = torch.Tensor(poses).to(device)
    if use_batching:
        rays_rgb = torch.Tensor(rays_rgb).to(device)

    print('Begin')
    print('TRAIN views are', i_train)
    print('TEST views are', i_test)
    print('VAL views are', i_val)

    # ----------------------
    # 7) Main optimization loop
    # ----------------------
    # writer = SummaryWriter(os.path.join(basedir, 'summaries', expname))  # Optional TB logging

    start = start + 1
    for i in trange(start, N_iters):
        time0 = time.time()

        # Ensure global step and loop iteration i are synced for checkpoint resume
        global_step = i

        # Sample a batch of rays and target colors
        if use_batching:
            # Random over all images (global batching)
            batch = rays_rgb[i_batch:i_batch+N_rand]  # [B, ro+rd+rgb, 3]
            batch = torch.transpose(batch, 0, 1)
            batch_rays, target_s = batch[:2], batch[2]

            # Move sliding window; reshuffle at epoch end
            i_batch += N_rand
            if i_batch >= rays_rgb.shape[0]:
                print("Shuffle data after an epoch!")
                rand_idx = torch.randperm(rays_rgb.shape[0])
                rays_rgb = rays_rgb[rand_idx]
                i_batch = 0
        else:
            # Random rays from a randomly chosen training image (per-image batching)
            img_i = np.random.choice(i_train)
            target = images[img_i]
            target = torch.Tensor(target).to(device)
            pose = poses[img_i, :3,:4]

            if N_rand is not None:
                # Compute per-pixel rays for this image, then sample N_rand of them
                rays_o, rays_d = get_rays(H, W, K, torch.Tensor(pose))  # (H, W, 3), (H, W, 3)

                if i < args.precrop_iters:
                    # Optional: focus early training on the image center (stabilizes training)
                    dH = int(H//2 * args.precrop_frac)
                    dW = int(W//2 * args.precrop_frac)
                    coords = torch.stack(
                        torch.meshgrid(
                            torch.linspace(H//2 - dH, H//2 + dH - 1, 2*dH),
                            torch.linspace(W//2 - dW, W//2 + dW - 1, 2*dW),
                            indexing='ij'  # explicit row/column indexing; silences the meshgrid warning
                        ), -1)
                    if i == start:
                        print(f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {args.precrop_iters}")
                else:
                    coords = torch.stack(
                        torch.meshgrid(torch.linspace(0, H-1, H), torch.linspace(0, W-1, W), indexing='ij'), -1
                    )  # (H, W, 2)

                coords = torch.reshape(coords, [-1,2])                         # (H*W, 2)
                select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False)  # (N_rand,)
                select_coords = coords[select_inds].long()                     # (N_rand, 2)
                rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]]      # (N_rand, 3)
                rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]]      # (N_rand, 3)
                batch_rays = torch.stack([rays_o, rays_d], 0)
                target_s = target[select_coords[:, 0], select_coords[:, 1]]    # (N_rand, 3)

        # ---- Core rendering + loss ----
        rgb, disp, acc, extras = render(
            H, W, K, chunk=args.chunk, rays=batch_rays, verbose=i < 10, retraw=True, **render_kwargs_train
        )

        optimizer.zero_grad()
        img_loss = img2mse(rgb, target_s)
        trans = extras['raw'][...,-1]
        loss = img_loss
        psnr = mse2psnr(img_loss)

        # If hierarchical sampling is enabled, include the coarse-pass loss
        if 'rgb0' in extras:
            img_loss0 = img2mse(extras['rgb0'], target_s)
            loss = loss + img_loss0
            psnr0 = mse2psnr(img_loss0)

        loss.backward()
        optimizer.step()

        # --- Learning rate decay (exponential) ---
        decay_rate = 0.1
        decay_steps = args.lrate_decay * 1000
        new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lrate

        dt = time.time()-time0
        # print(f"Step: {global_step}, Loss: {loss}, Time: {dt}")

        # ----------------------
        # 8) Periodic logging, checkpointing, and visualization
        # ----------------------
        if i%args.i_weights==0:
            # Save model checkpoints for resuming or analysis
            path = os.path.join(basedir, expname, "checkpoints" ,'{:06d}.tar'.format(i))
            ckpt_state = {
                'global_step': global_step,
                'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }
            # The fine network only exists when N_importance > 0
            if render_kwargs_train['network_fine'] is not None:
                ckpt_state['network_fine_state_dict'] = render_kwargs_train['network_fine'].state_dict()
            torch.save(ckpt_state, path)
            print('Saved checkpoints at', path)

        if i%args.i_video==0 and i > 0:
            # Render a validation trajectory and write MP4 previews
            with torch.no_grad():
                rgbs, disps = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test)
            print('Done, saving', rgbs.shape, disps.shape)
            moviebase = os.path.join(basedir, expname, '{}_spiral_{:06d}_'.format(expname, i))
            imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8)
            imageio.mimwrite(moviebase + 'disp.mp4', to8b(disps / np.max(disps)), fps=30, quality=8)

            # If you want to visualize view-dependent effects, you can fix the camera position
            # and vary only the view direction (see commented example in original code).

        if i%args.i_testset==0 and i > 0:
            # Render the held-out test set and save frames with comprehensive metrics
            testsavedir = os.path.join(basedir, expname, 'testset_{:06d}'.format(i))
            os.makedirs(testsavedir, exist_ok=True)

            # Make metrics dir to store our metrics
            metricsdir = os.path.join(basedir, expname, "metrics")
            os.makedirs(metricsdir, exist_ok=True)
            print('test poses shape', poses[i_test].shape)
            
            # LPIPS is slower than the other metrics but cheap enough here, so it is
            # computed on every test evaluation. To compute it only on every 3rd
            # evaluation instead, use:
            # include_lpips = ((i // args.i_testset) % 3 == 0)
            include_lpips = True
            
            with torch.no_grad():
                try:
                    # Enhanced render with metrics calculation
                    rgbs, disps, metrics = render_path(
                        torch.Tensor(poses[i_test]).to(device), hwf, K, args.chunk,
                        render_kwargs_test, 
                        gt_imgs=images[i_test], 
                        savedir=testsavedir,
                        calculate_metrics=True,
                        metrics_include_lpips=include_lpips,
                        metrics_device=device
                    )
                    
                    # Log metrics to both TXT and JSON formats
                    import json
                    import time as time_module
                    
                    # Text format (human readable)
                    metrics_file_txt = os.path.join(basedir, expname, "metrics", f'metrics_{i:06d}.txt')
                    with open(metrics_file_txt, 'w') as f:
                        f.write(f"Iteration: {i}\n")
                        f.write(f"Timestamp: {time_module.strftime('%Y-%m-%d %H:%M:%S')}\n")
                        f.write(f"LPIPS_included: {include_lpips}\n")
                        f.write("-" * 40 + "\n")
                        for key, value in metrics.items():
                            f.write(f"{key}: {value:.6f}\n")
                    
                    # JSON format (machine readable)
                    metrics_file_json = os.path.join(basedir, expname, "metrics", f'metrics_{i:06d}.json')
                    json_data = {
                        'iteration': i,
                        'timestamp': time_module.strftime('%Y-%m-%d %H:%M:%S'),
                        'lpips_included': include_lpips,
                        'metrics': {k: float(v) for k, v in metrics.items()}
                    }
                    with open(metrics_file_json, 'w') as f:
                        json.dump(json_data, f, indent=2)
                    
                    # Append to consolidated training log
                    training_log_file = os.path.join(basedir, expname, "metrics", 'training_metrics.json')
                    if os.path.exists(training_log_file):
                        with open(training_log_file, 'r') as f:
                            training_log = json.load(f)
                    else:
                        training_log = {'experiment': expname, 'metrics_history': []}
                    
                    training_log['metrics_history'].append(json_data)
                    with open(training_log_file, 'w') as f:
                        json.dump(training_log, f, indent=2)
                    
                    print(f"Metrics logged to: {metrics_file_txt} and {metrics_file_json}")
                    print(f"Training history updated: {training_log_file}")
                    
                except Exception as e:
                    print(f"Enhanced evaluation failed, using basic render: {e}")
                    # Fallback to basic rendering
                    render_path(torch.Tensor(poses[i_test]).to(device), hwf, K, args.chunk,
                               render_kwargs_test, gt_imgs=images[i_test], savedir=testsavedir)
            
            print('Saved test set with metrics')

        if i%args.i_print==0:
            tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss.item()}  PSNR: {psnr.item()}")
            
            import json
            import time as time_module
            
            # Save detailed training log
            training_data = {
                'iteration': i,
                'timestamp': time_module.strftime('%Y-%m-%d %H:%M:%S'),
                'loss': float(loss.item()),
                'psnr': float(psnr.item()),
                'learning_rate': optimizer.param_groups[0]['lr']
            }
            
            # Add coarse loss if available
            if 'rgb0' in extras:
                training_data['loss_coarse'] = float(img_loss0.item())
                training_data['psnr_coarse'] = float(psnr0.item())
            
            # Append to training log file
            training_log_file = os.path.join(basedir, expname, 'training_log.jsonl')
            with open(training_log_file, 'a') as f:
                f.write(json.dumps(training_data) + '\n')
            
            # Also save as CSV for easy analysis
            csv_log_file = os.path.join(basedir, expname, 'training_log.csv')
            if not os.path.exists(csv_log_file):
                # Create header
                with open(csv_log_file, 'w') as f:
                    headers = ['iteration', 'timestamp', 'loss', 'psnr', 'learning_rate']
                    if 'rgb0' in extras:
                        headers.extend(['loss_coarse', 'psnr_coarse'])
                    f.write(','.join(headers) + '\n')
            
            # Append data
            with open(csv_log_file, 'a') as f:
                row_data = [str(training_data[key]) for key in ['iteration', 'timestamp', 'loss', 'psnr', 'learning_rate']]
                if 'rgb0' in extras:
                    row_data.extend([str(training_data['loss_coarse']), str(training_data['psnr_coarse'])])
                f.write(','.join(row_data) + '\n')

        global_step += 1
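
The loop logs the total loss and the PSNR side by side, but they are not derived from the same number: PSNR is computed from the fine-pass MSE alone, while the loss adds the coarse-pass MSE when `rgb0` is present. The sketch below shows the conversion in both directions, assuming pixel values in [0, 1]. `mse_to_psnr` and `psnr_to_mse` are hypothetical helpers that mirror what `mse2psnr` from `nerf_helpers` is expected to compute.

```python
import math

def mse_to_psnr(mse):
    """PSNR in dB for images with values in [0, 1]."""
    return -10.0 * math.log10(mse)

def psnr_to_mse(psnr):
    """Inverse of mse_to_psnr."""
    return 10.0 ** (-psnr / 10.0)

print(round(mse_to_psnr(0.01), 2))   # 20.0
print(round(mse_to_psnr(0.001), 2))  # 30.0

# Values copied from the logged line for iteration 1000 in the training output.
logged_loss, logged_psnr = 0.02703813463449478, 18.873912811279297
fine_mse = psnr_to_mse(logged_psnr)
coarse_mse = logged_loss - fine_mse
print(fine_mse, coarse_mse)
```

For that logged line, the fine-pass MSE comes out near 0.013, which is about half of the logged loss. This is consistent with the coarse-pass term being included in the total.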

## Start Training
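
The number of iterations matters partly because the learning rate decays exponentially with the step count: `train()` multiplies the base rate by `0.1 ** (step / (lrate_decay * 1000))`. The sketch below tabulates that schedule. `decayed_lr` is a hypothetical helper, and the base rate and decay setting are assumed example values, not the ones in this experiment's yaml.

```python
def decayed_lr(base_lr, step, decay_ksteps, decay_rate=0.1):
    """Learning rate after `step` iterations of exponential decay."""
    return base_lr * decay_rate ** (step / (decay_ksteps * 1000))

# Assumed example values; substitute args.lrate and args.lrate_decay.
base_lr, decay_ksteps = 5e-4, 250

schedule = {step: decayed_lr(base_lr, step, decay_ksteps)
            for step in (0, 100_000, 200_000, 250_000)}
for step, lr in schedule.items():
    print(f"{step:>7d}  {lr:.3e}")
```

With these assumed values the rate drops by a factor of ten every 250,000 steps, so a 200,000-step run ends before the first full decade of decay.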

In [14]:
# Set the total number of training iterations (the +1 lets step 200000 itself run, so its checkpoint and logs are written)
N_iters = 200000 + 1
train()

Loaded blender (138, 800, 800, 4) torch.Size([40, 4, 4]) [800, 800, np.float64(1111.1110311937682)] ./data/nerf_synthetic/ship
Found ckpts []
Not ndc!
Begin
TRAIN views are [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
 96 97 98 99]
TEST views are [113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130
 131 132 133 134 135 136 137]
VAL views are [100 101 102 103 104 105 106 107 108 109 110 111 112]


  0%|          | 0/200000 [00:00<?, ?it/s]

[Config] Center cropping of size 400 x 400 is enabled until iter 500


  1%|          | 1001/200000 [02:29<8:07:48,  6.80it/s]

[TRAIN] Iter: 1000 Loss: 0.02703813463449478  PSNR: 18.873912811279297


  1%|          | 2001/200000 [04:59<8:12:16,  6.70it/s]

[TRAIN] Iter: 2000 Loss: 0.014134297147393227  PSNR: 21.651899337768555


  2%|▏         | 3001/200000 [07:31<8:11:59,  6.67it/s]

[TRAIN] Iter: 3000 Loss: 0.011281629092991352  PSNR: 22.807647705078125


  2%|▏         | 4001/200000 [10:02<8:28:59,  6.42it/s]

[TRAIN] Iter: 4000 Loss: 0.011225137859582901  PSNR: 22.735225677490234


  3%|▎         | 5001/200000 [12:34<8:02:10,  6.74it/s]

[TRAIN] Iter: 5000 Loss: 0.009835847653448582  PSNR: 23.127941131591797


  3%|▎         | 6001/200000 [15:05<8:42:35,  6.19it/s]

[TRAIN] Iter: 6000 Loss: 0.010197076946496964  PSNR: 23.46279525756836


  4%|▎         | 7001/200000 [17:36<8:16:19,  6.48it/s]

[TRAIN] Iter: 7000 Loss: 0.008184758946299553  PSNR: 24.578683853149414


  4%|▍         | 8001/200000 [20:08<7:58:23,  6.69it/s]

[TRAIN] Iter: 8000 Loss: 0.014123108237981796  PSNR: 22.141983032226562


  5%|▍         | 9001/200000 [22:41<7:49:26,  6.78it/s]

[TRAIN] Iter: 9000 Loss: 0.00782245397567749  PSNR: 24.787160873413086


  5%|▌         | 10001/200000 [25:12<8:32:09,  6.18it/s]

Saved checkpoints at ./logs\ship_blender200k_fullres_higher_samples\checkpoints\010000.tar
[TRAIN] Iter: 10000 Loss: 0.011145872995257378  PSNR: 23.285533905029297


  6%|▌         | 11001/200000 [27:43<7:47:11,  6.74it/s]

[TRAIN] Iter: 11000 Loss: 0.009289225563406944  PSNR: 24.767581939697266


  6%|▌         | 12001/200000 [30:14<7:43:59,  6.75it/s]

[TRAIN] Iter: 12000 Loss: 0.007449848111718893  PSNR: 24.514026641845703


  7%|▋         | 13001/200000 [32:45<7:57:48,  6.52it/s]

[TRAIN] Iter: 13000 Loss: 0.011484567075967789  PSNR: 23.209569931030273


  7%|▋         | 14001/200000 [35:16<8:02:00,  6.43it/s]

[TRAIN] Iter: 14000 Loss: 0.008831962943077087  PSNR: 24.231021881103516


  8%|▊         | 15001/200000 [37:48<7:32:48,  6.81it/s]

[TRAIN] Iter: 15000 Loss: 0.008566387929022312  PSNR: 24.492753982543945


  8%|▊         | 16001/200000 [40:13<7:23:44,  6.91it/s]

[TRAIN] Iter: 16000 Loss: 0.006082460284233093  PSNR: 26.032611846923828


  9%|▊         | 17001/200000 [42:40<7:52:11,  6.46it/s]

[TRAIN] Iter: 17000 Loss: 0.005529690533876419  PSNR: 25.723543167114258


  9%|▉         | 18001/200000 [45:16<7:45:11,  6.52it/s]

[TRAIN] Iter: 18000 Loss: 0.009225727058947086  PSNR: 24.037626266479492


 10%|▉         | 19001/200000 [47:55<7:42:04,  6.53it/s]

[TRAIN] Iter: 19000 Loss: 0.0064555564895272255  PSNR: 25.290658950805664


 10%|█         | 20001/200000 [50:31<7:45:04,  6.45it/s]

Saved checkpoints at ./logs\ship_blender200k_fullres_higher_samples\checkpoints\020000.tar
[TRAIN] Iter: 20000 Loss: 0.007642399985343218  PSNR: 25.073627471923828


 11%|█         | 21001/200000 [53:05<7:28:22,  6.65it/s]

[TRAIN] Iter: 21000 Loss: 0.008661126717925072  PSNR: 23.899343490600586


 11%|█         | 22001/200000 [55:40<7:32:09,  6.56it/s]

[TRAIN] Iter: 22000 Loss: 0.006721525453031063  PSNR: 25.82712173461914


 12%|█▏        | 23001/200000 [58:15<7:25:54,  6.62it/s]

[TRAIN] Iter: 23000 Loss: 0.006798615679144859  PSNR: 25.714977264404297


 12%|█▏        | 24001/200000 [1:00:47<7:22:56,  6.62it/s]

[TRAIN] Iter: 24000 Loss: 0.006691374816000462  PSNR: 25.345914840698242


 13%|█▎        | 25001/200000 [1:03:20<7:22:00,  6.60it/s]

[TRAIN] Iter: 25000 Loss: 0.006233075633645058  PSNR: 25.451080322265625


 13%|█▎        | 26001/200000 [1:05:55<7:29:08,  6.46it/s] 

[TRAIN] Iter: 26000 Loss: 0.007812762632966042  PSNR: 24.991661071777344


 14%|█▎        | 27001/200000 [1:08:31<7:16:44,  6.60it/s]

[TRAIN] Iter: 27000 Loss: 0.007583301514387131  PSNR: 24.596891403198242


 14%|█▍        | 28001/200000 [1:11:05<7:19:42,  6.52it/s]

[TRAIN] Iter: 28000 Loss: 0.005594450980424881  PSNR: 25.93825340270996


 15%|█▍        | 29001/200000 [1:13:39<7:05:16,  6.70it/s]

[TRAIN] Iter: 29000 Loss: 0.00743396021425724  PSNR: 25.54022789001465


 15%|█▌        | 30001/200000 [1:16:15<7:31:00,  6.28it/s]

Saved checkpoints at ./logs\ship_blender200k_fullres_higher_samples\checkpoints\030000.tar
[TRAIN] Iter: 30000 Loss: 0.009090022183954716  PSNR: 24.051387786865234


 16%|█▌        | 31001/200000 [1:18:49<7:33:21,  6.21it/s]

[TRAIN] Iter: 31000 Loss: 0.00683456240221858  PSNR: 25.557044982910156


 16%|█▌        | 32001/200000 [1:21:23<7:00:57,  6.65it/s]

[TRAIN] Iter: 32000 Loss: 0.006219794042408466  PSNR: 27.061830520629883


 17%|█▋        | 33001/200000 [1:23:56<7:07:11,  6.52it/s]

[TRAIN] Iter: 33000 Loss: 0.0047484394162893295  PSNR: 26.7652530670166


 17%|█▋        | 34001/200000 [1:26:30<6:59:31,  6.59it/s]

[TRAIN] Iter: 34000 Loss: 0.0061880131252110004  PSNR: 26.12705421447754


 18%|█▊        | 35001/200000 [1:29:06<7:39:04,  5.99it/s]

[TRAIN] Iter: 35000 Loss: 0.006873925216495991  PSNR: 25.015981674194336


 18%|█▊        | 36001/200000 [1:31:38<6:54:38,  6.59it/s]

[TRAIN] Iter: 36000 Loss: 0.005809767637401819  PSNR: 26.616756439208984


 19%|█▊        | 37001/200000 [1:34:12<6:58:51,  6.49it/s]

[TRAIN] Iter: 37000 Loss: 0.0067953369580209255  PSNR: 25.560213088989258


 19%|█▉        | 38001/200000 [1:36:47<6:56:25,  6.48it/s]

[TRAIN] Iter: 38000 Loss: 0.0037822378799319267  PSNR: 29.009016036987305


 20%|█▉        | 39001/200000 [1:39:22<6:55:46,  6.45it/s]

[TRAIN] Iter: 39000 Loss: 0.005900314077734947  PSNR: 26.393186569213867


 20%|██        | 40001/200000 [1:41:57<6:57:10,  6.39it/s]

Saved checkpoints at ./logs\ship_blender200k_fullres_higher_samples\checkpoints\040000.tar
[TRAIN] Iter: 40000 Loss: 0.005269918590784073  PSNR: 26.161104202270508


 21%|██        | 41001/200000 [1:44:32<6:47:42,  6.50it/s]

[TRAIN] Iter: 41000 Loss: 0.004055521450936794  PSNR: 27.665876388549805


 21%|██        | 42001/200000 [1:47:06<6:43:09,  6.53it/s]

[TRAIN] Iter: 42000 Loss: 0.006109561771154404  PSNR: 26.07720375061035


 22%|██▏       | 43001/200000 [1:49:40<6:40:07,  6.54it/s]

[TRAIN] Iter: 43000 Loss: 0.004919724073261023  PSNR: 27.10034942626953


 22%|██▏       | 44001/200000 [1:52:13<6:33:58,  6.60it/s]

[TRAIN] Iter: 44000 Loss: 0.003904170822352171  PSNR: 27.770381927490234


 23%|██▎       | 45001/200000 [1:54:44<6:27:38,  6.66it/s]

[TRAIN] Iter: 45000 Loss: 0.004818654619157314  PSNR: 27.465129852294922


 23%|██▎       | 46001/200000 [1:57:16<6:35:19,  6.49it/s]

[TRAIN] Iter: 46000 Loss: 0.007608301937580109  PSNR: 25.171993255615234


 24%|██▎       | 47001/200000 [1:59:49<6:24:17,  6.64it/s]

[TRAIN] Iter: 47000 Loss: 0.0053566559217870235  PSNR: 27.49433708190918


 24%|██▍       | 48001/200000 [2:02:21<6:25:28,  6.57it/s]

[TRAIN] Iter: 48000 Loss: 0.004898060578852892  PSNR: 26.850053787231445


 25%|██▍       | 49001/200000 [2:04:55<6:30:41,  6.44it/s]

[TRAIN] Iter: 49000 Loss: 0.004556602798402309  PSNR: 27.519981384277344


 25%|██▍       | 49999/200000 [2:07:28<6:27:24,  6.45it/s]

Saved checkpoints at ./logs\ship_blender200k_fullres_higher_samples\checkpoints\050000.tar




0 0.002465486526489258




torch.Size([800, 800, 3]) torch.Size([800, 800])
1 40.88119888305664




2 43.27020215988159




3 40.609182357788086




4 41.454001903533936




5 41.036882400512695




6 40.43398904800415




7 41.32930397987366




8 42.12748384475708




9 40.67503499984741




10 40.24250674247742




11 40.7868332862854




12 40.83236241340637




13 40.65360188484192




14 40.72481942176819




15 40.47212195396423




16 40.89969205856323




17 42.912891149520874




18 43.043986082077026




19 39.01219701766968




20 39.00898313522339




21 38.6578323841095




22 38.74339437484741




23 38.92628788948059




24 39.355818033218384




25 39.31734848022461




26 38.8914897441864




27 38.735692739486694




28 38.548213720321655




29 38.80468201637268




30 38.95731329917908




31 38.803335428237915




32 38.953521490097046




33 36.83932185173035




34 30.932189226150513




35 30.931474447250366




36 30.972352743148804




37 30.979562759399414




38 30.945709705352783




39 30.96122121810913


100%|██████████| 40/40 [25:40<00:00, 38.51s/it]


Done, saving (40, 800, 800, 3) (40, 800, 800)
test poses shape torch.Size([25, 4, 4])




0 0.0023288726806640625
torch.Size([800, 800, 3]) torch.Size([800, 800])
Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]




Loading model from: d:\GitHub\nerf-projects\nerf\.venv\Lib\site-packages\lpips\weights\v0.1\vgg.pth
Frame 0 - PSNR: 26.45, SSIM: 0.7505, LPIPS: 0.3062




1 33.08561110496521




Frame 1 - PSNR: 25.70, SSIM: 0.7451, LPIPS: 0.2963
2 31.01781463623047




Frame 2 - PSNR: 25.81, SSIM: 0.7514, LPIPS: 0.2872
3 31.003568410873413




Frame 3 - PSNR: 26.07, SSIM: 0.7571, LPIPS: 0.2949
4 31.05986523628235




Frame 4 - PSNR: 26.57, SSIM: 0.7894, LPIPS: 0.2756
5 31.03875470161438




Frame 5 - PSNR: 26.31, SSIM: 0.8123, LPIPS: 0.2701
6 31.06104564666748




Frame 6 - PSNR: 25.48, SSIM: 0.8167, LPIPS: 0.2681
7 31.044384002685547




Frame 7 - PSNR: 25.64, SSIM: 0.8220, LPIPS: 0.2619
8 31.03125762939453




Frame 8 - PSNR: 26.04, SSIM: 0.8298, LPIPS: 0.2548
9 31.021644115447998




Frame 9 - PSNR: 26.88, SSIM: 0.8552, LPIPS: 0.2275
10 31.069543600082397




Frame 10 - PSNR: 27.09, SSIM: 0.8718, LPIPS: 0.2135
11 31.054287433624268




Frame 11 - PSNR: 25.64, SSIM: 0.8563, LPIPS: 0.2207
12 31.019580125808716




Frame 12 - PSNR: 24.67, SSIM: 0.8506, LPIPS: 0.2254
13 31.040910243988037




Frame 13 - PSNR: 24.75, SSIM: 0.8555, LPIPS: 0.2131
14 31.040313482284546




Frame 14 - PSNR: 24.98, SSIM: 0.8711, LPIPS: 0.1967
15 31.023792028427124




Frame 15 - PSNR: 26.25, SSIM: 0.8842, LPIPS: 0.1818
16 30.9947988986969




Frame 16 - PSNR: 26.59, SSIM: 0.8690, LPIPS: 0.1926
17 31.038219690322876




Frame 17 - PSNR: 25.59, SSIM: 0.8438, LPIPS: 0.2246
18 31.048261404037476




Frame 18 - PSNR: 25.17, SSIM: 0.8277, LPIPS: 0.2508
19 31.0466468334198




Frame 19 - PSNR: 25.56, SSIM: 0.8163, LPIPS: 0.2695
20 31.04893207550049




Frame 20 - PSNR: 25.94, SSIM: 0.8155, LPIPS: 0.2787
21 31.081602811813354




Frame 21 - PSNR: 26.47, SSIM: 0.8109, LPIPS: 0.2901
22 31.063095569610596




Frame 22 - PSNR: 27.26, SSIM: 0.8166, LPIPS: 0.2960
23 31.03350806236267




Frame 23 - PSNR: 27.44, SSIM: 0.7969, LPIPS: 0.2993
24 31.05296277999878


Frame 24 - PSNR: 27.08, SSIM: 0.7709, LPIPS: 0.3040
100%|██████████| 25/25 [12:58<00:00, 31.12s/it]
 25%|██▌       | 50000/200000 [2:46:08<29004:12:00, 696.10s/it]

=== METRICS SUMMARY ===
Average PSNR: 26.0580 ± 0.7502
Average SSIM: 0.8195 ± 0.0402
Average LPIPS: 0.2560 ± 0.0378
Metrics logged to: ./logs\ship_blender200k_fullres_higher_samples\metrics\metrics_050000.txt and ./logs\ship_blender200k_fullres_higher_samples\metrics\metrics_050000.json
Training history updated: ./logs\ship_blender200k_fullres_higher_samples\metrics\training_metrics.json
Saved test set with metrics
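
The summary printed above can be cross-checked from the per-frame lines. The following is a minimal sketch, not the notebook's own metrics code: it assumes the `±` value is a population standard deviation over the 25 test frames, and the names `psnr_50k` and `summarize` are hypothetical. Because the per-frame values are only logged to two decimals, the result should land close to, but not exactly on, the logged `26.0580 ± 0.7502`.

```python
from statistics import fmean, pstdev

# Per-frame test PSNR values copied from the "Frame N" lines logged at iteration 50000.
psnr_50k = [
    26.45, 25.70, 25.81, 26.07, 26.57, 26.31, 25.48, 25.64, 26.04, 26.88,
    27.09, 25.64, 24.67, 24.75, 24.98, 26.25, 26.59, 25.59, 25.17, 25.56,
    25.94, 26.47, 27.26, 27.44, 27.08,
]


def summarize(values):
    """Return (mean, population standard deviation) of a metric list."""
    # Assumption: the logged "±" is a population std (divide by N, not N - 1).
    return fmean(values), pstdev(values)


mean_psnr, std_psnr = summarize(psnr_50k)
print(f"Average PSNR: {mean_psnr:.4f} ± {std_psnr:.4f}")
```

The same helper can be applied to the SSIM and LPIPS columns to check those summary lines as well.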
[TRAIN] Iter: 50000 Loss: 0.0059603676199913025  PSNR: 26.013748168945312


 26%|██▌       | 51001/200000 [2:48:33<6:00:03,  6.90it/s]     

[TRAIN] Iter: 51000 Loss: 0.004183984361588955  PSNR: 27.86517333984375


 26%|██▌       | 52001/200000 [2:50:58<5:56:31,  6.92it/s]

[TRAIN] Iter: 52000 Loss: 0.005671474151313305  PSNR: 26.84699821472168


 27%|██▋       | 53001/200000 [2:53:23<5:55:57,  6.88it/s]

[TRAIN] Iter: 53000 Loss: 0.005463732406497002  PSNR: 26.86399269104004


 27%|██▋       | 54001/200000 [2:55:47<5:50:54,  6.93it/s]

[TRAIN] Iter: 54000 Loss: 0.007277948781847954  PSNR: 24.842395782470703


 28%|██▊       | 55001/200000 [2:58:12<5:49:56,  6.91it/s]

[TRAIN] Iter: 55000 Loss: 0.006327393464744091  PSNR: 26.081729888916016


 28%|██▊       | 56001/200000 [3:00:56<7:07:46,  5.61it/s]

[TRAIN] Iter: 56000 Loss: 0.005565634462982416  PSNR: 26.498062133789062


 29%|██▊       | 57001/200000 [3:03:50<6:43:48,  5.90it/s]

[TRAIN] Iter: 57000 Loss: 0.004488984122872353  PSNR: 27.517047882080078


 29%|██▉       | 58001/200000 [3:06:43<6:18:23,  6.25it/s]

[TRAIN] Iter: 58000 Loss: 0.004457986447960138  PSNR: 27.28682518005371


 30%|██▉       | 59001/200000 [3:09:37<6:37:28,  5.91it/s]

[TRAIN] Iter: 59000 Loss: 0.005672221537679434  PSNR: 26.301687240600586


 30%|███       | 60001/200000 [3:12:33<6:30:53,  5.97it/s]

Saved checkpoints at ./logs\ship_blender200k_fullres_higher_samples\checkpoints\060000.tar
[TRAIN] Iter: 60000 Loss: 0.005299689713865519  PSNR: 25.9290714263916


 31%|███       | 61001/200000 [3:15:27<6:17:38,  6.13it/s]

[TRAIN] Iter: 61000 Loss: 0.0034662147518247366  PSNR: 29.236658096313477


 31%|███       | 62001/200000 [3:18:22<6:54:14,  5.55it/s]

[TRAIN] Iter: 62000 Loss: 0.006179520394653082  PSNR: 26.815662384033203


 32%|███▏      | 63001/200000 [3:21:16<6:21:09,  5.99it/s]

[TRAIN] Iter: 63000 Loss: 0.0054740216583013535  PSNR: 26.925926208496094


 32%|███▏      | 64001/200000 [3:24:07<6:44:41,  5.60it/s]

[TRAIN] Iter: 64000 Loss: 0.0063876016065478325  PSNR: 25.552223205566406


 33%|███▎      | 65001/200000 [3:26:47<5:37:29,  6.67it/s]

[TRAIN] Iter: 65000 Loss: 0.0045754555612802505  PSNR: 27.557844161987305


 33%|███▎      | 66001/200000 [3:29:19<5:37:46,  6.61it/s]

[TRAIN] Iter: 66000 Loss: 0.005206495523452759  PSNR: 26.645606994628906


 34%|███▎      | 67001/200000 [3:31:51<5:37:38,  6.57it/s]

[TRAIN] Iter: 67000 Loss: 0.0062256911769509315  PSNR: 25.81958770751953


 34%|███▍      | 68001/200000 [3:34:22<5:32:16,  6.62it/s]

[TRAIN] Iter: 68000 Loss: 0.007164039183408022  PSNR: 25.84735107421875


 35%|███▍      | 69001/200000 [3:36:48<5:13:38,  6.96it/s]

[TRAIN] Iter: 69000 Loss: 0.006812990177422762  PSNR: 25.72471046447754


 35%|███▌      | 70001/200000 [3:39:13<5:21:03,  6.75it/s]

Saved checkpoints at ./logs\ship_blender200k_fullres_higher_samples\checkpoints\070000.tar
[TRAIN] Iter: 70000 Loss: 0.003460866864770651  PSNR: 28.90595054626465


 36%|███▌      | 71001/200000 [3:41:38<5:11:03,  6.91it/s]

[TRAIN] Iter: 71000 Loss: 0.004611298441886902  PSNR: 27.09161949157715


 36%|███▌      | 72001/200000 [3:44:02<5:08:07,  6.92it/s]

[TRAIN] Iter: 72000 Loss: 0.005645987577736378  PSNR: 27.0596981048584


 37%|███▋      | 73001/200000 [3:46:29<5:40:05,  6.22it/s] 

[TRAIN] Iter: 73000 Loss: 0.004724087193608284  PSNR: 27.32999610900879


 37%|███▋      | 74001/200000 [3:49:05<5:19:56,  6.56it/s]

[TRAIN] Iter: 74000 Loss: 0.004790063481777906  PSNR: 26.808561325073242


 38%|███▊      | 75001/200000 [3:51:36<5:11:56,  6.68it/s]

[TRAIN] Iter: 75000 Loss: 0.004232112318277359  PSNR: 26.97902488708496


 38%|███▊      | 76001/200000 [3:54:08<5:07:01,  6.73it/s]

[TRAIN] Iter: 76000 Loss: 0.003962992690503597  PSNR: 28.582544326782227


 39%|███▊      | 77001/200000 [3:56:40<5:07:45,  6.66it/s]

[TRAIN] Iter: 77000 Loss: 0.00432491023093462  PSNR: 27.444311141967773


 39%|███▉      | 78001/200000 [3:59:05<4:54:29,  6.90it/s]

[TRAIN] Iter: 78000 Loss: 0.004715759307146072  PSNR: 26.948135375976562


 40%|███▉      | 79001/200000 [4:01:47<5:13:03,  6.44it/s] 

[TRAIN] Iter: 79000 Loss: 0.006067204289138317  PSNR: 25.645503997802734


 40%|████      | 80001/200000 [4:04:33<5:30:06,  6.06it/s]

Saved checkpoints at ./logs\ship_blender200k_fullres_higher_samples\checkpoints\080000.tar
[TRAIN] Iter: 80000 Loss: 0.005264965817332268  PSNR: 26.582971572875977


 41%|████      | 81001/200000 [4:07:24<5:36:57,  5.89it/s]

[TRAIN] Iter: 81000 Loss: 0.0037454338744282722  PSNR: 27.578296661376953


 41%|████      | 82001/200000 [4:10:11<5:10:24,  6.34it/s]

[TRAIN] Iter: 82000 Loss: 0.003791513154283166  PSNR: 28.663330078125


 42%|████▏     | 83001/200000 [4:12:58<5:12:28,  6.24it/s]

[TRAIN] Iter: 83000 Loss: 0.005584012717008591  PSNR: 27.08004379272461


 42%|████▏     | 84000/200000 [4:15:48<6:53:36,  4.67it/s]

[TRAIN] Iter: 84000 Loss: 0.0037321432027965784  PSNR: 28.517854690551758


 43%|████▎     | 85001/200000 [4:18:34<4:52:35,  6.55it/s]

[TRAIN] Iter: 85000 Loss: 0.004698870703577995  PSNR: 27.492204666137695


 43%|████▎     | 86001/200000 [4:21:06<4:49:46,  6.56it/s]

[TRAIN] Iter: 86000 Loss: 0.004388880450278521  PSNR: 27.60680389404297


 44%|████▎     | 87001/200000 [4:23:39<4:42:46,  6.66it/s]

[TRAIN] Iter: 87000 Loss: 0.005548551212996244  PSNR: 26.585208892822266


 44%|████▍     | 88001/200000 [4:26:12<4:45:28,  6.54it/s]

[TRAIN] Iter: 88000 Loss: 0.003921864554286003  PSNR: 28.011751174926758


 45%|████▍     | 89001/200000 [4:28:42<4:25:28,  6.97it/s]

[TRAIN] Iter: 89000 Loss: 0.005814752541482449  PSNR: 26.63813018798828


 45%|████▌     | 90001/200000 [4:31:07<4:32:36,  6.72it/s]

Saved checkpoints at ./logs\ship_blender200k_fullres_higher_samples\checkpoints\090000.tar
[TRAIN] Iter: 90000 Loss: 0.004889360163360834  PSNR: 27.557025909423828


 46%|████▌     | 91001/200000 [4:33:33<4:22:51,  6.91it/s]

[TRAIN] Iter: 91000 Loss: 0.004020638298243284  PSNR: 27.598478317260742


 46%|████▌     | 92001/200000 [4:35:58<4:20:16,  6.92it/s]

[TRAIN] Iter: 92000 Loss: 0.004967436194419861  PSNR: 27.380260467529297


 47%|████▋     | 93001/200000 [4:38:24<4:17:10,  6.93it/s]

[TRAIN] Iter: 93000 Loss: 0.004180189222097397  PSNR: 27.860538482666016


 47%|████▋     | 94001/200000 [4:40:49<4:17:28,  6.86it/s]

[TRAIN] Iter: 94000 Loss: 0.005019684787839651  PSNR: 27.49823760986328


 48%|████▊     | 95001/200000 [4:43:15<4:13:43,  6.90it/s]

[TRAIN] Iter: 95000 Loss: 0.004387437365949154  PSNR: 27.43877601623535


 48%|████▊     | 96001/200000 [4:45:40<4:11:26,  6.89it/s]

[TRAIN] Iter: 96000 Loss: 0.005791021510958672  PSNR: 26.420804977416992


 49%|████▊     | 97001/200000 [4:48:06<4:08:51,  6.90it/s]

[TRAIN] Iter: 97000 Loss: 0.004057193174958229  PSNR: 28.04981231689453


 49%|████▉     | 98001/200000 [4:50:31<4:05:37,  6.92it/s]

[TRAIN] Iter: 98000 Loss: 0.004707618150860071  PSNR: 27.164663314819336


 50%|████▉     | 99001/200000 [4:52:56<4:03:35,  6.91it/s]

[TRAIN] Iter: 99000 Loss: 0.003807329572737217  PSNR: 28.421049118041992


 50%|████▉     | 99999/200000 [4:55:21<4:03:28,  6.85it/s]

Saved checkpoints at ./logs\ship_blender200k_fullres_higher_samples\checkpoints\100000.tar




0 0.0023279190063476562




torch.Size([800, 800, 3]) torch.Size([800, 800])
1 31.067312002182007




2 30.954099416732788




3 30.975966453552246




4 30.894644021987915




5 30.945509672164917




6 30.93149447441101




7 30.971210956573486




8 30.95546865463257




9 30.92339301109314




10 30.96631622314453




11 30.946157217025757




12 30.96176528930664




13 30.925667762756348




14 30.97114133834839




15 30.963400840759277




16 30.94664764404297




17 30.981507301330566




18 30.95962953567505




19 30.933119297027588




20 30.979862213134766




21 30.977850437164307




22 30.919346570968628




23 30.919600009918213




24 30.962496042251587




25 30.95562505722046




26 30.95115065574646




27 30.9310142993927




28 30.938868522644043




29 30.935879945755005




30 30.945950746536255




31 30.944103717803955




32 30.985987663269043




33 30.92332673072815




34 30.966848373413086




35 30.939285278320312




36 30.93383765220642




37 30.935628175735474




38 30.952829122543335




39 30.919214725494385


100%|██████████| 40/40 [20:38<00:00, 30.95s/it]


Done, saving (40, 800, 800, 3) (40, 800, 800)
test poses shape torch.Size([25, 4, 4])




0 0.0019288063049316406




torch.Size([800, 800, 3]) torch.Size([800, 800])
Frame 0 - PSNR: 27.21, SSIM: 0.7698, LPIPS: 0.2687
1 31.06692385673523




Frame 1 - PSNR: 26.40, SSIM: 0.7641, LPIPS: 0.2633
2 31.0062096118927




Frame 2 - PSNR: 26.77, SSIM: 0.7733, LPIPS: 0.2574
3 31.042542934417725




Frame 3 - PSNR: 26.70, SSIM: 0.7718, LPIPS: 0.2665
4 31.04079794883728




Frame 4 - PSNR: 27.47, SSIM: 0.8033, LPIPS: 0.2466
5 31.030628204345703




Frame 5 - PSNR: 27.36, SSIM: 0.8283, LPIPS: 0.2379
6 31.058279275894165




Frame 6 - PSNR: 26.47, SSIM: 0.8326, LPIPS: 0.2416
7 31.022740602493286




Frame 7 - PSNR: 26.46, SSIM: 0.8373, LPIPS: 0.2346
8 30.99923038482666




Frame 8 - PSNR: 26.95, SSIM: 0.8450, LPIPS: 0.2282
9 31.029686212539673




Frame 9 - PSNR: 27.53, SSIM: 0.8654, LPIPS: 0.2085
10 31.039998054504395




Frame 10 - PSNR: 27.87, SSIM: 0.8827, LPIPS: 0.1947
11 31.045485734939575




Frame 11 - PSNR: 26.44, SSIM: 0.8670, LPIPS: 0.2008
12 31.0511155128479




Frame 12 - PSNR: 25.32, SSIM: 0.8609, LPIPS: 0.2133
13 31.038606643676758




Frame 13 - PSNR: 25.44, SSIM: 0.8675, LPIPS: 0.1998
14 31.00075387954712




Frame 14 - PSNR: 25.78, SSIM: 0.8855, LPIPS: 0.1791
15 31.04326367378235




Frame 15 - PSNR: 27.16, SSIM: 0.8977, LPIPS: 0.1652
16 31.049798488616943




Frame 16 - PSNR: 27.70, SSIM: 0.8829, LPIPS: 0.1744
17 31.04252862930298




Frame 17 - PSNR: 26.72, SSIM: 0.8606, LPIPS: 0.2008
18 31.05394434928894




Frame 18 - PSNR: 26.13, SSIM: 0.8428, LPIPS: 0.2270
19 31.044567108154297




Frame 19 - PSNR: 26.46, SSIM: 0.8315, LPIPS: 0.2468
20 31.065054655075073




Frame 20 - PSNR: 26.80, SSIM: 0.8315, LPIPS: 0.2497
21 31.06117057800293




Frame 21 - PSNR: 27.31, SSIM: 0.8247, LPIPS: 0.2619
22 31.060892820358276




Frame 22 - PSNR: 28.12, SSIM: 0.8283, LPIPS: 0.2703
23 31.04687261581421




Frame 23 - PSNR: 28.23, SSIM: 0.8107, LPIPS: 0.2666
24 31.04159426689148


Frame 24 - PSNR: 27.85, SSIM: 0.7872, LPIPS: 0.2712
100%|██████████| 25/25 [12:56<00:00, 31.04s/it]
 50%|█████     | 100000/200000 [5:28:57<16795:26:46, 604.64s/it]

=== METRICS SUMMARY ===
Average PSNR: 26.9064 ± 0.7650
Average SSIM: 0.8341 ± 0.0383
Average LPIPS: 0.2310 ± 0.0323
Metrics logged to: ./logs\ship_blender200k_fullres_higher_samples\metrics\metrics_100000.txt and ./logs\ship_blender200k_fullres_higher_samples\metrics\metrics_100000.json
Training history updated: ./logs\ship_blender200k_fullres_higher_samples\metrics\training_metrics.json
Saved test set with metrics
[TRAIN] Iter: 100000 Loss: 0.006200630217790604  PSNR: 25.94236183166504
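
The `Loss` and `PSNR` columns in these `[TRAIN]` lines do not satisfy `PSNR = -10 * log10(Loss)` directly. One likely explanation, stated here as an assumption rather than something this log confirms, is that the logged loss is the sum of the coarse and fine network MSEs while the PSNR is computed from the fine MSE alone, as in the reference implementation. The sketch below uses hypothetical helper names (`mse_to_psnr`, `psnr_to_mse`) to recover the MSE implied by a logged PSNR and the remainder that would be the coarse term under that assumption.

```python
import math


def mse_to_psnr(mse):
    """PSNR in dB for images scaled to [0, 1]: -10 * log10(MSE)."""
    return -10.0 * math.log10(mse)


def psnr_to_mse(psnr):
    """Inverse of mse_to_psnr."""
    return 10.0 ** (-psnr / 10.0)


# Values copied from the "[TRAIN] Iter: 100000" line above.
logged_loss = 0.006200630217790604
logged_psnr = 25.94236183166504

fine_mse = psnr_to_mse(logged_psnr)   # MSE implied by the logged PSNR
remainder = logged_loss - fine_mse    # coarse MSE, if loss = coarse + fine (assumption)
print(f"implied fine MSE: {fine_mse:.6f}, remainder: {remainder:.6f}")
```

These training values are computed on a single random ray batch, so they fluctuate from line to line far more than the test-set averages in the metrics summaries.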


 51%|█████     | 101001/200000 [5:31:22<3:54:37,  7.03it/s]     

[TRAIN] Iter: 101000 Loss: 0.004293335601687431  PSNR: 28.03034019470215


 51%|█████     | 102001/200000 [5:33:47<3:56:06,  6.92it/s]

[TRAIN] Iter: 102000 Loss: 0.004541709553450346  PSNR: 27.728885650634766


 52%|█████▏    | 103001/200000 [5:36:12<3:51:54,  6.97it/s]

[TRAIN] Iter: 103000 Loss: 0.005602912046015263  PSNR: 27.224464416503906


 52%|█████▏    | 104001/200000 [5:38:37<3:54:54,  6.81it/s]

[TRAIN] Iter: 104000 Loss: 0.005136994179338217  PSNR: 27.343841552734375


 53%|█████▎    | 105001/200000 [5:41:02<3:49:00,  6.91it/s]

[TRAIN] Iter: 105000 Loss: 0.005217303521931171  PSNR: 27.183639526367188


 53%|█████▎    | 106001/200000 [5:43:28<3:45:35,  6.94it/s]

[TRAIN] Iter: 106000 Loss: 0.005177712999284267  PSNR: 26.86688232421875


 54%|█████▎    | 107001/200000 [5:45:53<3:45:11,  6.88it/s]

[TRAIN] Iter: 107000 Loss: 0.005254107061773539  PSNR: 27.506122589111328


 54%|█████▍    | 108001/200000 [5:48:18<3:41:26,  6.92it/s]

[TRAIN] Iter: 108000 Loss: 0.005350267514586449  PSNR: 26.895803451538086


 55%|█████▍    | 109001/200000 [5:50:44<3:37:46,  6.96it/s]

[TRAIN] Iter: 109000 Loss: 0.0042017241939902306  PSNR: 27.819011688232422


 55%|█████▌    | 110001/200000 [5:53:09<3:43:27,  6.71it/s]

Saved checkpoints at ./logs\ship_blender200k_fullres_higher_samples\checkpoints\110000.tar
[TRAIN] Iter: 110000 Loss: 0.004724619444459677  PSNR: 27.157394409179688


 56%|█████▌    | 111001/200000 [5:55:34<3:32:14,  6.99it/s]

[TRAIN] Iter: 111000 Loss: 0.004086778499186039  PSNR: 27.698928833007812


 56%|█████▌    | 112001/200000 [5:58:00<3:30:46,  6.96it/s]

[TRAIN] Iter: 112000 Loss: 0.004201927687972784  PSNR: 27.80699348449707


 57%|█████▋    | 113001/200000 [6:00:25<3:28:11,  6.96it/s]

[TRAIN] Iter: 113000 Loss: 0.005704249255359173  PSNR: 26.449321746826172


 57%|█████▋    | 114001/200000 [6:02:50<3:26:10,  6.95it/s]

[TRAIN] Iter: 114000 Loss: 0.003693635342642665  PSNR: 28.82670021057129


 58%|█████▊    | 115001/200000 [6:05:15<3:24:34,  6.92it/s]

[TRAIN] Iter: 115000 Loss: 0.003365281270816922  PSNR: 28.542797088623047


 58%|█████▊    | 116001/200000 [6:07:40<3:22:57,  6.90it/s]

[TRAIN] Iter: 116000 Loss: 0.003482447238638997  PSNR: 29.210186004638672


 59%|█████▊    | 117001/200000 [6:10:05<3:19:29,  6.93it/s]

[TRAIN] Iter: 117000 Loss: 0.0029746186919510365  PSNR: 28.838104248046875


 59%|█████▉    | 118001/200000 [6:12:30<3:16:09,  6.97it/s]

[TRAIN] Iter: 118000 Loss: 0.004618261009454727  PSNR: 27.615617752075195


 60%|█████▉    | 119001/200000 [6:14:55<3:14:09,  6.95it/s]

[TRAIN] Iter: 119000 Loss: 0.004182604141533375  PSNR: 27.446109771728516


 60%|██████    | 120001/200000 [6:17:20<3:14:30,  6.85it/s]

Saved checkpoints at ./logs\ship_blender200k_fullres_higher_samples\checkpoints\120000.tar
[TRAIN] Iter: 120000 Loss: 0.0033258101902902126  PSNR: 28.137081146240234


 61%|██████    | 121001/200000 [6:19:45<3:09:02,  6.96it/s]

[TRAIN] Iter: 121000 Loss: 0.005257769487798214  PSNR: 26.686174392700195


 61%|██████    | 122001/200000 [6:22:10<3:06:35,  6.97it/s]

[TRAIN] Iter: 122000 Loss: 0.004072974435985088  PSNR: 28.23678207397461


 62%|██████▏   | 123001/200000 [6:24:35<3:02:55,  7.02it/s]

[TRAIN] Iter: 123000 Loss: 0.005924940574914217  PSNR: 25.961015701293945


 62%|██████▏   | 124001/200000 [6:27:00<3:02:38,  6.94it/s]

[TRAIN] Iter: 124000 Loss: 0.006723129190504551  PSNR: 26.560293197631836


 63%|██████▎   | 125001/200000 [6:29:25<2:59:31,  6.96it/s]

[TRAIN] Iter: 125000 Loss: 0.005099298898130655  PSNR: 26.432092666625977


 63%|██████▎   | 126001/200000 [6:31:50<2:57:35,  6.94it/s]

[TRAIN] Iter: 126000 Loss: 0.003798385849222541  PSNR: 28.47818374633789


 64%|██████▎   | 127001/200000 [6:34:15<2:56:48,  6.88it/s]

[TRAIN] Iter: 127000 Loss: 0.003936982713639736  PSNR: 27.715106964111328


 64%|██████▍   | 128001/200000 [6:36:40<2:52:33,  6.95it/s]

[TRAIN] Iter: 128000 Loss: 0.0053092110902071  PSNR: 25.600801467895508


 65%|██████▍   | 129001/200000 [6:39:05<2:49:19,  6.99it/s]

[TRAIN] Iter: 129000 Loss: 0.0043903132900595665  PSNR: 27.51958465576172


 65%|██████▌   | 130001/200000 [6:41:30<2:51:38,  6.80it/s]

Saved checkpoints at ./logs\ship_blender200k_fullres_higher_samples\checkpoints\130000.tar
[TRAIN] Iter: 130000 Loss: 0.0034678212832659483  PSNR: 28.968069076538086


 66%|██████▌   | 131001/200000 [6:43:55<2:46:01,  6.93it/s]

[TRAIN] Iter: 131000 Loss: 0.004328851122409105  PSNR: 27.67096519470215


 66%|██████▌   | 132001/200000 [6:46:20<2:42:11,  6.99it/s]

[TRAIN] Iter: 132000 Loss: 0.003102600108832121  PSNR: 29.176959991455078


 67%|██████▋   | 133001/200000 [6:48:45<2:41:25,  6.92it/s]

[TRAIN] Iter: 133000 Loss: 0.0050847032107412815  PSNR: 27.284000396728516


 67%|██████▋   | 134001/200000 [6:51:10<2:40:13,  6.87it/s]

[TRAIN] Iter: 134000 Loss: 0.004374898970127106  PSNR: 27.908533096313477


 68%|██████▊   | 135001/200000 [6:53:35<2:35:36,  6.96it/s]

[TRAIN] Iter: 135000 Loss: 0.005114286672323942  PSNR: 27.246885299682617


 68%|██████▊   | 136001/200000 [6:56:00<2:34:30,  6.90it/s]

[TRAIN] Iter: 136000 Loss: 0.003604557830840349  PSNR: 28.379344940185547


 69%|██████▊   | 137001/200000 [6:58:24<2:32:41,  6.88it/s]

[TRAIN] Iter: 137000 Loss: 0.007415123283863068  PSNR: 25.73487663269043


 69%|██████▉   | 138001/200000 [7:00:49<2:29:17,  6.92it/s]

[TRAIN] Iter: 138000 Loss: 0.004053005017340183  PSNR: 28.71807289123535


 70%|██████▉   | 139001/200000 [7:03:14<2:27:26,  6.89it/s]

[TRAIN] Iter: 139000 Loss: 0.005127715412527323  PSNR: 26.724227905273438


 70%|███████   | 140001/200000 [7:05:40<2:26:22,  6.83it/s]

Saved checkpoints at ./logs\ship_blender200k_fullres_higher_samples\checkpoints\140000.tar
[TRAIN] Iter: 140000 Loss: 0.00598546490073204  PSNR: 26.82205581665039


 71%|███████   | 141001/200000 [7:08:05<2:22:39,  6.89it/s]

[TRAIN] Iter: 141000 Loss: 0.003824398387223482  PSNR: 28.304418563842773


 71%|███████   | 142001/200000 [7:10:30<2:20:03,  6.90it/s]

[TRAIN] Iter: 142000 Loss: 0.004684159532189369  PSNR: 27.607040405273438


 72%|███████▏  | 143001/200000 [7:12:55<2:16:52,  6.94it/s]

[TRAIN] Iter: 143000 Loss: 0.003374851308763027  PSNR: 28.495933532714844


 72%|███████▏  | 144001/200000 [7:15:20<2:14:59,  6.91it/s]

[TRAIN] Iter: 144000 Loss: 0.005234038457274437  PSNR: 26.15675926208496


 73%|███████▎  | 145001/200000 [7:17:45<2:12:19,  6.93it/s]

[TRAIN] Iter: 145000 Loss: 0.004058373160660267  PSNR: 27.85353660583496


 73%|███████▎  | 146001/200000 [7:20:10<2:09:43,  6.94it/s]

[TRAIN] Iter: 146000 Loss: 0.003911430481821299  PSNR: 27.908241271972656


 74%|███████▎  | 147001/200000 [7:22:35<2:08:24,  6.88it/s]

[TRAIN] Iter: 147000 Loss: 0.0035303495824337006  PSNR: 30.09024429321289


 74%|███████▍  | 148001/200000 [7:25:00<2:05:50,  6.89it/s]

[TRAIN] Iter: 148000 Loss: 0.004676308482885361  PSNR: 26.844482421875


 75%|███████▍  | 149001/200000 [7:27:25<2:03:56,  6.86it/s]

[TRAIN] Iter: 149000 Loss: 0.005103479139506817  PSNR: 26.629369735717773


 75%|███████▍  | 149999/200000 [7:29:50<2:00:29,  6.92it/s]

Saved checkpoints at ./logs\ship_blender200k_fullres_higher_samples\checkpoints\150000.tar




0 0.0016825199127197266




torch.Size([800, 800, 3]) torch.Size([800, 800])
1 31.02525806427002




2 30.896987676620483




3 30.927396535873413




4 30.943825483322144




5 30.925148248672485




6 30.967805862426758




7 30.907572269439697




8 30.946932315826416




9 30.925058364868164




10 30.911763429641724




11 30.93995213508606




12 30.930694341659546




13 30.941697120666504




14 30.938517093658447




15 30.879674911499023




16 30.904196977615356




17 30.943002700805664




18 30.93690037727356




19 30.92060613632202




20 30.94822907447815




21 30.93238115310669




22 30.971762895584106




23 30.926698923110962




24 30.920185327529907




25 30.86725354194641




26 30.932913303375244




27 30.884018659591675




28 30.910454273223877




29 30.930466413497925




30 30.905494213104248




31 30.902165174484253




32 30.91347622871399




33 30.91509985923767




34 30.919166803359985




35 30.957637071609497




36 30.919461011886597




37 30.909079790115356




38 30.929067850112915




39 30.926214694976807


100%|██████████| 40/40 [20:37<00:00, 30.93s/it]


Done, saving (40, 800, 800, 3) (40, 800, 800)
test poses shape torch.Size([25, 4, 4])




0 0.0020101070404052734




torch.Size([800, 800, 3]) torch.Size([800, 800])
Frame 0 - PSNR: 27.72, SSIM: 0.7808, LPIPS: 0.2513
1 31.058589458465576




Frame 1 - PSNR: 26.87, SSIM: 0.7761, LPIPS: 0.2446
2 31.008776426315308




Frame 2 - PSNR: 27.28, SSIM: 0.7870, LPIPS: 0.2398
3 31.033056497573853




Frame 3 - PSNR: 27.06, SSIM: 0.7796, LPIPS: 0.2504
4 30.99478316307068




Frame 4 - PSNR: 27.89, SSIM: 0.8101, LPIPS: 0.2326
5 31.027581691741943




Frame 5 - PSNR: 27.93, SSIM: 0.8360, LPIPS: 0.2241
6 31.058845043182373




Frame 6 - PSNR: 26.99, SSIM: 0.8404, LPIPS: 0.2294
7 30.989991903305054




Frame 7 - PSNR: 27.14, SSIM: 0.8451, LPIPS: 0.2237
8 30.987743139266968




Frame 8 - PSNR: 27.56, SSIM: 0.8516, LPIPS: 0.2198
9 31.048692226409912




Frame 9 - PSNR: 28.22, SSIM: 0.8712, LPIPS: 0.1984
10 31.00091791152954




Frame 10 - PSNR: 28.58, SSIM: 0.8882, LPIPS: 0.1866
11 31.00019359588623




Frame 11 - PSNR: 26.85, SSIM: 0.8722, LPIPS: 0.1908
12 31.00786566734314




Frame 12 - PSNR: 25.79, SSIM: 0.8685, LPIPS: 0.1987
13 30.960247039794922




Frame 13 - PSNR: 25.92, SSIM: 0.8748, LPIPS: 0.1871
14 30.990488052368164




Frame 14 - PSNR: 26.16, SSIM: 0.8933, LPIPS: 0.1670
15 31.000401973724365




Frame 15 - PSNR: 27.78, SSIM: 0.9059, LPIPS: 0.1501
16 31.02807092666626




Frame 16 - PSNR: 28.24, SSIM: 0.8893, LPIPS: 0.1634
17 31.059220552444458




Frame 17 - PSNR: 27.33, SSIM: 0.8675, LPIPS: 0.1886
18 31.00451683998108




Frame 18 - PSNR: 26.65, SSIM: 0.8498, LPIPS: 0.2158
19 31.019503355026245




Frame 19 - PSNR: 27.03, SSIM: 0.8388, LPIPS: 0.2321
20 31.002334356307983




Frame 20 - PSNR: 27.44, SSIM: 0.8394, LPIPS: 0.2384
21 31.024771690368652




Frame 21 - PSNR: 27.85, SSIM: 0.8318, LPIPS: 0.2494
22 31.027872562408447




Frame 22 - PSNR: 28.69, SSIM: 0.8350, LPIPS: 0.2563
23 31.016927480697632




Frame 23 - PSNR: 28.82, SSIM: 0.8182, LPIPS: 0.2510
24 31.03699827194214


Frame 24 - PSNR: 28.34, SSIM: 0.7956, LPIPS: 0.2522
100%|██████████| 25/25 [12:55<00:00, 31.02s/it]
 75%|███████▌  | 150000/200000 [8:03:23<8391:34:32, 604.19s/it]

=== METRICS SUMMARY ===
Average PSNR: 27.4455 ± 0.8058
Average SSIM: 0.8418 ± 0.0371
Average LPIPS: 0.2177 ± 0.0307
Metrics logged to: ./logs\ship_blender200k_fullres_higher_samples\metrics\metrics_150000.txt and ./logs\ship_blender200k_fullres_higher_samples\metrics\metrics_150000.json
Training history updated: ./logs\ship_blender200k_fullres_higher_samples\metrics\training_metrics.json
Saved test set with metrics
[TRAIN] Iter: 150000 Loss: 0.00398494815453887  PSNR: 27.997238159179688


 76%|███████▌  | 151001/200000 [8:05:48<1:56:48,  6.99it/s]    

[TRAIN] Iter: 151000 Loss: 0.003865275764837861  PSNR: 28.621103286743164


 76%|███████▌  | 152001/200000 [8:08:13<1:55:26,  6.93it/s]

[TRAIN] Iter: 152000 Loss: 0.00553170358762145  PSNR: 26.91531753540039


 77%|███████▋  | 153001/200000 [8:10:38<1:53:22,  6.91it/s]

[TRAIN] Iter: 153000 Loss: 0.0040213847532868385  PSNR: 27.622180938720703


 77%|███████▋  | 154001/200000 [8:13:03<1:52:08,  6.84it/s]

[TRAIN] Iter: 154000 Loss: 0.007591934408992529  PSNR: 25.109071731567383


 78%|███████▊  | 155001/200000 [8:15:29<1:48:22,  6.92it/s]

[TRAIN] Iter: 155000 Loss: 0.004640092607587576  PSNR: 26.759418487548828


 78%|███████▊  | 156001/200000 [8:17:53<1:45:04,  6.98it/s]

[TRAIN] Iter: 156000 Loss: 0.005853986367583275  PSNR: 26.86829376220703


 79%|███████▊  | 157001/200000 [8:20:18<1:43:54,  6.90it/s]

[TRAIN] Iter: 157000 Loss: 0.0036208939272910357  PSNR: 28.123624801635742


 79%|███████▉  | 158001/200000 [8:22:43<1:40:43,  6.95it/s]

[TRAIN] Iter: 158000 Loss: 0.0036833910271525383  PSNR: 28.043729782104492


 80%|███████▉  | 159001/200000 [8:25:08<1:38:20,  6.95it/s]

[TRAIN] Iter: 159000 Loss: 0.002436238806694746  PSNR: 30.589113235473633


 80%|████████  | 160001/200000 [8:27:34<1:39:55,  6.67it/s]

Saved checkpoints at ./logs\ship_blender200k_fullres_higher_samples\checkpoints\160000.tar
[TRAIN] Iter: 160000 Loss: 0.003257605480030179  PSNR: 28.250001907348633


 81%|████████  | 161001/200000 [8:30:01<1:38:06,  6.63it/s]

[TRAIN] Iter: 161000 Loss: 0.003677332540974021  PSNR: 28.748178482055664


 81%|████████  | 162001/200000 [8:32:38<1:36:55,  6.53it/s]

[TRAIN] Iter: 162000 Loss: 0.00447121262550354  PSNR: 27.009580612182617


 82%|████████▏ | 163001/200000 [8:35:13<1:34:40,  6.51it/s]

[TRAIN] Iter: 163000 Loss: 0.002999486867338419  PSNR: 30.199893951416016


 82%|████████▏ | 164001/200000 [8:37:48<1:33:24,  6.42it/s]

[TRAIN] Iter: 164000 Loss: 0.006042309571057558  PSNR: 26.753955841064453


 83%|████████▎ | 165001/200000 [8:40:23<1:28:22,  6.60it/s]

[TRAIN] Iter: 165000 Loss: 0.004785830620676279  PSNR: 27.620893478393555


 83%|████████▎ | 166001/200000 [8:42:40<1:16:30,  7.41it/s]

[TRAIN] Iter: 166000 Loss: 0.002842977875843644  PSNR: 29.517826080322266


 84%|████████▎ | 167001/200000 [8:44:56<1:14:10,  7.41it/s]

[TRAIN] Iter: 167000 Loss: 0.0034402606543153524  PSNR: 28.39792823791504


 84%|████████▍ | 168001/200000 [8:47:12<1:12:21,  7.37it/s]

[TRAIN] Iter: 168000 Loss: 0.0031167897395789623  PSNR: 29.389673233032227


 85%|████████▍ | 169001/200000 [8:49:28<1:09:41,  7.41it/s]

[TRAIN] Iter: 169000 Loss: 0.004837749991565943  PSNR: 27.95529556274414


 85%|████████▌ | 170001/200000 [8:51:45<1:09:29,  7.20it/s]

Saved checkpoints at ./logs\ship_blender200k_fullres_higher_samples\checkpoints\170000.tar
[TRAIN] Iter: 170000 Loss: 0.00473704794421792  PSNR: 26.953365325927734


 86%|████████▌ | 171001/200000 [8:54:01<1:05:09,  7.42it/s]

[TRAIN] Iter: 171000 Loss: 0.003300874261185527  PSNR: 29.717453002929688


 86%|████████▌ | 172001/200000 [8:56:17<1:02:51,  7.42it/s]

[TRAIN] Iter: 172000 Loss: 0.00472907442599535  PSNR: 29.180816650390625


 87%|████████▋ | 173001/200000 [8:58:33<1:01:08,  7.36it/s]

[TRAIN] Iter: 173000 Loss: 0.004728070460259914  PSNR: 27.274417877197266


 87%|████████▋ | 174001/200000 [9:00:50<59:21,  7.30it/s]  

[TRAIN] Iter: 174000 Loss: 0.0035652623046189547  PSNR: 28.630290985107422


 88%|████████▊ | 175001/200000 [9:03:06<56:56,  7.32it/s]

[TRAIN] Iter: 175000 Loss: 0.004304095171391964  PSNR: 28.187604904174805


 88%|████████▊ | 176001/200000 [9:05:22<53:49,  7.43it/s]

[TRAIN] Iter: 176000 Loss: 0.005350832361727953  PSNR: 26.895156860351562


 89%|████████▊ | 177001/200000 [9:07:38<52:24,  7.31it/s]

[TRAIN] Iter: 177000 Loss: 0.00293621513992548  PSNR: 29.601482391357422


 89%|████████▉ | 178001/200000 [9:09:54<49:37,  7.39it/s]

[TRAIN] Iter: 178000 Loss: 0.005099562928080559  PSNR: 27.016246795654297


 90%|████████▉ | 179001/200000 [9:12:11<47:03,  7.44it/s]

[TRAIN] Iter: 179000 Loss: 0.005641625262796879  PSNR: 27.078336715698242


 90%|█████████ | 180001/200000 [9:14:27<46:06,  7.23it/s]

Saved checkpoints at ./logs\ship_blender200k_fullres_higher_samples\checkpoints\180000.tar
[TRAIN] Iter: 180000 Loss: 0.0036047559697180986  PSNR: 28.67751121520996


 91%|█████████ | 181001/200000 [9:16:43<43:01,  7.36it/s]

[TRAIN] Iter: 181000 Loss: 0.0036166564095765352  PSNR: 28.236404418945312


 91%|█████████ | 182001/200000 [9:19:00<40:44,  7.36it/s]

[TRAIN] Iter: 182000 Loss: 0.005110572092235088  PSNR: 27.15743064880371


 92%|█████████▏| 183001/200000 [9:21:16<38:34,  7.34it/s]

[TRAIN] Iter: 183000 Loss: 0.003563643665984273  PSNR: 28.190204620361328


 92%|█████████▏| 184001/200000 [9:23:32<36:08,  7.38it/s]

[TRAIN] Iter: 184000 Loss: 0.0045555355027318  PSNR: 28.257755279541016


 93%|█████████▎| 185001/200000 [9:25:49<33:55,  7.37it/s]

[TRAIN] Iter: 185000 Loss: 0.004909185692667961  PSNR: 27.040857315063477


 93%|█████████▎| 186001/200000 [9:28:05<31:35,  7.38it/s]

[TRAIN] Iter: 186000 Loss: 0.0033997236751019955  PSNR: 28.71727752685547


 94%|█████████▎| 187001/200000 [9:30:21<29:19,  7.39it/s]

[TRAIN] Iter: 187000 Loss: 0.004810669459402561  PSNR: 27.781835556030273


 94%|█████████▍| 188001/200000 [9:32:38<26:56,  7.42it/s]

[TRAIN] Iter: 188000 Loss: 0.005122028291225433  PSNR: 27.942386627197266


 95%|█████████▍| 189001/200000 [9:34:54<24:55,  7.35it/s]

[TRAIN] Iter: 189000 Loss: 0.0038134329952299595  PSNR: 28.857641220092773


 95%|█████████▌| 190001/200000 [9:37:10<22:54,  7.27it/s]

Saved checkpoints at ./logs\ship_blender200k_fullres_higher_samples\checkpoints\190000.tar
[TRAIN] Iter: 190000 Loss: 0.0039069755002856255  PSNR: 27.968833923339844


 96%|█████████▌| 191001/200000 [9:39:27<20:19,  7.38it/s]

[TRAIN] Iter: 191000 Loss: 0.0030521138105541468  PSNR: 29.804264068603516


 96%|█████████▌| 192001/200000 [9:41:43<17:59,  7.41it/s]

[TRAIN] Iter: 192000 Loss: 0.003878558985888958  PSNR: 29.060653686523438


 97%|█████████▋| 193001/200000 [9:43:59<15:53,  7.34it/s]

[TRAIN] Iter: 193000 Loss: 0.004135563038289547  PSNR: 28.201889038085938


 97%|█████████▋| 194001/200000 [9:46:16<13:30,  7.40it/s]

[TRAIN] Iter: 194000 Loss: 0.002895652549341321  PSNR: 29.027353286743164


 98%|█████████▊| 195001/200000 [9:48:32<11:17,  7.38it/s]

[TRAIN] Iter: 195000 Loss: 0.003973056562244892  PSNR: 27.85625457763672


 98%|█████████▊| 196001/200000 [9:50:49<09:03,  7.36it/s]

[TRAIN] Iter: 196000 Loss: 0.003619161434471607  PSNR: 28.851978302001953


 99%|█████████▊| 197001/200000 [9:53:05<06:52,  7.27it/s]

[TRAIN] Iter: 197000 Loss: 0.004476785659790039  PSNR: 27.302406311035156


 99%|█████████▉| 198001/200000 [9:55:21<04:32,  7.35it/s]

[TRAIN] Iter: 198000 Loss: 0.0028619635850191116  PSNR: 29.365449905395508


100%|█████████▉| 199001/200000 [9:57:38<02:15,  7.35it/s]

[TRAIN] Iter: 199000 Loss: 0.00480707036331296  PSNR: 27.200193405151367


100%|█████████▉| 199999/200000 [9:59:54<00:00,  7.33it/s]

Saved checkpoints at ./logs\ship_blender200k_fullres_higher_samples\checkpoints\200000.tar




0 0.0017039775848388672
torch.Size([800, 800, 3]) torch.Size([800, 800])
1 29.001757621765137
2 28.97565245628357
3 29.01278829574585
4 29.001259088516235
5 29.00534224510193
6 32.92943215370178
7 29.738768339157104
8 29.389034032821655
9 29.392902135849
10 29.39199924468994
11 29.392184019088745
12 29.413095712661743
13 29.40765905380249
14 29.405179262161255
15 29.409266233444214
16 29.402910709381104
17 29.399266004562378
18 29.03937602043152
19 29.0261709690094
20 29.040441751480103
21 29.03043794631958
22 29.0256404876709
23 29.017939567565918
24 29.045901775360107
25 29.02931833267212
26 29.042506217956543
27 29.034425735473633
28 29.040652990341187
29 29.03409719467163
30 29.04183578491211
31 29.03053879737854
32 30.384079933166504
33 35.53028655052185
34 29.37862539291382
35 29.389444589614868
36 29.383774757385254
37 29.37813115119934
38 29.377781867980957
39 29.37355089187622


100%|██████████| 40/40 [19:39<00:00, 29.49s/it]


Done, saving (40, 800, 800, 3) (40, 800, 800)
test poses shape torch.Size([25, 4, 4])




0 0.001956462860107422
torch.Size([800, 800, 3]) torch.Size([800, 800])
Frame 0 - PSNR: 28.02, SSIM: 0.7883, LPIPS: 0.2417
1 29.540377616882324
Frame 1 - PSNR: 27.17, SSIM: 0.7843, LPIPS: 0.2324
2 29.50209069252014
Frame 2 - PSNR: 27.62, SSIM: 0.7957, LPIPS: 0.2296
3 29.597777128219604
Frame 3 - PSNR: 27.25, SSIM: 0.7855, LPIPS: 0.2415
4 29.121036529541016
Frame 4 - PSNR: 28.12, SSIM: 0.8145, LPIPS: 0.2249
5 29.15498638153076
Frame 5 - PSNR: 28.27, SSIM: 0.8416, LPIPS: 0.2140
6 29.15734052658081
Frame 6 - PSNR: 27.31, SSIM: 0.8461, LPIPS: 0.2199
7 29.12328815460205
Frame 7 - PSNR: 27.46, SSIM: 0.8513, LPIPS: 0.2130
8 29.129096508026123
Frame 8 - PSNR: 27.97, SSIM: 0.8574, LPIPS: 0.2088
9 29.120907068252563
Frame 9 - PSNR: 28.58, SSIM: 0.8746, LPIPS: 0.1945
10 29.111916542053223
Frame 10 - PSNR: 28.86, SSIM: 0.8913, LPIPS: 0.1823
11 29.130505323410034
Frame 11 - PSNR: 26.99, SSIM: 0.8749, LPIPS: 0.1859
12 29.13886857032776
Frame 12 - PSNR: 26.05, SSIM: 0.8718, LPIPS: 0.1883
13 29.137919664382935
Frame 13 - PSNR: 26.06, SSIM: 0.8789, LPIPS: 0.1771
14 29.136210441589355
Frame 14 - PSNR: 26.54, SSIM: 0.8974, LPIPS: 0.1571
15 29.151644468307495
Frame 15 - PSNR: 28.12, SSIM: 0.9103, LPIPS: 0.1445
16 29.13982105255127
Frame 16 - PSNR: 28.62, SSIM: 0.8947, LPIPS: 0.1558
17 29.10737442970276
Frame 17 - PSNR: 27.59, SSIM: 0.8731, LPIPS: 0.1789
18 29.126954555511475
Frame 18 - PSNR: 26.81, SSIM: 0.8535, LPIPS: 0.2065
19 29.138595819473267
Frame 19 - PSNR: 27.36, SSIM: 0.8447, LPIPS: 0.2211
20 29.147693634033203
Frame 20 - PSNR: 27.83, SSIM: 0.8455, LPIPS: 0.2293
21 29.157140016555786
Frame 21 - PSNR: 28.16, SSIM: 0.8367, LPIPS: 0.2391
22 29.147220134735107
Frame 22 - PSNR: 28.99, SSIM: 0.8386, LPIPS: 0.2456
23 29.126991987228394
Frame 23 - PSNR: 29.10, SSIM: 0.8227, LPIPS: 0.2430
24 29.151573419570923


100%|██████████| 25/25 [12:09<00:00, 29.19s/it]
100%|██████████| 200000/200000 [10:31:44<00:00,  5.28it/s] 

Frame 24 - PSNR: 28.50, SSIM: 0.8001, LPIPS: 0.2466

=== METRICS SUMMARY ===
Average PSNR: 27.7337 ± 0.8293
Average SSIM: 0.8469 ± 0.0361
Average LPIPS: 0.2089 ± 0.0299
Metrics logged to: ./logs\ship_blender200k_fullres_higher_samples\metrics\metrics_200000.txt and ./logs\ship_blender200k_fullres_higher_samples\metrics\metrics_200000.json
Training history updated: ./logs\ship_blender200k_fullres_higher_samples\metrics\training_metrics.json
Saved test set with metrics
[TRAIN] Iter: 200000 Loss: 0.003799606580287218  PSNR: 29.015316009521484
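
The metrics summary in the log reports an average test PSNR of 27.7337 ± 0.8293 over 25 frames. As a sanity check, the sketch below recomputes that figure from the per-frame PSNR values printed in the log. The values are copied by hand and are rounded to two decimals, so only approximate agreement is expected. The use of a population standard deviation (`pstdev`) is an assumption; the log does not show how the notebook computes the spread, although the population form happens to agree with the reported value to about two decimals.

```python
from statistics import mean, pstdev

# Per-frame test PSNR values (frames 0 to 24), copied from the log output.
# They are rounded to two decimals in the log, so results are approximate.
frame_psnr = [
    28.02, 27.17, 27.62, 27.25, 28.12,
    28.27, 27.31, 27.46, 27.97, 28.58,
    28.86, 26.99, 26.05, 26.06, 26.54,
    28.12, 28.62, 27.59, 26.81, 27.36,
    27.83, 28.16, 28.99, 29.10, 28.50,
]

avg = mean(frame_psnr)
# Assumption: population standard deviation (divide by N, not N - 1).
spread = pstdev(frame_psnr)

print(f"Frames: {len(frame_psnr)}")
print(f"Average PSNR: {avg:.2f} ± {spread:.2f}")
```

The same pattern applies to the SSIM and LPIPS columns; only the list of values changes.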



