Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

V2 #62

Open
wants to merge 9 commits into
base: v2
Choose a base branch
from
Open

V2 #62

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion nerfies/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from typing import Any, Mapping, Optional, Tuple

import dataclasses
from flax import nn
from flax import linen as nn
import gin
import immutabledict

Expand Down
2 changes: 1 addition & 1 deletion nerfies/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def render_image(
ret_map = jax_utils.unreplicate(model_out[ret_key])
ret_map = jax.tree_map(lambda x: utils.unshard(x, padding), ret_map)
ret_maps.append(ret_map)
ret_map = jax.tree_multimap(lambda *x: jnp.concatenate(x, axis=0), *ret_maps)
ret_map = jax.tree_map(lambda *x: jnp.concatenate(x, axis=0), *ret_maps)
logging.info('Rendering took %.04s', time.time() - start_time)
out = {}
for key, value in ret_map.items():
Expand Down
2 changes: 1 addition & 1 deletion nerfies/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def volumetric_rendering(rgb,
last_sample_z = 1e10 if sample_at_infinity else 1e-19
dists = jnp.concatenate([
z_vals[..., 1:] - z_vals[..., :-1],
jnp.broadcast_to([last_sample_z], z_vals[..., :1].shape)
jnp.broadcast_to(jnp.array([last_sample_z]), z_vals[..., :1].shape)
], -1)
dists = dists * jnp.linalg.norm(dirs[..., None, :], axis=-1)
alpha = 1.0 - jnp.exp(-sigma * dists)
Expand Down
3 changes: 2 additions & 1 deletion nerfies/schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Annealing Schedules."""
import abc
import collections
from collections.abc import Mapping
import copy
import math
from typing import Any, Iterable, List, Tuple, Union
Expand All @@ -38,7 +39,7 @@ def from_config(schedule):
return schedule
if isinstance(schedule, Tuple) or isinstance(schedule, List):
return from_tuple(schedule)
if isinstance(schedule, collections.Mapping):
if isinstance(schedule, Mapping):
return from_dict(schedule)

raise ValueError(f'Unknown type {type(schedule)}.')
Expand Down
2 changes: 1 addition & 1 deletion nerfies/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def strided_subset(sequence, count):

def tree_collate(list_of_pytrees):
"""Collates a list of pytrees with the same structure."""
return tree_util.tree_multimap(lambda *x: np.stack(x), *list_of_pytrees)
return tree_util.tree_map(lambda *x: np.stack(x), *list_of_pytrees)


@contextlib.contextmanager
Expand Down