In [None]:
import os
import numpy as np
import torch
from torch import nn
import matplotlib.pyplot as plt
from tqdm import trange

from model import *
from rays_util import *
from render import *
from load_llff import *

In [None]:
# Helpers that prepare model inputs for the forward pass

def chunkify(
  inputs: torch.Tensor,
  chunksize: int = 2**15
):
  """
  Split "inputs" along its first dimension into consecutive slices of at
  most "chunksize" rows. The final slice may be shorter; an input with
  zero rows yields an empty list.
  Return: List[torch.Tensor]
  """
  n_rows = inputs.shape[0]
  chunks = []
  for begin in range(0, n_rows, chunksize):
    end = begin + chunksize
    chunks.append(inputs[begin:end])
  return chunks

def batchify_points(
  points: torch.Tensor,
  encoding_function: torch.Tensor,
  chunksize: int = 2**15
):
  """
  Flatten "points" to rows of 3 coordinates, encode them, and split the
  encoded rows into chunks ready to be fed to the NeRF model.

  "encoding_function" is invoked as a callable on the flattened points
  (the tensor annotation in the signature is kept only for compatibility).
  Return: List[torch.Tensor]
  """
  flat_points = points.reshape((-1, 3))
  encoded_points = encoding_function(flat_points)
  return chunkify(encoded_points, chunksize=chunksize)

def batchify_viewdirs(
  points: torch.Tensor,
  rays_d: torch.Tensor,
  encoding_function: torch.Tensor,
  chunksize: int = 2**15
):
  """
  Build one encoded viewing direction per sample point and chunk the result.

  Ray directions are normalised to unit length along the last dimension,
  repeated along a new second axis so they match the shape of "points",
  flattened to rows of 3, encoded with "encoding_function" (called as a
  function) and finally split into chunks.
  Return: List[torch.Tensor]
  """
  direction_norm = torch.norm(rays_d, dim=-1, keepdim=True)
  unit_dirs = rays_d / direction_norm
  # One direction per sample: insert a sample axis and broadcast to points' shape.
  per_sample_dirs = unit_dirs[:, None, ...].expand(points.shape)
  flat_dirs = per_sample_dirs.reshape((-1, 3))
  encoded_dirs = encoding_function(flat_dirs)
  return chunkify(encoded_dirs, chunksize=chunksize)

In [None]:
def _query_model(
  model: nn.Module,
  sample_pts: torch.Tensor,
  rays_d: torch.Tensor,
  encoding_fn,
  viewdirs_encoding_fn,
  chunksize: int
):
  """
  Encode the sample points (and, optionally, view directions), run "model"
  chunk by chunk and reassemble the raw predictions.
  Return: torch.Tensor shaped sample_pts.shape[:2] + [model output width]
  """
  batches = batchify_points(sample_pts, encoding_fn, chunksize=chunksize)
  if viewdirs_encoding_fn is not None:
    batches_viewdirs = batchify_viewdirs(sample_pts, rays_d,
                                         viewdirs_encoding_fn,
                                         chunksize=chunksize)
  else:
    batches_viewdirs = [None] * len(batches)

  # Run the model on one chunk at a time and concatenate (avoids OOM).
  predictions = [
    model(batch, viewdirs=batch_viewdirs)
    for batch, batch_viewdirs in zip(batches, batches_viewdirs)
  ]
  raw = torch.cat(predictions, dim=0)
  return raw.reshape(list(sample_pts.shape[:2]) + [raw.shape[-1]])

def nerf_forward(
  rays_o: torch.Tensor,
  rays_d: torch.Tensor,
  near: float,
  far: float,
  encoding_fn: torch.Tensor,
  coarse_model: nn.Module,
  fine_model = None,
  stratified_sampling_kwargs: dict = None,
  N_importance: int = 0,
  hierarchical_sampling_kwargs: dict = None,
  viewdirs_encoding_fn: torch.Tensor = None,
  chunksize: int = 2**15
):
  """
  Full forward pass: stratified sampling, coarse model, volume rendering
  and, when N_importance > 0, hierarchical sampling plus a fine-model pass.

  Args:
    rays_o, rays_d: ray origins and directions passed to the samplers.
    near, far: sampling bounds passed to sample_stratified.
    encoding_fn: callable applied to the flattened sample points.
    coarse_model: model queried with the stratified samples.
    fine_model: model queried with the hierarchical samples; falls back
      to coarse_model when None.
    stratified_sampling_kwargs: extra kwargs for sample_stratified.
    N_importance: number of hierarchical samples; 0 skips the fine pass.
    hierarchical_sampling_kwargs: extra kwargs for sample_hierarchical.
    viewdirs_encoding_fn: callable encoding view directions, or None to
      query the models without view directions.
    chunksize: maximum number of encoded rows per model call.

  Return: dict with keys 'rgb_map', 'depth_map', 'acc_map', 'weights' and
  'z_vals_stratified'; when N_importance > 0 also 'z_vals_hierarchical',
  'rgb_map_0', 'depth_map_0' and 'acc_map_0' (the coarse-pass results).
  """

  # Set no kwargs if none are given.
  if stratified_sampling_kwargs is None:
    stratified_sampling_kwargs = {}
  if hierarchical_sampling_kwargs is None:
    hierarchical_sampling_kwargs = {}

  # Stratified sampling (stage 1) for coarse query points.
  sample_pts, z_vals = sample_stratified(rays_o, rays_d, near, far, **stratified_sampling_kwargs)

  # Coarse model pass.
  raw = _query_model(coarse_model, sample_pts, rays_d,
                     encoding_fn, viewdirs_encoding_fn, chunksize)

  # Volume rendering (the disparity map is not part of the returned outputs).
  rgb_map, _, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d)
  outputs = {
      'z_vals_stratified': z_vals
  }

  # Fine model pass.
  if N_importance > 0:
    # Keep the coarse results so they can be returned alongside the fine ones.
    outputs['rgb_map_0'] = rgb_map
    outputs['depth_map_0'] = depth_map
    outputs['acc_map_0'] = acc_map

    # Apply hierarchical sampling (stage 2) for fine query points.
    sample_pts, z_vals_combined, z_hierarch = sample_hierarchical(
      rays_o, rays_d, z_vals, weights, N_importance,
      **hierarchical_sampling_kwargs)

    # Forward pass new samples through the fine model (or the coarse one
    # when no separate fine model was supplied).
    fine_model = fine_model if fine_model is not None else coarse_model
    raw = _query_model(fine_model, sample_pts, rays_d,
                       encoding_fn, viewdirs_encoding_fn, chunksize)

    # Volume rendering
    rgb_map, _, acc_map, weights, depth_map = raw2outputs(raw, z_vals_combined, rays_d)
    outputs['z_vals_hierarchical'] = z_hierarch

  # Store the final (fine if available, otherwise coarse) outputs.
  outputs['rgb_map'] = rgb_map
  outputs['depth_map'] = depth_map
  outputs['acc_map'] = acc_map
  outputs['weights'] = weights
  return outputs

In [None]:
"""
Hyperparameters
"""

# Model
d_filter = 256          # Dimensions of linear layer filters
n_layers = 8            # Number of layers in network bottleneck
skip = [4]              # Layers at which to apply input residual
use_fine_model = True   # If set, creates a fine model
d_filter_fine = 256     # Dimensions of linear layer filters of fine network
n_layers_fine = 8       # Number of layers in fine network bottleneck

# Stratified sampling
n_samples = 64          # Number of spatial samples per ray
perturb = True          # If set, applies noise to sample positions
inverse_depth = False   # If set, samples points linearly in inverse depth

# Hierarchical sampling
N_importance = 64            # Number of samples per ray
perturb_hierarchical = True  # If set, applies noise to sample positions

# Encoder
d_input = 3           # Number of input dimensions
n_freqs = 10          # Number of encoding functions for samples
log_sampling = True   # If set, frequencies scale in log space
use_viewdirs = True   # If set, use view direction as input
n_freqs_views = 4     # Number of encoding functions for views

# Optimizer
lr = 5e-4             # Initial learning rate
lr_decay = 250        # Decay horizon, in thousands of steps (multiplied by 1000 in the training loop)
decay_rate = 0.1      # Factor the learning rate is scaled by over one decay horizon

# Training
n_iters = 10000+1           # Upper bound (exclusive) of the training iteration counter
batch_size = 1024           # Number of rays per gradient step (power of 2)
one_image_per_step = False  # One image per gradient step (disables batching)
chunksize = 1024            # Modify as needed to fit in GPU memory
center_crop = True          # Crop the center of the image (only used with one_image_per_step)
center_crop_iters = 50      # Stop cropping center after this many iterations
display_rate = 100          # Print training loss/PSNR every X iterations

# LLFF & Dataloading
start = 0                   # Starting iteration, 0 by default
# Paths are built with os.path.join so they resolve on any OS; on Windows the
# values are identical to the previous ".\\logs" / ".\\fern" literals.
basedir = os.path.join(".", "logs")   # Base directory for logs and ckpts
expname = "fern_check"       # Custom experiment name
datadir = os.path.join(".", "fern")   # Input data directory
llff_hold = 8               # if > 0, every llff_hold-th image is held out as test set
use_ndc = True              # use ndc for forward facing scenes
render_test = False         # render test set instead of custom poses
save_rate = 25           # Save a checkpoint every X iterations

# Early Stopping
warmup_iters = 100          # Number of iterations during warmup phase
warmup_min_fitness = 10.0   # Min val PSNR to continue training at warmup_iters
n_restarts = 10             # Number of times to restart if training stalls

# Function kwargs
stratified_sampling_kwargs = {
    'n_samples': n_samples,
    'perturb': perturb,
    'inverse_depth': inverse_depth
}
hierarchical_sampling_kwargs = {
    'perturb': perturb_hierarchical
}

In [None]:
'''
Classes and Functions for training
'''
def crop_center(
  img: torch.Tensor,
  frac: float = 0.5
):
  """
  Crop the central region of an image.

  A margin of round(size * frac / 2) pixels is removed from each side of
  the first two dimensions, so with the default frac=0.5 the central half
  of the height and of the width is kept.

  Args:
    img: tensor whose first two dimensions are cropped; any trailing
      dimensions are left untouched.
    frac: total fraction of each of the two dimensions to remove.
  Return: torch.Tensor
  """
  height, width = img.shape[0], img.shape[1]
  h_offset = round(height * (frac / 2))
  w_offset = round(width * (frac / 2))
  # Use explicit end indices: the negative-slice form img[o:-o] collapses to
  # an empty tensor when an offset rounds to 0 (frac == 0 or a tiny image).
  return img[h_offset:height - h_offset, w_offset:width - w_offset]

class EarlyStopping:
  """
  Patience-based early-stopping tracker.

  Remembers the best fitness seen so far and the iteration at which it was
  reached; signals a stop once "patience" iterations have passed without an
  improvement larger than "min_improve".
  """
  def __init__(
    self,
    patience: int = 50,
    min_improve: float = 1e-4
  ):
    # A falsy patience (0 or None) means "never stop".
    self.patience = patience if patience else float('inf')
    self.min_improve = min_improve
    self.best_iter = 0
    self.best_fitness = 0.0  # PSNR

  def __call__(
    self,
    iter: int,
    fitness: float
  ):
    """
    Record "fitness" for iteration "iter" and report whether to stop.
    Return: Bool
    """
    gain = fitness - self.best_fitness
    if gain > self.min_improve:
      # New best: remember its value and when it happened.
      self.best_fitness = fitness
      self.best_iter = iter
    # Stop once the best result is at least "patience" iterations old.
    return (iter - self.best_iter) >= self.patience

In [None]:
# Load the LLFF scene.
# NOTE(review): this is an absolute, machine-specific path and the `datadir`
# hyperparameter is not used here -- confirm which location is intended.
data_dir = "H:\\NeRF-NTUE-project\\fern"
images, poses, bds, render_poses, i_test = load_llff_data(data_dir)

# The last column of the first pose carries height and width (rows 0 and 1);
# the focal length is read from row 2, column 4.
height, width = poses[0, :2, -1]
height, width = int(height), int(width)
focal = poses[0, 2, 4]
# Keep only the first three rows / four columns of every pose.
poses = poses[:, :3, :4]

print(f'Images shape: {images.shape}')
print(f'Poses shape: {poses.shape}')
print(f'Focal: {focal}')

# Hold-out split: wrap the loader's test index in a list, or, when a hold-out
# frequency is configured, take every llff_hold-th image instead.
if not isinstance(i_test, list):
    i_test = [i_test]

if llff_hold > 0:
    print(f'Set hold out frequency: {llff_hold}')
    i_test = np.arange(images.shape[0])[::llff_hold]
i_val = i_test

# Every image index that is neither a test nor a validation view.
train_indices = []
for i in np.arange(int(images.shape[0])):
    if i in i_test or i in i_val:
        continue
    train_indices.append(i)
n_training = np.array(train_indices)

# Sampling bounds: fixed to [0, 1] when use_ndc is set, otherwise derived
# from the loaded bounds.
if use_ndc:
    near, far = 0., 1.
else:
    near = bds.min() * .9
    far = bds.max() * 1.

print(f'Define near far: {near}, {far}')

print(f'Train views are: {n_training}')
print(f'Val views are: {i_val}')
print(f'Test views are: {i_test}')

# Show the first validation view together with its pose as a sanity check.
first_val = i_val[0]
plt.imshow(images[first_val])
print('Pose')
print(poses[first_val])

if render_test:
    render_poses = np.array(poses[i_test])

# Make sure the experiment directory exists.
os.makedirs(os.path.join(basedir, expname), exist_ok=True)

In [None]:
def init_models():
  """
  Initialize models, encoders, and optimizer for training.

  Builds the positional encoders, the coarse model and (when use_fine_model
  is set) the fine model, and an Adam optimizer over all their parameters.
  If a checkpoint (*.tar) exists in the experiment directory, the last one
  in sorted filename order is loaded into the models and optimizer, and the
  module-level "start" is set to the stored global step.

  Return: Tuple[model, fine_model, encode, encode_viewdirs, optimizer, warmup_stopper]
  (fine_model and encode_viewdirs are None when disabled).
  """
  global start

  # Encoders
  encoder = PositionalEncoder(d_input, n_freqs, log_sampling=log_sampling)
  encode = lambda x: encoder(x)

  # View direction encoders
  if use_viewdirs:
    encoder_viewdirs = PositionalEncoder(d_input, n_freqs_views,
                                        log_sampling=log_sampling)
    encode_viewdirs = lambda x: encoder_viewdirs(x)
    d_viewdirs = encoder_viewdirs.d_output
  else:
    encode_viewdirs = None
    d_viewdirs = None

  # Models
  model = NeRF(encoder.d_output, n_layers=n_layers, d_filter=d_filter, skip=skip,
              d_viewdirs=d_viewdirs)
  model.to(device)
  model_params = list(model.parameters())
  if use_fine_model:
    # The fine network has its own width/depth settings.
    fine_model = NeRF(encoder.d_output, n_layers=n_layers_fine, d_filter=d_filter_fine, skip=skip,
                      d_viewdirs=d_viewdirs)
    fine_model.to(device)
    model_params = model_params + list(fine_model.parameters())
  else:
    fine_model = None

  # Optimizer
  optimizer = torch.optim.Adam(model_params, lr=lr)

  # Early Stopping
  warmup_stopper = EarlyStopping(patience=50)

  # Checkpoints loading
  ckpt_dir = os.path.join(basedir, expname)
  os.makedirs(ckpt_dir, exist_ok=True)
  # Match the '.tar' extension only; a substring test would also pick up
  # unrelated files whose name merely contains "tar".
  ckpts = [os.path.join(ckpt_dir, f) for f in sorted(os.listdir(ckpt_dir)) if f.endswith('.tar')]
  print("Found ckpts", ckpts)

  if len(ckpts) > 0:
    ckpt_path = ckpts[-1]
    print('Reloading from', ckpt_path)
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only
    # machine (and vice versa).
    ckpt = torch.load(ckpt_path, map_location=device)

    start = ckpt['global_step']
    optimizer.load_state_dict(ckpt['optimizer_state_dict'])

    # Load model
    model.load_state_dict(ckpt['coarse_model_dict'])
    if fine_model is not None:
        fine_model.load_state_dict(ckpt['fine_model_dict'])

  return model, fine_model, encode, encode_viewdirs, optimizer, warmup_stopper

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# To torch tensors
# The held-out views must be taken from the FULL image array before `images`
# is narrowed to the training subset; indexing the subset with i_test would
# select the wrong (training) images.
testimg = torch.from_numpy(images[i_test]).to(device)
images = torch.from_numpy(images[n_training]).to(device)   # training views only
poses = torch.from_numpy(poses).to(device)                 # poses of ALL views
testpose = poses[i_test]

In [None]:
#Training Loop

def train():
  """
  Launch training.

  Reads its configuration, data and models from module-level names
  (hyperparameters, images/poses, model, fine_model, encode,
  encode_viewdirs, optimizer, warmup_stopper). Iterations run from
  start + 1 up to n_iters (exclusive); the module-level "start" itself is
  left untouched so repeated calls do not drift.

  Return: Tuple[bool, List[float]] -- (True, train PSNRs) when the loop
  finishes, (False, train PSNRs) when the warmup early-stopper fires.
  """
  global_step = start
  first_iter = start + 1

  if not one_image_per_step:
    # Pre-compute the NDC rays of every training view and shuffle them
    # together with their target colours.
    height, width = images.shape[1:3]

    all_ndc_rays = []
    for p in poses[n_training]:
      rays_o, rays_d = get_rays(height, width, focal, p)
      rays_o, rays_d = ndc_rays(height, width, focal, 1., rays_o, rays_d)
      all_ndc_rays.append(torch.stack((rays_o, rays_d), dim=0))
    all_ndc_rays = torch.stack(all_ndc_rays, dim=0)  # [N, 2, H, W, 3]

    rays_rgb = torch.cat([all_ndc_rays, images[:, None]], 1)  # [N, 3, H, W, 3]
    rays_rgb = torch.permute(rays_rgb, [0, 2, 3, 1, 4])       # [N, H, W, 3, 3]
    rays_rgb = rays_rgb.reshape([-1, 3, 3])                   # [N*H*W, 3, 3]
    rays_rgb = rays_rgb.type(torch.float32)
    rays_rgb = rays_rgb[torch.randperm(rays_rgb.shape[0])]
    i_batch = 0

  train_psnrs = []
  for i in trange(first_iter, n_iters):
    model.train()

    if one_image_per_step:
      # Randomly pick an image as target.
      target_img_idx = np.random.randint(images.shape[0])
      target_img = images[target_img_idx]
      if center_crop and i < center_crop_iters:
        target_img = crop_center(target_img)
      height, width = target_img.shape[:2]
      # `images` holds only the training views while `poses` holds every
      # view, so map the subset index back to the original view index.
      target_pose = poses[int(n_training[target_img_idx])]
      rays_o, rays_d = get_rays(height, width, focal, target_pose)
      rays_o, rays_d = ndc_rays(height, width, focal, 1., rays_o, rays_d)
      rays_o = rays_o.reshape([-1, 3])
      rays_d = rays_d.reshape([-1, 3])
    else:
      # Random over all images.
      batch = rays_rgb[i_batch:i_batch + batch_size]
      batch = torch.transpose(batch, 0, 1)  # [3, batch, 3]
      rays_o, rays_d, target_img = batch    # each [batch, 3]
      i_batch += batch_size
      # Shuffle after one epoch
      if i_batch >= rays_rgb.shape[0]:
          print("An epoch ends, shuffle all data!")
          rays_rgb = rays_rgb[torch.randperm(rays_rgb.shape[0])]
          i_batch = 0

    target_img = target_img.reshape([-1, 3])

    # Run one iteration and get the rendered RGB image.
    outputs = nerf_forward(rays_o, rays_d,
                           near, far, encode, model,
                           fine_model=fine_model,
                           stratified_sampling_kwargs=stratified_sampling_kwargs,
                           N_importance=N_importance,
                           hierarchical_sampling_kwargs=hierarchical_sampling_kwargs,
                           viewdirs_encoding_fn=encode_viewdirs,
                           chunksize=chunksize)

    # Check for any numerical issues.
    for k, v in outputs.items():
      if torch.isnan(v).any():
        print(f"! [Numerical Alert] {k} contains NaN.")
      if torch.isinf(v).any():
        print(f"! [Numerical Alert] {k} contains Inf.")

    # Backprop
    rgb_predicted = outputs['rgb_map']
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(rgb_predicted, target_img)
    psnr = -10. * torch.log10(loss)

    loss.backward()
    optimizer.step()

    # Plain float from here on: keeping the tensor around (e.g. inside the
    # early-stopper) would also keep its autograd graph alive.
    psnr_value = psnr.item()
    train_psnrs.append(psnr_value)

    # Update learning rate
    decay_steps = lr_decay * 1000
    new_lrate = lr * (decay_rate ** (global_step / decay_steps))
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lrate

    # Logging.
    if i % display_rate == 0:
      print(f"Step: {global_step}, Loss: {loss.item()}, psnr: {psnr_value}")

    if i % save_rate == 0:
      path = os.path.join(basedir, expname, '{0:06d}.tar'.format(i))
      ckpt = {
          # Number of completed steps, so a resumed run continues with the
          # next iteration instead of repeating this one.
          'global_step': global_step + 1,
          'coarse_model_dict': model.state_dict(),
          'optimizer_state_dict': optimizer.state_dict(),
      }
      # The fine model is optional (use_fine_model may be disabled).
      if fine_model is not None:
        ckpt['fine_model_dict'] = fine_model.state_dict()
      torch.save(ckpt, path)
      print('Saved checkpoints at', path)

    # Check PSNR for issues and stop if any are found.
    # Disabled validation-PSNR check, kept for reference:
    #   if i == warmup_iters - 1:
    #     if val_psnr < warmup_min_fitness:
    #       print(f'Val PSNR {val_psnr} below warmup_min_fitness {warmup_min_fitness}. Stopping...')
    #       return False, train_psnrs, val_psnrs
    if i < warmup_iters:
      if warmup_stopper is not None and warmup_stopper(i, psnr_value):
        print(f'Train PSNR flatlined at {psnr_value} for {warmup_stopper.patience} iters. Stopping...')
        return False, train_psnrs

    global_step += 1

  return True, train_psnrs

In [None]:
# Training session
# Run training, re-initialising and retrying (at most n_restarts attempts)
# whenever train() reports an unsuccessful run.
for _ in range(n_restarts):
  (model, fine_model,
   encode, encode_viewdirs,
   optimizer, warmup_stopper) = init_models()
  success, train_psnrs = train()
  if not success:
    continue
  print('Training successful!')
  break

print('')
print('Done!')

In [None]:
# IPython shell escape: print the current GPU status via the nvidia-smi command-line tool.
! nvidia-smi