## Tiny NeRF
This is a simplified version of the method presented in *NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis*.

[Project Website](http://www.matthewtancik.com/nerf)

[arXiv Paper](https://arxiv.org/abs/2003.08934)

[Full Code](https://github.com/bmild/nerf)

Components of the full method that this notebook omits:
  * 5D input including view directions
  * Hierarchical Sampling

In [None]:
import os, sys, time

import imageio
import ipywidgets as widgets
import matplotlib.pyplot as plt

from jax import grad, jit, lax, random, partial, vmap
from jax import numpy as jnp
from jax.experimental import optimizers, stax
from jax.experimental.stax import Dense, FanInConcat, FanOut, Identity, Relu
from jax.nn import relu, sigmoid
from tqdm.notebook import tqdm

from typing import Any, Optional, List, Tuple

# IPython magics: render matplotlib figures inline in the notebook, as SVG.
%config InlineBackend.figure_format = 'svg'
%matplotlib inline

In [None]:
import urllib.request

# Download the dataset to the exact path that is checked here and loaded in
# the next cell. (`curl -O` saved it into the current working directory, so
# the existence check never passed and the subsequent load failed.)
_DATA_PATH = "/data/tiny_nerf_data.npz"
_DATA_URL = "https://people.eecs.berkeley.edu/~bmild/nerf/tiny_nerf_data.npz"
if not os.path.exists(_DATA_PATH):
    os.makedirs(os.path.dirname(_DATA_PATH), exist_ok=True)
    urllib.request.urlretrieve(_DATA_URL, _DATA_PATH)

# Load Input Images and Poses

In [None]:
data = jnp.load("/data/tiny_nerf_data.npz")
images = jnp.array(data["images"])
poses = jnp.array(data["poses"])
# Convert to a plain Python float. focal is formatted with `:.5f` below and
# passed to the jitted get_rays. A hashable scalar works there whether or not
# focal is declared a static jit argument; a 0-d array does not.
focal = float(data["focal"])
_, H, W, _ = images.shape

# Hold out image 101 as the test view. Keep only its first 3 channels, exactly
# like the training images, so its shape matches the rendered RGB.
# Train on the first 100 images/poses.
testimg, testpose = images[101, ..., :3], poses[101]
images = images[:100, ..., :3]
poses = poses[:100]

print(
    f"Images shape: {images.shape}\nPoses shape: {poses.shape}\nFocal value: {focal:.5f}"
)
plt.imshow(testimg)
plt.axis("off")
plt.show()

# Optimize NeRF

In [None]:
def build_model(D: int = 8, W: int = 256, skip: int = 5) -> Any:
    """Tiny NeRF network as a ``stax`` ``(init_fn, apply_fn)`` pair.

    Architecture:
    1. The embedded input goes through ``skip`` Dense+ReLU layers of width ``W``.
    2. The result is concatenated with the original input (skip connection).
    3. It then goes through the remaining ``D - skip`` Dense+ReLU layers.
    4. A final linear ``Dense(4)`` layer produces the outputs. ``render_rays``
       reads channels 0-2 as colour and channel 3 as density.

    Args:
        D: total number of hidden Dense+ReLU layers. The default ``D=8`` (with
            ``skip=5``) gives the 5 + 3 layout.
        W: width of every hidden layer.
        skip: number of hidden layers before the skip concatenation.

    Raises:
        ValueError: if ``skip`` is not in ``[1, D]``.
    """
    if not 1 <= skip <= D:
        raise ValueError(f"skip must be in [1, D={D}], got {skip}")

    def dense_block(n_layers: int) -> list:
        # n_layers consecutive (Dense(W), Relu) pairs.
        return [Dense(W), Relu] * n_layers

    model = stax.serial(
        FanOut(2),
        stax.parallel(stax.serial(*dense_block(skip)), Identity),
        FanInConcat(-1),
        *dense_block(D - skip),
        Dense(4),
    )
    return model


def embed_fn(x: jnp.ndarray, L_embed: int) -> jnp.ndarray:
    """Positional encoding: append sin/cos features at ``L_embed`` octaves.

    Along the last axis, the output is the concatenation
    ``[x, sin(2^0 x), cos(2^0 x), ..., sin(2^(L-1) x), cos(2^(L-1) x)]``.
    If the last axis of ``x`` has size C, the result has
    ``C * (1 + 2 * L_embed)`` features.
    """
    features = [x]
    for i in range(L_embed):
        scaled = 2.0 ** i * x
        features += [jnp.sin(scaled), jnp.cos(scaled)]
    return jnp.concatenate(features, -1)


# Only H and W must be static (they size jnp.arange). focal is traced, so it may
# be a Python float or an (unhashable) array, e.g. straight from np.load.
@partial(jit, static_argnums=(0, 1))
def get_rays(H: int, W: int, focal: float, c2w: jnp.ndarray) -> jnp.ndarray:
    """Generate one ray per pixel for a pinhole camera with pose ``c2w``.

    Args:
        H, W: image height and width in pixels.
        focal: focal length in pixels.
        c2w: camera-to-world matrix; its top-left 3x3 rotates directions and
            its first three rows of the last column are the ray origin.

    Returns:
        Array of shape ``(2, H, W, 3)``: ray origins stacked with ray directions.
    """
    i, j = jnp.meshgrid(jnp.arange(W), jnp.arange(H), indexing="xy")
    dirs = jnp.stack(
        [(i - W * 0.5) / focal, -(j - H * 0.5) / focal, -jnp.ones_like(i)], -1
    )
    # Per pixel: rays_d = R @ dirs, with R = c2w[:3, :3].
    rays_d = jnp.sum(dirs[..., None, :] * c2w[:3, :3], axis=-1)
    # Every ray starts at the camera centre.
    rays_o = jnp.broadcast_to(c2w[:3, -1], rays_d.shape)
    return jnp.stack([rays_o, rays_d])


def render_rays(
    net_fn: Any,
    rays: jnp.ndarray,
    near: float = 2.0,
    far: float = 6.0,
    N_samples: int = 64,
    L_embed: int = 6,
    batch_size: int = 10000,
    rng: Optional[Any] = None,
    rand: bool = False,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Volume-render a bundle of rays through the network.

    Args:
        net_fn: called on batches of embedded points of shape
            ``(batch_size, F)``. It must return raw outputs of shape
            ``(batch_size, 4)``. Channels 0-2 go through ``sigmoid`` (colour);
            channel 3 goes through ``relu`` (density).
        rays: stacked ``(rays_o, rays_d)``. Each is assumed to be ``(..., 3)``,
            e.g. the output of ``get_rays``.
        near, far: depth range sampled along every ray.
        N_samples: number of depth samples per ray.
        L_embed: number of positional-encoding frequencies passed to ``embed_fn``.
        batch_size: number of points per ``net_fn`` call. The total point count
            need not be a multiple of it; the last batch is zero-padded and the
            padded outputs are discarded.
        rng: PRNG key; required when ``rand`` is True.
        rand: if True, offset every sample depth by a uniform amount in
            ``[0, (far - near) / N_samples)``.

    Returns:
        ``(rgb_map, depth_map, acc_map)`` with shapes ``(..., 3)``, ``(...)``
        and ``(...)``, where ``...`` is the leading shape of ``rays_o``.

    Raises:
        ValueError: if ``rand`` is True but ``rng`` is None.
    """
    if rand and rng is None:
        raise ValueError("render_rays: rand=True requires an rng key")

    rays_o, rays_d = rays

    # Compute 3D query points: evenly spaced depths, optionally jittered per ray.
    z_vals = jnp.linspace(near, far, N_samples)
    if rand:
        z_vals += (
            random.uniform(rng, list(rays_o.shape[:-1]) + [N_samples])
            * (far - near)
            / N_samples
        )
    pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]

    # Run the network in fixed-size batches with lax.map. Without padding, the
    # reshape to (-1, batch_size, F) fails unless the point count is a multiple
    # of batch_size. So pad with zeros, then drop the extra outputs.
    pts_flat = embed_fn(jnp.reshape(pts, [-1, 3]), L_embed)
    n_pts, n_feat = pts_flat.shape
    n_pad = (-n_pts) % batch_size
    pts_padded = jnp.pad(pts_flat, ((0, n_pad), (0, 0)))
    raw = lax.map(net_fn, jnp.reshape(pts_padded, [-1, batch_size, n_feat]))
    raw = jnp.reshape(raw, [-1, raw.shape[-1]])[:n_pts]
    raw = jnp.reshape(raw, list(pts.shape[:-1]) + [4])

    # Compute opacities and colours.
    sigma_a = relu(raw[..., 3])
    rgb = sigmoid(raw[..., :3])

    # Volume rendering. Distances between adjacent samples; the last sample
    # gets 1e10, so its alpha is ~1 whenever its density is positive.
    dists = jnp.concatenate(
        [
            z_vals[..., 1:] - z_vals[..., :-1],
            jnp.broadcast_to([1e10], z_vals[..., :1].shape),
        ],
        -1,
    )

    alpha = 1.0 - jnp.exp(-sigma_a * dists)
    # Exclusive cumulative product of (1 - alpha + 1e-10), capped at 1:
    # each sample is weighted by the product over all *earlier* samples only.
    alpha_ = jnp.minimum(1.0, 1.0 - alpha + 1e-10)
    trans = jnp.concatenate([jnp.ones_like(alpha_[..., :1]), alpha_[..., :-1]], -1)
    weights = alpha * jnp.cumprod(trans, -1)

    rgb_map = jnp.sum(weights[..., None] * rgb, -2)
    depth_map = jnp.sum(weights * z_vals, -1)
    acc_map = jnp.sum(weights, -1)

    return rgb_map, depth_map, acc_map

In [None]:
L_embed = 6
key = random.PRNGKey(0)

# Per-pixel rays for every training pose, plus the held-out test pose.
train_rays = lax.map(lambda c2w: get_rays(H, W, focal, c2w), poses)
test_rays = get_rays(H, W, focal, testpose)

# Network input width: xyz (3) plus a sin and a cos of xyz per frequency.
init_fn, model_fn = build_model()
_, model_params = init_fn(key, input_shape=(3 * (1 + 2 * L_embed),))

opt_init, opt_update, get_params = optimizers.adam(5e-4)
opt_state = opt_init(model_params)

In [None]:
def loss_fun(
    params: Any,
    batch: Tuple[jnp.ndarray, jnp.ndarray],
    rng: Optional[Any] = None,
    rand: bool = False,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Mean squared error between the rendered rays and a target image.

    Args:
        params: parameters for ``model_fn``, i.e. the pytree from ``init_fn`` or
            ``get_params``, not a single array.
        batch: ``(rays, target)``. ``rays`` is as produced by ``get_rays``;
            ``target`` is the matching RGB image.
        rng, rand: forwarded to ``render_rays`` to control depth jittering.

    Returns:
        ``(loss, rgb)``: the scalar MSE (the value ``grad`` differentiates) and
        the rendered image as auxiliary output.
    """
    rays, target = batch
    net_fn = partial(model_fn, params)
    # Forward the notebook-level L_embed. The embedding width then always
    # matches the input size the network was initialised with, instead of
    # silently relying on render_rays' default.
    rgb, _, _ = render_rays(net_fn, rays, L_embed=L_embed, rng=rng, rand=rand)
    return jnp.mean(jnp.square(rgb - target)), rgb


@jit
def update(i: int, opt_state: Any, rng: Any) -> Any:
    """Take one optimizer step on a randomly chosen training image.

    A per-step key is derived by folding the step index ``i`` into ``rng``.
    It is split into one key that picks the image and one that drives the
    depth jitter in ``loss_fun``. Returns the new optimizer state.
    """
    step_key = random.fold_in(rng, i)
    pick_key, jitter_key = random.split(step_key)
    n_images = images.shape[0]
    idx = random.randint(pick_key, (1,), minval=0, maxval=n_images)[0]

    current = get_params(opt_state)
    loss_grad = grad(loss_fun, has_aux=True)
    grads, _rendered = loss_grad(current, (train_rays[idx], images[idx]), jitter_key, True)
    return opt_update(i, grads, opt_state)


@jit
def evaluate(params: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Render the held-out test view and score it.

    Returns the rendered image and its PSNR against ``testimg``, computed as
    ``-10 * log10(mse)``.
    """
    mse, rendered = loss_fun(params, (test_rays, testimg))
    ln10 = jnp.log(10.0)
    psnr = -10.0 * jnp.log(mse) / ln10
    return rendered, psnr

In [None]:
from jax import tree_util  # used only to wait for async results when timing

N_iters = 1000
psnrs: List[float] = []
iternums: List[int] = []
i_plot = 100

# Start of the current timing window. It is reset after each report, so
# evaluation and plotting are not counted as training time.
t = time.perf_counter()
for i in range(N_iters + 1):
    opt_state = update(i, opt_state, key)

    if i % i_plot == 0:
        # JAX dispatches asynchronously. Wait until the latest update has really
        # finished; otherwise only the dispatch of one call would be timed.
        tree_util.tree_map(lambda x: x.block_until_ready(), get_params(opt_state))
        # The window covers iterations (i - i_plot, i]. At i == 0 it is the
        # single first step, which also includes JIT compilation of `update`.
        steps = i_plot if i > 0 else 1
        sec_per_iter = (time.perf_counter() - t) / steps
        print(f"Iterations: {i:4d}\t{sec_per_iter:2.5f} sec/iter", end="")

        rgb, psnr = evaluate(get_params(opt_state))
        psnr = float(psnr)  # store a Python float, as psnrs is annotated
        print(f"\tPSNR: {psnr:.5f}")
        psnrs.append(psnr)
        iternums.append(i)

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
        ax1.imshow(rgb)
        ax1.axis("off")
        ax2.plot(iternums, psnrs)
        ax2.set_xlabel("iteration")
        ax2.set_ylabel("test PSNR")
        plt.show()

        t = time.perf_counter()

final_model_fn = partial(model_fn, get_params(opt_state))

# Interactive Visualization

In [None]:
def trans_t(t):
    """4x4 homogeneous matrix translating by ``t`` along the z axis."""
    return jnp.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, t], [0, 0, 0, 1]])


def rot_phi(phi):
    """4x4 homogeneous rotation by ``phi`` radians about the x axis."""
    c, s = jnp.cos(phi), jnp.sin(phi)
    return jnp.array(
        [
            [1, 0, 0, 0],
            [0, c, -s, 0],
            [0, s, c, 0],
            [0, 0, 0, 1],
        ]
    )


def rot_theta(th):
    """4x4 homogeneous rotation by ``th`` radians about the y axis."""
    c, s = jnp.cos(th), jnp.sin(th)
    return jnp.array(
        [
            [c, 0, -s, 0],
            [0, 1, 0, 0],
            [s, 0, c, 0],
            [0, 0, 0, 1],
        ]
    )


@jit
def get_rgb(
    c2w: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Render the scene from camera pose ``c2w`` with the trained network.

    Args:
        c2w: camera-to-world matrix; only its top 3x4 block is used.

    Returns:
        ``(rgb, depth, acc, img)``. ``img`` is ``rgb`` clipped to ``[0, 1]``,
        scaled by 255 and cast to uint8.
    """
    rays = get_rays(H, W, focal, c2w[:3, :4])
    # Forward the notebook-level L_embed so the embedding width matches the
    # input size the network was initialised with.
    rgb, depth, acc = render_rays(final_model_fn, rays, L_embed=L_embed)
    img = (255 * jnp.clip(rgb, 0, 1)).astype(jnp.uint8)
    return rgb, depth, acc, img


@jit
def pose_spherical(theta: float, phi: float, radius: float) -> jnp.ndarray:
    """Build a 4x4 camera-to-world matrix from angles (degrees) and a radius.

    Steps:
    1. Start from a translation by ``radius`` along z.
    2. Apply ``rot_phi(phi)``, then ``rot_theta(theta)``.
    3. Finish with a fixed axis flip/swap matrix.
    """
    axis_swap = jnp.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
    phi_rad = phi / 180.0 * jnp.pi
    theta_rad = theta / 180.0 * jnp.pi
    return axis_swap @ (rot_theta(theta_rad) @ (rot_phi(phi_rad) @ trans_t(radius)))


def f(**kwargs) -> None:
    """Widget callback: render and show the view for the current slider values.

    All keyword arguments are forwarded unchanged to ``pose_spherical``.
    """
    rgb = get_rgb(pose_spherical(**kwargs))[0]

    plt.figure(2, figsize=(20, 6))
    plt.imshow(jnp.clip(rgb, 0, 1))
    plt.axis("off")
    plt.show()


def sldr(v, mi, ma):
    """FloatSlider starting at ``v`` and spanning ``[mi, ma]`` in 0.01 steps."""
    return widgets.FloatSlider(value=v, min=mi, max=ma, step=0.01)


# Each entry is (keyword name, [initial value, min, max]). The keywords are
# passed by `f` straight through to `pose_spherical`.
names = [
    ["theta", [100.0, 0.0, 360]],
    ["phi", [-30.0, -90, 0]],
    ["radius", [4.0, 3.0, 5.0]],
]

interactive_plot = widgets.interactive(
    f, **{name: sldr(*spec) for name, spec in names}
)
# The last child is the output area; give it a fixed height.
output = interactive_plot.children[-1]
output.layout.height = "350px"
interactive_plot

# Render 360 Video

In [None]:
# 120 poses evenly spaced around the full circle (endpoint excluded, so the
# looped video does not repeat its first frame).
video_angle = jnp.linspace(0.0, 360.0, 120, endpoint=False)
v_c2w = vmap(lambda th: pose_spherical(th, -30.0, 4.0))(video_angle)
# Keep only the uint8 frame (last output of get_rgb). Previously the float
# rgb/depth/acc maps of every frame were also held in memory until the end.
frames = [get_rgb(c2w)[-1] for c2w in v_c2w]

file_name = "video.mp4"
imageio.mimwrite(file_name, tuple(frames), fps=30, quality=7)

In [None]:
%%HTML
<!-- Play back video.mp4, the 360-degree render written by the previous cell. -->
<video width="500" height="500" controls>
  <source src="video.mp4" type="video/mp4">
</video>