In [1]:
import torch
from torch import nn, Tuple
import numpy as np
import matplotlib.pyplot as plt

from hash_encoding_batch import *

In [71]:
levels = 5
n_min = 45
n_features_per_level = 2
n_max = 320
log2_hashmap_size = 13
b = np.exp((np.log(n_max) - np.log(n_min)) / (levels - 1))


size  = 320
x = torch.arange(size)
y = torch.arange(size)
z = torch.arange(4)
coil = torch.arange(3)

points = torch.meshgrid(x, y, z, coil, indexing="ij")
points = torch.stack(points, dim=-1).reshape(-1, len(points)).float()
xy = points[:,:2]

In [72]:
def _get_number_of_embeddings(level_idx: int) -> int:
    max_size = 2 ** log2_hashmap_size
    n_l = int(n_min * (b ** level_idx).item())
    n_l_embeddings = (n_l + 5) ** 2
    return min(max_size, n_l_embeddings)

def bilinear_interp(x: torch.Tensor, box_indices: torch.Tensor, box_embedds: torch.Tensor) -> torch.Tensor:
    device = x.device
    
    if box_indices.shape[1] > 2:
        weights = torch.norm(box_indices - x[:, None, :], dim=2)
        den = weights.sum(dim=1, keepdim=True)
        
        weights /= den # Normalize weights
        weights = 1-weights # NOTE: More weight is given to vertex closer to the point of interest
        
        weights = weights.to(device)
        box_embedds = box_embedds.to(device)

        Npoints = len(den)
        xi_embedding = torch.zeros((Npoints, 2), device = device)
        
        for i in range(4):
            xi_embedding += weights[...,i].unsqueeze(1) * box_embedds[...,i,:]
            
    else:
        xi_embedding = box_embedds
        
    return xi_embedding

def _get_box_idx(points: torch.Tensor, n_l: int) -> tuple:
    
    # Get bounding box indices for a batch of points
    if points.dim() > 1:
        x = points[:,0]
        y = points[:,1]
    else:
        x = points[0]
        y = points[1]

    if n_max == n_l:
        box_idx = points
        hashed_box_idx = _hash(points)
    else:
        # Calculate box size based on the total boxes
        box_width = n_max // n_l  # Width of each box
        box_height = n_max // n_l  # Height of each box

        x_min = torch.maximum(torch.zeros_like(x), (x // box_width) * box_width)
        y_min = torch.maximum(torch.zeros_like(y), (y // box_height) * box_height)
        x_max = torch.minimum(torch.full_like(x, n_max), x_min + box_width)
        y_max = torch.minimum(torch.full_like(y, n_max), y_min + box_height)
        
        # Stack to create four corners per point, maintaining the batch dimension
        box_idx = torch.stack([
            torch.stack([x_min, y_min], dim=1),
            torch.stack([x_max, y_min], dim=1),
            torch.stack([x_min, y_max], dim=1),
            torch.stack([x_max, y_max], dim=1)
        ], dim=1)  # Shape: (batch_size, 4, 2)
        
        # Determine if the coordinates can be directly mapped or need hashing
        max_hashtable_size = 2 ** log2_hashmap_size
        if max_hashtable_size >= (n_l + 5) ** 2:
            hashed_box_idx, _ = _to_1D(box_idx, n_l)
        else:
            hashed_box_idx = _hash(box_idx)
            
    return box_idx, hashed_box_idx

## Hash encoders
def _to_1D(coors, n_l):

    scale_factor = n_max // n_l
    scaled_coords = torch.div(coors, scale_factor, rounding_mode="floor").int()    
    x = scaled_coords[...,0]
    y = scaled_coords[...,1]
    
    return (y * n_l + x), scaled_coords


def _hash(coords: torch.Tensor) -> torch.Tensor:
    """
    coords: this function can process upto 7 dim coordinates
    log2T:  logarithm of T w.r.t 2
    """
    device = coords.device
    primes = torch.tensor([
        1,
        2654435761,
        805459861,
        3674653429,
        2097192037,
        1434869437,
        2165219737,
    ], dtype = torch.int64, device=device
    )

    xor_result = torch.zeros(coords.shape[:-1], dtype=torch.int64, device=device)

    for i in range(coords.shape[-1]): # Loop around all possible dimensions of the vector containing the bounding box positions
        xor_result ^= coords[...,i].to(torch.int64)*primes[i]
        
    hash_mask = (1 << log2_hashmap_size) - 1
    return xor_result & hash_mask


embeddings = nn.ModuleList([
            nn.Embedding(_get_number_of_embeddings(i), n_features_per_level)
            for i in range(levels)])

print(embeddings)


ModuleList(
  (0): Embedding(2500, 2)
  (1): Embedding(6084, 2)
  (2-4): 3 x Embedding(8192, 2)
)


In [76]:
n_min = 16
levels = 5

b = np.exp((np.log(n_max) - np.log(n_min)) / (levels - 1))

for i in range(levels):
    n_l = int(n_min * b ** i)
    print(f'No collision existed : {n_l}, {(n_l+5)**2}')
    print(f'Reality : {n_l}, {_get_number_of_embeddings(i)}')

No collision existed : 16, 441
Reality : 16, 441
No collision existed : 33, 1444
Reality : 33, 1444
No collision existed : 71, 5776
Reality : 71, 5776
No collision existed : 151, 24336
Reality : 151, 8192
No collision existed : 319, 104976
Reality : 319, 8192


In [20]:
xy_embedded_all = []

for i in range(levels):
    n_l = int(n_min * b ** i)
    
    box_idx, hashed_box_idx = _get_box_idx(xy, n_l)
    
    box_embedds = embeddings[i](hashed_box_idx)
    
    xy_embedded = bilinear_interp(xy, box_idx, box_embedds)
    
    xy_embedded_all.append(xy_embedded)
    
    
xy_embedded_all = torch.cat(xy_embedded_all, dim = 1)    


In [7]:
hashed_box_idx

tensor([[   0,    1,   56,   57],
        [   0,    1,   56,   57],
        [   0,    1,   56,   57],
        ...,
        [3591, 3592, 3647, 3648],
        [3591, 3592, 3647, 3648],
        [3591, 3592, 3647, 3648]], dtype=torch.int32)