In [1]:
import torch
from torch import nn, Tuple
import numpy as np
import matplotlib.pyplot as plt

In [2]:
L = 10
n_inputs_dim = 2
log2_hashmap_size = 12 # NOTE : This is the size of the hashtable (T)

n_features_per_level = 2
Nmax = 320
Nmin = 16

b = np.exp((np.log(Nmax)-np.log(Nmin))/(L-1))

def _get_number_of_embeddings(level_idx: int) -> int:
    """
    level_idx: level index
    returns: number of embeddings for given level. Max number is 2**self.log2_hashmap_size
    """

    max_size = 2**log2_hashmap_size

    resolution = int(Nmin * (b**level_idx).item())
    
    n_level_size = (resolution + 4) ** 2  # see explanation below at 'def _to_1D(...)' why we do + 2

    return min(max_size, n_level_size)

embeddings = nn.ModuleList(
    [
        nn.Embedding(
            _get_number_of_embeddings([i]), n_features_per_level
        )
        for i in range(L)
    ]
)

print(embeddings)


#### Functions required for hash_encoding

def bilinear_interp(
    x: torch.Tensor,
    box_indices: torch.Tensor,
    box_embedds: torch.Tensor,
) -> torch.Tensor:
    """
    Function that computes the bilinear interpolation of the embedding for a point in a bounding box of identified corners
    """
    if box_indices.shape[0] > 2:
        w11 = np.linalg.norm(box_indices[0] - x)
        w12 = np.linalg.norm(box_indices[1] - x)
        w21 = np.linalg.norm(box_indices[2] - x)
        w22 = np.linalg.norm(box_indices[3] - x)

        den = w11+w12+w21+w22

        w11 /= den
        w12 /= den
        w21 /= den
        w22 /= den

        xi_embedding = w11*box_embedds[0] + w12*box_embedds[1] + w21*box_embedds[2] + w22*box_embedds[3]  
        
    else:
        xi_embedding = box_embedds
        
    return xi_embedding


def _get_box_idx (point, Nmax, resolution, log2_hashmap_size):
    x, y = point

    if Nmax == resolution:
        box_idx = torch.tensor((point[0], point[1]))

    else:
        # Calculate box size based on the total boxes
        box_width = Nmax // resolution  # Width of each box
        box_height = Nmax // resolution  # Height of each box

        x_min = max(0, (x // box_width) * box_width)
        y_min = max(0, (y // box_height) * box_height)
        x_max = min(Nmax, x_min + box_width)
        y_max = min(Nmax, y_min + box_height)
        
        
        p00 = [x_min, y_min]
        p01 = [x_max, y_min]
        p10 = [x_min, y_max]
        p11 = [x_max, y_max]
        
        box_idx = torch.tensor((p00, p01, p10, p11))
        
    max_hashtable_size = 2**log2_hashmap_size
    if max_hashtable_size >= (resolution+4)**2:
        hashed_box_idx, scaled_coords = _to_1D(box_idx, Nmax, resolution)
    else:
        hashed_box_idx = _hash(box_idx, log2_hashmap_size)
        scaled_coords = box_idx
    return box_idx, scaled_coords, hashed_box_idx
    # return box_idx

def _to_1D(coors, Nmax, resolution):
    scale_factor = Nmax // resolution
    scaled_coords = torch.div(coors, scale_factor, rounding_mode="floor").int()    
    x = scaled_coords[:,0]
    y = scaled_coords[:,1]
    
    return (y * resolution + x), scaled_coords

def _hash(coords: torch.Tensor, log2_hashmap_size: int) -> torch.Tensor:
    """
    coords: this function can process upto 7 dim coordinates
    log2T:  logarithm of T w.r.t 2
    """
    primes = torch.tensor([
        1,
        2654435761,
        805459861,
        3674653429,
        2097192037,
        1434869437,
        2165219737,
    ], dtype = torch.int64
    )

    xor_result = torch.zeros_like(coords, dtype=torch.int64)[..., 0]
    for i in range(coords.shape[-1]): # Loop around all possible dimensions of the vector containing the bounding box positions
        xor_result ^= coords[...,i].to(torch.int64)*primes[i]

    return (
        torch.tensor((1 << log2_hashmap_size) - 1, device=xor_result.device)
        & xor_result
        )
    


ModuleList(
  (0): Embedding(400, 2)
  (1): Embedding(676, 2)
  (2): Embedding(1225, 2)
  (3): Embedding(2209, 2)
  (4-9): 6 x Embedding(4096, 2)
)


In [14]:
for i in range(L):
    Nl = int(Nmin*b**i)
    max_size = (Nl+4)**2
    print(Nl, max_size)

16 400
19 529
24 784
30 1156
37 1681
46 2500
57 3721
71 5625
88 8464
109 12769
135 19321
168 29584
208 44944
258 68644
320 104976


In [4]:
Nl = 16
all_hash = []
for i in range(30):
    for j in range(30):
        coors = torch.tensor([i,j])
        box_idx = _get_box_idx(coors, 320, Nl, 12)
        hashed = _to_1D(box_idx, 320, Nl)
        # print(hashed)
        # all_hash.append({'box_idx': box_idx, 'hashed': hashed, 'coors': coors} )
    

TypeError: div() received an invalid combination of arguments - got (tuple, int, rounding_mode=str), but expected one of:
 * (Tensor input, Tensor other, *, Tensor out = None)
 * (Tensor input, Tensor other, *, str rounding_mode, Tensor out = None)
 * (Tensor input, Number other, *, str rounding_mode)


In [3]:
from hash_encoding_batch import *
point = torch.tensor(([25,87], [26, 87]), dtype=float)
x_embedded_all = []
embedder = hash_encoder(levels=15, log2_hashmap_size=12, n_features_per_level=2, n_max=320, n_min=16)
weights_all = []

for i in range(L):
    
    resolution = int(Nmin * b**i)
    box_idx, hashed_box_idx = embedder._get_box_idx(point, n_l=resolution)
    print(box_idx[0])
    print(hashed_box_idx[0])
    box_embedds = embeddings[i](hashed_box_idx)
    x_embedded, weights = embedder.bilinear_interp(point, box_idx, box_embedds)
    weights_all.append(weights)
    x_embedded_all.append(x_embedded)

    
x1 = torch.cat(x_embedded_all, dim=-1)
print(x1)

tensor([[ 20.,  80.],
        [ 40.,  80.],
        [ 20., 100.],
        [ 40., 100.]], dtype=torch.float64)
tensor([65, 66, 81, 82], dtype=torch.int32)
tensor([[14., 84.],
        [28., 84.],
        [14., 98.],
        [28., 98.]], dtype=torch.float64)
tensor([133, 134, 155, 156], dtype=torch.int32)
tensor([[20., 80.],
        [30., 80.],
        [20., 90.],
        [30., 90.]], dtype=torch.float64)
tensor([250, 251, 281, 282], dtype=torch.int32)
tensor([[21., 84.],
        [28., 84.],
        [21., 91.],
        [28., 91.]], dtype=torch.float64)
tensor([519, 520, 562, 563], dtype=torch.int32)
tensor([[25., 85.],
        [30., 85.],
        [25., 90.],
        [30., 90.]], dtype=torch.float64)
tensor([2012, 2011, 2083, 2084])
tensor([[24., 87.],
        [27., 87.],
        [24., 90.],
        [27., 90.]], dtype=torch.float64)
tensor([2879, 2876, 2082, 2081])
tensor([[24., 86.],
        [26., 86.],
        [24., 88.],
        [26., 88.]], dtype=torch.float64)
tensor([ 366,  364, 1216

In [4]:
weights_all[0]

tensor([[0.8540, 0.7191, 0.7637, 0.6632],
        [0.8418, 0.7315, 0.7544, 0.6723]], dtype=torch.float64)

In [15]:
point = torch.tensor([319,319])
x_embedded_all = []

for i in range(L):
    
    resolution = int(Nmin * b**i)
    # print(resolution)
    box_idx, scaled_coords, hashed_box_idx = _get_box_idx(point, Nmax, resolution, log2_hashmap_size)
    print(scaled_coords)
    print(hashed_box_idx)
    box_embedds = embeddings[i](hashed_box_idx)
    x_embedded = bilinear_interp(point, box_idx, box_embedds)
    x_embedded_all.append(x_embedded)

x1 = torch.cat(x_embedded_all, dim=-1)
print(x1)

tensor([[15, 15],
        [16, 15],
        [15, 16],
        [16, 16]], dtype=torch.int32)
tensor([255, 256, 271, 272], dtype=torch.int32)
tensor([[19, 19],
        [20, 19],
        [19, 20],
        [20, 20]], dtype=torch.int32)
tensor([380, 381, 399, 400], dtype=torch.int32)
tensor([[24, 24],
        [24, 24],
        [24, 24],
        [24, 24]], dtype=torch.int32)
tensor([600, 600, 600, 600], dtype=torch.int32)
tensor([[31, 31],
        [32, 31],
        [31, 32],
        [32, 32]], dtype=torch.int32)
tensor([961, 962, 991, 992], dtype=torch.int32)
tensor([[39, 39],
        [40, 39],
        [39, 40],
        [40, 40]], dtype=torch.int32)
tensor([1482, 1483, 1519, 1520], dtype=torch.int32)
tensor([[53, 53],
        [53, 53],
        [53, 53],
        [53, 53]], dtype=torch.int32)
tensor([2491, 2491, 2491, 2491], dtype=torch.int32)
tensor([[63, 63],
        [64, 63],
        [63, 64],
        [64, 64]], dtype=torch.int32)
tensor([3654, 3655, 3711, 3712], dtype=torch.int32)
tensor([

In [11]:
point = torch.tensor([310, 317])
x_embedded_all = []

for i in range(L):
    
    resolution = int(Nmin * b**i)
    print(resolution)
    box_idx, hashed_box_idx = _get_box_idx(point, Nmax, resolution, log2_hashmap_size)
    box_embedds = embeddings[i](hashed_box_idx)
    x_embedded = bilinear_interp(point, box_idx, box_embedds)
    x_embedded_all.append(x_embedded)

x1 = torch.cat(x_embedded_all, dim=-1)
print(x1)

16
22
31
43
60
84
117
164
229
320
tensor([ 0.4948, -0.6677,  0.4109,  1.1651,  0.6012, -0.9509, -1.4937,  0.5453,
        -0.3050, -0.0969, -0.9447, -0.5184, -0.6131, -0.4333,  1.2609, -0.0287,
        -0.6124, -0.7133, -0.8770,  2.7351], grad_fn=<CatBackward0>)


In [9]:
point = torch.tensor([0, 0])
x_embedded_all = []

for i in range(L):
    
    resolution = int(Nmin * b**i)
    print(resolution)
    box_idx, hashed_box_idx = _get_box_idx(point, Nmax, resolution, log2_hashmap_size)
    box_embedds = embeddings[i](hashed_box_idx)
    x_embedded = bilinear_interp(point, box_idx, box_embedds)
    x_embedded_all.append(x_embedded)

x1 = torch.cat(x_embedded_all, dim=-1)
print(x1)

16
22
31
43
60
84
117
164
229
320
tensor([ 0.8431,  0.2460,  0.4135,  0.3757, -0.2419, -0.0620, -0.4686, -1.2061,
        -0.0622,  0.1241,  0.3490,  1.3119,  0.0768, -0.1606,  0.4591, -0.5743,
        -0.2125,  0.3595,  0.7171, -2.1822], grad_fn=<CatBackward0>)


In [None]:

point = torch.rand(20,2)
x_embedded_all = []
print(point)
for i in range(L):
    
    resolution = int(Nmin * b**i)
    print(resolution)
    box_idx, hashed_box_idx = _get_box_idx(point, Nmax, resolution, log2_hashmap_size)
    box_embedds = embeddings[i](hashed_box_idx)
    x_embedded = bilinear_interp(point, box_idx, box_embedds)
    x_embedded_all.append(x_embedded)

x1 = torch.cat(x_embedded_all, dim=-1)
print(x1)

tensor([[0.7580, 0.0689],
        [0.4019, 0.2114],
        [0.8250, 0.1597],
        [0.3326, 0.7257],
        [0.2233, 0.3285],
        [0.3154, 0.6191],
        [0.0753, 0.6940],
        [0.9564, 0.2477],
        [0.5836, 0.0924],
        [0.4902, 0.1884],
        [0.4021, 0.0060],
        [0.4242, 0.0730],
        [0.3489, 0.2683],
        [0.3255, 0.8294],
        [0.1232, 0.3607],
        [0.8837, 0.5485],
        [0.6858, 0.8674],
        [0.1638, 0.1519],
        [0.6985, 0.1344],
        [0.6996, 0.6958]])


In [1]:
from hash_encoding import *
point = torch.tensor([100,10])

embedder = hash_encoder(levels=10, log2_hashmap_size=12, n_features_per_level=2, n_max = 320, n_min = 16)

embedder(point)

New level: 
 tensor([[100,   0],
        [120,   0],
        [100,  20],
        [120,  20]])
tensor([[5, 0],
        [6, 0],
        [5, 1],
        [6, 1]], dtype=torch.int32)
New level: 
 tensor([[ 98,   0],
        [112,   0],
        [ 98,  14],
        [112,  14]])
tensor([[7, 0],
        [8, 0],
        [7, 1],
        [8, 1]], dtype=torch.int32)
New level: 
 tensor([[100,  10],
        [110,  10],
        [100,  20],
        [110,  20]])
tensor([[10,  1],
        [11,  1],
        [10,  2],
        [11,  2]], dtype=torch.int32)
New level: 
 tensor([[ 98,   7],
        [105,   7],
        [ 98,  14],
        [105,  14]])
tensor([[14,  1],
        [15,  1],
        [14,  2],
        [15,  2]], dtype=torch.int32)
New level: 
 tensor([[100,  10],
        [105,  10],
        [100,  15],
        [105,  15]])
tensor([[20,  2],
        [21,  2],
        [20,  3],
        [21,  3]], dtype=torch.int32)
New level: 
 tensor([[ 99,   9],
        [102,   9],
        [ 99,  12],
        [102,

tensor([-0.4597,  0.1142, -0.4894, -0.3159,  0.9899, -0.7720, -0.0278, -0.5469,
        -0.8895,  0.9419, -0.4393,  0.3824,  0.3304, -0.2077,  0.2283, -0.1087,
        -0.3258, -0.5330, -0.6824, -0.3033], grad_fn=<CatBackward0>)