In [1]:
import sys
import attr
# insert at 1, 0 is the script path (or '' in REPL)
sys.path.insert(1, '../src')

from dataset import DatasetConfig, DatasetBuilder
from geometry import Rays
import os
from functools import partial
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

from collections import defaultdict
from typing import Callable, Dict, List, Tuple

from nerf import NerfConfig, Nerf
import nerf_utils
import jax.numpy as jnp

import jax
from jax import jit, random

%matplotlib inline

plt.rcParams["figure.figsize"] = (12, 12)

In [2]:
ds_config =  DatasetConfig(model_dir='../dataset/pinecone/sparse/0/',
                           images_dir='../dataset/pinecone/images/',
                           batch_from_single_image=True,
                           batch_size=64)
dataset_builder = DatasetBuilder(ds_config)
train_ds = dataset_builder.build_train_dataset()

In [10]:
data = list(train_ds.take(1).as_numpy_iterator())[0]

In [11]:
def prepare_inputs(data):
    return Rays(*data)

@attr.s(frozen=True, auto_attribs=True)
class NerfBuilder:
    config: NerfConfig

    def build(self, rng: jnp.ndarray, rays: Rays):
        model = Nerf(self.config)
        key1, key2, rng = random.split(rng, num=3)
        init_params = model.init(
            key1, rng=key2, rays=rays)
        return model, init_params

rng = random.PRNGKey(0)
rays = prepare_inputs(data)
model, init_params = NerfBuilder(NerfConfig()).build(rng, rays)
coarse_rgb, fine_rgb = model.apply(init_params, rng, rays)

In [9]:
coarse_rgb

DeviceArray([[0.40388322, 0.4156903 , 0.33409858],
             [0.        , 0.        , 0.        ],
             [0.        , 0.        , 0.        ],
             [0.        , 0.        , 0.        ],
             [0.07171033, 0.08403395, 0.06920688],
             [0.        , 0.        , 0.        ],
             [0.        , 0.        , 0.        ],
             [0.        , 0.        , 0.        ],
             [0.        , 0.        , 0.        ],
             [0.        , 0.        , 0.        ],
             [0.2193469 , 0.23089232, 0.19843762],
             [0.        , 0.        , 0.        ],
             [0.        , 0.        , 0.        ],
             [0.2261976 , 0.21365152, 0.15540776],
             [0.        , 0.        , 0.        ],
             [0.5547238 , 0.6250224 , 0.52054083],
             [0.        , 0.        , 0.        ],
             [0.        , 0.        , 0.        ],
             [0.        , 0.        , 0.        ],
             [0.        , 0.   

In [27]:
EPS = 1e-10

def volumetric_rendering(points_rgb, points_sigma, points_z, dirs, white_bkgd=True):
    dists = points_z[..., 1:] - points_z[..., :-1]
    dists = jnp.concatenate([dists, 1e10 * jnp.ones_like(dists[..., :1])], axis=-1)
    dists = dists * jnp.linalg.norm(dirs, axis=-1, keepdims=True)

    dists_sigma = points_sigma[..., 0] * dists
    alpha = 1.0 - jnp.exp(-dists_sigma)

    transmit = jnp.exp(-jnp.cumsum(dists_sigma[:, :-1], axis=-1))
    transmit = jnp.concatenate([jnp.ones_like(transmit[:, :1]), transmit], axis=-1)
    weights = alpha * transmit
    
    rgb = jnp.sum(weights[..., jnp.newaxis] * points_rgb, axis=1)
    depth = jnp.sum(weights * points_z, axis=1, keepdims=True)
    
    acc = jnp.sum(weights, axis=-1, keepdims=True)
    disp = acc / (depth + EPS)
    if white_bkgd:
        rgb = rgb + (1. - acc)
    return rgb, disp, acc, weights

points_rgb = -random.uniform(rng, (64, 100, 3)) / 1000
points_sigma = -random.uniform(rng, (64, 100, 1)) / 1000
rgb, disp, acc, weights = nerf_utils.volumetric_rendering(points_rgb, points_sigma, points_z, rays.directions)

In [28]:
rgb

DeviceArray([[inf, inf, inf],
             [inf, inf, inf],
             [inf, inf, inf],
             [inf, inf, inf],
             [inf, inf, inf],
             [inf, inf, inf],
             [inf, inf, inf],
             [inf, inf, inf],
             [inf, inf, inf],
             [inf, inf, inf],
             [inf, inf, inf],
             [inf, inf, inf],
             [inf, inf, inf],
             [inf, inf, inf],
             [inf, inf, inf],
             [inf, inf, inf],
             [inf, inf, inf],
             [inf, inf, inf],
             [inf, inf, inf],
             [inf, inf, inf],
             [inf, inf, inf],
             [inf, inf, inf],
             [inf, inf, inf],
             [inf, inf, inf],
             [inf, inf, inf],
             [inf, inf, inf],
             [inf, inf, inf],
             [inf, inf, inf],
             [inf, inf, inf],
             [inf, inf, inf],
             [inf, inf, inf],
             [inf, inf, inf],
             [inf, inf, inf],
          

In [7]:
import functools

from flax import jax_utils

num_rays = rays.origins.shape[0]
num_bins = 100
bins = jnp.linspace(0.001, 1000, num_bins)
bins = jnp.broadcast_to(bins, (num_rays, num_bins))
weights = jnp.ones_like(bins)
# sampler_fn = jit(functools.partial(nerf_utils.sample_along_rays,
#                                    100,
#                                    randomized=True))
sampler_fn = functools.partial(nerf_utils.sample_along_rays,
                               num_samples=100,
                               randomized=True)
points, points_z = sampler_fn(rng, rays, bins, weights)
encode = nerf_utils.positional_encoding(points, 0, 10)

encode.shape

(64, 100, 63)

In [8]:
points_z.shape

(64, 100)

In [6]:
encode = nerf_utils.positional_encoding(rays.directions[:, jnp.newaxis, :], 0, 10)

encode.shape

(64, 1, 63)

In [54]:
rays.directions.shape

(64, 3)