In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"  # specify which GPU(s) to be used

import sys
import attr
# insert at 1, 0 is the script path (or '' in REPL)
sys.path.insert(1, '../src')

from dataset import DatasetConfig, DatasetBuilder
from geometry import Rays
import os
from functools import partial
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

from collections import defaultdict
from typing import Callable, Dict, List, Tuple

from nerf import NerfConfig, Nerf
import nerf_utils
import jax.numpy as jnp

import jax
from jax import jit, random
from trainer import Trainer, TrainerConfig

%matplotlib inline

plt.rcParams["figure.figsize"] = (12, 12)

In [2]:
ds_config =  DatasetConfig(model_dir='../dataset/pinecone/sparse/0/',
                           images_dir='../dataset/pinecone/images/',
                           batch_from_single_image=True,
                           batch_size=8)
trainer_config = TrainerConfig(dataset_config=ds_config)
trainer = Trainer(trainer_config)
train_iter, val_iter = trainer.create_dataset(to_device=False)

In [3]:
rays = next(train_iter)

In [4]:
state = trainer.create_train_state(rays)

In [10]:
coarse_rgb, fine_rgb = state.apply_fn(state.params, trainer._rng, jax.tree_map(lambda x: x[0], rays))

In [11]:
coarse_rgb

DeviceArray([[0.        , 0.        , 0.        ],
             [0.        , 0.        , 0.        ],
             [0.32859775, 0.34741738, 0.30696827],
             [0.        , 0.        , 0.        ],
             [0.        , 0.        , 0.        ],
             [0.        , 0.        , 0.        ],
             [0.23736162, 0.25658837, 0.17968707],
             [0.35462773, 0.331079  , 0.24922337]], dtype=float32)

In [None]:
@attr.s(frozen=True, auto_attribs=True)
class NerfBuilder:
    config: NerfConfig

    def build(self, rng: jnp.ndarray, rays: Rays):
        model = Nerf(self.config)
        key1, key2, rng = random.split(rng, num=3)
        init_params = model.init(
            key1, rng=key2, rays=rays)
        return model, init_params

rng = random.PRNGKey(0)
rays_local = jax.tree_map(lambda x:x[0], rays)
model, init_params = NerfBuilder(NerfConfig()).build(rng, rays_local)
coarse_rgb, fine_rgb = model.apply(init_params, rng, rays_local)

In [None]:
coarse_rgb

In [12]:
EPS = 1e-10

def volumetric_rendering(points_rgb, points_sigma, points_z, dirs, white_bkgd=True):
    dists = points_z[..., 1:] - points_z[..., :-1]
    dists = jnp.concatenate([dists, 1e10 * jnp.ones_like(dists[..., :1])], axis=-1)
    dists = dists * jnp.linalg.norm(dirs, axis=-1, keepdims=True)

    dists_sigma = points_sigma[..., 0] * dists
    alpha = 1.0 - jnp.exp(-dists_sigma)

    transmit = jnp.exp(-jnp.cumsum(dists_sigma[:, :-1], axis=-1))
    transmit = jnp.concatenate([jnp.ones_like(transmit[:, :1]), transmit], axis=-1)
    weights = alpha * transmit
    
    rgb = jnp.sum(weights[..., jnp.newaxis] * points_rgb, axis=1)
    depth = jnp.sum(weights * points_z, axis=1, keepdims=True)
    
    acc = jnp.sum(weights, axis=-1, keepdims=True)
    disp = acc / (depth + EPS)
    if white_bkgd:
        rgb = rgb + (1. - acc)
    return rgb, disp, acc, weights

points_rgb = random.uniform(rng, (64, 100, 3)) / 1000
points_sigma = random.uniform(rng, (64, 100, 1)) / 1000
rgb, disp, acc, weights = nerf_utils.volumetric_rendering(points_rgb, points_sigma, points_z, rays.directions)

In [13]:
rgb

DeviceArray([[0.00015901, 0.00037874, 0.0004901 ],
             [0.00029306, 0.00049331, 0.00040989],
             [0.00074104, 0.00071467, 0.00047685],
             [0.00046635, 0.00074081, 0.00072935],
             [0.00036115, 0.00029611, 0.00028599],
             [0.00081129, 0.00040781, 0.00028003],
             [0.00068469, 0.00026013, 0.00054004],
             [0.00021147, 0.00046594, 0.00066487],
             [0.00074065, 0.00069683, 0.00037903],
             [0.00063101, 0.00024434, 0.00079861],
             [0.00057646, 0.00052661, 0.0005742 ],
             [0.00066256, 0.00073601, 0.0006188 ],
             [0.00056572, 0.00045589, 0.00035953],
             [0.00053161, 0.00033082, 0.0006    ],
             [0.00023862, 0.00050335, 0.00044   ],
             [0.00019101, 0.00024335, 0.00034161],
             [0.00034782, 0.00065425, 0.00066272],
             [0.00040164, 0.0003118 , 0.00036887],
             [0.00080923, 0.00032702, 0.00032641],
             [0.00055754, 0.000

In [7]:
import functools

from flax import jax_utils

num_rays = rays.origins.shape[0]
num_bins = 100
bins = jnp.linspace(0.001, 1000, num_bins)
bins = jnp.broadcast_to(bins, (num_rays, num_bins))
weights = jnp.ones_like(bins)
# sampler_fn = jit(functools.partial(nerf_utils.sample_along_rays,
#                                    100,
#                                    randomized=True))
sampler_fn = functools.partial(nerf_utils.sample_along_rays,
                               num_samples=100,
                               randomized=True)
points, points_z = sampler_fn(rng, rays, bins, weights)
encode = nerf_utils.positional_encoding(points, 0, 10)

encode.shape

(64, 100, 63)

In [8]:
points_z.shape

(64, 100)

In [6]:
encode = nerf_utils.positional_encoding(rays.directions[:, jnp.newaxis, :], 0, 10)

encode.shape

(64, 1, 63)

In [54]:
rays.directions.shape

(64, 3)