# Tiny NeRF with Flax

<a href="https://colab.research.google.com/github/myagues/flax_nerf/blob/main/tiny_nerf.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This is a simplified version of the method presented in *NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis*, implemented with Flax and supporting multi-device TPU or GPU training.

Original work:
- [Project Website](https://www.matthewtancik.com/nerf)
- [arXiv Paper](https://arxiv.org/abs/2003.08934)
- [Full Code](https://www.github.com/bmild/nerf)

Components not included in the notebook:
- 5D input including view directions
- Hierarchical Sampling

In [None]:
!pip install -q git+https://www.github.com/google/flax

In [None]:
import functools
import os
import time

import imageio
import jax

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np

from flax import jax_utils, linen as nn, optim
from flax.training import common_utils
from jax import numpy as jnp, lax, random
from jax.config import config
from typing import Any, Callable, Sequence

# Render inline matplotlib figures as SVG.
%config InlineBackend.figure_format = 'svg'
%matplotlib inline

# NOTE(review): omnistaging became the default in later JAX releases, where
# this call was deprecated and then removed; confirm it is still valid for
# the JAX version this notebook installs.
config.enable_omnistaging()

In [None]:
# Boilerplate for connecting JAX to a Colab TPU through the tpu_driver backend.
if "google.colab" in str(get_ipython()) and "COLAB_TPU_ADDR" in os.environ:
    import requests

    tpu_addr = os.environ["COLAB_TPU_ADDR"]
    # Request the tpu_driver version named in the URL from the TPU host.
    url = f"http://{tpu_addr.split(':')[0]}:8475/requestversion/tpu_driver0.1-dev20200416"
    resp = requests.post(url)
    # Raise explicitly instead of `assert`, which is skipped under `python -O`.
    if resp.status_code != 200:
        raise RuntimeError(
            f"TPU driver version request failed with HTTP {resp.status_code}: {url}"
        )
    TPU_DRIVER_MODE = 1

    config.FLAGS.jax_xla_backend = "tpu_driver"
    config.FLAGS.jax_backend_target = f"grpc://{tpu_addr}"
    print(f"Registered TPU: {config.FLAGS.jax_backend_target}")
else:
    print("No TPU detected.")
print(jax.local_devices())

In [None]:
if not os.path.exists("tiny_nerf_data.npz"):
    !curl -O https://people.eecs.berkeley.edu/~bmild/nerf/tiny_nerf_data.npz

data = np.load("tiny_nerf_data.npz")
images = jnp.array(data["images"])
poses = data["poses"]
focal = float(data["focal"])
_, img_h, img_w, _ = images.shape

testimg, testpose = images[101], poses[101]
images = images[:100, ..., :3]
poses = poses[:100]

print(f"Images shape: {images.shape}")
print(f"Poses shape: {poses.shape}")
print(f"Focal value: {focal:.5f}")

plt.imshow(testimg)
plt.axis("off")
plt.show()

## Optimize NeRF

In [None]:
@functools.partial(jax.jit, static_argnums=(0, 1, 2))
def get_rays(img_h, img_w, focal, c2w):
    """Compute one camera ray per pixel for a pinhole camera.

    Args:
        img_h, img_w: image size in pixels (static under jit).
        focal: focal length in pixels (static under jit).
        c2w: camera-to-world matrix; its top-left 3x3 block rotates the
            directions and the first three rows of its last column give the
            ray origin.

    Returns:
        Array of shape ``(2, img_h, img_w, 3)``: ray origins stacked on top of
        world-space ray directions.
    """
    cols, rows = jnp.meshgrid(jnp.arange(img_w), jnp.arange(img_h), indexing="xy")
    # Camera-space directions: +x to the right, +y up, camera looks along -z.
    dir_x = (cols - img_w * 0.5) / focal
    dir_y = -(rows - img_h * 0.5) / focal
    dir_z = -jnp.ones_like(cols)
    cam_dirs = jnp.stack([dir_x, dir_y, dir_z], -1)
    # world[..., k] = sum_l R[k, l] * cam[..., l]
    world_dirs = jnp.einsum("ijl,kl", cam_dirs, c2w[:3, :3])
    origins = jnp.broadcast_to(c2w[:3, -1], world_dirs.shape)
    return jnp.stack([origins, world_dirs])


def render_rays(
    net_fn,
    rays,
    near=2.0,
    far=6.0,
    num_samples=64,
    batch_size=10000,
    rng=None,
    rand=False,
):
    rays_o, rays_d = rays
    # Compute 3D query points
    z_vals = jnp.linspace(near, far, num_samples)
    z_shape = [*rays_o.shape[:-1], num_samples]
    if rand:
        z_vals += random.uniform(rng, z_shape) * (far - near) / num_samples
    pts = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., :, None]

    # Run network
    raw = lax.map(net_fn, jnp.reshape(pts, [-1, batch_size, 3]))
    raw = jnp.reshape(raw, [*pts.shape[:-1], 4])
    
    # Compute opacities and colors
    sigma_a = nn.relu(raw[..., 3])
    rgb = nn.sigmoid(raw[..., :3])

    # Do volume rendering
    dists = z_vals[..., 1:] - z_vals[..., :-1]
    dists = jnp.concatenate(
        [dists, jnp.broadcast_to([1e10], dists[..., :1].shape)], axis=-1
    )

    alpha = 1.0 - jnp.exp(-sigma_a * dists)
    alpha_ = jnp.clip(1.0 - alpha, 1e-10, 1.0)
    trans = jnp.concatenate([jnp.ones_like(alpha_[..., :1]), alpha_[..., :-1]], -1)
    weights = alpha * jnp.cumprod(trans, -1)  # (img_h, img_w, num_samples)

    rgb_map = jnp.einsum("...k,...kl", weights, rgb)
    depth_map = jnp.einsum("...k,...k", weights, z_vals)
    acc_map = jnp.einsum("...k->...", weights)

    return rgb_map, depth_map, acc_map

Out-of-memory (OOM) errors can appear when using TPUs (each TPUv2 core in Colab has 8 GB of HBM, whereas the GPUs typically available have 12 GB to 16 GB), because TPUs pad tensors to their tile sizes (read the [TPU performance guide](https://cloud.google.com/tpu/docs/performance-guide#consequences_of_tiling) for more information). To work around these limitations, you can:
- reduce `net_width` and/or `net_depth` (worse results)
- use the `nn.remat` decorator on `NeRF.__call__` (already applied in the model below), which recomputes activations during the backward pass to save memory at the cost of a slower time per step. More about `jax.remat` in [JAX #1749](https://github.com/google/jax/pull/1749).

In [None]:
class NeRF(nn.Module):
    """Coordinate MLP mapping 3D points to raw per-point outputs.

    The points are optionally positionally encoded. They then pass through
    ``net_depth`` Dense+ReLU layers, with the network input re-concatenated
    after every layer index in ``skips``, and a final linear layer with
    ``output_channels`` units.
    """

    net_depth: int = 8  # number of hidden Dense+ReLU layers
    net_width: int = 256  # units per hidden layer
    skips: Sequence[int] = (4,)  # hidden-layer indices followed by an input skip
    periodic_fns: Sequence[Callable] = (jnp.sin, jnp.cos)  # encoding functions
    output_channels: int = 4  # width of the final linear layer
    use_embedding: bool = True  # apply the positional encoding in `embed`
    l_embed: int = 6  # frequencies 2**0 .. 2**(l_embed - 1)
    dtype: Any = jnp.float32  # computation dtype of the Dense layers

    def embed(self, inputs):
        """Positional encoding ``[x, fn(2**k * x) for k in freqs for fn in fns]``.

        Args:
            inputs: array of shape ``(..., D)``.

        Returns:
            Array of shape ``(..., D + D * len(periodic_fns) * l_embed)``. The
            features are ordered frequency-major, then function, then coordinate.
        """
        freqs = 2.0 ** jnp.arange(self.l_embed)
        scaled = inputs[..., None, :] * freqs[:, None]  # (..., L, D)
        feats = jnp.stack([fn(scaled) for fn in self.periodic_fns], axis=-2)  # (..., L, F, D)
        feats = feats.reshape(*inputs.shape[:-1], -1)
        return jnp.concatenate([inputs, feats], axis=-1)

    @nn.remat
    @nn.compact
    def __call__(self, inputs_pts):
        """Return the raw network outputs, shape ``(..., output_channels)``."""
        # Network input (encoded when use_embedding); reused by the skips.
        inputs = self.embed(inputs_pts) if self.use_embedding else inputs_pts
        x = inputs
        for i in range(self.net_depth):
            x = nn.Dense(self.net_width, dtype=self.dtype)(x)
            x = nn.relu(x)
            if i in self.skips:
                # Skip connection re-injects the *encoded* input, as in the
                # reference NeRF MLP, not the raw xyz coordinates.
                x = jnp.concatenate([x, inputs], axis=-1)
        return nn.Dense(self.output_channels, dtype=self.dtype)(x)


def initialized(key, input_pts_shape):
    """Construct a default ``NeRF`` and initialise its parameters.

    Args:
        key: PRNG key used for parameter initialisation.
        input_pts_shape: shape of the dummy point batch traced during
            ``init``, e.g. ``(10000, 3)``.

    Returns:
        ``(model, params)``: the module and its ``"params"`` collection.
    """
    model = NeRF()
    dummy_pts = jnp.ones(input_pts_shape)
    variables = model.init({"params": key}, dummy_pts)
    return model, variables["params"]

In [None]:
def update(opt, rng):
    """Run one training step on a randomly chosen training view.

    Intended to run under ``jax.pmap(..., axis_name="batch")``: the gradients
    are averaged across devices with ``lax.pmean`` before being applied.

    Args:
        opt: optimizer whose ``target`` holds the model parameters.
        rng: PRNG key for this device (picks the view and jitters samples).

    Returns:
        The optimizer after applying the device-averaged gradients.
    """
    view_key, jitter_key = random.split(rng)
    view = random.randint(view_key, (1,), minval=0, maxval=images.shape[0])[0]

    def loss_fn(params):
        apply_fn = functools.partial(model.apply, {"params": params})
        rendered, *_ = render_rays(apply_fn, train_rays[view], rng=jitter_key, rand=True)
        return jnp.mean(jnp.square(rendered - images[view]))

    device_grads = jax.grad(loss_fn)(opt.target)
    mean_grads = lax.pmean(device_grads, axis_name="batch")
    return opt.apply_gradient(mean_grads)


@jax.jit
def evaluate(params):
    """Render the held-out test view and score it against ``testimg``.

    Returns:
        ``(rgb, psnr)``: the rendered image and ``-10 * log10(MSE)``, i.e. the
        PSNR in dB for a peak value of 1.
    """

    def apply_fn(pts):
        return model.apply({"params": params}, pts)

    rendered = render_rays(apply_fn, test_rays)[0]
    mse = jnp.mean(jnp.square(rendered - testimg))
    psnr = -10.0 * jnp.log(mse) / jnp.log(10.0)
    return rendered, psnr


# Data-parallel train step: one replica per local device; `update` averages
# gradients over the "batch" axis with lax.pmean. (The previous
# functools.partial wrapper bound no arguments and was a no-op.)
p_update = jax.pmap(update, axis_name="batch")

# Precomputed rays: train_rays is (num_train_views, 2, img_h, img_w, 3) and
# test_rays is (2, img_h, img_w, 3), as returned by get_rays.
train_rays = lax.map(lambda pose: get_rays(img_h, img_w, focal, pose), poses)
test_rays = get_rays(img_h, img_w, focal, testpose)

In [None]:
key = random.PRNGKey(0)
key, rng = random.split(key)
model, params = initialized(key, (10000, 3))
optimizer = optim.Adam(learning_rate=5e-4).create(params)
# Replicate the optimizer onto every local device for the pmapped update.
optimizer = jax_utils.replicate(optimizer)

psnrs = []
iternums = []
num_iters = 1000
i_plot = 50  # evaluate and plot every i_plot iterations

for i in range(num_iters + 1):
    t = time.perf_counter()
    rng, rng_step = random.split(rng)
    # One key per device, so each device draws its own view and jitter.
    sharded_rngs = common_utils.shard_prng_key(rng_step)
    optimizer = p_update(optimizer, sharded_rngs)

    if i % i_plot == 0:
        # JAX dispatches work asynchronously: wait for the updated parameters
        # so the timing covers the step itself, not just its dispatch.
        for leaf in jax.tree_util.tree_leaves(optimizer.target):
            leaf.block_until_ready()
        t_end = time.perf_counter() - t
        optimizer_ = jax_utils.unreplicate(optimizer)
        rgb, psnr = evaluate(optimizer_.target)
        psnr = float(psnr)
        print(f"Iters: {i:4d}\t{t_end:2.5f} sec/iter\tPSNR: {psnr:.5f}")
        psnrs.append(psnr)
        iternums.append(i)

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
        ax1.imshow(rgb)
        ax1.axis("off")
        ax2.plot(iternums, psnrs)
        plt.show()

optimizer = jax_utils.unreplicate(optimizer)

## Interactive Visualization

In [None]:
def trans_t(t):
    """Homogeneous 4x4 matrix translating by ``t`` along the z axis."""
    return jnp.array(
        [
            [1, 0, 0, 0],
            [0, 1, 0, 0],
            [0, 0, 1, t],
            [0, 0, 0, 1],
        ]
    )


def rot_phi(phi):
    """Homogeneous 4x4 rotation by ``phi`` radians about the x axis."""
    c, s = jnp.cos(phi), jnp.sin(phi)
    return jnp.array(
        [
            [1, 0, 0, 0],
            [0, c, -s, 0],
            [0, s, c, 0],
            [0, 0, 0, 1],
        ]
    )


def rot_theta(th):
    """Homogeneous 4x4 rotation about the y axis by ``th`` radians.

    Note the sign convention: the -sin term sits at row 0 and the +sin term
    at row 2.
    """
    c, s = jnp.cos(th), jnp.sin(th)
    return jnp.array(
        [
            [c, 0, -s, 0],
            [0, 1, 0, 0],
            [s, 0, c, 0],
            [0, 0, 0, 1],
        ]
    )


@jax.jit
def get_rgb(c2w):
    """Render the full image seen from camera pose ``c2w`` with the trained model.

    Args:
        c2w: camera-to-world matrix; only its top 3x4 block is used.

    Returns:
        ``(rgb, depth, acc, img)``: the outputs of ``render_rays`` plus
        ``img``, the RGB image clipped to [0, 1] and scaled to ``uint8``.
    """
    # NOTE: `optimizer` is a global read while tracing; the compiled function
    # is cached, so later reassignments of `optimizer` are not picked up.
    params = optimizer.target

    def apply_fn(pts):
        return model.apply({"params": params}, pts)

    cam_rays = get_rays(img_h, img_w, focal, c2w[:3, :4])
    rgb, depth, acc = render_rays(apply_fn, cam_rays)
    img = (255 * jnp.clip(rgb, 0, 1)).astype(jnp.uint8)
    return rgb, depth, acc, img


@jax.jit
def pose_spherical(theta, phi, radius):
    """Camera-to-world pose on a sphere of the given radius around the origin.

    Args:
        theta: rotation angle about the y axis, in degrees.
        phi: rotation angle about the x axis, in degrees.
        radius: translation along z before the rotations are applied.

    Returns:
        4x4 camera-to-world matrix.
    """
    # Final axis change: flip x and swap the y and z axes.
    axis_fix = jnp.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
    about_x = rot_phi(phi / 180.0 * jnp.pi)
    about_y = rot_theta(theta / 180.0 * jnp.pi)
    return axis_fix @ (about_y @ (about_x @ trans_t(radius)))


def f(**kwargs) -> None:
    """Widget callback: render and display the view for the slider pose.

    ``kwargs`` are forwarded to ``pose_spherical`` (theta, phi, radius).
    """
    rgb = get_rgb(pose_spherical(**kwargs))[0]
    plt.figure(2, figsize=(20, 6))
    plt.imshow(jnp.clip(rgb, 0, 1))
    plt.axis("off")
    plt.show()


def sldr(v, mi, ma):
    """Float slider starting at ``v`` on ``[mi, ma]`` with step 0.01."""
    return widgets.FloatSlider(value=v, min=mi, max=ma, step=0.01)


# Keyword argument of f() -> [initial value, min, max].
names = [
    ["theta", [100.0, 0.0, 360]],
    ["phi", [-30.0, -90, 0]],
    ["radius", [4.0, 3.0, 5.0]],
]

slider_kwargs = {}
for arg_name, (init, lo, hi) in names:
    slider_kwargs[arg_name] = sldr(init, lo, hi)
interactive_plot = widgets.interactive(f, **slider_kwargs)
# Fix the height of the last child widget (the plot output area).
output = interactive_plot.children[-1]
output.layout.height = "475px"
interactive_plot

## Render 360 Video

In [None]:
# 120 poses orbiting the object at a fixed phi of -30 degrees and radius 4.
video_angle = jnp.linspace(0.0, 360.0, 120, endpoint=False)
v_c2w = jax.vmap(lambda th: pose_spherical(th, -30.0, 4.0))(video_angle)
# Keep only the uint8 image (last output of get_rgb) for each pose.
frames = [np.array(get_rgb(c2w)[-1]) for c2w in v_c2w]

file_name = "video.mp4"
imageio.mimwrite(file_name, tuple(frames), fps=30, quality=7)

In [None]:
# %%HTML
# <video width="500" controls autoplay loop>
#   <source src="video.mp4" type="video/mp4">
# </video>

In [None]:
from IPython.display import HTML
from base64 import b64encode

# Embed the rendered video as a base64 data URL so the cell output does not
# depend on the file being served. Use `with` so the file handle is closed.
with open("video.mp4", "rb") as video_file:
    mp4 = video_file.read()
data_url = f"data:video/mp4;base64,{b64encode(mp4).decode()}"
HTML(
    """
<video width=500 controls autoplay loop>
      <source src="%s" type="video/mp4">
</video>
"""
    % data_url
)