In [1]:
import os, sys
import numpy as np
import imageio
import json
import random
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm, trange

import matplotlib.pyplot as plt

from run_nerf_helpers import *

from load_llff import load_llff_data

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seeds NumPy's global RNG only.
# NOTE(review): torch and `random` are not seeded in this cell — confirm whether that is intended.
np.random.seed(0)
# When True, render_rays prints a warning for any output that contains NaN or Inf.
DEBUG = False

In [2]:
# Scratch check: number of rows obtained when a [32768, 64, 3] point tensor is flattened to [-1, 3].
32768*  64

2097152

In [3]:
# takeaway function and not decompose it
def batchify(fn, chunk):
    """Wrap 'fn' so that it is evaluated on slices of at most 'chunk' rows.

    Args:
      fn: callable taking a tensor whose first dimension is the batch.
      chunk: int or None. Maximum number of rows passed to 'fn' per call.
        When None, 'fn' is returned unchanged.
    Returns:
      A callable that slices its input along dim 0, applies 'fn' to every
      slice in order, and concatenates the results along dim 0.

    Example sizes observed while debugging: 2097152 input rows with
    chunk=65536 gives 32 calls, each returning a [65536, 4] tensor.
    """
    if chunk is None:
        return fn

    def chunked_fn(inputs):
        total_rows = inputs.shape[0]
        pieces = []
        for start in range(0, total_rows, chunk):
            pieces.append(fn(inputs[start:start + chunk]))
        return torch.cat(pieces, 0)

    return chunked_fn

def run_network(inputs, view_direction, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
    """Embed sample points (and optionally view directions) and query network 'fn'.

    Args:
      inputs: tensor [..., C] of sample positions, e.g. [32768, 64, 3].
      view_direction: tensor with one direction per ray (e.g. [32768, 3]),
        or None to skip the direction embedding.
      fn: network applied to the embedded rows.
      embed_fn: embedding applied to the flattened positions.
      embeddirs_fn: embedding applied to the flattened directions.
      netchunk: maximum number of rows sent through 'fn' at once.
    Returns:
      Tensor with the leading dimensions of 'inputs' and the network's
      output width as last dimension, e.g. [32768, 64, 4].
    """
    leading_dims = list(inputs.shape[:-1])

    # Collapse every leading dimension: [32768, 64, 3] -> [2097152, 3].
    points_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])
    features = embed_fn(points_flat)  # e.g. [2097152, 63]

    if view_direction is not None:
        # Repeat each ray's direction for all of its samples, then flatten
        # it the same way as the positions.
        dirs_per_sample = view_direction[:,None].expand(inputs.shape)
        dirs_flat = torch.reshape(dirs_per_sample, [-1, dirs_per_sample.shape[-1]])
        dir_features = embeddirs_fn(dirs_flat)  # e.g. [..., 27]
        features = torch.cat([features, dir_features], -1)  # e.g. [..., 90]

    chunked_fn = batchify(fn, netchunk)
    outputs_flat = chunked_fn(features)  # e.g. [2097152, 4]

    # Restore the leading dimensions of 'inputs'.
    out_shape = leading_dims + [outputs_flat.shape[-1]]
    return torch.reshape(outputs_flat, out_shape)


def batchify_rays(rays_flat, chunk=1024*32, **kwargs):
    """Render rays in smaller minibatches to avoid OOM.

    Args:
      rays_flat: tensor [N_rays, ...]; each row describes one ray.
      chunk: maximum number of rays handed to render_rays per call.
      **kwargs: forwarded unchanged to render_rays.
    Returns:
      Dict with the same keys render_rays returns (e.g. 'rgb_map',
      'disp_map', 'acc_map', and optionally 'raw', 'rgb0', 'disp0',
      'acc0', 'z_std'), each value being the per-chunk results
      concatenated along dim 0.
    """
    per_key_chunks = {}
    n_rays = rays_flat.shape[0]

    # e.g. range(0, 190512, 32768) for a full 378x504 image
    for start in range(0, n_rays, chunk):
        chunk_out = render_rays(rays_flat[start:start + chunk], **kwargs)
        for key, value in chunk_out.items():
            per_key_chunks.setdefault(key, []).append(value)

    # Stitch the per-chunk tensors back together, key by key.
    return {key: torch.cat(parts, 0) for key, parts in per_key_chunks.items()}




def render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True,
                  near=0., far=1.,
                  use_viewdirs=False, c2w_staticcam=None,
                  **kwargs):
    """Render rays
    Args:
      H: int. Height of image in pixels.
      W: int. Width of image in pixels.
      K: camera intrinsics, passed to get_rays; K[0][0] is forwarded to
        ndc_rays (observed as a 3x3 matrix with the focal length at [0][0]).
      chunk: int. Maximum number of rays to process simultaneously. Used to
        control maximum memory usage. Does not affect final results.
      rays: pair (rays_o, rays_d), e.g. an array of shape [2, batch_size, 3].
        Only used when c2w is None.
      c2w: array of shape [3, 4]. Camera-to-world transformation matrix. When
        given, rays for the full H x W image are generated from it.
      ndc: bool. If True, represent ray origin, direction in NDC coordinates.
      near: float or array of shape [batch_size]. Nearest distance for a ray.
      far: float or array of shape [batch_size]. Farthest distance for a ray.
      use_viewdirs: bool. If True, use viewing direction of a point in space in model.
      c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for
        camera while using other c2w argument for viewing directions.
      **kwargs: forwarded to batchify_rays (and from there to render_rays).
    Returns:
      A list [rgb_map, disp_map, acc_map, extras]:
        rgb_map: [batch_size, 3]. Predicted RGB values for rays.
        disp_map: [batch_size]. Disparity map. Inverse of depth.
        acc_map: [batch_size]. Accumulated opacity (alpha) along a ray.
        extras: dict with every other entry returned by render_rays().
      For full-image rendering the leading dimension is [H, W] instead,
      e.g. rgb_map [378, 504, 3].
    """
    if c2w is None:
        # use provided ray batch
        rays_o, rays_d = rays
    else:
        # special case to render full image, e.g. rays of shape [378, 504, 3]
        rays_o, rays_d = get_rays(H, W, K, c2w)

    viewdirs = None
    if use_viewdirs:
        # provide ray directions as input (taken before any static-camera swap)
        viewdirs = rays_d
        if c2w_staticcam is not None:
            # special case to visualize effect of viewdirs
            rays_o, rays_d = get_rays(H, W, K, c2w_staticcam)
        viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
        viewdirs = torch.reshape(viewdirs, [-1,3]).float()  # e.g. [1024, 3]

    # Leading dims of the ray tensor; used to restore the output layout.
    out_prefix = list(rays_d.shape[:-1])

    if ndc:
        # for forward facing scenes
        rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d)

    # Create ray batch: one row per ray.
    rays_o = torch.reshape(rays_o, [-1,3]).float()
    rays_d = torch.reshape(rays_d, [-1,3]).float()

    ones = torch.ones_like(rays_d[...,:1])  # [N_rays, 1]
    columns = [rays_o, rays_d, near * ones, far * ones]  # 8 columns
    if use_viewdirs:
        columns.append(viewdirs)  # 11 columns
    rays = torch.cat(columns, -1)

    # Render and reshape every output back to out_prefix + per-ray shape.
    rendered = batchify_rays(rays, chunk, **kwargs)
    rendered = {
        name: torch.reshape(value, out_prefix + list(value.shape[1:]))
        for name, value in rendered.items()
    }

    primary = ['rgb_map', 'disp_map', 'acc_map']
    extras = {name: value for name, value in rendered.items() if name not in primary}
    return [rendered[name] for name in primary] + [extras]




def render_rays(ray_batch,
                network_fn,
                network_query_fn,
                N_samples,
                return_raw=False,
                linear_depend_inverse_depth=False,
                perturb=0.,
                N_importance=0,
                network_fine=None,
                white_background=False,
                raw_noise_std=0.,
                verbose=False,
                pytest=False):
    """Volumetric rendering.
    Args:
      ray_batch: array of shape [batch_size, ...]. All information necessary
        for sampling along a ray, including: ray origin (columns 0:3), ray
        direction (3:6), min dist and max dist (6:8), and, when more than 8
        columns are present, the unit-magnitude viewing direction (last 3).
      network_fn: function. Model for predicting RGB and density at each point
        in space.
      network_query_fn: function used for passing queries to network_fn.
      N_samples: int. Number of different times to sample along each ray.
      return_raw: bool. If True, include model's raw, unprocessed predictions.
      linear_depend_inverse_depth: bool. If True, sample linearly in inverse
        depth rather than in depth.
      perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified
        random points in time.
      N_importance: int. Number of additional times to sample along each ray.
        These samples are only passed to network_fine.
      network_fine: "fine" network with same spec as network_fn. When None,
        network_fn is reused for the fine pass.
      white_background: bool. If True, assume a white background.
      raw_noise_std: forwarded to raw2outputs_kernel_function.
      verbose: bool. Accepted for interface compatibility; not used here.
      pytest: bool. If True, the stratified random offsets are replaced by
        NumPy draws with a fixed seed (0); also forwarded to the helpers.
    Returns:
      Dict with the following entries:
      rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from fine model.
      disp_map: [num_rays]. Disparity map. 1 / depth.
      acc_map: [num_rays]. Accumulated opacity along each ray. Comes from fine model.
      raw: [num_rays, num_samples, 4]. Raw predictions from model
        (only when return_raw is True).
      rgb0: See rgb_map. Output for coarse model (only when N_importance > 0).
      disp0: See disp_map. Output for coarse model (only when N_importance > 0).
      acc0: See acc_map. Output for coarse model (only when N_importance > 0).
      z_std: [num_rays]. Standard deviation of distances along ray for each
        sample (only when N_importance > 0).
    """
    N_rays = ray_batch.shape[0]  # e.g. ray_batch is [1024, 11]

    rays_origin, rays_direction = ray_batch[:,0:3], ray_batch[:,3:6] # [N_rays, 3] each

    view_direction = ray_batch[:,-3:] if ray_batch.shape[-1] > 8 else None

    bounds = torch.reshape(ray_batch[...,6:8], [-1,1,2])  # [N_rays, 1, 2]
    near, far = bounds[...,0], bounds[...,1] # [-1,1]

    t_vals = torch.linspace(0., 1., steps=N_samples).to(device)

    if not linear_depend_inverse_depth:
        # evenly spaced in depth
        z_vals = near * (1.-t_vals) + far * (t_vals)
    else:
        # evenly spaced in inverse depth
        z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals))

    z_vals = z_vals.expand([N_rays, N_samples]).to(device)  # [N_rays, N_samples]

    if perturb > 0.:
        # get intervals between samples
        mids = .5 * (z_vals[...,1:] + z_vals[...,:-1])   # [N_rays, N_samples-1]
        upper = torch.cat([mids, z_vals[...,-1:]], -1)   # [N_rays, N_samples]
        lower = torch.cat([z_vals[...,:1], mids], -1)    # [N_rays, N_samples]

        # stratified samples in those intervals
        t_rand = torch.rand(z_vals.shape).to(device)

        # Pytest, overwrite u with numpy's fixed random numbers
        if pytest:
            np.random.seed(0)
            t_rand = np.random.rand(*list(z_vals.shape))
            t_rand = torch.Tensor(t_rand).to(device)

        z_vals = lower + (upper - lower) * t_rand

    # walk from the origin along the ray direction to form the sample points
    pts = rays_origin[...,None,:] + rays_direction[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples, 3]

    raw = network_query_fn(pts, view_direction, network_fn)  # e.g. [1024, 64, 4]

    rgb_map, disparity_map, accumulated_opacity_map, weights, depth_map = raw2outputs_kernel_function(raw, z_vals, rays_direction, raw_noise_std, white_background, pytest=pytest)

    if N_importance > 0:
        # keep the coarse results before they are overwritten by the fine pass
        rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disparity_map, accumulated_opacity_map

        z_vals_mid = .5 * (z_vals[...,1:] + z_vals[...,:-1])
        z_samples = sample_pdf(z_vals_mid, weights[...,1:-1], N_importance, det=(perturb==0.), pytest=pytest)
        z_samples = z_samples.detach()

        z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
        pts = rays_origin[...,None,:] + rays_direction[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples + N_importance, 3]

        run_fn = network_fn if network_fine is None else network_fine
        raw = network_query_fn(pts, view_direction, run_fn)

        rgb_map, disparity_map, accumulated_opacity_map, weights, depth_map = raw2outputs_kernel_function(raw, z_vals, rays_direction, raw_noise_std, white_background, pytest=pytest)

    ret = {'rgb_map' : rgb_map, 'disp_map' : disparity_map, 'acc_map' : accumulated_opacity_map}
    if return_raw:
        ret['raw'] = raw
    if N_importance > 0:
        ret['rgb0'] = rgb_map_0
        ret['disp0'] = disp_map_0
        ret['acc0'] = acc_map_0
        ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False)  # [N_rays]

    # Only scan the outputs when debugging: turning `.any()` into a Python
    # bool forces a host/device synchronisation for every tensor, which the
    # original paid on each call even with DEBUG disabled.
    if DEBUG:
        for k in ret:
            if torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any():
                print(f"! [Numerical Error] {k} contains nan or inf.")

    return ret



def raw2outputs_kernel_function(raw, z_vals, rays_direction, raw_noise_std=0, white_background=False, pytest=False):
    """Transforms model's predictions to semantically meaningful values.
    Args:
        raw: [num_rays, num_samples along ray, 4]. Prediction from model.
        z_vals: [num_rays, num_samples along ray]. Integration time.
        rays_direction: [num_rays, 3]. Direction of each ray.
    Returns:
        rgb_map: [num_rays, 3]. Estimated RGB color of a ray.
        disparity_map: [num_rays]. Disparity map. Inverse of depth map.
        accumulated_opacity_map: [num_rays]. Sum of weights along each ray.
        weights: [num_rays, num_samples]. Weights assigned to each sampled color.
        depth_map: [num_rays]. Estimated distance to object.
    """
    raw2alpha = lambda raw, distances, act_fn=F.relu: 1.-torch.exp(-act_fn(raw)*distances)
    



    
    distances = z_vals[...,1:] - z_vals[...,:-1]
    # print(distances.shape) # torch.Size([1024, 63])
    distances = torch.cat([distances, torch.Tensor([1e10]).to(device).expand(distances[...,:1].shape)], -1)  # [N_rays, N_samples]
    # print(distances.shape) # torch.Size([1024, 64])
    # print(distances)
    # tensor([[1.4942e-02, 1.5233e-02, 1.4264e-02,  ..., 1.0590e-02, 8.3287e-03,
    #  1.0000e+10],
    # [4.7994e-03, 2.6690e-02, 5.5363e-03,  ..., 1.2192e-02, 1.2408e-02,
    #  1.0000e+10],
    # [1.0520e-02, 7.1725e-03, 1.8006e-02,  ..., 1.3228e-02, 7.9815e-03,
    #  1.0000e+10],
    # ...,
    # [1.3458e-02, 1.5161e-02, 1.3487e-02,  ..., 1.1210e-02, 1.1830e-02,
    #  1.0000e+10],
    # [1.3381e-02, 1.7829e-02, 2.5521e-03,  ..., 6.9237e-03, 1.6017e-02,
    #  1.0000e+10],
    # [1.2072e-02, 1.7107e-02, 1.1094e-02,  ..., 2.1299e-02, 6.8605e-03,
    # 1.0000e+10]], device='cuda:0')
    distances = distances * torch.norm(rays_direction[...,None,:], dim=-1)
    
    # print(torch.norm(rays_direction[...,None,:], dim=-1))
    # tensor([[2.0986],
    #     [2.0023],
    #     [2.0339],
    #     ...,
    #     [2.0334],
    #     [2.0881],
    #     [2.0427]], device='cuda:0')
    
    
    # print(distances.shape) # torch.Size([1024, 64])
    # print(distances)
    # tensor([[3.0871e-02, 3.1471e-02, 2.9470e-02,  ..., 2.1879e-02, 1.7207e-02,
    #          2.0660e+10],
    #         [9.6703e-03, 5.3778e-02, 1.1155e-02,  ..., 2.4566e-02, 2.5000e-02,
    #          2.0149e+10],
    #         [2.1828e-02, 1.4883e-02, 3.7363e-02,  ..., 2.7448e-02, 1.6562e-02,
    #          2.0750e+10],
    #         ...,
    #         [2.7521e-02, 3.1004e-02, 2.7580e-02,  ..., 2.2924e-02, 2.4192e-02,
    #          2.0450e+10],
    #         [2.7263e-02, 3.6327e-02, 5.1998e-03,  ..., 1.4107e-02, 3.2634e-02,
    #          2.0375e+10],
    #         [2.4599e-02, 3.4857e-02, 2.2606e-02,  ..., 4.3398e-02, 1.3979e-02,
    #          2.0376e+10]], device='cuda:0')
    rgb = torch.sigmoid(raw[...,:3])  # [N_rays, N_samples, 3]
    # print(raw[...,:3])
    # tensor([[[ 0.0404, -0.1108, -0.0680],
    #      [ 0.0393, -0.1103, -0.0705],
    #      [ 0.0395, -0.1105, -0.0707],
    #      ...,
    #      [ 0.0379, -0.1106, -0.0705],
    #      [ 0.0364, -0.1127, -0.0677],
    #      [ 0.0359, -0.1121, -0.0677]],

    #     [[-0.0153, -0.1384, -0.0835],
    #      [-0.0165, -0.1386, -0.0814],
    #      [-0.0185, -0.1409, -0.0810],
    #      ...,
    #      [-0.0207, -0.1419, -0.0842],
    #      [-0.0196, -0.1402, -0.0828],
    #      [-0.0196, -0.1417, -0.0822]],

    #     [[ 0.0229, -0.1222, -0.0631],
    #      [ 0.0215, -0.1270, -0.0625],
    #      [ 0.0215, -0.1254, -0.0623],
    #      ...,
    #      [ 0.0254, -0.1230, -0.0614],
    #      [ 0.0237, -0.1229, -0.0615],
    #      [ 0.0230, -0.1261, -0.0606]],

    #     ...,

    #     [[-0.0094, -0.1177, -0.0549],
    #      [-0.0085, -0.1184, -0.0538],
    #      [-0.0069, -0.1176, -0.0548],
    #      ...,
    #      [-0.0080, -0.1179, -0.0549],
    #      [-0.0099, -0.1207, -0.0525],
    #      [-0.0092, -0.1203, -0.0524]],

    #     [[ 0.0120, -0.1012, -0.1174],
    #      [ 0.0140, -0.1005, -0.1179],
    #      [ 0.0138, -0.1007, -0.1172],
    #      ...,
    #      [ 0.0111, -0.0992, -0.1177],
    #      [ 0.0105, -0.1011, -0.1163],
    #      [ 0.0118, -0.1011, -0.1173]],

    #     [[-0.0286, -0.1325, -0.0700],
    #      [-0.0293, -0.1317, -0.0702],
    #      [-0.0264, -0.1328, -0.0693],
    #      ...,
    #      [-0.0265, -0.1357, -0.0692],
    #      [-0.0273, -0.1366, -0.0681],
    #      [-0.0285, -0.1353, -0.0693]]], device='cuda:0',
    #    grad_fn=<SliceBackward0>)
    # print(rgb)
    # tensor([[[0.5101, 0.4723, 0.4830],
    #      [0.5098, 0.4724, 0.4824],
    #      [0.5099, 0.4724, 0.4823],
    #      ...,
    #      [0.5095, 0.4724, 0.4824],
    #      [0.5091, 0.4719, 0.4831],
    #      [0.5090, 0.4720, 0.4831]],

    #     [[0.4962, 0.4655, 0.4791],
    #      [0.4959, 0.4654, 0.4797],
    #      [0.4954, 0.4648, 0.4798],
    #      ...,
    #      [0.4948, 0.4646, 0.4790],
    #      [0.4951, 0.4650, 0.4793],
    #      [0.4951, 0.4646, 0.4795]],

    #     [[0.5057, 0.4695, 0.4842],
    #      [0.5054, 0.4683, 0.4844],
    #      [0.5054, 0.4687, 0.4844],
    #      ...,
    #      [0.5063, 0.4693, 0.4846],
    #      [0.5059, 0.4693, 0.4846],
    #      [0.5057, 0.4685, 0.4849]],

    #     ...,

    #     [[0.4976, 0.4706, 0.4863],
    #      [0.4979, 0.4704, 0.4865],
    #      [0.4983, 0.4706, 0.4863],
    #      ...,
    #      [0.4980, 0.4706, 0.4863],
    #      [0.4975, 0.4698, 0.4869],
    #      [0.4977, 0.4700, 0.4869]],

    #     [[0.5030, 0.4747, 0.4707],
    #      [0.5035, 0.4749, 0.4706],
    #      [0.5035, 0.4749, 0.4707],
    #      ...,
    #      [0.5028, 0.4752, 0.4706],
    #      [0.5026, 0.4748, 0.4709],
    #      [0.5029, 0.4747, 0.4707]],

    #     [[0.4929, 0.4669, 0.4825],
    #      [0.4927, 0.4671, 0.4825],
    #      [0.4934, 0.4669, 0.4827],
    #      ...,
    #      [0.4934, 0.4661, 0.4827],
    #      [0.4932, 0.4659, 0.4830],
    #      [0.4929, 0.4662, 0.4827]]], device='cuda:0',
    #    grad_fn=<SigmoidBackward0>)
    #    sigmond( -0.0285 ) = 0.4929
    
    noise = 0.
    if raw_noise_std > 0.:
        noise = torch.randn(raw[...,3].shape).to(device) * raw_noise_std

        # Overwrite randomly sampled data if pytest
        if pytest:
            np.random.seed(0)
            noise = np.random.rand(*list(raw[...,3].shape)) * raw_noise_std
            noise = torch.Tensor(noise).to(device)

    alpha = raw2alpha(raw[...,3] + noise, distances)  # [N_rays, N_samples]
    
    # print(raw[...,3].shape) # torch.Size([1024, 64])
    # print(distances.shape)  # torch.Size([1024, 64])
    # print(alpha.shape)      # torch.Size([1024, 64])
    
    # print( torch.cat([torch.ones((alpha.shape[0], 1)).to(device), 1.-alpha + 1e-10], -1) )
    # print( torch.cat([torch.ones((alpha.shape[0], 1)).to(device), 1.-alpha + 1e-10], -1).shape )
    
    # tensor([[1.0000e+00, 1.0000e+00, 9.8664e-01,  ..., 9.3879e-01, 9.7269e-01,
    #      1.0000e+00],
    #     [1.0000e+00, 9.8739e-01, 1.0000e+00,  ..., 9.9525e-01, 9.7405e-01,
    #      1.0000e+00],
    #     [1.0000e+00, 1.0000e+00, 1.0000e+00,  ..., 9.8951e-01, 9.9980e-01,
    #      1.0000e+00],
    #     ...,
    #     [1.0000e+00, 9.6991e-01, 9.7920e-01,  ..., 9.5538e-01, 1.0000e+00,
    #      1.0000e-10],
    #     [1.0000e+00, 1.0000e+00, 9.7404e-01,  ..., 9.8908e-01, 1.0000e+00,
    #      1.0000e+00],
    #     [1.0000e+00, 9.7760e-01, 1.0000e+00,  ..., 9.7234e-01, 9.6326e-01,
    #      1.0000e-10]], device='cuda:0', grad_fn=<CatBackward0>)    

    # print( [torch.ones((alpha.shape[0], 1)).to(device), 1.-alpha + 1e-10] )
    
    # print( torch.ones((alpha.shape[0], 1)).shape )

    # [tensor([[1.],
    #     [1.],
    #     [1.],
    #     ...,
    #     [1.],
    #     [1.],
    #     [1.]], device='cuda:0'), tensor([[1.0000e+00, 9.8664e-01, 1.0000e+00,  ..., 9.3879e-01, 9.7269e-01,
    #      1.0000e+00],
    #     [9.8739e-01, 1.0000e+00, 1.0000e+00,  ..., 9.9525e-01, 9.7405e-01,
    #      1.0000e+00],
    #     [1.0000e+00, 1.0000e+00, 9.7597e-01,  ..., 9.8951e-01, 9.9980e-01,
    #      1.0000e+00],
    #     ...,
    #     [9.6991e-01, 9.7920e-01, 9.5232e-01,  ..., 9.5538e-01, 1.0000e+00,
    #      1.0000e-10],
    #     [1.0000e+00, 9.7404e-01, 1.0000e+00,  ..., 9.8908e-01, 1.0000e+00,
    #      1.0000e+00],
    #     [9.7760e-01, 1.0000e+00, 9.8615e-01,  ..., 9.7234e-01, 9.6326e-01,
    #      1.0000e-10]], device='cuda:0', grad_fn=<AddBackward0>)]
    
    # print( torch.ones((alpha.shape[0], 1)).to(device) )
    # print( torch.ones((alpha.shape[0], 1)).shape )
    # tensor([[1.],
    #     [1.],
    #     [1.],
    #     ...,
    #     [1.],
    #     [1.],
    #     [1.]], device='cuda:0')
    
    # print( 1.-alpha + 1e-10 )
    
    # tensor([[1.0000e+00, 9.8664e-01, 1.0000e+00,  ..., 9.3879e-01, 9.7269e-01,
    #      1.0000e+00],
    #     [9.8739e-01, 1.0000e+00, 1.0000e+00,  ..., 9.9525e-01, 9.7405e-01,
    #      1.0000e+00],
    #     [1.0000e+00, 1.0000e+00, 9.7597e-01,  ..., 9.8951e-01, 9.9980e-01,
    #      1.0000e+00],
    #     ...,
    #     [9.6991e-01, 9.7920e-01, 9.5232e-01,  ..., 9.5538e-01, 1.0000e+00,
    #      1.0000e-10],
    #     [1.0000e+00, 9.7404e-01, 1.0000e+00,  ..., 9.8908e-01, 1.0000e+00,
    #      1.0000e+00],
    #     [9.7760e-01, 1.0000e+00, 9.8615e-01,  ..., 9.7234e-01, 9.6326e-01,
    #      1.0000e-10]], device='cuda:0', grad_fn=<AddBackward0>)
    
    
    # print( alpha.shape ) # torch.Size([1024, 64])   
    # print( alpha ) # 1.-torch.exp(-act_fn(raw)*distances)
    




     
    # weights = alpha * tf.math.cumprod(1.-alpha + 1e-10, -1, exclusive=True)
    # tensor = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float)
    # cumprod_tensor = torch.cumprod(tensor, dim=0)
    # tensor([1., 2., 6., 24., 120.])
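    # torch.cumprod() yields the accumulated transmittance Ti at every sample point of a ray
    
    # alpha = 1.-torch.exp(-act_fn(raw)*distances)
    # Ti = exclusive cumulative product of (1 - alpha) along the sample axis
    # weights = Ti * alpha = Ti * (1 - exp(-sigma * delta)), with sigma = act_fn(raw) and delta = distances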
    
    Ti = torch.cumprod(torch.cat([torch.ones((alpha.shape[0], 1)).to(device), 1.-alpha + 1e-10], -1), -1)[:, :-1]
    weights = Ti * alpha 
    rgb_map = torch.sum(weights[...,None] * rgb, -2)  # [N_rays, 3]

    depth_map = torch.sum(weights * z_vals, -1)
    disparity_map = 1./torch.max(1e-10 * torch.ones_like(depth_map).to(device), depth_map / torch.sum(weights, -1))
    accumulated_opacity_map = torch.sum(weights, -1)

    if white_background:
        rgb_map = rgb_map + (1.-accumulated_opacity_map[...,None])

    return rgb_map, disparity_map, accumulated_opacity_map, weights, depth_map


def render_path(render_poses, hwf, K, chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0):
    """Render one frame per camera pose and return the stacked results.

    Args:
        render_poses: iterable of camera-to-world matrices; each one is sliced
            with ``[:3, :4]`` before being handed to ``render``.
        hwf: ``(H, W, focal)`` triple describing the full-resolution image.
        K: 3x3 camera intrinsics matrix (array-like) forwarded to ``render``.
        chunk: forwarded to ``render`` as its ``chunk`` argument.
        render_kwargs: extra keyword arguments forwarded to ``render``.
        gt_imgs: accepted for interface compatibility; the PSNR report that
            used it is currently disabled (see the note inside the loop).
        savedir: when not None, every RGB frame is also written to
            ``<savedir>/<index:03d>.png``.
        render_factor: when non-zero, H and W are integer-divided by this
            factor and the intrinsics are scaled down accordingly.

    Returns:
        ``(rgbs, disps)``: numpy arrays holding all RGB frames and all
        disparity frames, stacked along a new leading axis.
    """
    H, W, _focal = hwf

    if render_factor != 0:
        # Render downsampled for speed. The intrinsics have to shrink together
        # with the image size, otherwise rays are generated for the
        # full-resolution camera on a smaller pixel grid.
        H = H // render_factor
        W = W // render_factor
        K = np.array(K, dtype=np.float64)  # fresh copy: the caller's K stays untouched
        K[:2] /= render_factor             # scales fx, fy, cx, cy; the last row stays [0, 0, 1]

    rgbs = []
    disps = []

    t = time.time()
    for i, c2w in enumerate(tqdm(render_poses)):
        # Seconds spent since the previous frame started.
        print(i, time.time() - t)
        t = time.time()
        rgb, disp, acc, _ = render(H, W, K, chunk=chunk, c2w=c2w[:3,:4], **render_kwargs)
        rgbs.append(rgb.cpu().numpy())
        disps.append(disp.cpu().numpy())
        if i == 0:
            print(rgb.shape, disp.shape)

        # Disabled PSNR report against the ground truth, kept for reference:
        #   if gt_imgs is not None and render_factor==0:
        #       p = -10. * np.log10(np.mean(np.square(rgb.cpu().numpy() - gt_imgs[i])))
        #       print(p)

        if savedir is not None:
            rgb8 = to8b(rgbs[-1])
            filename = os.path.join(savedir, '{:03d}.png'.format(i))
            imageio.imwrite(filename, rgb8)

    rgbs = np.stack(rgbs, 0)
    disps = np.stack(disps, 0)

    return rgbs, disps





![image.png](attachment:image.png)

In [4]:
def train():
    """Train a coarse + fine NeRF on an LLFF scene using a hard-coded configuration.

    The function takes no arguments and returns nothing. It
      1. loads the LLFF data set from ``datadir``,
      2. builds the coarse model (and a fine model when ``N_importance > 0``)
         together with an Adam optimizer,
      3. resumes from the last ``*tar*`` file found in ``basedir/expname``
         (or from ``ft_path``) unless ``no_reload`` is set,
      4. optimises on shuffled batches of ``N_rand`` rays for ``N_iters - 1``
         iterations with an exponentially decayed learning rate,
      5. periodically writes checkpoints (``i_weights``), spiral videos
         (``i_video``) and test-set renderings (``i_testset``), and logs the
         loss / PSNR every ``i_print`` iterations.

    Raises:
        NotImplementedError: if ``no_batching`` is True; only the
            "random rays over all training images" sampling is implemented here.
    """
    #########################################################################
    # Configuration, equivalent to: python run_nerf.py --config configs/fern.txt
    N_importance = 64
    N_rand = 1024
    N_samples = 64
    basedir = './logs'
    chunk = 32768
    config = 'configs/fern.txt'
    datadir = './data/nerf_llff_data/fern'
    dataset_type = 'llff'
    expname = 'fern_test'
    factor = 8
    ft_path = None
    half_res = False
    i_embed = 0
    i_img = 500
    i_print = 100
    i_testset = 5000
    i_video = 5000
    i_weights = 10000
    lindisp = False
    llffhold = 8
    lrate = 0.0005
    lrate_decay = 250
    multires = 10
    multires_views = 4
    netchunk = 65536
    netdepth = 8
    netdepth_fine = 8
    netwidth = 256
    netwidth_fine = 256
    no_batching = False
    no_ndc = False
    no_reload = False
    perturb = 1.0
    precrop_frac = 0.5
    precrop_iters = 0
    raw_noise_std = 1.0
    render_factor = 0
    render_only = False
    render_test = False
    shape = 'greek'
    spherify = False
    testskip = 8
    use_viewdirs = True
    white_background = False
    ################################################################################

    use_batching = not no_batching
    if not use_batching:
        # The training loop below only knows how to draw rays from the shuffled
        # pool built over all training images; fail early with a clear message
        # instead of hitting an undefined `batch_rays` inside the loop.
        raise NotImplementedError('train() only supports random ray batching (no_batching=False)')

    images, poses, bds, render_poses, i_test = load_llff_data(datadir, factor, recenter=True, bd_factor=.75, spherify=spherify)
    print("================[0][load_data]=================")
    # Shapes observed on the fern scene with factor=8:
    #   images (20, 378, 504, 3), poses (20, 3, 5), bds (20, 2),
    #   render_poses (120, 3, 5) (poses of the spiral path), i_test 12

    # The last column of every pose carries (H, W, focal); the first four
    # columns are the 3x4 camera-to-world matrix.
    hwf = poses[0,:3,-1]
    poses = poses[:,:3,:4]  # (20, 3, 4)
    print(poses[0])
    print('Loaded llff', images.shape, render_poses.shape, hwf, datadir)

    if not isinstance(i_test, list):
        i_test = [i_test]

    if llffhold > 0:
        # Hold out every `llffhold`-th image, e.g. [0, 8, 16] for 20 images.
        print('Auto LLFF holdout,', llffhold)
        i_test = np.arange(images.shape[0])[::llffhold]

    i_val = i_test
    i_train = np.array([i for i in np.arange(int(images.shape[0])) if
                    (i not in i_test and i not in i_val)])

    print('DEFINING BOUNDS')
    if no_ndc:
        near = np.ndarray.min(bds) * .9
        far = np.ndarray.max(bds) * 1.
    else:
        # NDC space: the sampling range is normalised to [0, 1].
        near = 0.
        far = 1.
    print('NEAR FAR', near, far)  # prints "NEAR FAR 0.0 1.0" with the config above

    # Cast intrinsics to the right types.
    H, W, focal = hwf
    H, W = int(H), int(W)
    hwf = [H, W, focal]  # e.g. [378, 504, 407.5658]

    # Pinhole intrinsics with the principal point at the image centre.
    K = np.array([
        [focal, 0, 0.5*W],
        [0, focal, 0.5*H],
        [0, 0, 1]
    ])

    print(K)

    os.makedirs(os.path.join(basedir, expname), exist_ok=True)

    print("================[1][create_nerf()]=================")
    # Positional encoding for sample positions and, optionally, view directions.
    embed_fn, input_ch = get_embedder(multires, i_embed)

    input_ch_views = 0
    embeddirs_fn = None
    if use_viewdirs:
        embeddirs_fn, input_ch_views = get_embedder(multires_views, i_embed)
    output_ch = 5 if N_importance > 0 else 4
    skips = [4]

    # Coarse network.
    model = NeRF(D=netdepth, W=netwidth,
                 input_ch=input_ch, output_ch=output_ch, skips=skips,
                 input_ch_views=input_ch_views, use_viewdirs=use_viewdirs).to(device)
    grad_vars = list(model.parameters())

    # Fine network, only used when importance sampling is enabled.
    model_fine = None
    if N_importance > 0:
        model_fine = NeRF(D=netdepth_fine, W=netwidth_fine,
                          input_ch=input_ch, output_ch=output_ch, skips=skips,
                          input_ch_views=input_ch_views, use_viewdirs=use_viewdirs).to(device)
        grad_vars += list(model_fine.parameters())

    network_query_fn = lambda inputs, viewdirs, network_fn : run_network(inputs, viewdirs, network_fn,
                                                                embed_fn=embed_fn,
                                                                embeddirs_fn=embeddirs_fn,
                                                                netchunk=netchunk)

    # Create optimizer
    optimizer = torch.optim.Adam(params=grad_vars, lr=lrate, betas=(0.9, 0.999))

    start = 0

    ##########################

    # Load checkpoints
    if ft_path is not None and ft_path!='None':
        ckpts = [ft_path]
    else:
        ckpts = [os.path.join(basedir, expname, f) for f in sorted(os.listdir(os.path.join(basedir, expname))) if 'tar' in f]

    print('Found ckpts', ckpts)
    if len(ckpts) > 0 and not no_reload:
        ckpt_path = ckpts[-1]
        print('Reloading from', ckpt_path)
        # map_location lets a checkpoint written on one device be resumed on another.
        ckpt = torch.load(ckpt_path, map_location=device)

        start = ckpt['global_step']
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])

        # Load model
        model.load_state_dict(ckpt['network_fn_state_dict'])
        if model_fine is not None:
            model_fine.load_state_dict(ckpt['network_fine_state_dict'])

    ##########################

    render_kwargs_train = {
        'network_query_fn' : network_query_fn,
        'perturb' : perturb,
        'N_importance' : N_importance,
        'network_fine' : model_fine,
        'N_samples' : N_samples,
        'network_fn' : model,
        'use_viewdirs' : use_viewdirs,
        'white_background' : white_background,
        'raw_noise_std' : raw_noise_std,
    }

    # NDC only good for LLFF-style forward facing data
    if dataset_type != 'llff' or no_ndc:
        print('Not ndc!')
        render_kwargs_train['ndc'] = False
        render_kwargs_train['lindisp'] = lindisp

    # Evaluation uses the same settings but without sample jitter or density noise.
    render_kwargs_test = {k : render_kwargs_train[k] for k in render_kwargs_train}
    render_kwargs_test['perturb'] = False
    render_kwargs_test['raw_noise_std'] = 0.

    print("================[2][get training data, validation data]=================")

    global_step = start

    bds_dict = {
        'near' : near,
        'far' : far,
    }

    # Both render configurations need the near / far bounds.
    render_kwargs_train.update(bds_dict)
    render_kwargs_test.update(bds_dict)

    # Move testing data to GPU
    render_poses = torch.Tensor(render_poses).to(device)

    # Build the pool of training rays: every pixel of every training image
    # becomes one (ray origin, ray direction, rgb) triple.
    print('get rays')
    rays = np.stack([get_rays_np(H, W, K, p) for p in poses[:,:3,:4]], 0) # [N, ro+rd, H, W, 3]
    print('done, concats')
    rays_rgb = np.concatenate([rays, images[:,None]], 1) # [N, ro+rd+rgb, H, W, 3]
    rays_rgb = np.transpose(rays_rgb, [0,2,3,1,4]) # [N, H, W, ro+rd+rgb, 3]
    rays_rgb = np.stack([rays_rgb[i] for i in i_train], 0) # train images only
    rays_rgb = np.reshape(rays_rgb, [-1,3,3]) # [N_train*H*W, ro+rd+rgb, 3]
    rays_rgb = rays_rgb.astype(np.float32)
    print('shuffle rays')
    np.random.shuffle(rays_rgb)

    print('done')
    i_batch = 0

    # Move training data to GPU
    images = torch.Tensor(images).to(device)
    poses = torch.Tensor(poses).to(device)  # torch.Size([20, 3, 4]) on the fern scene
    rays_rgb = torch.Tensor(rays_rgb).to(device)

    N_iters = 100000 + 1
    print('Begin')
    print('TRAIN views are', i_train)
    print('TEST views are', i_test)
    print('VAL views are', i_val)

    start = start + 1
    for i in trange(start, N_iters):

        # Take the next N_rand rays from the shuffled pool.
        batch = rays_rgb[i_batch:i_batch+N_rand] # [B, ro+rd+rgb, 3]
        batch = torch.transpose(batch, 0, 1)     # [ro+rd+rgb, B, 3]
        batch_rays, target_s = batch[:2], batch[2]

        i_batch += N_rand
        if i_batch >= rays_rgb.shape[0]:
            # Pool exhausted: reshuffle and start the next epoch.
            print("Shuffle data after an epoch!")
            rand_idx = torch.randperm(rays_rgb.shape[0])
            rays_rgb = rays_rgb[rand_idx]
            i_batch = 0

        #####  Core optimization loop  #####
        rgb, disp, acc, extras = render(H, W, K, chunk=chunk, rays=batch_rays, verbose=i < 10, return_raw=True, **render_kwargs_train)

        optimizer.zero_grad()
        # MSE between the predicted colours and the target pixels.
        img_loss = img2mse(rgb, target_s)
        loss = img_loss
        psnr = mse2psnr(img_loss)

        if 'rgb0' in extras:
            # 'rgb0' is supervised with the same target and added to the loss.
            img_loss0 = img2mse(extras['rgb0'], target_s)
            loss = loss + img_loss0

        loss.backward()
        optimizer.step()

        # NOTE: IMPORTANT!
        ###   update learning rate   ###
        # Exponential decay: the rate drops by 10x every `lrate_decay * 1000` steps.
        decay_rate = 0.1
        decay_steps = lrate_decay * 1000
        new_lrate = lrate * (decay_rate ** (global_step / decay_steps))
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lrate
        ################################

        # Rest is logging
        if i%i_weights==0:
            path = os.path.join(basedir, expname, '{:06d}.tar'.format(i))
            fine_net = render_kwargs_train['network_fine']
            torch.save({
                'global_step': global_step,
                'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(),
                # No fine network exists when N_importance == 0.
                'network_fine_state_dict': fine_net.state_dict() if fine_net is not None else None,
                'optimizer_state_dict': optimizer.state_dict(),
            }, path)
            print('Saved checkpoints at', path)

        if i%i_video==0 and i > 0:
            # Render the spiral path without tracking gradients.
            with torch.no_grad():
                rgbs, disps = render_path(render_poses, hwf, K, chunk, render_kwargs_test)
            print('Done, saving', rgbs.shape, disps.shape)
            moviebase = os.path.join(basedir, expname, '{}_spiral_{:06d}_'.format(expname, i))
            imageio.mimwrite(moviebase + 'rgb.mp4', to8b(rgbs), fps=30, quality=8)
            imageio.mimwrite(moviebase + 'disp.mp4', to8b(disps / np.max(disps)), fps=30, quality=8)

        if i%i_testset==0 and i > 0:
            testsavedir = os.path.join(basedir, expname, 'testset_{:06d}'.format(i))
            os.makedirs(testsavedir, exist_ok=True)
            # `poses` already lives on `device` as a tensor, so index it directly
            # rather than passing it through the torch.Tensor constructor again.
            test_poses = poses[i_test]
            print('test poses shape', test_poses.shape)
            with torch.no_grad():
                render_path(test_poses, hwf, K, chunk, render_kwargs_test, gt_imgs=images[i_test], savedir=testsavedir)
            print('Saved test set')

        if i%i_print==0:
            tqdm.write(f"[TRAIN] Iter: {i} Loss: {loss.item()}  PSNR: {psnr.item()}")

        global_step += 1

In [None]:
# Run the full training loop; every setting is hard-coded inside train().
train()

[[ 0.99569476 -0.02079598 -0.09033062 -0.3081002 ]
 [ 0.02503342  0.9986262   0.04603354  0.1346772 ]
 [ 0.0892492  -0.04809664  0.99484736  0.03989876]]
Loaded llff (20, 378, 504, 3) (120, 3, 5) [378.     504.     407.5658] ./data/nerf_llff_data/fern
Auto LLFF holdout, 8
DEFINING BOUNDS
NEAR FAR 0.0 1.0
[[407.5657959   0.        252.       ]
 [  0.        407.5657959 189.       ]
 [  0.          0.          1.       ]]
Found ckpts []
get rays
done, concats
shuffle rays
done
Begin
TRAIN views are [ 1  2  3  4  5  6  7  9 10 11 12 13 14 15 17 18 19]
TEST views are [ 0  8 16]
VAL views are [ 0  8 16]


  0%|                       | 101/100000 [00:16<4:35:39,  6.04it/s]

[TRAIN] Iter: 100 Loss: 0.04974399134516716  PSNR: 16.108917236328125


  0%|                       | 201/100000 [00:33<4:38:52,  5.96it/s]

[TRAIN] Iter: 200 Loss: 0.041705481708049774  PSNR: 16.85163116455078


  0%|                       | 301/100000 [00:49<4:34:34,  6.05it/s]

[TRAIN] Iter: 300 Loss: 0.03454691171646118  PSNR: 17.709415435791016


  0%|                       | 401/100000 [01:06<4:50:03,  5.72it/s]

[TRAIN] Iter: 400 Loss: 0.02851698361337185  PSNR: 18.407367706298828


  1%|                       | 501/100000 [01:24<4:47:15,  5.77it/s]

[TRAIN] Iter: 500 Loss: 0.02818201296031475  PSNR: 18.530838012695312


  1%|▏                      | 601/100000 [01:41<4:33:22,  6.06it/s]

[TRAIN] Iter: 600 Loss: 0.027459295466542244  PSNR: 18.674081802368164


  1%|▏                      | 701/100000 [01:58<4:38:06,  5.95it/s]

[TRAIN] Iter: 700 Loss: 0.025530800223350525  PSNR: 18.982576370239258


  1%|▏                      | 801/100000 [02:15<4:41:51,  5.87it/s]

[TRAIN] Iter: 800 Loss: 0.020086877048015594  PSNR: 19.898605346679688


  1%|▏                      | 901/100000 [02:32<4:44:00,  5.82it/s]

[TRAIN] Iter: 900 Loss: 0.02056136727333069  PSNR: 19.8769474029541


  1%|▏                     | 1001/100000 [02:50<4:50:18,  5.68it/s]

[TRAIN] Iter: 1000 Loss: 0.021127691492438316  PSNR: 19.7226619720459


  1%|▏                     | 1101/100000 [03:07<5:03:26,  5.43it/s]

[TRAIN] Iter: 1100 Loss: 0.0174005888402462  PSNR: 20.5457820892334


  1%|▎                     | 1201/100000 [03:24<4:34:45,  5.99it/s]

[TRAIN] Iter: 1200 Loss: 0.018036074936389923  PSNR: 20.515663146972656


  1%|▎                     | 1301/100000 [03:42<4:51:55,  5.64it/s]

[TRAIN] Iter: 1300 Loss: 0.018692269921302795  PSNR: 20.371692657470703


  1%|▎                     | 1401/100000 [03:59<4:49:01,  5.69it/s]

[TRAIN] Iter: 1400 Loss: 0.018849756568670273  PSNR: 20.329788208007812


  2%|▎                     | 1501/100000 [04:16<4:35:45,  5.95it/s]

[TRAIN] Iter: 1500 Loss: 0.016456592828035355  PSNR: 20.762516021728516


  2%|▎                     | 1601/100000 [04:33<4:31:18,  6.04it/s]

[TRAIN] Iter: 1600 Loss: 0.018721764907240868  PSNR: 20.248821258544922


  2%|▎                     | 1701/100000 [04:50<4:47:34,  5.70it/s]

[TRAIN] Iter: 1700 Loss: 0.016486097127199173  PSNR: 20.96978759765625


  2%|▍                     | 1801/100000 [05:07<4:33:20,  5.99it/s]

[TRAIN] Iter: 1800 Loss: 0.01693684048950672  PSNR: 20.736488342285156


  2%|▍                     | 1901/100000 [05:24<4:48:06,  5.67it/s]

[TRAIN] Iter: 1900 Loss: 0.016412552446126938  PSNR: 20.77263641357422


  2%|▍                     | 2001/100000 [05:41<4:34:02,  5.96it/s]

[TRAIN] Iter: 2000 Loss: 0.015589991584420204  PSNR: 21.29852867126465


  2%|▍                     | 2101/100000 [05:59<4:42:05,  5.78it/s]

[TRAIN] Iter: 2100 Loss: 0.012462159618735313  PSNR: 22.162721633911133


  2%|▍                     | 2201/100000 [06:16<5:08:08,  5.29it/s]

[TRAIN] Iter: 2200 Loss: 0.015192899852991104  PSNR: 21.307836532592773


  2%|▌                     | 2301/100000 [06:34<4:41:40,  5.78it/s]

[TRAIN] Iter: 2300 Loss: 0.014527879655361176  PSNR: 21.338153839111328


  2%|▌                     | 2401/100000 [06:52<4:49:29,  5.62it/s]

[TRAIN] Iter: 2400 Loss: 0.015139332041144371  PSNR: 21.383255004882812


  3%|▌                     | 2501/100000 [07:09<4:55:27,  5.50it/s]

[TRAIN] Iter: 2500 Loss: 0.01348717138171196  PSNR: 22.103736877441406


  3%|▌                     | 2601/100000 [07:26<4:42:34,  5.74it/s]

[TRAIN] Iter: 2600 Loss: 0.013039221987128258  PSNR: 21.856651306152344


  3%|▌                     | 2701/100000 [07:44<4:53:12,  5.53it/s]

[TRAIN] Iter: 2700 Loss: 0.014127520844340324  PSNR: 21.48214340209961


  3%|▌                     | 2801/100000 [08:02<4:48:40,  5.61it/s]

[TRAIN] Iter: 2800 Loss: 0.013877320103347301  PSNR: 21.46422004699707


  3%|▋                     | 2901/100000 [08:20<4:45:02,  5.68it/s]

[TRAIN] Iter: 2900 Loss: 0.012456824071705341  PSNR: 22.07975196838379


  3%|▋                     | 3001/100000 [08:39<4:54:13,  5.49it/s]

[TRAIN] Iter: 3000 Loss: 0.015921786427497864  PSNR: 20.99517059326172


  3%|▋                     | 3101/100000 [08:57<4:52:09,  5.53it/s]

[TRAIN] Iter: 3100 Loss: 0.01246616430580616  PSNR: 22.182151794433594


  3%|▋                     | 3163/100000 [09:08<4:51:02,  5.55it/s]

Shuffle data after an epoch!


  3%|▋                     | 3201/100000 [09:15<4:56:48,  5.44it/s]

[TRAIN] Iter: 3200 Loss: 0.013744794763624668  PSNR: 21.549026489257812


  3%|▋                     | 3301/100000 [09:32<4:36:02,  5.84it/s]

[TRAIN] Iter: 3300 Loss: 0.013368941843509674  PSNR: 21.935606002807617


  3%|▋                     | 3401/100000 [09:50<4:41:24,  5.72it/s]

[TRAIN] Iter: 3400 Loss: 0.012275535613298416  PSNR: 22.199539184570312


  4%|▊                     | 3501/100000 [10:06<4:40:04,  5.74it/s]

[TRAIN] Iter: 3500 Loss: 0.013288315385580063  PSNR: 21.865182876586914


  4%|▊                     | 3601/100000 [10:24<4:44:08,  5.65it/s]

[TRAIN] Iter: 3600 Loss: 0.011918652802705765  PSNR: 22.37919807434082


  4%|▊                     | 3701/100000 [10:41<4:34:01,  5.86it/s]

[TRAIN] Iter: 3700 Loss: 0.012154617346823215  PSNR: 22.00478744506836


  4%|▊                     | 3801/100000 [10:58<4:50:03,  5.53it/s]

[TRAIN] Iter: 3800 Loss: 0.011031098663806915  PSNR: 22.623111724853516


  4%|▊                     | 3901/100000 [11:15<4:34:59,  5.82it/s]

[TRAIN] Iter: 3900 Loss: 0.010649677366018295  PSNR: 22.905269622802734


  4%|▉                     | 4001/100000 [11:33<4:41:58,  5.67it/s]

[TRAIN] Iter: 4000 Loss: 0.011081960052251816  PSNR: 22.704181671142578


  4%|▉                     | 4101/100000 [11:50<4:36:53,  5.77it/s]

[TRAIN] Iter: 4100 Loss: 0.011958520859479904  PSNR: 22.23698616027832


  4%|▉                     | 4201/100000 [12:08<4:49:57,  5.51it/s]

[TRAIN] Iter: 4200 Loss: 0.011082327924668789  PSNR: 22.701507568359375


  4%|▉                     | 4301/100000 [12:26<4:56:44,  5.37it/s]

[TRAIN] Iter: 4300 Loss: 0.009892651811242104  PSNR: 23.105199813842773


  4%|▉                     | 4401/100000 [12:44<4:51:46,  5.46it/s]

[TRAIN] Iter: 4400 Loss: 0.014266838319599628  PSNR: 21.543394088745117


  5%|▉                     | 4501/100000 [13:02<4:46:38,  5.55it/s]

[TRAIN] Iter: 4500 Loss: 0.011355161666870117  PSNR: 22.451444625854492


  5%|█                     | 4601/100000 [13:19<4:25:23,  5.99it/s]

[TRAIN] Iter: 4600 Loss: 0.010375185869634151  PSNR: 22.959850311279297


  5%|█                     | 4626/100000 [13:24<4:22:45,  6.05it/s]