In [None]:
blabla

In [1]:
import torch
from torch import nn, Tuple
import numpy as np
import matplotlib.pyplot as plt

In [107]:
n_levels = 16
n_inputs_dim = 2
log2_hashmap_size = 17 # NOTE : This is the size of the hashtable (T)
base_resolution = 16
max_resolution = 320
n_features_per_level = 2


b = np.exp((np.log(max_resolution) - np.log(base_resolution))/(base_resolution-1))


def _get_number_of_embeddings(self, level_idx: int) -> int:
    """
    level_idx: level index

    returns: number of embeddings for given level. Max number is 2**self.log2_hashmap_size
    """

    max_size = 2**self.log2_hashmap_size

    resolution = int(self.base_resolution * self.b**level_idx)
    n_level_size = (
        resolution + 2
    ) ** 3  # see explanation below at 'def _to_1D(...)' why we do + 2

    return min(max_size, n_level_size)

embeddings = nn.ModuleList(
    [
        nn.Embedding(
            _get_number_of_embeddings(i), n_features_per_level
        )
        for i in range(n_levels)
    ]
)

box_offsets = torch.tensor([[i, j] for i in [0, 1] for j in [0, 1]])

def _hash(coords: torch.Tensor, log2_hashmap_size: int) -> torch.Tensor:
    """
    coords: this function can process upto 7 dim coordinates
    log2T:  logarithm of T w.r.t 2
    """
    primes = [
        1,
        2654435761,
        805459861,
        3674653429,
        2097192037,
        1434869437,
        2165219737,
    ]

    xor_result = torch.zeros_like(coords)[..., 0]
    for i in range(coords.shape[-1]):
        xor_result ^= coords[..., i] * primes[i]

    return (
        torch.tensor((1 << log2_hashmap_size) - 1, device=xor_result.device)
        & xor_result
        )

def _to_1D(coords: torch.Tensor, resolution: int) -> torch.Tensor:
    """
    coords: 3D indices of grid
    resolution:  resolution of grid
    """

    """
    Given grid resolution, for instance 2, our coordinate values usually span from 0 to 1 (inclusive, on x, y and z dimensions).
    To convert this coordinate (which is between 0 and 1, inclusive) to a grid index,
    we multiply the coordinate with the resolution (which is 2 in this example).
    This means the maximum cell we can get is (2,2,2) when we multiply the coordinate (1,1,1) with resolution 2.
    
    If we want to convert the 3D cell index (2,2,2) into a 1D index (to retrieve the embedding),
    we can use the formula (z * resolution * resolution) + (y * resolution) + x. The resolution here however must be 3,
    since we are now dealing with a 3x3x3 grid. So, the 1D index is (2 * 3 * 3) + (2 * 3) + 2 = 26.
    
    If we use resolution 2, the 1D index would be (2 * 2 * 2) + (2 * 2) + 2 = 14. 
    This is however wrong, as it represents the wrong cell in a 3x3x3 grid.
    
    Now, we do resolution + 2 because we have offsets of + 1, so we can get a cell at location (3,3,3).
    
    """
    resolution = resolution + 2

    x = coords[:, 0]
    y = coords[:, 1]

    return (y * resolution) + x


def get_pixel_vertices(
        xy: torch.Tensor, resolution: float
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        xy = xy * resolution
        box_min_vertex = torch.floor(xy).int()
        box_indices = box_min_vertex.unsqueeze(0) + box_offsets

        max_size = 2**log2_hashmap_size
        n_level_size = (resolution + 2) ** 3
        if max_size > n_level_size: # No collisions, straightforward method
            hashed_voxel_indices = _to_1D(box_indices, resolution)
            
        else: # Collisions are present
            hashed_voxel_indices = _hash(box_indices, log2_hashmap_size)

        return box_indices, hashed_voxel_indices, xy


def bilinear_interp(
    x: torch.Tensor,
    box_indices: torch.Tensor,
    box_embedds: torch.Tensor,
) -> torch.Tensor:
    """
    x: B x 2
    voxel_min_vertex: B x 2
    voxel_max_vertex: B x 2
    voxel_embedds: B x 4 x 2
    """
    # source: https://en.wikipedia.org/wiki/Trilinear_interpolation
    
    w11 = np.linalg.norm(box_indices[0] - x)
    w12 = np.linalg.norm(box_indices[1] - x)
    w21 = np.linalg.norm(box_indices[2] - x)
    w22 = np.linalg.norm(box_indices[3] - x)

    den = w11+w12+w21+w22

    w11 /= den
    w12 /= den
    w21 /= den
    w22 /= den

    xi_embedding = w11*box_embedds[0] + w12*box_embedds[1] + w21*box_embedds[2] + w22*box_embedds[3]  
    return xi_embedding

def _get_bbox (x, dx):
    box_idx = torch.zeros((4,2))

    box_idx[0,0] = (torch.floor(x[0]/dx)*dx)
    box_idx[0,1] = (torch.floor(x[1]/dx)*dx)

    box_idx[1,:] = box_idx[0,:]
    box_idx[1,0] += dx

    box_idx[3,0] = box_idx[0,0] + dx
    box_idx[3,1] = box_idx[0,1] + dx

    box_idx[2,:] = box_idx[3,:]
    box_idx[2,0] -= dx
        
    return box_idx

In [108]:
embeddings = nn.ModuleList(
    [
        nn.Embedding(
            _get_number_of_embeddings([i]), n_features_per_level
        )
        for i in range(5)
    ]
)
print(embeddings)

ModuleList(
  (0): Embedding(256, 2)
  (1): Embedding(1024, 2)
  (2): Embedding(4096, 2)
  (3): Embedding(25600, 2)
  (4): Embedding(102400, 2)
)


In [105]:
## DERIVE THE BOX SURROUNDING THE POINT AT ANY LEVEL
L = 0
Nmax = 320
Nl = [16, 32, 64, 160, 320]
for i in range(5):
    dx = Nmax//Nl[i]

    x = torch.tensor((1,2))

    box_idx = _get_bbox(x, dx)

    max_size = 2**10
    n_level_size = (Nl[L]) ** 2

    # if max_size > n_level_size:
    #     hashed_box_idx = _to_1D(box_idx, Nl[i])
    # else:
    hashed_box_idx = _hash(box_idx, log2_hashmap_size)
    
    voxel_embedds = embeddings[i](hashed_box_idx.int())

RuntimeError: "bitwise_xor_cpu" not implemented for 'Float'

In [40]:
x_embedded_all = []
x = torch.tensor([120,180])

embeddings = nn.ModuleList(
    [
        nn.Embedding(
            _get_number_of_embeddings(i), n_features_per_level
        )
        for i in range(n_levels)
    ]
)

for i in range(n_levels):
    resolution = int(base_resolution * b**i)
    (voxel_min_vertex, hashed_voxel_indices,xi) = get_voxel_vertices(x, resolution)
    # voxel_embedds = embeddings[i](hashed_voxel_indices)
    # x_embedded = trilinear_interp(xi, voxel_min_vertex, voxel_embedds)
    # x_embedded_all.append(x_embedded)
# return torch.cat(x_embedded_all, dim=-1)

RuntimeError: The size of tensor a (2) must match the size of tensor b (4) at non-singleton dimension 0