In [1]:
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn

In [6]:
L = 10
n_inputs_dim = 2
log2_hashmap_size = 12 # NOTE : This is the size of the hashtable (T)

n_features_per_level = 2
Nmax = 320
Nmin = 16

b = np.exp((np.log(Nmax)-np.log(Nmin))/(L-1))

def _get_number_of_embeddings(level_idx: int) -> int:
    """
    Number of embedding rows to allocate for one resolution level.

    level_idx: level index (an int; a one-element sequence such as [i] is also accepted)
    returns:   min(2**log2_hashmap_size, (resolution + 4)**2), where
               resolution = int(Nmin * b**level_idx) using the module-level Nmin and b.
    """
    max_size = 2**log2_hashmap_size

    # np.asarray(...).item() yields a Python float whether `b` is a NumPy or a
    # plain Python float, and whether level_idx is an int or a one-element list.
    growth = np.asarray(b ** np.asarray(level_idx)).item()
    resolution = int(Nmin * growth)

    # Padding of 4 per axis (not 2). Dense levels are indexed by `_to_1D`, whose
    # scaled corner coordinates can exceed `resolution` when Nmax is not a multiple
    # of it (e.g. 320 // (320 // 43) == 45 for resolution 43), so extra rows are needed.
    # NOTE(review): this headroom is not shown to suffice for every (Nmin, Nmax, L,
    # log2_hashmap_size) combination — confirm before changing the configuration.
    n_level_size = (resolution + 4) ** 2

    return min(max_size, n_level_size)

# One trainable table per level. Its row count comes from _get_number_of_embeddings,
# which caps it at 2**log2_hashmap_size.
embeddings = nn.ModuleList(
    [
        # Pass the level index as a plain int, matching the helper's `int` signature.
        nn.Embedding(_get_number_of_embeddings(i), n_features_per_level)
        for i in range(L)
    ]
)

print(embeddings)


#### Functions required for hash_encoding

def bilinear_interp(
    x: torch.Tensor,
    box_indices: torch.Tensor,
    box_embedds: torch.Tensor,
) -> torch.Tensor:
    """
    Bilinearly interpolate corner embeddings at the point `x`.

    x:           point coordinates (x, y), in the same units as `box_indices`
    box_indices: corner coordinates of shape (4, 2), ordered
                 (x_min, y_min), (x_max, y_min), (x_min, y_max), (x_max, y_max)
                 (the order `_get_box_idx` builds), or a single vertex of shape (2,)
    box_embedds: corner embeddings of shape (4, F), or (F,) for a single vertex
    returns:     interpolated embedding of shape (F,)

    With t = (x - lower_corner) / (upper_corner - lower_corner) per axis, the corner
    weights are (1-tx)(1-ty), tx(1-ty), (1-tx)ty and tx*ty. They sum to 1 and give
    a corner full weight when x lies on it. (The previous distance-based weights
    favoured the *farthest* corners and gave a coincident corner zero weight.)
    """
    # A single vertex (point already on the grid): nothing to interpolate.
    if box_indices.shape[0] <= 2:
        return box_embedds

    # Do the arithmetic in torch (in the embeddings' dtype) so gradients reach box_embedds.
    dtype = box_embedds.dtype
    corners = box_indices.to(dtype)
    point = torch.as_tensor(x, device=corners.device).to(dtype)

    lower, upper = corners[0], corners[3]
    extent = upper - lower

    # Zero-width boxes (a corner clipped at the grid border) would divide by zero;
    # use t = 0 on such an axis so the lower corner carries the weight.
    has_extent = extent > 0
    safe_extent = torch.where(has_extent, extent, torch.ones_like(extent))
    t = torch.where(has_extent, (point - lower) / safe_extent, torch.zeros_like(extent))
    t = t.clamp(0.0, 1.0)  # keep the weights convex if x falls marginally outside the box
    tx, ty = t[0], t[1]

    weights = torch.stack((
        (1 - tx) * (1 - ty),  # (x_min, y_min)
        tx * (1 - ty),        # (x_max, y_min)
        (1 - tx) * ty,        # (x_min, y_max)
        tx * ty,              # (x_max, y_max)
    ))

    return (weights.unsqueeze(-1) * box_embedds).sum(dim=0)


def _get_box_idx(point, Nmax, resolution, log2_hashmap_size):
    """
    Find the grid cell around `point` at one level and the table rows of its corners.

    point:             coordinate pair (x, y) on the finest grid (the demo cells pass
                       integer tensors)
    Nmax:              finest-grid resolution, i.e. the coordinate units of `point`
    resolution:        grid resolution of this level
    log2_hashmap_size: log2 of the maximum table size
    returns: (box_idx, hashed_box_idx)
        box_idx        - corner coordinates in `point` units, shape (4, 2), ordered
                         (x_min, y_min), (x_max, y_min), (x_min, y_max), (x_max, y_max);
                         when resolution == Nmax it is the point itself, shape (2,)
        hashed_box_idx - table row(s): dense indices from `_to_1D` when
                         2**log2_hashmap_size > resolution**2, otherwise from `_hash`
    """
    px, py = point  # unpacked up front, so a point without exactly two entries fails here

    if Nmax == resolution:
        # Finest level: the point is itself a grid vertex, no box needed.
        corners = torch.tensor((point[0], point[1]))
    else:
        cell = Nmax // resolution  # side of a square cell, in `point` units
        # NOTE(review): when Nmax is not a multiple of resolution, cells start at
        # multiples of `cell` and the last one is clipped to Nmax (so it is narrower).
        lo_x, lo_y = (max(0, (c // cell) * cell) for c in (px, py))
        hi_x, hi_y = (min(Nmax, c + cell) for c in (lo_x, lo_y))
        corners = torch.tensor((
            [lo_x, lo_y],
            [hi_x, lo_y],
            [lo_x, hi_y],
            [hi_x, hi_y],
        ))

    if 2**log2_hashmap_size > resolution**2:
        table_ids = _to_1D(corners, Nmax, resolution)[0]
    else:
        table_ids = _hash(corners, log2_hashmap_size)

    return corners, table_ids

def _to_1D(coors, Nmax, resolution):
    scale_factor = Nmax // resolution
    scaled_coords = torch.div(coors, scale_factor, rounding_mode="floor").int()    
    x = scaled_coords[:,0]
    y = scaled_coords[:,1]
    
    return (y * resolution + x), scaled_coords

def _hash(coords: torch.Tensor, log2_hashmap_size: int) -> torch.Tensor:
    """
    coords: this function can process upto 7 dim coordinates
    log2T:  logarithm of T w.r.t 2
    """
    primes = torch.tensor([
        1,
        2654435761,
        805459861,
        3674653429,
        2097192037,
        1434869437,
        2165219737,
    ], dtype = torch.int64
    )

    xor_result = torch.zeros_like(coords, dtype=torch.int64)[..., 0]
    for i in range(coords.shape[-1]): # Loop around all possible dimensions of the vector containing the bounding box positions
        xor_result ^= coords[...,i].to(torch.int64)*primes[i]

    return (
        torch.tensor((1 << log2_hashmap_size) - 1, device=xor_result.device)
        & xor_result
        )
    


ModuleList(
  (0): Embedding(400, 2)
  (1): Embedding(676, 2)
  (2): Embedding(1225, 2)
  (3): Embedding(2209, 2)
  (4-9): 6 x Embedding(4096, 2)
)


In [7]:
point = torch.tensor([319, 319])
x_embedded_all = []

# Encode `point` at every level, printing each level's resolution.
for i in range(L):
    resolution = int(Nmin * b**i)
    print(resolution)
    corners, table_ids = _get_box_idx(point, Nmax, resolution, log2_hashmap_size)
    x_embedded_all.append(bilinear_interp(point, corners, embeddings[i](table_ids)))

# Concatenate the per-level features into one vector.
x1 = torch.cat(x_embedded_all, dim=-1)
print(x1)

16
22
31
43
60
84
117
164
229
320
tensor([ 0.5554, -0.7761,  0.4109,  1.1651,  0.8021, -1.1589, -0.9328,  1.0611,
        -0.3175,  0.6355, -0.5556, -0.0240,  0.5324,  0.0339, -0.4623, -0.5715,
         0.2971,  0.4746,  0.5212,  1.7220], grad_fn=<CatBackward0>)


In [11]:
point = torch.tensor([310, 317])
x_embedded_all = []

# Encode `point` at every level, printing each level's resolution.
for i in range(L):
    resolution = int(Nmin * b**i)
    print(resolution)
    corners, table_ids = _get_box_idx(point, Nmax, resolution, log2_hashmap_size)
    x_embedded_all.append(bilinear_interp(point, corners, embeddings[i](table_ids)))

# Concatenate the per-level features into one vector.
x1 = torch.cat(x_embedded_all, dim=-1)
print(x1)

16
22
31
43
60
84
117
164
229
320
tensor([ 0.4948, -0.6677,  0.4109,  1.1651,  0.6012, -0.9509, -1.4937,  0.5453,
        -0.3050, -0.0969, -0.9447, -0.5184, -0.6131, -0.4333,  1.2609, -0.0287,
        -0.6124, -0.7133, -0.8770,  2.7351], grad_fn=<CatBackward0>)


In [9]:
point = torch.tensor([0, 0])
x_embedded_all = []

# Encode `point` at every level, printing each level's resolution.
for i in range(L):
    resolution = int(Nmin * b**i)
    print(resolution)
    corners, table_ids = _get_box_idx(point, Nmax, resolution, log2_hashmap_size)
    x_embedded_all.append(bilinear_interp(point, corners, embeddings[i](table_ids)))

# Concatenate the per-level features into one vector.
x1 = torch.cat(x_embedded_all, dim=-1)
print(x1)

16
22
31
43
60
84
117
164
229
320
tensor([ 0.8431,  0.2460,  0.4135,  0.3757, -0.2419, -0.0620, -0.4686, -1.2061,
        -0.0622,  0.1241,  0.3490,  1.3119,  0.0768, -0.1606,  0.4591, -0.5743,
        -0.2125,  0.3595,  0.7171, -2.1822], grad_fn=<CatBackward0>)


In [10]:
point = torch.tensor([120, 120])
x_embedded_all = []

# Encode `point` at every level, printing each level's resolution.
for i in range(L):
    resolution = int(Nmin * b**i)
    print(resolution)
    corners, table_ids = _get_box_idx(point, Nmax, resolution, log2_hashmap_size)
    x_embedded_all.append(bilinear_interp(point, corners, embeddings[i](table_ids)))

# Concatenate the per-level features into one vector.
x1 = torch.cat(x_embedded_all, dim=-1)
print(x1)

16
22
31
43
60
84
117
164
229
320
tensor([-0.0125, -0.3245, -0.4046,  0.6017,  0.1132, -0.6288,  0.0319,  0.0631,
         0.3316, -0.3720,  0.3281, -0.6362,  0.4788, -0.9974,  1.3093, -0.3233,
         0.9508, -1.0356, -0.5815, -0.3331], grad_fn=<CatBackward0>)


In [2]:
# Import only the name this cell uses: a star import can silently shadow names
# already defined earlier in the notebook.
from hash_encoding import hash_encoder

point = torch.tensor([1, 1])

embedder = hash_encoder(levels=10, log2_hashmap_size=12, n_features_per_level=2, n_max = 320, n_min = 16)

# NOTE(review): the captured output shows this call printing intermediate corner
# tensors, which looks like leftover debug output inside hash_encoding; confirm and remove.
embedder(point)

tensor([[ 0,  0],
        [20,  0],
        [ 0, 20],
        [20, 20]])
tensor([[0, 0],
        [1, 0],
        [0, 1],
        [1, 1]], dtype=torch.int32)
tensor([[ 0,  0],
        [14,  0],
        [ 0, 14],
        [14, 14]])
tensor([[0, 0],
        [1, 0],
        [0, 1],
        [1, 1]], dtype=torch.int32)
tensor([[ 0,  0],
        [10,  0],
        [ 0, 10],
        [10, 10]])
tensor([[0, 0],
        [1, 0],
        [0, 1],
        [1, 1]], dtype=torch.int32)
tensor([[0, 0],
        [7, 0],
        [0, 7],
        [7, 7]])
tensor([[0, 0],
        [1, 0],
        [0, 1],
        [1, 1]], dtype=torch.int32)
tensor([[0, 0],
        [5, 0],
        [0, 5],
        [5, 5]])
tensor([[0, 0],
        [1, 0],
        [0, 1],
        [1, 1]], dtype=torch.int32)
tensor([[0, 0],
        [3, 0],
        [0, 3],
        [3, 3]])
tensor([[0, 0],
        [3, 0],
        [0, 3],
        [3, 3]])
tensor([[0, 0],
        [2, 0],
        [0, 2],
        [2, 2]])
tensor([[0, 0],
        [2, 0],
    

tensor([ 0.4450,  0.3627, -0.7832,  0.6388,  0.2325,  0.7201,  0.2879,  0.6227,
         0.4419, -0.2468,  0.4043, -0.5034, -0.1639, -0.4500,  0.0418,  0.8081,
         0.6485,  1.2017,  0.4955,  0.5192], grad_fn=<CatBackward0>)