In [12]:
import os
import numpy as np
import torch
from torch import nn
import matplotlib.pyplot as plt
from tqdm import tqdm, trange
import time

from model import *
from rays_util import *
from render import *
from load_llff import *

In [13]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)

In [14]:
def batchify(fn, chunk):
    """
    Wrap 'fn' so that it is evaluated on slices of at most 'chunk' rows.

    Args:
      fn: callable taking (inputs, viewdirs) and returning a tensor.
      chunk: int or None. Slice length along dim 0; None returns 'fn' itself.
    Returns:
      A callable (inputs, viewdirs) that concatenates the per-slice results
      along dim 0.
    """
    if chunk is None:
        return fn

    def ret(inputs, viewdirs):
        pieces = []
        for begin in range(0, inputs.shape[0], chunk):
            end = begin + chunk
            # Both arguments are sliced with the same row window.
            pieces.append(fn(inputs[begin:end], viewdirs[begin:end]))
        return torch.cat(pieces, 0)

    return ret

In [15]:
def batchify_rays(rays_flat, chunk=1024*32, **kwargs):
    """
    Call render_rays() on slices of 'rays_flat' to avoid OOM.

    Args:
      rays_flat: tensor whose dim 0 indexes rays.
      chunk: int. Maximum number of rays passed to render_rays() at once.
      **kwargs: forwarded unchanged to render_rays().
    Returns:
      dict mapping every key returned by render_rays() to the per-slice
      results concatenated along dim 0.
    """
    collected = {}
    total = rays_flat.shape[0]
    for begin in range(0, total, chunk):
        partial = render_rays(rays_flat[begin:begin+chunk], **kwargs)
        for key, value in partial.items():
            collected.setdefault(key, []).append(value)

    return {key: torch.cat(parts, 0) for key, parts in collected.items()}

In [16]:
def run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
    """
    Prepares inputs and applies network 'fn'.

    Args:
      inputs: tensor [..., C]. Sample positions; flattened to [-1, C] before
        being embedded.
      viewdirs: tensor or None. When given, it is expanded with
        viewdirs[:, None].expand(inputs.shape), flattened and embedded.
      fn: network called as fn(embedded, viewdirs).
      embed_fn: embedding applied to the flattened inputs.
      embeddirs_fn: embedding applied to the flattened view directions; only
        used when 'viewdirs' is not None.
      netchunk: int or None. Maximum number of rows sent to 'fn' per call;
        None sends everything in a single call.
    Returns:
      Network output reshaped to inputs.shape[:-1] + [output_channels].
    """
    inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])
    embedded = embed_fn(inputs_flat)

    if viewdirs is None:
        # No view directions: previously 'embedded_dirs' was left undefined on
        # this path. Pass None through to the network, chunking here because
        # the batched helper slices its second argument.
        if netchunk is None:
            outputs_flat = fn(embedded, viewdirs=None)
        else:
            outputs_flat = torch.cat(
                [fn(embedded[i:i+netchunk], None)
                 for i in range(0, embedded.shape[0], netchunk)], 0)
    else:
        input_dirs = viewdirs[:, None].expand(inputs.shape)
        input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])
        embedded_dirs = embeddirs_fn(input_dirs_flat)
        outputs_flat = batchify(fn, netchunk)(embedded, viewdirs=embedded_dirs)

    outputs = torch.reshape(outputs_flat, list(inputs.shape[:-1]) + [outputs_flat.shape[-1]])
    return outputs

In [17]:
def render(H, W, K, chunk=1024*32, rays=None, c2w=None, use_ndc=True,
                  near=0., far=1.,
                  use_viewdirs=False, c2w_staticcam=None,
                  **kwargs):
    """Render rays
    Args:
      H: int. Height of image in pixels.
      W: int. Width of image in pixels.
      K: camera intrinsics; only K[0][0] is read and used as the focal length.
      chunk: int. Maximum number of rays to process simultaneously. Used to
        control maximum memory usage.
      rays: pair (rays_o, rays_d) of ray origins and directions. Used only
        when c2w is None.
      c2w: camera-to-world matrix. If not None, rays for the full image are
        generated with get_rays() and 'rays' is ignored.
      use_ndc: bool. If True, rays are converted with ndc_rays().
      near: float or array. Nearest distance for a ray.
      far: float or array. Farthest distance for a ray.
      use_viewdirs: bool. If True, normalized ray directions are appended to
        every ray as viewing directions.
      c2w_staticcam: if not None (and use_viewdirs is True), rays are generated
        from this matrix while the viewing directions still come from the
        rays selected above.
      **kwargs: forwarded to batchify_rays().
    Returns:
      A list [rgb_map, disp_map, acc_map, extras] where extras is a dict with
      every other entry returned by batchify_rays(). Each tensor is reshaped
      so its leading dims match rays_d's leading dims.
    """
    focal_len = K[0][0]

    # Pick the ray source: explicit batch, or a full image from the pose.
    if c2w is None:
        rays_o, rays_d = rays
    else:
        rays_o, rays_d = get_rays(H, W, focal_len, c2w)

    if use_viewdirs:
        # Viewing directions are taken before any static-camera override.
        dirs = rays_d
        if c2w_staticcam is not None:
            rays_o, rays_d = get_rays(H, W, focal_len, c2w_staticcam)
        dirs = dirs / torch.norm(dirs, dim=-1, keepdim=True)
        viewdirs = torch.reshape(dirs, [-1, 3]).float()

    # Leading dims of the ray directions, remembered before flattening.
    lead_dims = list(rays_d.shape[:-1])

    if use_ndc:
        # for forward facing scenes
        rays_o, rays_d = ndc_rays(H, W, focal_len, 1., rays_o, rays_d)

    # Flatten and assemble [origin, direction, near, far (, viewdir)] per ray.
    rays_o = torch.reshape(rays_o, [-1, 3]).float()
    rays_d = torch.reshape(rays_d, [-1, 3]).float()
    near = near * torch.ones_like(rays_d[..., :1])
    far = far * torch.ones_like(rays_d[..., :1])
    columns = [rays_o, rays_d, near, far]
    if use_viewdirs:
        columns.append(viewdirs)
    rays = torch.cat(columns, -1)

    # Render in chunks, then restore the original leading dims.
    flat_ret = batchify_rays(rays, chunk, **kwargs)
    shaped = {
        key: torch.reshape(val, lead_dims + list(val.shape[1:]))
        for key, val in flat_ret.items()
    }

    primary = ['rgb_map', 'disp_map', 'acc_map']
    extras = {key: val for key, val in shaped.items() if key not in primary}
    return [shaped[key] for key in primary] + [extras]

In [18]:
def render_path(render_poses, hwf, K, chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0):
    """
    Render one image per pose and optionally save each frame as a PNG.

    Args:
      render_poses: iterable of pose matrices; c2w[:3, :4] of each is rendered.
      hwf: sequence (H, W, focal). Only H and W are used here; the focal
        length used for rendering is read from K by render().
      K: 3x3 camera intrinsics (numpy array, nested list or torch tensor).
      chunk: int. Maximum number of rays processed at once by render().
      render_kwargs: dict of extra keyword arguments forwarded to render().
      gt_imgs: currently unused.
      savedir: str or None. If given, frame i is written to
        '<savedir>/<i:03d>.png'; the directory is created if missing.
      render_factor: int. If non-zero, H, W and the intrinsics are divided by
        this factor to render downsampled images for speed.
    Returns:
      (rgbs, disps): numpy arrays stacked along a new first axis.
    """
    H, W, _ = hwf

    if render_factor != 0:
        # Render downsampled for speed
        H = H // render_factor
        W = W // render_factor
        # render() takes the focal length from K[0][0], so the intrinsics
        # have to be scaled as well; work on a copy to leave the caller's K
        # untouched.
        if torch.is_tensor(K):
            K = K.clone().float()
        else:
            K = np.array(K, dtype=np.float64)
        K[:2] = K[:2] / render_factor

    if savedir is not None:
        os.makedirs(savedir, exist_ok=True)

    rgbs = []
    disps = []

    t = time.time()
    for i, c2w in enumerate(tqdm(render_poses)):
        print(i, time.time() - t)
        t = time.time()
        rgb, disp, acc, _ = render(H, W, K, chunk=chunk, c2w=c2w[:3,:4], **render_kwargs)
        rgbs.append(rgb.cpu().numpy())
        disps.append(disp.cpu().numpy())
        if i == 0:
            print(rgb.shape, disp.shape)

        if savedir is not None:
            rgb8 = to8b(rgbs[-1])
            filename = os.path.join(savedir, '{:03d}.png'.format(i))
            imageio.imwrite(filename, rgb8)

    rgbs = np.stack(rgbs, 0)
    disps = np.stack(disps, 0)

    return rgbs, disps

In [19]:
# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------

# Model
d_filter = 256          # Dimensions of linear layer filters
n_layers = 8            # Number of layers in network bottleneck
skip = [4]              # Layers at which to apply input residual
use_fine_model = True   # If set, creates a fine model
d_filter_fine = 256     # Dimensions of linear layer filters of fine network
n_layers_fine = 8       # Number of layers in fine network bottleneck

# Stratified sampling
n_samples = 64          # Number of spatial samples per ray
perturb = False         # If set, applies noise to sample positions
inverse_depth = False   # If set, samples points linearly in inverse depth

# Hierarchical sampling
N_importance = 64             # Number of samples per ray
perturb_hierarchical = False  # If set, applies noise to sample positions

# Encoder
d_input = 3           # Number of input dimensions
n_freqs = 10          # Number of encoding functions for samples
log_sampling = True   # If set, frequencies scale in log space
use_viewdirs = True   # If set, use view direction as input
n_freqs_views = 4     # Number of encoding functions for views

# Training
n_iters = 10000+1
batch_size = 1024            # Number of rays per gradient step (power of 2)
one_image_per_step = False   # One image per gradient step (disables batching)
chunksize = 1024             # Modify as needed to fit in GPU memory
center_crop = False          # Crop the center of image (presumably only with one_image_per_step)
center_crop_iters = 50       # Stop cropping center after this many epochs
display_rate = 25            # Display test output every X epochs

# Logging / data locations. os.path.join keeps these portable: on Windows the
# values are identical to the previous ".\\logs" and ".\\fern".
start = 0
basedir = os.path.join(".", "logs")
datadir = os.path.join(".", "fern")
expname = "fern_test"
save_rate = 25

llff_hold = 8

use_ndc = True
render_test = False

# Function kwargs
stratified_sampling_kwargs = {
    'n_samples': n_samples,
    'perturb': perturb,
    'inverse_depth': inverse_depth
}
hierarchical_sampling_kwargs = {
    # Uses the hierarchical-specific switch rather than the stratified one.
    'perturb': perturb_hierarchical
}

In [20]:
def init_models():
    """
    Initialize encoders and models, restore the latest checkpoint if one
    exists, and build the keyword-argument dicts used for rendering.

    Reads the module-level hyperparameters and, when a checkpoint is loaded,
    overwrites the module-level 'start' with the checkpoint's 'global_step'.

    Returns:
      render_kwargs_train: dict of rendering kwargs for training.
      render_kwargs_test: copy of the above with 'perturb' set to False and
        'raw_noise_std' set to 0.
      start: int. Step to resume from.
      model_params: list of parameters of the coarse (and fine) model.
    """
    global start

    # Encoders
    encoder = PositionalEncoder(d_input, n_freqs, log_sampling=log_sampling)
    encode = lambda x: encoder(x)

    # View direction encoders
    if use_viewdirs:
        encoder_viewdirs = PositionalEncoder(d_input, n_freqs_views,
                                             log_sampling=log_sampling)
        encode_viewdirs = lambda x: encoder_viewdirs(x)
        d_viewdirs = encoder_viewdirs.d_output
    else:
        encode_viewdirs = None
        d_viewdirs = None

    # Models
    model = NeRF(encoder.d_output, n_layers=n_layers, d_filter=d_filter, skip=skip,
                 d_viewdirs=d_viewdirs)
    model.to(device)
    model_params = list(model.parameters())
    if use_fine_model:
        # The fine network has its own size hyperparameters.
        fine_model = NeRF(encoder.d_output, n_layers=n_layers_fine, d_filter=d_filter_fine,
                          skip=skip, d_viewdirs=d_viewdirs)
        fine_model.to(device)
        model_params = model_params + list(fine_model.parameters())
    else:
        fine_model = None

    network_query_fn = lambda inputs, viewdirs, network_fn: run_network(
        inputs, viewdirs, network_fn,
        embed_fn=encode,
        embeddirs_fn=encode_viewdirs,
        netchunk=chunksize)

    # Checkpoints loading
    ckpt_dir = os.path.join(basedir, expname)
    os.makedirs(ckpt_dir, exist_ok=True)
    ckpts = [os.path.join(ckpt_dir, f) for f in sorted(os.listdir(ckpt_dir)) if f.endswith('.tar')]
    print("Found ckpts", ckpts)

    if len(ckpts) > 0:
        ckpt_path = ckpts[-1]
        print('testing from', ckpt_path)
        # map_location lets a checkpoint saved on GPU be restored on a
        # CPU-only machine (and vice versa).
        ckpt = torch.load(ckpt_path, map_location=device)

        start = ckpt['global_step']
        # The optimizer state ('optimizer_state_dict') is not restored here.

        # Load model
        model.load_state_dict(ckpt['coarse_model_dict'])
        if fine_model is not None:
            fine_model.load_state_dict(ckpt['fine_model_dict'])

    render_kwargs_train = {
        'network_query_fn': network_query_fn,
        'perturb': perturb,
        'N_importance': N_importance,
        'network_fine': fine_model,
        'N_samples': n_samples,
        'network_fn': model,
        'use_viewdirs': use_viewdirs,
        'white_bkgd': False,
        'raw_noise_std': 0.,
    }

    # NDC only good for LLFF-style forward facing data
    if not use_ndc:
        print('Not ndc!')
        # render() names this switch 'use_ndc'; any other key would be passed
        # on to render_rays() and rejected as an unexpected keyword.
        render_kwargs_train['use_ndc'] = False

    render_kwargs_test = dict(render_kwargs_train)
    render_kwargs_test['perturb'] = False
    render_kwargs_test['raw_noise_std'] = 0.

    return render_kwargs_train, render_kwargs_test, start, model_params

In [21]:
def render_rays(ray_batch,
                network_fn,
                network_query_fn,
                N_samples,
                retraw=False,
                lindisp=False,
                perturb=0.,
                N_importance=0,
                network_fine=None,
                white_bkgd=False,
                raw_noise_std=0.,
                verbose=False,
                pytest=False):
    """Volumetric rendering.
    Args:
      ray_batch: array of shape [batch_size, ...]. All information necessary
        for sampling along a ray, including: ray origin (columns 0:3), ray
        direction (3:6), min dist and max dist (6:8), and, when more than 8
        columns are present, viewing direction (last 3 columns).
      network_fn: function. Model for predicting RGB and density at each point
        in space.
      network_query_fn: function used for passing queries to network_fn.
      N_samples: int. Number of different times to sample along each ray.
      retraw: bool. If True, include model's raw, unprocessed predictions.
      lindisp: bool. If True, sample linearly in inverse depth rather than in depth.
      perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified
        random points in time.
      N_importance: int. Number of additional times to sample along each ray.
        These samples are only passed to network_fine.
      network_fine: "fine" network with same spec as network_fn.
      white_bkgd: bool. If True, assume a white background.
      raw_noise_std: forwarded to raw2outputs().
      verbose: bool. Currently unused.
      pytest: bool. If True, the stratified noise is drawn from numpy with a
        fixed seed instead of torch.
    Returns:
      dict with:
      rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from fine model.
      disp_map: [num_rays]. Disparity map. 1 / depth.
      acc_map: [num_rays]. Accumulated opacity along each ray. Comes from fine model.
      raw: [num_rays, num_samples, 4]. Raw predictions from model (only if retraw).
      rgb0: See rgb_map. Output for coarse model (only if N_importance > 0).
      disp0: See disp_map. Output for coarse model (only if N_importance > 0).
      acc0: See acc_map. Output for coarse model (only if N_importance > 0).
      z_std: [num_rays]. Standard deviation of distances along ray for each
        sample (only if N_importance > 0).
    """
    N_rays = ray_batch.shape[0]
    rays_o, rays_d = ray_batch[:,0:3], ray_batch[:,3:6] # [N_rays, 3] each
    viewdirs = ray_batch[:,-3:] if ray_batch.shape[-1] > 8 else None
    bounds = torch.reshape(ray_batch[...,6:8], [-1,1,2])
    near, far = bounds[...,0], bounds[...,1] # [-1,1]

    # Create the sample positions on the same device as the rays so the
    # arithmetic below never mixes devices.
    t_vals = torch.linspace(0., 1., steps=N_samples, device=ray_batch.device)
    if not lindisp:
        z_vals = near * (1.-t_vals) + far * (t_vals)
    else:
        z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals))

    z_vals = z_vals.expand([N_rays, N_samples])

    if perturb > 0.:
        # get intervals between samples
        mids = .5 * (z_vals[...,1:] + z_vals[...,:-1])
        upper = torch.cat([mids, z_vals[...,-1:]], -1)
        lower = torch.cat([z_vals[...,:1], mids], -1)
        # stratified samples in those intervals; the noise must live on the
        # same device as z_vals
        t_rand = torch.rand(z_vals.shape, device=z_vals.device)

        # Pytest, overwrite u with numpy's fixed random numbers
        if pytest:
            np.random.seed(0)
            t_rand = np.random.rand(*list(z_vals.shape))
            t_rand = torch.Tensor(t_rand).to(z_vals.device)

        z_vals = lower + (upper - lower) * t_rand

    pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples, 3]

    # Coarse pass
    raw = network_query_fn(pts, viewdirs, network_fn)
    rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd)

    if N_importance > 0:

        rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map

        # Draw extra samples from the distribution given by the coarse weights
        z_vals_mid = .5 * (z_vals[...,1:] + z_vals[...,:-1])
        z_samples = sample_pdf(z_vals_mid, weights[...,1:-1], N_importance, perturb=False)
        z_samples = z_samples.detach()

        z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
        pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples + N_importance, 3]

        # Fine pass (falls back to the coarse network if no fine one is given)
        run_fn = network_fn if network_fine is None else network_fine
        raw = network_query_fn(pts, viewdirs, run_fn)

        rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd)

    ret = {'rgb_map' : rgb_map, 'disp_map' : disp_map, 'acc_map' : acc_map}
    if retraw:
        ret['raw'] = raw
    if N_importance > 0:
        ret['rgb0'] = rgb_map_0
        ret['disp0'] = disp_map_0
        ret['acc0'] = acc_map_0
        ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False)  # [N_rays]

    return ret

In [22]:
def test():
    """
    Load the LLFF scene from the configured 'datadir', restore the model via
    init_models() and render the pose path to
    '<basedir>/<expname>/renderonly_<test|path>_<step>'.

    This notebook only supports render-only mode, so the function always
    returns None after rendering.
    """
    # Use the configured dataset location instead of a machine-specific path.
    images, poses, bds, render_poses, i_test = load_llff_data(datadir)
    focal = poses[0, 2, 4]
    height, width = poses[0, :2, -1]
    height, width = int(height), int(width)
    hwf = [height, width, focal]
    poses = poses[:, :3, :4]

    print(f'Images shape: {images.shape}')
    print(f'Poses shape: {poses.shape}')
    print(f'Focal: {focal}')

    if not isinstance(i_test, list):
        i_test = [i_test]

    if llff_hold > 0:
        print(f'Set hold out frequency: {llff_hold}')
        i_test = np.arange(images.shape[0])[::llff_hold]
    i_val = i_test
    n_training = np.array([i for i in np.arange(int(images.shape[0])) if (i not in i_test and i not in i_val)])

    if use_ndc:
        near, far = 0., 1.
    else:
        near = np.ndarray.min(bds) *.9
        far = np.ndarray.max(bds) * 1.

    print(f'Define near far: {near}, {far}')

    print(f'Train views are: {n_training}')
    print(f'Val views are: {i_val}')
    print(f'Test views are: {i_test}')

    # Pinhole intrinsics with the principal point at the image center.
    K = np.array([
        [focal, 0, 0.5*width],
        [0, focal, 0.5*height],
        [0, 0, 1]
    ])

    if render_test:
        render_poses = np.array(poses[i_test])

    os.makedirs(os.path.join(basedir, expname), exist_ok=True)
    # Create nerf model
    render_kwargs_train, render_kwargs_test, start, grad_vars = init_models()

    bds_dict = {
        'near' : near,
        'far' : far,
    }
    render_kwargs_train.update(bds_dict)
    render_kwargs_test.update(bds_dict)

    # Move testing data to GPU
    render_poses = torch.Tensor(render_poses).to(device)

    # Render-only: no training is performed in this notebook.
    print('RENDER ONLY')
    with torch.no_grad():
        if render_test:
            # render_test switches to test poses
            images = images[i_test]
        else:
            # Default is smoother render_poses path
            images = None

        testsavedir = os.path.join(basedir, expname, 'renderonly_{}_{:06d}'.format('test' if render_test else 'path', start))
        os.makedirs(testsavedir, exist_ok=True)
        print('test poses shape', render_poses.shape)

        rgbs, _ = render_path(render_poses, hwf, K, chunksize, render_kwargs_test, gt_imgs=images, savedir=testsavedir, render_factor=0)
        print('Done rendering', testsavedir)

    return

In [23]:
if __name__ == "__main__":
    # Entry point: run the render-only pipeline.
    # (Optional, disabled: torch.set_default_tensor_type('torch.cuda.FloatTensor'))
    test()

Loaded image data (20, 378, 504, 3) [378.         504.         407.56579161]
Loaded H:\NeRF-NTUE-project\fern 16.985296178676084 80.00209740336334
recentered (3, 5)
[[ 1.0000000e+00  0.0000000e+00  0.0000000e+00  1.4901161e-09]
 [ 0.0000000e+00  1.0000000e+00 -1.8730975e-09 -9.6857544e-09]
 [-0.0000000e+00  1.8730975e-09  1.0000000e+00  0.0000000e+00]]
Data:
(20, 3, 5) (20, 378, 504, 3) (20, 2)
HOLDOUT view is 12
Images shape: (20, 378, 504, 3)
Poses shape: (20, 3, 4)
Focal: 407.5657958984375
Set hold out frequency: 8
Define near far: 0.0, 1.0
Train views are: [ 1  2  3  4  5  6  7  9 10 11 12 13 14 15 17 18 19]
Val views are: [ 0  8 16]
Test views are: [ 0  8 16]
Found ckpts ['.\\logs\\fern_test\\010000.tar']
testing from .\logs\fern_test\010000.tar
RENDER ONLY
test poses shape torch.Size([120, 3, 5])


  0%|          | 0/120 [00:00<?, ?it/s]

0 0.004998445510864258


  1%|          | 1/120 [00:58<1:56:34, 58.78s/it]

torch.Size([378, 504, 3]) torch.Size([378, 504])
1 58.78027057647705


  2%|▏         | 2/120 [01:57<1:55:24, 58.68s/it]

2 58.61058592796326


  2%|▎         | 3/120 [02:54<1:53:24, 58.16s/it]

3 57.532970666885376


  3%|▎         | 4/120 [03:51<1:51:36, 57.73s/it]

4 57.06911301612854


  4%|▍         | 5/120 [04:48<1:50:04, 57.43s/it]

5 56.89284586906433


  5%|▌         | 6/120 [05:47<1:49:53, 57.84s/it]

6 58.640366554260254


  6%|▌         | 7/120 [06:44<1:48:32, 57.63s/it]

7 57.20531129837036


  7%|▋         | 8/120 [07:40<1:46:36, 57.11s/it]

8 56.002527952194214


  7%|▋         | 8/120 [07:58<1:51:39, 59.82s/it]


KeyboardInterrupt: 

# Render session
# test() builds its own models through init_models() and returns None, so
# nothing is unpacked here; reaching the print means no exception was raised.
for _ in range(1):
    test()
    print('Testing successful!')
    break

print('')
print('Done!')