
add support for nd interpolation #2

Merged
merged 1 commit on Jan 21, 2022
58 changes: 45 additions & 13 deletions src/neural_hash_encoding/hash_array.py
@@ -1,6 +1,8 @@
from typing import Tuple, Any
from typing import Tuple, Any, Iterable
from functools import reduce
import operator
import numpy as np
from dataclasses import dataclass
from dataclasses import dataclass, field
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

@@ -9,26 +11,56 @@
https://github.com/NVlabs/tiny-cuda-nn/blob/d0639158dd64b5d146d659eb881702304abaaa24/include/tiny-cuda-nn/encodings/grid.h
"""

Shape = Iterable[int]
Dtype = Any # this could be a real type?
Array = Any

PRIMES = (73856093, 19349663)
PRIMES = (1, 73856093, 19349663, 83492791)


@register_pytree_node_class
@dataclass
class HashArray2D:
data: Array
shape: Tuple[int, int]

class HashArray:
"""
This is a sparse array backed by simple hash table. It minimally implements an array
interface as to be used for (nd) linear interpolation.
There is no collision resolution or even bounds checking.

Attributes:
data: The hash table represented as a 2D array.
First dim is indexed with the hash index and second dim is the feature
shape: The shape of the array.

NVIDIA Implementation of multi-res hash grid:
https://github.com/NVlabs/tiny-cuda-nn/blob/master/include/tiny-cuda-nn/encodings/grid.h#L66-L80
"""
def spatial_hash(self, y, x):
return (x ^ (y * PRIMES[0])) % self.data.shape[-2]
data: Array
shape: Shape

def __post_init__(self):
assert self.data.ndim == 2, "Hash table data should be 2d"
assert self.data.shape[1] == self.shape[-1]

@property
def ndim(self):
return len(self.shape)

@property
def dtype(self):
return self.data.dtype

def spatial_hash(self, coords):
assert len(coords) <= len(PRIMES), "Add more PRIMES!"
if len(coords) == 1:
i = (coords[0] ^ PRIMES[1])
else:
i = reduce(operator.xor, (c * p for c, p in zip(coords, PRIMES)))
return i % self.data.shape[0]

def __getitem__(self, i):
x, y, d = i[-3:] if len(i) == 3 else (*i[-2:], Ellipsis)
i = self.spatial_hash(y, x)
return self.data[i, d]
*spatial_i, feature_i = i if len(i) == self.ndim else (*i, Ellipsis)
i = self.spatial_hash(spatial_i)
return self.data[i, feature_i]

def __array__(self, dtype=None):
H, W, _ = self.shape
@@ -37,7 +69,7 @@ def __array__(self, dtype=None):
return arr

def __repr__(self):
return "HashArray2D(" + str(np.asarray(self)) + ")"
return "HashArray(" + str(np.asarray(self)) + ")"

def tree_flatten(self):
return (self.data, self.shape)
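
The spatial hash follows the tiny-cuda-nn scheme referenced in the docstring: each integer coordinate is multiplied by a per-dimension prime, the products are XOR-reduced, and the result is wrapped into the table by a modulo. A minimal standalone sketch of the multi-coordinate path (the merged code special-cases 1D; table_size here is a stand-in for data.shape[0]):

from functools import reduce
import operator

PRIMES = (1, 73856093, 19349663, 83492791)

def spatial_hash(coords, table_size):
    # Multiply each coordinate by its prime, XOR-reduce, wrap into the table.
    i = reduce(operator.xor, (c * p for c, p in zip(coords, PRIMES)))
    return i % table_size

# A 3D voxel coordinate hashed into a 2**14-entry table:
print(spatial_hash((12, 345, 6789), 2**14))

Because the first prime is 1, the 2D case keeps the same form as the old hash, and collisions are simply accepted rather than resolved, as the docstring notes.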
113 changes: 98 additions & 15 deletions src/neural_hash_encoding/interpolate.py
@@ -1,8 +1,20 @@
from typing import Any
from typing import Any, Sequence
from dataclasses import dataclass
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

import operator
import itertools
import functools
from jax._src.scipy.ndimage import (
_nonempty_prod,
_nonempty_sum,
_INDEX_FIXERS,
_round_half_away_from_zero,
_nearest_indices_and_weights,
_linear_indices_and_weights,
)

Array = Any


@@ -37,22 +49,93 @@ def bilinear_interpolate(arr, x, y, clip_to_bounds=False):
return wa*Ia + wb*Ib + wc*Ic + wd*Id


def map_coordinates(input, coordinates, order, mode='constant', cval=0.0):
"""
Adapted from jax.scipy.map_coordinates, but with a few key differences.

1.) interpolations are always broadcasted along the last dimension of the `input`
i.e. a 3 channel rgb image with shape [H, W, 3] will be interpolated with 2d
coordinates and broadcasted across the channel dimension

2.) `input` isn't required to be jax `DeviceArray` -- it can be any type that
supports numpy fancy indexing

Note on interpolation: `map_coordinates` indexes in the order of the axes,
so for an image it indexes the coordinates as [y, x]
"""

coordinates = [jnp.asarray(c) for c in coordinates]
cval = jnp.asarray(cval, input.dtype)

if len(coordinates) != input.ndim-1:
raise ValueError('coordinates must be a sequence of length input.ndim - 1, but '
'{} != {}'.format(len(coordinates), input.ndim - 1))

index_fixer = _INDEX_FIXERS.get(mode)
if index_fixer is None:
raise NotImplementedError(
'map_coordinates does not support mode {}. '
'Currently supported modes are {}.'.format(mode, set(_INDEX_FIXERS)))

if mode == 'constant':
is_valid = lambda index, size: (0 <= index) & (index < size)
else:
is_valid = lambda index, size: True

if order == 0:
interp_fun = _nearest_indices_and_weights
elif order == 1:
interp_fun = _linear_indices_and_weights
else:
raise NotImplementedError(
'map_coordinates currently requires order<=1')

valid_1d_interpolations = []
for coordinate, size in zip(coordinates, input.shape[:-1]):
interp_nodes = interp_fun(coordinate)
valid_interp = []
for index, weight in interp_nodes:
fixed_index = index_fixer(index, size)
valid = is_valid(index, size)
valid_interp.append((fixed_index, valid, weight))
valid_1d_interpolations.append(valid_interp)

outputs = []
for items in itertools.product(*valid_1d_interpolations):
indices, validities, weights = zip(*items)
if all(valid is True for valid in validities):
# fast path
contribution = input[(*indices, Ellipsis)]
else:
all_valid = functools.reduce(operator.and_, validities)
contribution = jnp.where(all_valid[..., None], input[(*indices, Ellipsis)], cval)
outputs.append(_nonempty_prod(weights)[..., None] * contribution)

result = _nonempty_sum(outputs)
if jnp.issubdtype(input.dtype, jnp.integer):
result = _round_half_away_from_zero(result)
return result.astype(input.dtype)


@dataclass
@register_pytree_node_class
class Interpolate2D:
arr: Array

def __call__(self, x, y, normalized=True):
if normalized:
# un-normalize
y = y * (self.arr.shape[0] - 1)
x = x * (self.arr.shape[1] - 1)
return bilinear_interpolate(self.arr, x, y)
class Interpolate:
arr: Array
order: int
mode: str
cval: float = 0.0

def __call__(self, coords, normalized=True):
coords = [jnp.asarray(c) for c in coords]
assert len(coords) == (self.arr.ndim - 1)
if normalized:
# un-normalize
coords = [c * (s-1) for c, s in zip(coords, self.arr.shape)]
return map_coordinates(self.arr, coords, order=self.order, mode=self.mode, cval=self.cval)

def tree_flatten(self):
return (self.arr, None)

@classmethod
def tree_unflatten(cls, aux_data, data):
return cls(data)
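
As the docstring above notes, map_coordinates indexes coordinates in axis order and broadcasts across the trailing channel dimension. A quick sketch of the calling convention using the stock jax.scipy.ndimage.map_coordinates on a single-channel image (the adapted version in this PR additionally handles a channel axis and any fancy-indexable array such as HashArray):

import jax.numpy as jnp
from jax.scipy import ndimage

img = jnp.arange(12.0).reshape(3, 4)   # [H, W], single channel
y = jnp.array([0.5, 1.5])              # row coordinates come first...
x = jnp.array([1.0, 2.5])              # ...then column coordinates
vals = ndimage.map_coordinates(img, [y, x], order=1, mode='nearest')
# vals[i] is the bilinear sample at (y[i], x[i])

The Interpolate dataclass wraps the adapted map_coordinates, first rescaling normalized [0, 1] coordinates by (size - 1) per axis when normalized=True.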
52 changes: 23 additions & 29 deletions src/neural_hash_encoding/model.py
@@ -5,8 +5,8 @@
import jax.numpy as jnp
from flax import linen as nn

from neural_hash_encoding.hash_array import HashArray2D, _get_level_res_nd
from neural_hash_encoding.interpolate import Interpolate2D
from neural_hash_encoding.hash_array import HashArray, _get_level_res_nd
from neural_hash_encoding.interpolate import Interpolate

# Copied from flax
PRNGKey = Any
@@ -20,58 +20,52 @@ def init(key, shape, dtype=dtype):
return init


class DenseEncodingLevel2D(nn.Module):
class DenseEncodingLevel(nn.Module):
res: Shape
features: int = 2
dtype: Dtype = jnp.float32
param_dtype: Dtype = jnp.float32
table_init: Callable[[PRNGKey, Shape, Dtype], Array] = uniform_init(-1e-4, 1e-4)

interp: Interpolate2D = field(init=False)
interp: Interpolate = field(init=False)

def setup(self):
array = self.param('table',
self.table_init,
(*self.res, self.features),
self.param_dtype)
self.interp = Interpolate2D(jnp.asarray(array))
self.interp = Interpolate(jnp.asarray(array), order=1, mode='nearest')

def __call__(self, xy):
assert xy.shape[-1] == 2
xy = jnp.asarray(xy, self.dtype)
x, y = xy[..., 0], xy[..., 1]
return self.interp(x, y, normalized=True)
def __call__(self, coords):
assert len(coords) == (self.interp.arr.ndim - 1)
return self.interp(coords, normalized=True)



class HashEncodingLevel2D(nn.Module):
class HashEncodingLevel(nn.Module):
res: Shape
features: int = 2
table_size: int = 2**14
dtype: Dtype = jnp.float32
param_dtype: Dtype = jnp.float32
table_init: Callable[[PRNGKey, Shape, Dtype], Array] = uniform_init(-1e-4, 1e-4)

interp: Interpolate2D = field(init=False)
interp: Interpolate = field(init=False)

def setup(self):
table = self.param('table',
self.table_init,
(self.table_size, self.features),
self.param_dtype)
shape = (*self.res, self.features)
array = HashArray2D(jnp.asarray(table), shape)
self.interp = Interpolate2D(array)
array = HashArray(jnp.asarray(table), shape)
self.interp = Interpolate(array, order=1, mode='nearest')

def __call__(self, xy):
assert xy.shape[-1] == 2
xy = jnp.asarray(xy, self.dtype)
x, y = xy[..., 0], xy[..., 1]
return self.interp(x, y, normalized=True)
def __call__(self, coords):
assert len(coords) == (self.interp.arr.ndim - 1)
return self.interp(coords, normalized=True)


class MultiResEncoding2D(nn.Module):

class MultiResEncoding(nn.Module):
levels: int=16
table_size: int = 2**14
features: int = 2
@@ -89,12 +83,12 @@ def setup(self):
features=self.features, dtype=self.dtype,
param_dtype=self.param_dtype, table_init=self.param_init)
# First level is always dense
L0 = DenseEncodingLevel2D(res_levels[0], **kwargs)
L0 = DenseEncodingLevel(res_levels[0], **kwargs)
# Rest are sparse hash arrays
self.L = tuple([L0, *(HashEncodingLevel2D(l, table_size=self.table_size, **kwargs) for l in res_levels[1:])])
self.L = tuple([L0, *(HashEncodingLevel(l, table_size=self.table_size, **kwargs) for l in res_levels[1:])])

def __call__(self, xy):
features = [l(xy) for l in self.L]
def __call__(self, coords):
features = [l(coords) for l in self.L]
features = jnp.concatenate(features, -1)
return features

Expand All @@ -121,11 +115,11 @@ class ImageModel(nn.Module):
minres: Shape = (16, 16)

def setup(self):
self.embedding = MultiResEncoding2D(self.levels, self.table_size,
self.embedding = MultiResEncoding(self.levels, self.table_size,
self.features, self.minres, self.res)
self.decoder = MLP((64, 64, self.channels))

def __call__(self, xy):
features = self.embedding(xy)
def __call__(self, coords):
features = self.embedding(coords)
color = self.decoder(features)
return color
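
With the 2D-specific modules generalized, a model is called with a sequence of per-axis coordinate arrays in [y, x] axis order rather than a stacked xy array. A hypothetical init/apply sketch under that convention (the resolution and batch shape are assumptions, mirroring the dummy data in train_image.py below):

import jax
import jax.numpy as jnp
from neural_hash_encoding.model import ImageModel

model = ImageModel(res=(512, 512), table_size=2**14)
coords = jnp.ones((2, 1))  # [ndim, batch]: row 0 = y, row 1 = x, in [0, 1]
params = model.init(jax.random.PRNGKey(0), coords)['params']
rgb = model.apply({'params': params}, coords)  # [batch, channels]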
16 changes: 8 additions & 8 deletions src/train_image.py
@@ -32,8 +32,9 @@ def __iter__(self):
x = np.random.randint(0, W-1, self.batch_size)
y = np.random.randint(0, H-1, self.batch_size)
rgb = img[y, x, :]
xy = np.stack([x / (W-1), y / (H-1)], -1)
yield xy, rgb / 255
# Normalize coordinates to [0, 1]
yx = [y / (H-1), x / (W-1)]
yield yx, rgb / 255


class RandomPixelLoader(DataLoader):
@@ -77,7 +78,7 @@ def l2_loss(params, l2=1e-6):
def create_train_state(rng, learning_rate):
"""Creates initial `TrainState`."""
image_model = ImageModel(res=(H, W), table_size=table_size)
x = jnp.ones((1, 2)) # Dummy data
x = jnp.ones((2, 1)) # Dummy data
params = image_model.init(rng, x)['params']
tx = optax.adamw(learning_rate, b1=.9, b2=.99, eps=1e-10)
return train_state.TrainState.create(
@@ -86,10 +87,10 @@
@jax.jit
def train_step(state, batch):
"""Train for a single step."""
xy, colors_targ = batch
yx, colors_targ = batch

def loss_fn(params, weight_decay=1e-6):
colors_pred = ImageModel((H, W), table_size=table_size).apply({'params': params}, xy)
colors_pred = ImageModel((H, W), table_size=table_size).apply({'params': params}, yx)
mlp_params = params['decoder']
loss = mse_loss(colors_pred, colors_targ) + l2_loss(mlp_params, weight_decay)
return loss, colors_pred
@@ -105,8 +106,8 @@ def write_region_plot(path, params, img, s=np.s_[10000:12048, 20000:22048, :]):
assert len(s) == 3
crop = img[s]

xy = jnp.roll(jnp.mgrid[s[:2]], 1, 0).reshape(2, -1).T / jnp.array([W-1, H-1])
rgb = ImageModel((H, W), table_size=table_size).apply({'params': params}, xy)
yx = jnp.mgrid[s[:2]].reshape(2, -1) / jnp.array([H-1, W-1]).reshape(2, 1)
rgb = ImageModel((H, W), table_size=table_size).apply({'params': params}, yx)
crop2 = (rgb.reshape(*crop.shape) * 255).round(0).clip(0, 255).astype(np.uint8)

fig, axs = plt.subplots(1, 2, figsize=(16, 12))
@@ -133,7 +134,6 @@ def write_region_plot(path, params, img, s=np.s_[10000:12048, 20000:22048, :]):
for epoch in range(epochs):
for i, batch in enumerate(loader):
step = epoch * len(ds) + i
batch = (jnp.asarray(batch[0]), jnp.asarray(batch[1]))
state, metrics = train_step(state, batch)
loss, psnr = metrics['loss'], metrics['psnr']
if step > 1 and (np.log10(step) == int(np.log10(step))): # exponential logging
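
The region plot builds a dense [y, x] coordinate grid over a crop and normalizes each axis independently, matching the loader's convention. A small sketch of that indexing pattern (toy crop; H and W are assumed image dimensions):

import numpy as np
import jax.numpy as jnp

H, W = 1024, 2048        # assumed image size
s = np.s_[10:20, 30:40]  # a toy 10x10 crop
yx = jnp.mgrid[s].reshape(2, -1) / jnp.array([H - 1, W - 1]).reshape(2, 1)
# yx[0] holds normalized row coordinates and yx[1] normalized column
# coordinates, one column per pixel -- the [ndim, batch] layout the model expects.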