Upgrading flax to linen.
PiperOrigin-RevId: 352432643
jonbarron authored and Copybara-Service committed Jan 18, 2021
1 parent 14deb9a commit 690acfd
Showing 5 changed files with 233 additions and 241 deletions.
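For orientation, the core of the change is the move from the pre-linen flax.nn Module API to flax.linen. The sketch below is a minimal, hypothetical example of that migration pattern; the module name, sizes, and shapes are illustrative and not taken from this repository.

# Pre-linen flax.nn style (shown only as a comment, since flax.nn is
# deprecated): hyperparameters are keyword arguments of apply(), and layers
# are called as functions that take the input together with their options.
#
#   class Tiny(nn.Module):
#     def apply(self, x, features=8):
#       return nn.Dense(x, features)
#
#   model_def = Tiny.partial(features=8)
#   _, params = model_def.init(rng, x)
#   model = nn.Model(model_def, params)
#   y = model(x)

# flax.linen style that this commit adopts: hyperparameters become dataclass
# attributes, parameters are created lazily inside an @nn.compact __call__,
# and the module is driven through explicit init/apply calls on a variable
# collection.
import jax
import jax.numpy as jnp
from flax import linen as nn

class Tiny(nn.Module):
  features: int = 8

  @nn.compact
  def __call__(self, x):
    return nn.Dense(self.features)(x)

x = jnp.ones((2, 4))
variables = Tiny().init(jax.random.PRNGKey(0), x)  # {"params": {...}}
y = Tiny().apply(variables, x)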
47 changes: 22 additions & 25 deletions jaxnerf/eval.py
@@ -15,19 +15,19 @@

# Lint as: python3
"""Evaluation script for Nerf."""
import functools
from os import path

from absl import app
from absl import flags
from flax import optim
import flax
from flax.metrics import tensorboard
from flax.training import checkpoints
import jax
from jax import random
import numpy as np

from jaxnerf.nerf import datasets
from jaxnerf.nerf import model_utils
from jaxnerf.nerf import models
from jaxnerf.nerf import utils

@@ -45,18 +45,19 @@ def main(unused_argv):
raise ValueError("train_dir must be set. None set now.")
if FLAGS.data_dir is None:
raise ValueError("data_dir must be set. None set now.")
# Force rendering to be deterministic even if training was randomized, as this
# eliminates "speckle" artifacts.
FLAGS.__dict__["randomized"] = False

dataset = datasets.get_dataset("test", FLAGS)
rng, key = random.split(rng)
init_model, init_state = models.get_model(key, dataset.peek(), FLAGS)
optimizer_def = optim.Adam(FLAGS.lr_init)
optimizer = optimizer_def.create(init_model)
model, init_variables = models.get_model(key, dataset.peek(), FLAGS)
optimizer = flax.optim.Adam(FLAGS.lr_init).create(init_variables)
state = utils.TrainState(optimizer=optimizer)
del optimizer, init_variables

def render_fn(key_0, key_1, model, rays):
# Note rng_keys are useless in eval mode since there's no randomness.
return jax.lax.all_gather(model(key_0, key_1, *rays), axis_name="batch")
# Rendering is forced to be deterministic even if training was randomized, as
# this eliminates "speckle" artifacts.
def render_fn(variables, key_0, key_1, rays):
return jax.lax.all_gather(
model.apply(variables, key_0, key_1, *rays, False), axis_name="batch")

# pmap over only the data input.
render_pfn = jax.pmap(
@@ -66,10 +67,6 @@ def render_fn(key_0, key_1, model, rays):
axis_name="batch",
)

state = model_utils.TrainState(
step=0, optimizer=optimizer, model_state=init_state)
del init_model, init_state

last_step = 0
out_dir = path.join(FLAGS.train_dir,
"path_renders" if FLAGS.render_path else "test_preds")
@@ -78,7 +75,8 @@ def render_fn(key_0, key_1, model, rays):
path.join(FLAGS.train_dir, "eval"))
while True:
state = checkpoints.restore_checkpoint(FLAGS.train_dir, state)
if state.step <= last_step:
step = int(state.optimizer.state.step)
if step <= last_step:
continue
if FLAGS.save_output and (not utils.isdir(out_dir)):
utils.makedirs(out_dir)
@@ -89,9 +87,8 @@ def render_fn(key_0, key_1, model, rays):
print(f"Evaluating {idx+1}/{dataset.size}")
batch = next(dataset)
pred_color, pred_disp, pred_acc = utils.render_image(
state,
functools.partial(render_pfn, state.optimizer.target),
batch["rays"],
render_pfn,
rng,
FLAGS.dataset == "llff",
chunk=FLAGS.chunk)
@@ -112,20 +109,20 @@ def render_fn(key_0, key_1, model, rays):
utils.save_img(pred_disp[Ellipsis, 0],
path.join(out_dir, "disp_{:03d}.png".format(idx)))
if (not FLAGS.eval_once) and (jax.host_id() == 0):
summary_writer.image("pred_color", showcase_color, state.step)
summary_writer.image("pred_disp", showcase_disp, state.step)
summary_writer.image("pred_acc", showcase_acc, state.step)
summary_writer.image("pred_color", showcase_color, step)
summary_writer.image("pred_disp", showcase_disp, step)
summary_writer.image("pred_acc", showcase_acc, step)
if not FLAGS.render_path:
summary_writer.scalar("psnr", np.mean(np.array(psnrs)), state.step)
summary_writer.image("target", showcase_gt, state.step)
summary_writer.scalar("psnr", np.mean(np.array(psnrs)), step)
summary_writer.image("target", showcase_gt, step)
if FLAGS.save_output and (not FLAGS.render_path) and (jax.host_id() == 0):
with utils.open_file(path.join(out_dir, "psnr.txt"), "w") as pout:
pout.write("{}".format(np.mean(np.array(psnrs))))
if FLAGS.eval_once:
break
if int(state.step) >= FLAGS.max_steps:
if int(step) >= FLAGS.max_steps:
break
last_step = state.step
last_step = step


if __name__ == "__main__":
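One bookkeeping change in eval.py is worth spelling out: the old model_utils.TrainState carried an explicit step field, whereas the new code wraps only the optimizer in utils.TrainState and reads the step off the optimizer itself. The sketch below shows the assumed shape of that state, using the flax.optim API current at the time of this commit; utils.TrainState is not part of this diff, so treat the field list as an assumption.

import flax

@flax.struct.dataclass
class TrainState:
  # Assumed: only the optimizer is kept, since the linen module holds no
  # mutable state of its own and the step counter already lives inside the
  # flax.optim optimizer state.
  optimizer: flax.optim.Optimizer

# Hence the eval loop above reads the step as
#   step = int(state.optimizer.state.step)
# instead of the old state.step.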
76 changes: 29 additions & 47 deletions jaxnerf/nerf/model_utils.py
@@ -17,38 +17,29 @@
"""Helper functions/classes for model definition."""

import functools
from typing import Any, Callable

import flax
from flax import nn
from flax import optim
from flax import linen as nn
import jax
from jax import lax
from jax import random
import jax.numpy as jnp


@flax.struct.dataclass
class TrainState:
step: int
optimizer: optim.Optimizer
model_state: nn.Collection


class MLP(nn.Module):
"""A simple MLP."""

def apply(self,
x,
condition=None,
net_depth=8,
net_width=256,
net_depth_condition=1,
net_width_condition=128,
net_activation=nn.relu,
skip_layer=4,
num_rgb_channels=3,
num_sigma_channels=1):
"""Multi-layer perception for nerf.
net_depth: int = 8 # The depth of the first part of MLP.
net_width: int = 256 # The width of the first part of MLP.
net_depth_condition: int = 1 # The depth of the second part of MLP.
net_width_condition: int = 128 # The width of the second part of MLP.
net_activation: Callable[Ellipsis, Any] = nn.relu # The activation function.
skip_layer: int = 4 # The layer to add skip layers to.
num_rgb_channels: int = 3 # The number of RGB channels.
num_sigma_channels: int = 1 # The number of sigma channels.

@nn.compact
def __call__(self, x, condition):
"""Evaluate the MLP.
Args:
x: jnp.ndarray(float32), [batch, num_samples, feature], points.
@@ -57,15 +48,6 @@ def apply(self,
concatenated with the output vector of the first part of the MLP. If
None, only the first part of the MLP will be used with input x. In the
original paper, this variable is the view direction.
net_depth: int, the depth of the first part of MLP.
net_width: int, the width of the first part of MLP.
net_depth_condition: int, the depth of the second part of MLP.
net_width_condition: int, the width of the second part of MLP.
net_activation: function, the activation function used in the MLP.
skip_layer: int, add a skip connection to the output vector of every
skip_layer layers.
num_rgb_channels: int, the number of RGB channels.
num_sigma_channels: int, the number of density channels.
Returns:
raw_rgb: jnp.ndarray(float32), with a shape of
@@ -78,32 +60,32 @@ def apply(self,
x = x.reshape([-1, feature_dim])
dense_layer = functools.partial(
nn.Dense, kernel_init=jax.nn.initializers.glorot_uniform())

inputs = x
for i in range(net_depth):
x = dense_layer(x, net_width)
x = net_activation(x)
if i % skip_layer == 0 and i > 0:
for i in range(self.net_depth):
x = dense_layer(self.net_width)(x)
x = self.net_activation(x)
if i % self.skip_layer == 0 and i > 0:
x = jnp.concatenate([x, inputs], axis=-1)
raw_sigma = dense_layer(x, num_sigma_channels).reshape(
[-1, num_samples, num_sigma_channels])
raw_sigma = dense_layer(self.num_sigma_channels)(x).reshape(
[-1, num_samples, self.num_sigma_channels])

if condition is not None:
# Output of the first part of MLP.
bottleneck = dense_layer(x, net_width)
bottleneck = dense_layer(self.net_width)(x)
# Broadcast condition from [batch, feature] to
# [batch, num_samples, feature] since all the samples along the same ray
# has the same viewdir.
# have the same viewdir.
condition = jnp.tile(condition[:, None, :], (1, num_samples, 1))
# Collapse the [batch, num_samples, feature] tensor to
# [batch * num_samples, feature] so that it can be feed into nn.Dense.
# [batch * num_samples, feature] so that it can be fed into nn.Dense.
condition = condition.reshape([-1, condition.shape[-1]])
x = jnp.concatenate([bottleneck, condition], axis=-1)
# Here use 1 extra layer to align with the original nerf model.
for i in range(net_depth_condition):
x = dense_layer(x, net_width_condition)
x = net_activation(x)
raw_rgb = dense_layer(x, num_rgb_channels).reshape(
[-1, num_samples, num_rgb_channels])
for i in range(self.net_depth_condition):
x = dense_layer(self.net_width_condition)(x)
x = self.net_activation(x)
raw_rgb = dense_layer(self.num_rgb_channels)(x).reshape(
[-1, num_samples, self.num_rgb_channels])
return raw_rgb, raw_sigma


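To make the new calling convention of the rewritten MLP concrete, here is a hedged usage sketch; the shapes are illustrative stand-ins for encoded sample points and view directions and do not come from this diff.

import jax
import jax.numpy as jnp
from jaxnerf.nerf import model_utils

mlp = model_utils.MLP(net_depth=8, net_width=256)
x = jnp.ones((4, 64, 63))      # [batch, num_samples, feature]: encoded sample points
condition = jnp.ones((4, 27))  # [batch, feature]: e.g. encoded view directions
variables = mlp.init(jax.random.PRNGKey(0), x, condition)
raw_rgb, raw_sigma = mlp.apply(variables, x, condition)
# raw_rgb: [batch, num_samples, 3], raw_sigma: [batch, num_samples, 1]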
