This Colab is a tool for visualizing the 3D NeRF and semantic scene representations produced as part of NeSF: Neural Semantic Fields.


The project website for NeSF can be found here: https://nesf3d.github.io/


Accompanying code can be found on GitHub at: https://github.com/google-research/jax3d/tree/main/jax3d/projects/nesf

# Setup

In [None]:
# @title Set up environment
# Clone jax3d, install it, then install/upgrade the packages used below.
import sys
!git clone https://github.com/google-research/jax3d.git
# All following commands in this cell run from inside the checkout.
%cd /content/jax3d
!python -m pip install --upgrade pip
# Install from the local checkout first, then request the `nesf` extras.
!pip install .
!pip install --upgrade "jax3d[nesf]"
!pip install --upgrade "jax[cpu]"
# flax is pinned to an exact version.
# NOTE(review): presumably because a later cell uses the legacy `flax.optim`
# API — confirm before changing the pin.
!pip install flax==0.5.3

In [None]:
# @title Configure datasets and checkpoints
!wget https://storage.googleapis.com/kubric-public/data/NeSFDatasets/NeSF%20datasets/klevr.tar.gz
!tar -xvf klevr.tar.gz
!rm klevr.tar.gz
!wget https://storage.googleapis.com/kubric-public/data/NeSFDatasets/NeRF%20checkpoints/klevr.tar.gz
!mkdir klevr_checkpoints
!mv klevr.tar.gz klevr_checkpoints
%cd klevr_checkpoints/
!tar -xvf klevr.tar.gz
!wget https://storage.googleapis.com/kubric-public/data/NeSFDatasets/NeSFCheckpoints/klevr.tar.gz
!mkdir klevr_semantic_checkpoints
!mv klevr.tar.gz klevr_semantic_checkpoints
%cd klevr_semantic_checkpoints/
!tar -xvf klevr.tar.gz

In [None]:
# @title Reorganize code
# Move the contents of the jax3d checkout (including anything downloaded into
# it by the previous cell) directly under /content.
%cd /content
# Rename the checkout first so its contents can be moved into /content.
!mv jax3d jax3d_old
# NOTE: the `*` glob skips dot-entries (e.g. `.git`); whatever is left behind
# is deleted together with jax3d_old on the next line.
!mv /content/jax3d_old/* /content
!rm -R /content/jax3d_old
%cd /content

# Imports

In [None]:
import re

import sys
import chex
import flax
import imageio
import plotly.graph_objects as go
import numpy as np
import scipy
import sklearn
import seaborn as sns
import matplotlib.pyplot as plt
import mediapy
import jax
import pandas as pd
import plotly.express as px
from jax import numpy as jnp
from flax import linen as nn
import tensorflow as tf


In [None]:
import gin
gin.enter_interactive_mode()  # Avoid error when reloading modules

import jax3d.projects.nesf as j3d
from jax3d.projects.nesf import nerfstatic as nf

from jax3d.projects.nesf.nerfstatic.utils import train_utils
from jax3d.projects.nesf.nerfstatic.utils import eval_utils
from jax3d.projects.nesf.nerfstatic import datasets
from jax3d.projects.nesf.nerfstatic.datasets import klevr
from jax3d.projects.nesf.nerfstatic.models import models
from jax3d.projects.nesf.nerfstatic.models import model_utils
from jax3d.projects.nesf.nerfstatic.nerf import utils
from jax3d.projects.nesf.nerfstatic.utils import config as nerf_config
from jax3d.projects.nesf.nerfstatic.utils import types
from jax3d.projects.nesf.nerfstatic.utils import semantic_utils

import importlib
importlib.reload(j3d)


# Load Dataset

In [None]:
# Gin bindings for this notebook; a later cell parses this string with
# gin.parse_config(). The dataset, NeRF-checkpoint and semantic-checkpoint
# paths all point under /content.
# NOTE(review): EvalParams.chunk is 32788 rather than the power of two 32768 —
# confirm this is intentional.
GIN_CONFIG = """
ConfigParams.models = "NerfParams"

DatasetParams.batch_size = 16384
DatasetParams.data_dir = '/content/klevr'
TrainParams.nerf_model_ckpt = '/content/klevr_checkpoints'
DatasetParams.eval_scenes = '0:1'
DatasetParams.novel_scenes = '80:81'
DatasetParams.train_scenes = '0:1'

DatasetParams.dataset = 'klevr'
DatasetParams.factor = 0
DatasetParams.num_scenes_per_batch = 1
DatasetParams.max_num_train_images_per_scene = 9
DatasetParams.max_num_test_images_per_scene = 4

ModelParams.num_semantic_classes = 6  # Needed to make semantic predictions.

TrainParams.mode = "SEMANTIC"
TrainParams.print_every = 100
TrainParams.train_dir = "/content/klevr_semantic_checkpoints/klevr/"  # Will be overriden by XManager.
TrainParams.train_steps = 25000
TrainParams.save_every = 500
TrainParams.semantic_smoothness_regularization_num_points_per_device = 8192
TrainParams.semantic_smoothness_regularization_weight = 0.01
TrainParams.semantic_smoothness_regularization_stddev = 0.05
TrainParams.nerf_model_recompute_sigma_grid = True
TrainParams.nerf_model_recompute_sigma_grid_shape = (64, 64, 64)
TrainParams.nerf_model_recompute_sigma_grid_convert_sigma_to_density = True

EvalParams.chunk = 32788
EvalParams.eval_num_log_images = 8
EvalParams.eval_once = False

ModelParams.unet_depth = 3
ModelParams.unet_feature_size = (32, 64, 128, 256)
ModelParams.num_fine_samples = 192
ModelParams.apply_random_scene_rotations = True
"""

In [None]:
from absl import app

# Addresses `UnrecognizedFlagError: Unknown command line flag 'f'`
sys.argv = sys.argv[:1]

# `app.run` calls `sys.exit`, so SystemExit is the only exception expected
# here. Catching it specifically (instead of a bare `except:`) lets
# KeyboardInterrupt and unexpected errors propagate instead of being hidden.
try:
  app.run(lambda argv: None)
except SystemExit:
  pass

In [None]:
# Load experiment HParams

# Drop any bindings parsed earlier (presumably so the cell can be re-run),
# then parse the bindings defined in GIN_CONFIG and build the config object.
gin.clear_config()
gin.parse_config(GIN_CONFIG)
params = nerf_config.root_config_from_flags()

# Show the resulting config as the cell output.
params

In [None]:
# Load Dataset

DATA_DIR_ROOT = params.datasets.data_dir
# train_scenes is a 'start:end' string ('0:1' in GIN_CONFIG); the part before
# the colon is used as the scene's directory name.
SCENE_ID = params.datasets.train_scenes.split(':')[0]
# NOTE(review): `/` assumes data_dir is a path-like object rather than a plain
# str — confirm.
DATA_DIR = DATA_DIR_ROOT / SCENE_ID

In [None]:
# Seeded random state; a later cell draws a key from it via `rng.next()`.
rng = j3d.RandomState(params.train.random_seed)

In [None]:
# Build the "train" split of the configured dataset, yielding ray examples.
dataset = datasets.get_dataset(
    split="train",
    args=params.datasets,
    model_args=params.models,
    example_type=datasets.ExampleType.RAY,
    ds_state=None,
    is_novel_scenes=False,
)

In [None]:
# peek() returns a pair; only its second element (a batch) is used here.
_, placeholder_batch = dataset.peek()
# Index away the two leading axes of every leaf in the batch.
# NOTE(review): presumably the device and scenes-per-batch axes — confirm.
placeholder_batch = jax.tree.map(lambda t: t[0, 0, ...], placeholder_batch)
print(placeholder_batch.target_view.rays.scene_id.shape)
print('scene_name:', dataset.all_metadata[0].scene_name)

In [None]:
# Sigma-grid recomputation options, built from the training params.
recompute_sigma_grid_opts = semantic_utils.RecomputeSigmaGridOptions.from_params(params.train)

In [None]:
# Initialize & load per-scene NeRF models.

# Checkpoints are read from TrainParams.nerf_model_ckpt. No novel-scene dataset
# is passed in this call.
recovered_nerf_state = semantic_utils.load_all_nerf_variables(
    save_dir=params.train.nerf_model_ckpt,
    train_dataset=dataset,
    novel_dataset=None,
    recompute_sigma_grid_opts=recompute_sigma_grid_opts
)

In [None]:
# Select pretrained NeRF corresponding to scene 0.

# scene_id corresponding to ray=0.
scene_id = placeholder_batch.target_view.rays.scene_id[0, 0]
print('scene_id:', scene_id)

# Pick this scene's variables and sigma grid out of the per-scene training
# state, stacked for a single device (num_devices=1).
nerf_variables = semantic_utils.select_and_stack([scene_id],
                                                  recovered_nerf_state.train_variables,
                                                  num_devices=1)
nerf_sigma_grid = semantic_utils.select_and_stack([scene_id],
                                                  recovered_nerf_state.train_sigma_grids,
                                                  num_devices=1)

In [None]:
# Extract NeRF state corresponding to device=0, scene=0.

# Index [0, 0] on every leaf. Only one scene and one device were requested
# from select_and_stack, so this presumably strips the two stacking axes.
nerf_variables = jax.tree.map(lambda x: x[0, 0], nerf_variables)
nerf_sigma_grid = jax.tree.map(lambda x: x[0, 0], nerf_sigma_grid)

In [None]:
# Drop the first dimension of nerf_sigma_grid if necessary.

# A rank-5 grid is only accepted when its leading axis has size 1; that axis
# is then removed so the grid becomes rank 4.
if len(nerf_sigma_grid.shape) == 5:
  assert nerf_sigma_grid.shape[0] == 1
  nerf_sigma_grid = nerf_sigma_grid[0]
# Show the resulting shape as the cell output.
nerf_sigma_grid.shape

# Visualize Sigma Field

In [None]:
# Plot XYZ values where density > min_density for various values of min_density.

def plot_density_coordinates(kept_points, color=None, ax=None):
  """Scatter-plots a set of 3D points on a matplotlib 3D axis.

  Args:
    kept_points: Array indexed as `kept_points[:, i]`; columns 0, 1 and 2 are
      used as the x, y and z coordinates.
    color: Optional value forwarded as `c=` to `ax.scatter`. When None, the
      points are colored by their z coordinate.
    ax: Optional axis to draw on. When None, a new axis is created with
      `plt.axes(projection='3d')`.

  Returns:
    The axis that was drawn on.
  """
  xs = kept_points[:, 0]
  ys = kept_points[:, 1]
  zs = kept_points[:, 2]
  if ax is None:
    ax = plt.axes(projection='3d')

  # Fall back to z-based coloring when no explicit colors are given.
  point_colors = zs if color is None else color
  ax.scatter(xs, ys, zs, c=point_colors, s=1, cmap='viridis', linewidth=1)
  ax.set_xlabel('x')
  ax.set_ylabel('y')
  ax.set_zlabel('z')
  return ax

def plot_density_coordinates_min_density(nerf_sigma_grid, eligible_points, min_density_values):
  """Draws one 3D scatter panel per density threshold, side by side.

  NOTE(review): `binarize_sigma_grid` is not defined in the visible part of
  this notebook; it must be provided elsewhere before this function is called.

  Args:
    nerf_sigma_grid: Passed through to `binarize_sigma_grid`.
    eligible_points: Passed through to `binarize_sigma_grid`.
    min_density_values: Sequence of thresholds; one panel is drawn for each.

  Returns:
    A `(fig, axs)` pair: the matplotlib figure and the list of per-threshold
    axes, in the order of `min_density_values`.
  """
  num_panels = len(min_density_values)
  fig = plt.figure(figsize=(num_panels * 4, 4))
  axs = []
  # Subplot positions are 1-based, hence start=1.
  for panel_idx, min_density in enumerate(min_density_values, start=1):
    kept = binarize_sigma_grid(nerf_sigma_grid, eligible_points, min_density)
    panel = fig.add_subplot(1, num_panels, panel_idx, projection='3d')
    panel = plot_density_coordinates(kept, ax=panel)
    panel.set_title(f'min_density = {min_density}')
    axs.append(panel)
  return fig, axs


MIN_DENSITY_VALUES = [0, 2, 4, 8, 16, 32, 64]

In [None]:
# Plot XYZ values where density > 0, interactively.

def plot_density_coordinates_interactive(kept_points, color=None):
  """Builds an interactive plotly 3D scatter plot of a set of points.

  Args:
    kept_points: Non-empty collection of (x, y, z) records.
    color: Optional per-point color values. When None, points are shaded by
      their negated z coordinate.

  Returns:
    The plotly figure (not shown; call `.show()` on it to display).
  """
  assert len(kept_points), 'kept_points must contain at least one point.'
  df = pd.DataFrame.from_records(kept_points, columns=['x', 'y', 'z'])
  df['color'] = df['z'] * -1 if color is None else color
  fig = px.scatter_3d(df, x='x', y='y', z='z', color='color',
                      color_continuous_scale=px.colors.sequential.gray)

  # Reduce size of each dot.
  fig.update_traces(marker={'size': 2})

  # Render the scene with equal x, y and z extents.
  fig.update_layout(scene_aspectmode='manual',
                    scene_aspectratio=dict(x=1, y=1, z=1))

  # Update axis direction to match matplotlib. The axes of a 3D plot live under
  # `layout.scene`, which `update_yaxes` (2D cartesian axes only) does not
  # touch, so the scene's y-axis has to be addressed via `update_scenes`.
  fig.update_scenes(yaxis_autorange="reversed")

  return fig

MIN_DENSITY = 4

In [None]:
# Generate a 3D lattice of query points.

# n evenly spaced samples per axis over [-1, 1], endpoints included.
n = 64
X = np.linspace(-1, 1, num=n)
Y = np.linspace(-1, 1, num=n)
Z = np.linspace(-1, 1, num=n)

In [None]:
nerf_model = recovered_nerf_state.model

In [None]:
# indexing='ij' keeps the axis order (x, y, z): each of xx, yy, zz has shape
# (len(X), len(Y), len(Z)).
xx, yy, zz = np.meshgrid(X, Y, Z, indexing='ij')

In [None]:
xx.shape

In [None]:
# Flatten the lattice into a [1, num_points, 3] array of (x, y, z) queries.
# Stacking the flattened coordinate grids along a new last axis yields the same
# rows as zipping them element by element, without building one Python list
# per lattice point.
positions = jnp.asarray(
    np.stack([xx.flatten(), yy.flatten(), zz.flatten()], axis=-1)[None])

In [None]:
positions.shape

In [None]:
# Bundle the lattice into a SamplePoints query. scene_id is hard-coded to 0,
# and a single random view direction is supplied rather than one per point.
# NOTE(review): assumes scene_id/direction broadcast against the points axis
# of `position` — confirm.
p = types.SamplePoints(
    scene_id=jnp.asarray([[0]]),
    position=positions,
    direction=jax.random.uniform(rng.next(), shape=[1, 3]))

In [None]:
# Query sigma field across 3D lattice of query points.

result = nerf_model.apply(nerf_variables, p)
sigma_values = result.sigma
# Squash the model's rgb output into (0, 1) so it can be used as plot colors.
color_values = jax.nn.sigmoid(result.rgb)

In [None]:
print(sigma_values.shape, sigma_values.dtype)

In [None]:
# Plot sigma field

# One 3D scatter panel per density threshold, laid out in a single row.
min_density_values = [0, 2, 4, 8, 16, 32, 64]
fig = plt.figure(figsize=(len(min_density_values) * 4, 4))
axs = []
for i, min_density in enumerate(min_density_values):
  # Keep only lattice points whose predicted density exceeds the threshold.
  # NOTE(review): the [0, :, 0] indexing assumes sigma_values is shaped
  # [1, num_points, 1] — confirm.
  mask = sigma_values > min_density
  kept_points = positions[0][mask[0, :, 0]]
  kept_points_color = color_values[0][mask[0, :, 0]]
  # Subplot positions are 1-based.
  ax = fig.add_subplot(1, len(min_density_values), i+1, projection='3d')
  ax = plot_density_coordinates(kept_points, color=kept_points_color, ax=ax)
  ax.set_title(f'min_density = {min_density}')
  axs.append(ax)

# Visualize Semantic Field

In [None]:
from jax3d.projects.nesf.nerfstatic.models import volumetric_semantic_model
from jax3d.projects.nesf.nerfstatic.utils import types
from jax3d.projects.nesf.utils.typing import PRNGKey, Tree, f32  # pylint: disable=g-multiple-import

In [None]:
# NOTE(review): this rebinds `rng` from the j3d.RandomState created earlier to
# a raw JAX PRNG key. Cells that call `rng.next()` presumably cannot be re-run
# after this point without re-creating the RandomState — confirm.
rng = jax.random.PRNGKey(params.train.random_seed)

In [None]:
def plot_3D(points, clusters):
  """Shows an interactive plotly 3D scatter plot of labeled points.

  Args:
    points: Array indexed as `points[:, i]`; columns 0, 1 and 2 are used as the
      x, y and z coordinates.
    clusters: Per-point values used to color the markers.
  """
  fig = px.scatter_3d(x=points[:, 0],
                      y=points[:, 1],
                      z=points[:, 2],
                      color=clusters,
                      )

  fig.update_traces(marker=dict(size=1),
                    selector=dict(mode='markers'))

  # Pin the x and y axes to [-1, 1]. The axes of a 3D plot live under
  # `layout.scene`; `update_xaxes`/`update_yaxes` only address 2D cartesian
  # axes and therefore had no effect on this figure.
  fig.update_scenes(xaxis_range=[-1, 1], yaxis_range=[-1, 1])
  fig.show()

In [None]:
def predict_fn_3d(
    rng: PRNGKey,
    points: types.SamplePoints,
    nerf_variables: Tree[jnp.ndarray],
    nerf_sigma_grid: f32["1 x y z c"],
    *,
    semantic_variables: Tree[jnp.ndarray],
    semantic_model: volumetric_semantic_model.VolumetricSemanticModel,
) -> f32["D n k"]:
  """Predict semantic logits for a set of 3D points.

  Args:
    rng: JAX PRNG key. It is split to provide one key for each RNG stream
      passed to the semantic model ("params", "sampling",
      "data_augmentation").
    points: 3D points to evaluate. Batch size is 'n'.
    nerf_variables: NeRF Model's variables
    nerf_sigma_grid: NeRF sigma grid.
    semantic_variables: Semantic model variables.
    semantic_model: Semantic model for rendering.

  Returns:
    semantic_logits: Array of shape [D, n, k]. Contains logits for
      semantic predictions for each point in 'points' from all devices
      participating in this computation. The return value of this
      function's dimensions correspond to,
        D - number of total devices
        n - number of points per device.
        k - number of semantic classes.
  """
  rng_names = ["params", "sampling", "data_augmentation"]
  # One more key than there are streams is generated; the first is unused and
  # the rest are paired with `rng_names` below.
  _, *rng_keys = jax.random.split(rng, len(rng_names) + 1)

  # Construct dummy rays to render. The current implementation of
  # VolumetricSemanticModel requires a set of rays to be provided.

  normalize_fn = lambda x: x / jnp.linalg.norm(x, axis=-1, keepdims=True)
  # One dummy ray per local device (the count is always at least 1).
  n = jax.local_device_count()
  dummy_rays = types.Rays(scene_id=jnp.zeros((n, 1), dtype=jnp.int32),
                          origin=jnp.zeros((n, 3)),
                          direction=normalize_fn(jnp.ones((n, 3))))

  # Only the second element of the model output is returned.
  # NOTE(review): randomized_sampling=True is combined with is_train=False —
  # confirm that randomized sampling is intended for this evaluation path.
  _, predictions = semantic_model.apply(
      semantic_variables,
      rngs=dict(zip(rng_names, rng_keys)),
      rays=dummy_rays,
      sigma_grid=nerf_sigma_grid,
      randomized_sampling=True,
      is_train=False,
      nerf_model_weights=nerf_variables,
      points=points)

  return predictions

In [None]:
# Create placeholder batch for model initialization.

# Take the batch (second element of the peeked pair) and index away the two
# leading axes of every leaf.
placeholder_batch = dataset.peek()[1]
placeholder_batch = jax.tree.map(lambda t: t[0, 0, ...], placeholder_batch)

In [None]:
# Load pre-trained NeRF model sigma grids and parameters.

# This replaces the `recovered_nerf_state` loaded earlier in the notebook.
# Unlike that call, the training dataset is also passed as `novel_dataset`.
recovered_nerf_state = semantic_utils.load_all_nerf_variables(
    save_dir = params.train.nerf_model_ckpt,
    train_dataset = dataset,
    novel_dataset = dataset,
    recompute_sigma_grid_opts=(
        semantic_utils.RecomputeSigmaGridOptions.from_params(params.train)
    )
)


In [None]:
# Initialize semantic model.

# Built with a fixed seed (0) on top of the loaded NeRF model, using the first
# training scene's sigma grid and variables.
initialized_vol_sem_model = models.construct_volumetric_semantic_model(
    rng=j3d.RandomState(0),
    num_scenes=-1,
    placeholder_batch=placeholder_batch,
    args=params.models,
    nerf_model=recovered_nerf_state.model,
    nerf_sigma_grid=recovered_nerf_state.train_sigma_grids[0],
    nerf_variables=recovered_nerf_state.train_variables[0]
)


In [None]:
vol_sem_model = initialized_vol_sem_model.model
semantic_variables = initialized_vol_sem_model.variables

# Wrap the freshly initialized variables in an Adam optimizer and TrainState;
# the next cell restores a checkpoint into this `state`.
# NOTE(review): `flax.optim` is a legacy flax API; presumably the reason flax
# is pinned to 0.5.3 in the setup cell — confirm.
optimizer = flax.optim.Adam(params.train.lr_init).create(semantic_variables)
state = utils.TrainState(optimizer=optimizer)

In [None]:
# Restore semantic model from checkpoint.

# The checkpoint directory is derived from the parsed config.
# NOTE(review): check how restore_opt_checkpoint behaves when `save_dir` holds
# no checkpoint — if it returns `state` unchanged, the plots below would
# silently use randomly initialized weights.
save_dir = train_utils.checkpoint_dir(params)
state = train_utils.restore_opt_checkpoint(save_dir=save_dir, state=state)

In [None]:
save_dir

In [None]:
# Query semantic model across 3D lattice of points.

predictions = predict_fn_3d(rng,
                            p,
                            recovered_nerf_state.train_variables[0],
                            recovered_nerf_state.train_sigma_grids[0],
                            # Use the variables held by the restored train
                            # state, not the freshly initialized
                            # `semantic_variables`.
                            semantic_variables=state.optimizer.target,
                            semantic_model=vol_sem_model)

In [None]:
# Pick the highest-scoring class per point (argmax over the last axis).
semantic_predictions = jnp.argmax(predictions, axis=-1)

In [None]:
# Visualize semantic model predictions across 3D lattice of points.

# Keep only lattice points whose NeRF density exceeds MIN_DENSITY and color
# them by predicted semantic class.
MIN_DENSITY = 16
mask = sigma_values > MIN_DENSITY
# NOTE(review): the [0, :, 0] indexing assumes sigma_values is shaped
# [1, num_points, 1] — confirm.
kept_points = positions[0][mask[0, :, 0]]
kept_points_color = semantic_predictions[0][mask[0, :, 0]]
print(kept_points.shape)
plot_3D(kept_points, kept_points_color)

# Visualize Ground Truth

In [None]:
# Load ground truth train data examples.

# make_examples returns a pair; only the first element (the examples) is kept.
# image_idxs=None presumably selects every image of the split — confirm.
examples, _ = klevr.make_examples(data_dir=DATA_DIR, split='train', image_idxs=None, enable_sqrt2_buffer=True)

In [None]:
@chex.dataclass
class LabeledPointCloud:
  """A set of points together with per-point semantic labels."""
  # `num_points` requires both arrays to be rank 2 with equal leading sizes.
  points: jnp.ndarray
  semantics: jnp.ndarray

  @property
  def num_points(self):
    """Number of points, i.e. the size of the arrays' shared leading axis.

    Asserts that `points` and `semantics` are both rank 2 and have the same
    number of rows.
    """
    pts_shape = self.points.shape
    assert len(pts_shape) == 2, pts_shape
    sem_shape = self.semantics.shape
    assert len(sem_shape) == 2, sem_shape
    assert pts_shape[0] == sem_shape[0]
    return pts_shape[0]

def construct_labeled_point_cloud(batch):
  """Builds a semantically labeled point cloud from a batch's target view.

  Every ray is followed for its depth (`origin + depth * direction`) to get a
  3D point. Only points whose coordinates all lie in [-1, 1] are kept, along
  with their semantic labels.

  Args:
    batch: Object exposing `target_view.rays.origin`,
      `target_view.rays.direction`, `target_view.depth` and
      `target_view.semantics`.

  Returns:
    A LabeledPointCloud holding the kept points and their semantics.
  """
  view = batch.target_view
  origins = view.rays.origin
  directions = view.rays.direction
  depths = view.depth
  labels = view.semantics

  surface_points = origins + depths * directions

  # True where x, y and z all fall inside the [-1, 1] cube.
  inside_cube = np.all((surface_points >= -1) & (surface_points <= 1), axis=-1)

  return LabeledPointCloud(points=surface_points[inside_cube],
                           semantics=labels[inside_cube])

In [None]:
# Construct labeled semantic point cloud from ground truth dataset (i.e. using semantic masks, ray origins & directions from known cameras, and depth)

labeled_point_cloud = construct_labeled_point_cloud(examples)

In [None]:
labeled_point_cloud.num_points

In [None]:
# Randomly subsample the point cloud for plotting. np.random.randint draws with
# replacement (indices may repeat), and no seed is set here.
idxs = np.random.randint(labeled_point_cloud.num_points, size=200000)  # alternative: size=50000
# Apply the same index selection to every field of the point cloud.
mini_point_cloud = jax.tree.map(lambda x: x[idxs], labeled_point_cloud)
mini_point_cloud.num_points

In [None]:
# Visualize ground truth labeled semantic point cloud.

# Points are colored by the first column of their semantics array.
fig = px.scatter_3d(x=mini_point_cloud.points[:, 0],
                    y=mini_point_cloud.points[:, 1],
                    z=mini_point_cloud.points[:, 2],
                    color=mini_point_cloud.semantics[:, 0],)


# Shrink the markers.
fig.update_traces(marker=dict(size=1),
                  selector=dict(mode='markers'))

fig.show()