In [1]:
import torch

import numpy as np
from torch import nn

def get_emb(sin_inp):
    """
    Build the interleaved sine/cosine embedding for a single axis.

    :param sin_inp: Tensor of sinusoid arguments; the last dimension indexes frequencies.
    :return: Tensor whose last dimension is twice as long, laid out as
        (sin_0, cos_0, sin_1, cos_1, ...).
    """
    sines = torch.sin(sin_inp)
    cosines = torch.cos(sin_inp)
    # Pair every sine with its cosine on a new trailing axis, then merge that
    # axis into the frequency axis so the two alternate.
    paired = torch.stack([sines, cosines], dim=-1)
    return paired.flatten(start_dim=-2, end_dim=-1)

class PositionalEncoding2D(nn.Module):
    """
    Fixed sinusoidal positional encoding for channels-last 2D feature maps.

    The channel axis is split into two halves: the first half encodes the
    position along ``x`` and the second half the position along ``y``. Each
    half interleaves sines and cosines as produced by ``get_emb``.
    """

    def __init__(self, channels, dtype_override=None):
        """
        :param channels: The last dimension of the tensor you want to apply pos emb to.
        :param dtype_override: If set, overrides the dtype of the output embedding.
        """
        super().__init__()
        self.org_channels = channels
        # Channels per axis: half of the total, rounded up to an even number so
        # that every frequency owns a (sin, cos) pair.
        channels = int(np.ceil(channels / 4) * 2)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels))
        self.register_buffer("inv_freq", inv_freq)
        # Non-persistent: the cache is derived data and is kept out of state_dict.
        self.register_buffer("cached_penc", None, persistent=False)
        self.dtype_override = dtype_override
        self.channels = channels

    def forward(self, tensor):
        """
        :param tensor: A 4d tensor of size (batch_size, x, y, ch)
        :return: Positional Encoding Matrix of size (batch_size, x, y, ch)
        :raises RuntimeError: If ``tensor`` is not 4-dimensional.

        The result is cached and reused for as long as the requested shape,
        device and output dtype stay the same. The returned tensor IS the
        cache, so callers should not modify it in place.
        """
        if tensor.dim() != 4:
            raise RuntimeError("The input tensor has to be 4d!")

        out_dtype = (
            self.dtype_override if self.dtype_override is not None else tensor.dtype
        )

        # Shape alone is not a sufficient cache key: an input of the same shape
        # but another dtype or device must not receive the stale tensor.
        cached = self.cached_penc
        if (
            cached is not None
            and cached.shape == tensor.shape
            and cached.device == tensor.device
            and cached.dtype == out_dtype
        ):
            return cached

        self.cached_penc = None
        batch_size, x, y, orig_ch = tensor.shape
        pos_x = torch.arange(x, device=tensor.device, dtype=self.inv_freq.dtype)
        pos_y = torch.arange(y, device=tensor.device, dtype=self.inv_freq.dtype)
        sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
        sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq)
        # emb_x gets a singleton y-axis so it broadcasts over y; emb_y
        # broadcasts over x through its missing leading axis.
        emb_x = get_emb(sin_inp_x).unsqueeze(1)
        emb_y = get_emb(sin_inp_y)

        emb = torch.zeros(
            (x, y, self.channels * 2),
            device=tensor.device,
            dtype=out_dtype,
        )
        emb[:, :, : self.channels] = emb_x
        emb[:, :, self.channels : 2 * self.channels] = emb_y

        # Drop the channels added by rounding up, then replicate over the batch.
        self.cached_penc = emb[None, :, :, :orig_ch].repeat(batch_size, 1, 1, 1)
        return self.cached_penc

In [2]:
class PositionalEncodingND(nn.Module):
    """
    Fixed sinusoidal positional encoding for channels-last N-dimensional tensors.

    The channel axis is split into ``n_dim`` equally sized groups; group ``i``
    encodes the position along spatial axis ``i`` with interleaved sines and
    cosines as produced by ``get_emb``.
    """

    def __init__(self, n_dim, channels, dtype_override=None):
        """
        :param n_dim: Number of spatial axes of the input (must be >= 1).
        :param channels: The last dimension of the tensor you want to apply pos emb to.
        :param dtype_override: If set, overrides the dtype of the output embedding.
        :raises ValueError: If ``n_dim`` is smaller than 1.
        """
        super().__init__()
        if n_dim < 1:
            raise ValueError("n_dim must be at least 1, got {}".format(n_dim))
        self.n_dim = n_dim
        self.org_channels = channels
        # Channels per axis: an equal share of the total, rounded up to an even
        # number so that every frequency owns a (sin, cos) pair.
        channels = int(np.ceil(channels / (2 * n_dim)) * 2)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels))
        self.register_buffer("inv_freq", inv_freq)
        self.dtype_override = dtype_override
        self.channels = channels

    def forward(self, tensor):
        """
        :param tensor: A 2+nd tensor of size (batch_size, x1, x2, ..., xn, ch)
        :return: Positional Encoding Matrix of size (batch_size, x1, x2, ..., xn, ch)
        :raises RuntimeError: If ``tensor`` does not have ``n_dim + 2`` dimensions.
        """
        if tensor.dim() != self.n_dim + 2:
            raise RuntimeError("The input tensor has to be {}d!".format(self.n_dim + 2))

        shape = tensor.shape
        batch_size, orig_ch = shape[0], shape[-1]
        spatial = list(shape[1:-1])

        emb = torch.zeros(
            spatial + [self.channels * self.n_dim],
            device=tensor.device,
            dtype=(
                self.dtype_override if self.dtype_override is not None else tensor.dtype
            ),
        )

        for axis, size in enumerate(spatial):
            pos = torch.arange(size, device=tensor.device, dtype=self.inv_freq.dtype)
            sin_inp = torch.einsum("i,j->ij", pos, self.inv_freq)
            emb_axis = get_emb(sin_inp)
            # Shape (size, 1, ..., 1, channels): one singleton per later spatial
            # axis, so right-aligned broadcasting places `size` on this axis and
            # repeats the encoding over every other axis.
            trailing = self.n_dim - axis - 1
            emb_axis = emb_axis.reshape([size] + [1] * trailing + [self.channels])
            emb[..., axis * self.channels : (axis + 1) * self.channels] = emb_axis

        # Drop the channels added by rounding up, then replicate over the batch.
        return emb[None, ..., :orig_ch].repeat(batch_size, *([1] * self.n_dim), 1)

In [None]:

# Smoke test: 4 spatial axes, 10 channels on the last axis.
pos_enc_nd = PositionalEncodingND(4, 10)

# Channels-last input: (batch, x1, x2, x3, x4, ch). Only its shape, dtype and
# device are read by the encoding; the random values themselves are ignored.
x = torch.randn(1, 5, 5, 5, 5, 10)


# Apply the positional encoding
pos_enc_x_nd = pos_enc_nd(x)

# The encoding is expected to have the same shape as the input.
print(pos_enc_x_nd.shape)

In [4]:
import torch

def get_spatial_representation(
        X: torch.Tensor,
        coords: torch.Tensor,
    ) -> torch.Tensor:
    """
    Computes the spatial representation of a bag given the sequential representation and the coordinates.

    Given the input tensor `X` of shape `(batch_size, bag_size, dim)` and the coordinates `coords` of shape
    `(batch_size, bag_size, n)`, this function returns the spatial representation `X_esp` of shape
    `(batch_size, coord1, coord2, ..., coordn, dim)`, where `coordk` is one more than the largest
    k-th coordinate found anywhere in the batch.

    The coordinates index the elements of the spatial representation:
    `X_esp[batch, i1, i2, ..., in, :] = X[batch, idx, :]` where `(i1, i2, ..., in) = coords[batch, idx]`.
    Positions that no coordinate points to are left at zero.

    Notes:
        - `coords` may have any numeric dtype; it is cast to int64 (truncating toward zero) before indexing.
        - If the same coordinate appears more than once within a bag, which instance ends up in that
          position is not specified.
        - Negative coordinates follow normal tensor indexing and wrap around from the end of the axis.

    Arguments:
        X (Tensor): Sequential representation of shape `(batch_size, bag_size, dim)`.
        coords (Tensor): Coordinates of shape `(batch_size, bag_size, n)`.

    Returns:
        X_esp: Spatial representation of shape `(batch_size, coord1, coord2, ..., coordn, dim)`.
    """
    batch_size, bag_size = X.shape[0], X.shape[1]
    n = coords.shape[-1]

    # Index tensors must be integral and live next to the data they index.
    coords = coords.to(device=X.device, dtype=torch.long)

    # Extent of every spatial axis: largest coordinate seen in the batch, plus one.
    extents = [int(coords[..., i].max()) + 1 for i in range(n)]
    X_esp = torch.zeros([batch_size] + extents + [X.shape[-1]], device=X.device, dtype=X.dtype)

    # Batch indices of shape (batch_size, bag_size), paired element-wise with the coordinates.
    batch_indices = torch.arange(batch_size, device=X.device).unsqueeze(1).expand(-1, bag_size)

    # One index tensor per indexed axis, each of shape (batch_size, bag_size).
    index_tuple = (batch_indices,) + tuple(coords[..., i] for i in range(n))

    # Advanced-indexing scatter: every instance is written to its own position.
    X_esp[index_tuple] = X

    return X_esp

def get_seq_representation(
        X_esp: torch.Tensor,
        coords: torch.Tensor,
    ) -> torch.Tensor:
    """
    Computes the sequential representation of a bag given the spatial representation and the coordinates.

    Given the spatial tensor `X_esp` of shape `(batch_size, coord1, coord2, ..., coordn, dim)` and the
    coordinates `coords` of shape `(batch_size, bag_size, n)`, this function returns the sequential
    representation `X_seq` of shape `(batch_size, bag_size, dim)`.

    The coordinates index the elements of the spatial representation:
    `X_seq[batch, idx, :] = X_esp[batch, i1, i2, ..., in, :]` where `(i1, i2, ..., in) = coords[batch, idx]`.

    Notes:
        - `coords` may have any numeric dtype; it is cast to int64 (truncating toward zero) before indexing.
        - Negative coordinates follow normal tensor indexing and wrap around from the end of the axis.

    Arguments:
        X_esp (Tensor): Spatial representation of shape `(batch_size, coord1, coord2, ..., coordn, dim)`.
        coords (Tensor): Coordinates of shape `(batch_size, bag_size, n)`.

    Returns:
        X_seq: Sequential representation of shape `(batch_size, bag_size, dim)`.
    """
    batch_size = X_esp.shape[0]
    bag_size = coords.shape[1]
    n = coords.shape[-1]

    # Index tensors must be integral and live next to the data they index.
    coords = coords.to(device=X_esp.device, dtype=torch.long)

    # Batch indices of shape (batch_size, bag_size), paired element-wise with the coordinates.
    batch_indices = torch.arange(batch_size, device=X_esp.device).unsqueeze(1).expand(-1, bag_size)

    # One index tensor per indexed axis, each of shape (batch_size, bag_size).
    index_tuple = (batch_indices,) + tuple(coords[..., i] for i in range(n))

    # Advanced-indexing gather; the un-indexed trailing `dim` axis is kept,
    # giving a result of shape (batch_size, bag_size, dim).
    return X_esp[index_tuple]


In [None]:
# Round-trip check: sequential -> spatial -> sequential -> spatial.
batch_size = 4
bag_size = 5
dim = 1
n = 2
grid_size = 5  # exclusive upper bound of every coordinate

X_seq = torch.randn(batch_size, bag_size, dim)

# Rejection-sample coordinates until no bag contains a repeated coordinate.
# Two instances sharing a position would overwrite each other in the spatial
# representation and make the round trip lossy.
while True:
    coords = torch.randint(0, grid_size, (batch_size, bag_size, n))
    if all(
        torch.unique(coords[b], dim=0).shape[0] == bag_size
        for b in range(batch_size)
    ):
        break


X_esp = get_spatial_representation(X_seq, coords)

X_seq_reconstructed = get_seq_representation(X_esp, coords)
X_esp_reconstructed = get_spatial_representation(X_seq_reconstructed, coords)

# With unique coordinates both directions are pure copies, so the round trip
# must be exact.
assert torch.equal(X_seq_reconstructed, X_seq)
assert torch.equal(X_esp_reconstructed, X_esp)

print('X_seq:', X_seq.shape)
print('X_esp:', X_esp.shape)
print('X_seq_reconstructed:', X_seq_reconstructed.shape)
print('X_esp_reconstructed:', X_esp_reconstructed.shape)

In [None]:
# Print the actual values so the original and reconstructed tensors can be
# compared by eye (sequential pair first, then spatial pair).
print('X_seq:', X_seq)
print('X_seq_reconstructed:', X_seq_reconstructed)
print('X_esp:', X_esp)
print('X_esp_reconstructed:', X_esp_reconstructed)

In [7]:
def normalize(x):
    """
    Min-max scale ``x`` into the interval [0, 1].

    The minimum and maximum are taken over ALL elements, so a multi-dimensional
    input is scaled by one common factor rather than per axis.

    Parameters:
        x (array-like): Real numbers of any shape.

    Returns:
        np.ndarray: Float array of the same shape as ``x``. An empty input is
        returned unchanged; a constant input maps to all zeros (instead of the
        NaNs a 0/0 division would give).
    """
    x = np.asarray(x, dtype=float)
    if x.size == 0:
        return x
    lowest = x.min()
    span = x.max() - lowest
    if span == 0:
        return np.zeros_like(x)
    return (x - lowest) / span

In [8]:
def quantize_unique(x):
    """
    Quantize a list of real numbers into unique integers based on their rank.

    Tied values receive consecutive ranks in their order of appearance (a
    stable sort is used), so the result is deterministic for any input.

    Parameters:
        x (list or np.ndarray): List of n real numbers.

    Returns:
        np.ndarray: Array of unique integers (0 to n-1) corresponding to the rank order of x.
    """
    x = np.asarray(x)
    order = np.argsort(x, kind="stable")
    # `order[r]` is the index of the r-th smallest element; inverting that
    # permutation yields the rank of every element.
    ranks = np.empty_like(order)
    ranks[order] = np.arange(len(x))
    return ranks

In [None]:
import numpy as np

# Demo: rank-based quantization of a few uniform random numbers.
n = 5
x = np.random.rand(n)

# Each value's rank should appear at the position of that value.
print('x:', x)
print('quantize_unique(x):', quantize_unique(x)) 

In [None]:
import numpy as np

def normalize(x):
    """
    Min-max scale ``x`` into the interval [0, 1].

    The minimum and maximum are taken over ALL elements, so a multi-dimensional
    input is scaled by one common factor rather than per axis.

    Parameters:
        x (array-like): Real numbers of any shape.

    Returns:
        np.ndarray: Float array of the same shape as ``x``. An empty input is
        returned unchanged; a constant input maps to all zeros (instead of the
        NaNs a 0/0 division would give).
    """
    x = np.asarray(x, dtype=float)
    if x.size == 0:
        return x
    lowest = x.min()
    span = x.max() - lowest
    if span == 0:
        return np.zeros_like(x)
    return (x - lowest) / span

def quantize_coords_1d(x):
    """
    Quantize a list of real numbers onto an integer grid that keeps distinct values apart.

    The values are min-max normalized to [0, 1] and multiplied by
    ``m = ceil(1 / d)``, where ``d`` is the smallest positive gap between two
    consecutive normalized values, then rounded. Unlike a rank transform, the
    relative spacing of the values is retained, so the result lies in
    ``[0, m]`` and is generally not ``0..n-1``.

    Parameters:
        x (list or np.ndarray): List of n real numbers.

    Returns:
        np.ndarray: Integer array of shape (n,). Equal input values map to the
        same integer. If ``x`` has fewer than two distinct values, all zeros
        are returned.
    """
    x = np.array(x, dtype=float)

    # With fewer than two distinct values there is no gap to resolve, and
    # normalizing would divide by a zero range.
    if x.size < 2 or x.max() == x.min():
        return np.zeros(x.shape, dtype=int)

    x = normalize(x)

    # Smallest gap between two DISTINCT values; zero gaps come from duplicates
    # and would otherwise make the scale factor infinite.
    gaps = np.diff(np.sort(x))
    min_dist = np.min(gaps[gaps > 0])

    m = np.ceil(1 / min_dist)
    x_quant = np.round(x * m).astype(int)
    return x_quant

# Demo: distance-preserving 1-D quantization of a few uniform random numbers.
n = 5
x = np.random.rand(n)

print('x:', x)
print('quantize_coords_1d(x):', quantize_coords_1d(x))

In [1]:
import numpy as np

def normalize(x):
    """
    Min-max scale ``x`` into the interval [0, 1].

    The minimum and maximum are taken over ALL elements, so a multi-dimensional
    input is scaled by one common factor rather than per axis (relative
    distances between points are therefore kept).

    Parameters:
        x (array-like): Real numbers of any shape.

    Returns:
        np.ndarray: Float array of the same shape as ``x``. An empty input is
        returned unchanged; a constant input maps to all zeros (instead of the
        NaNs a 0/0 division would give).
    """
    x = np.asarray(x, dtype=float)
    if x.size == 0:
        return x
    lowest = x.min()
    span = x.max() - lowest
    if span == 0:
        return np.zeros_like(x)
    return (x - lowest) / span

def quantize_coords_nd(X):
    """
    Quantize N-dimensional coordinates onto an integer grid while preserving relative distances.

    All coordinates are min-max normalized with one common factor, multiplied
    by ``m = ceil(1 / d)`` where ``d`` is the smallest nonzero Euclidean
    distance between two points, and rounded.

    The nearest-neighbour distance is obtained from a KD-tree built on the
    distinct points, so no ``n x n`` distance matrix is materialized.

    Parameters:
        X (np.ndarray): Array of shape (n, d) representing n points in d-dimensional space.

    Returns:
        np.ndarray: Quantized integer coordinates of shape (n, d). Identical
        points map to identical grid cells. If there are fewer than two
        distinct points, all zeros are returned.
    """
    from scipy.spatial import KDTree

    X = np.array(X, dtype=float)

    # Nothing to separate, and normalizing a constant array would divide by zero.
    if X.size == 0 or X.max() == X.min():
        return np.zeros(X.shape, dtype=int)

    X = normalize(X)

    # Duplicates are at distance zero from each other; drop them so that the
    # nearest neighbour of every remaining point is a genuinely different point.
    distinct = np.unique(X, axis=0)
    if len(distinct) < 2:
        return np.zeros(X.shape, dtype=int)

    # k=2: the first neighbour of each point is the point itself (distance 0),
    # the second one is its nearest distinct neighbour.
    neighbour_dist, _ = KDTree(distinct).query(distinct, k=2)
    min_dist = neighbour_dist[:, 1].min()

    # Compute scaling factor and quantize
    m = np.ceil(1 / min_dist)
    X_quant = np.round(X * m).astype(int)

    return X_quant



In [None]:

# Larger run: 10000 uniform random points in 20 dimensions.
n = 10000
x = np.random.rand(n, 20)

# Printing the raw input is disabled because of its size.
# print('x:', x)
print('quantize_coords(x):', quantize_coords_nd(x))

In [None]:
class A:
    """Toy base class that announces its calls and carries the attribute ``d``."""

    def __init__(self):
        """Print a trace line and set ``d`` to 3."""
        print('A init')
        self.d = 3

    def forward(self, x):
        """Print a trace line and hand ``x`` back unchanged."""
        print('A forward')
        return x

class B:
    """Toy mixin that announces its construction and carries the attribute ``r``."""

    def __init__(self, mode='r'):
        """Print a trace line including ``mode`` and set ``r`` to 5."""
        print('B init, mode:', mode)
        self.r = 5

class B_p(A, B):
    """Multiple-inheritance experiment: combines A and B, initialising both explicitly."""

    def __init__(self):
        # For a plain B_p instance the MRO is B_p -> A -> B -> object, so this
        # runs A.__init__. A.__init__ does not call super().__init__(), hence
        # B.__init__ is NOT reached through the cooperative chain.
        super().__init__()
        # B therefore has to be initialised by hand, which also allows passing
        # a non-default mode.
        B.__init__(self, 'w')

# Instantiating prints the trace lines of both initialisers (A first, then B).
b = B_p()
# forward is inherited from A.
b.forward(1)
# Both attributes are present: d is set by A.__init__, r by B.__init__.
b.d, b.r

In [None]:
import torch

# Extract overlapping 3x3 patches with stride 2 and zero padding 1.
soft_split1 = torch.nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))

bs = 1
seq_len = 10
h = 5
w = 5

# NOTE: Unfold interprets dim 1 as the channel axis, so seq_len plays the role
# of channels in this experiment.
x = torch.randn(bs, seq_len, h, w) # (bs, seq_len, h, w)
# Per the Unfold documentation the result has shape
# (bs, seq_len * 3 * 3, L), with L the number of patch positions
# (3 per spatial axis for these sizes, i.e. L = 9).
soft_split1(x) 

In [None]:
# Check how transpose handles a negative axis: swapping dim 1 with the last
# dim turns (3, 10, 5, 8) into (3, 8, 5, 10).
x = torch.randn(3, 10, 5, 8)
x.transpose(1, -1).shape