## Cross-checking the original implementation against the optimized implementation

In [1]:
# Torch
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch import optim 
from torchsummary import summary


# Train + Data 
import sys 
sys.path.append('../Layers')
from Conv1d_NN import *
from Conv2d_NN import *
from Conv1d_NN_spatial import * 
from Conv2d_NN_spatial import * 
from ConvNN_CNN_Branching import *

import time

### 1. Distance Matrix Calculation

In [2]:
def _calculate_distance_matrix(matrix, sqrt=False):
    """Return pairwise squared Euclidean distances between the tokens of ``matrix``.

    Uses the expansion ``|x_i - x_j|^2 = |x_i|^2 + |x_j|^2 - 2 <x_i, x_j>``,
    with dim 1 as the feature axis and dim 2 as the token axis.

    Args:
        matrix: 3-D tensor (``bmm`` and ``transpose`` on dims 1/2 require it);
            the cell below passes ``[32, 3, 28]``.
        sqrt: When true, return distances instead of squared distances.

    Returns:
        Tensor of shape ``[batch, tokens, tokens]``.
    """
    tokens_first = matrix.transpose(1, 2)
    gram = torch.bmm(tokens_first, matrix)

    sq_norms = (matrix ** 2).sum(dim=1, keepdim=True)
    distances = sq_norms + sq_norms.transpose(1, 2) - 2 * gram

    # Round-off can leave tiny negative values; clip them so the optional
    # square root stays real.
    distances = distances.clamp(min=0)

    return distances.sqrt() if sqrt else distances

# Smoke test: a [32, 3, 28] input should yield a [32, 28, 28] distance matrix.
ex = torch.randn(32, 3, 28)
o1 = _calculate_distance_matrix(ex)
print(o1.shape)


torch.Size([32, 28, 28])


In [3]:
# N Samples
def original_calculate_distance_matrix_N(matrix, matrix_sample, sqrt=False):
    """Calculate squared Euclidean distances between two batched matrices.

    Args:
        matrix: Tensor used as ``[b, c, t]`` (features on dim 1).
        matrix_sample: Tensor used as ``[b, c, s]``; it must share ``b`` and
            ``c`` with ``matrix`` for the ``bmm`` to succeed.
        sqrt: When true, return distances instead of squared distances.

    Returns:
        Tensor of shape ``[b, t, s]``.

    Note:
        The shape printouts are deliberate debugging output for this notebook.
    """
    # Squared norm of every token of `matrix`: [b, 1, t], then as a column [b, t, 1].
    norm_squared_row = torch.sum(matrix ** 2, dim=1, keepdim=True)
    norm_squared = norm_squared_row.permute(0, 2, 1)

    print("norm squared", norm_squared.shape)
    print("norm sqaured no permute", norm_squared_row.shape)

    # The sample norms are already in the row form [b, 1, s] that broadcasting
    # needs; the former transpose(2, 1).permute(0, 2, 1) pair swapped the last
    # two dims twice and was a no-op.
    norm_squared_sample = torch.sum(matrix_sample ** 2, dim=1, keepdim=True)
    print("norm squared sample", norm_squared_sample.shape)
    print("norm squared sample no permute", norm_squared_sample.shape)

    dot_product = torch.bmm(matrix.transpose(2, 1), matrix_sample)
    print("dot product", dot_product.shape)

    # [b, t, 1] + [b, 1, s] - [b, t, s] broadcasts to [b, t, s].
    dist_matrix = norm_squared + norm_squared_sample - 2 * dot_product
    print("dist matrix", dist_matrix.shape)

    # Clip round-off negatives before the optional square root.
    dist_matrix = torch.clamp(dist_matrix, min=0.0)

    if sqrt:
        dist_matrix = torch.sqrt(dist_matrix)
    return dist_matrix

# Smoke test: 28 tokens against 10 sample tokens should give [32, 28, 10].
ex = torch.randn(32, 3, 28)
ex1 = torch.randn(32, 3, 10)
o1 = original_calculate_distance_matrix_N(ex, ex1)
print(o1.shape)
    

norm squared torch.Size([32, 28, 1])
norm sqaured no permute torch.Size([32, 1, 28])
norm squared sample torch.Size([32, 1, 10])
norm squared sample no permute torch.Size([32, 1, 10])
dot product torch.Size([32, 28, 10])
dist matrix torch.Size([32, 28, 10])
torch.Size([32, 28, 10])


### 2. Similarity Matrix Calculation

In [4]:
# All Samples
def original_calculate_similarity_matrix(matrix): 
    """Return the cosine similarity between every pair of tokens in ``matrix``.

    Each token vector (taken along dim 1) is scaled to unit L2 norm, so the
    batched Gram matrix of the result holds cosine similarities.
    """
    unit_tokens = F.normalize(matrix, p=2, dim=1)
    return torch.bmm(unit_tokens.transpose(1, 2), unit_tokens)

def optimized_calculate_similarity_matrix(matrix): 
    """Return the cosine similarity between every pair of tokens in ``matrix``.

    Args:
        matrix: 3-D tensor with features on dim 1 and tokens on dim 2.

    Returns:
        Tensor of shape ``[batch, tokens, tokens]``.
    """
    # F.normalize is the real API name; `F.normalized` does not exist and
    # raised AttributeError on the first call.
    normalized_matrix = F.normalize(matrix, p=2, dim=1)
    similarity_matrix = torch.bmm(normalized_matrix.transpose(2, 1), normalized_matrix)
    return similarity_matrix

In [5]:
# N Samples
def original_calculate_similarity_matrix_N(matrix, matrix_sample): 
    """Return cosine similarities between tokens of ``matrix`` and ``matrix_sample``.

    Both inputs are L2-normalised along dim 1 before the batched product, so
    entry ``[b, i, j]`` compares token ``i`` of ``matrix`` with token ``j`` of
    ``matrix_sample``.
    """
    unit_tokens = F.normalize(matrix, p=2, dim=1)
    unit_sample = F.normalize(matrix_sample, p=2, dim=1)
    return torch.bmm(unit_tokens.transpose(1, 2), unit_sample)


### 3. Topk Selection

In [6]:
# All Samples

def prime_vmap_2d(matrix, magnitude_matrix, num_nearest_neighbors, maximum): 
    """Gather every token's nearest neighbours for a whole batch, flattened.

    ``process_batch`` is mapped with ``torch.vmap`` over dim 0 of ``matrix``
    and ``magnitude_matrix``; ``num_nearest_neighbors`` is shared across the
    batch (``in_dims`` entry ``None``) and the keyword arguments are passed
    through unmapped.

    Args:
        matrix: Batched features; the cell below passes ``[32, 3, 28]``.
        magnitude_matrix: Batched score matrix; the cell below passes ``[32, 28, 28]``.
        num_nearest_neighbors: Neighbours kept per token.
        maximum: Forwarded to ``torch.topk`` as ``largest``.

    Returns:
        Per-sample results of ``process_batch(..., flatten=True)`` stacked on dim 0.
    """
    # Plain module-level function: the former @staticmethod decorator had no
    # class to attach to and made the name uncallable on Python < 3.10.
    batched_process = torch.vmap(process_batch, in_dims=(0, 0, None), out_dims=0)
    prime = batched_process(matrix, magnitude_matrix, num_nearest_neighbors, flatten=True, maximum=maximum)
    return prime

def prime_vmap_3d(matrix, magnitude_matrix, num_nearest_neighbors, maximum): 
    """Gather every token's nearest neighbours for a whole batch, unflattened.

    Same mapping as the 2-D variant (``process_batch`` vmapped over dim 0 of
    the first two arguments) but with ``flatten=False``, so the neighbour axis
    is kept separate.

    Args:
        matrix: Batched features.
        magnitude_matrix: Batched score matrix.
        num_nearest_neighbors: Neighbours kept per token.
        maximum: Forwarded to ``torch.topk`` as ``largest``.

    Returns:
        Per-sample results of ``process_batch(..., flatten=False)`` stacked on dim 0.
    """
    # Plain module-level function: the former @staticmethod decorator had no
    # class to attach to and made the name uncallable on Python < 3.10.
    batched_process = torch.vmap(process_batch, in_dims=(0, 0, None), out_dims=0)
    prime = batched_process(matrix, magnitude_matrix, num_nearest_neighbors, flatten=False, maximum=maximum)
    return prime

def process_batch(matrix, magnitude_matrix, num_nearest_neighbors, flatten, maximum): 
    """Gather the K best-scoring neighbours of every token for one sample.

    Args:
        matrix: Per-sample features indexed as ``[channels, tokens]``.
        magnitude_matrix: Per-sample score matrix; ``topk`` runs over its last dim.
        num_nearest_neighbors: Number of neighbours (K) kept per row.
        flatten: When true, merge the token and neighbour axes.
        maximum: Passed to ``torch.topk`` as ``largest``.

    Returns:
        ``[channels, tokens * K]`` when ``flatten`` is true, otherwise
        ``[channels, tokens, K]``.
    """
    # No decorator: this is a module-level function, and a @staticmethod
    # object is not callable outside a class on Python < 3.10.
    ind = torch.topk(magnitude_matrix, num_nearest_neighbors, largest=maximum).indices
    neigh = matrix[:, ind]
    if flatten:
        return torch.flatten(neigh, start_dim=1)
    return neigh

# Smoke test for the vmap gather: 28 tokens with 5 neighbours each.
mat = torch.randn(32, 3, 28)
mag = original_calculate_similarity_matrix(mat)
prime = prime_vmap_2d(mat, mag, 5, True)

print('original matrix shape', mat.shape)
print('original similarity matrix shape', mag.shape)
print('prime shape', prime.shape)

# Flattened length: 28 tokens * 5 neighbours = 140



original matrix shape torch.Size([32, 3, 28])
original similarity matrix shape torch.Size([32, 28, 28])
prime shape torch.Size([32, 3, 140])


In [7]:
# Print tensors with 4 decimals and without scientific notation.
torch.set_printoptions(precision=4, sci_mode=False)


def prime_brute(matrix, dist_matrix, K): 
    """Reference (loop-based) gather of each token's ``K`` top-scoring neighbours.

    For every batch element ``i`` and token ``j`` the ``K`` largest entries of
    ``dist_matrix[i, j, :]`` are located and the matching columns of
    ``matrix[i]`` are copied out; the per-token pieces are laid side by side.

    Args:
        matrix: Tensor indexed as ``[batch, channels, tokens]``.
        dist_matrix: Tensor indexed as ``[batch, tokens, candidates]``; the
            largest values are selected (``largest=True``).
        K: Number of neighbours kept per token.

    Returns:
        Tensor of shape ``[batch, channels, tokens * K]``.
    """
    stack_list = []
    for i in range(matrix.shape[0]):
        # One [channels, K] slab per token, in token order.
        concat_list = [
            matrix[i, :, torch.topk(dist_matrix[i, j, :], K, largest=True).indices]
            for j in range(matrix.shape[2])
        ]
        stack_list.append(torch.cat(concat_list, dim=1))
    return torch.stack(stack_list, dim=0)
        

# Run the brute-force gather on a fresh random input.
mat = torch.randn(32, 3, 28)
sim = original_calculate_similarity_matrix(mat)
prime_b = prime_brute(mat, sim, 5)
print('original matrix shape', mat.shape)
print('original similarity matrix shape', sim.shape)
# Report the tensor produced in this cell; `prime` is a leftover from the
# earlier vmap cell and was printed here by mistake.
print('prime shape', prime_b.shape)
print(prime_b)

original matrix shape torch.Size([32, 3, 28])
original similarity matrix shape torch.Size([32, 28, 28])
prime shape torch.Size([32, 3, 140])
tensor([[[ 2.3307,  1.7655,  0.5343,  ...,  0.7687,  0.7188,  1.3099],
         [ 0.5816,  0.8979,  0.8920,  ..., -2.1041, -0.6582, -0.9405],
         [-0.9124,  0.5200,  0.1506,  ..., -0.5886,  0.5900,  0.9902]],

        [[ 1.7783,  0.3377,  0.8736,  ..., -0.4878, -0.6091, -0.3954],
         [-0.6327, -0.1128, -0.1515,  ..., -1.2519, -1.8199, -0.7666],
         [ 0.2774,  0.1525, -0.0997,  ..., -0.9424, -0.2138, -1.4212]],

        [[-1.8757, -0.3479, -2.0375,  ..., -0.1934, -1.3167,  0.4096],
         [ 1.4788,  0.5998,  0.9316,  ..., -1.5570, -0.1427, -1.2595],
         [-1.7019, -0.7538, -0.4602,  ...,  1.0518,  0.8979,  2.8289]],

        ...,

        [[-1.2432, -0.2307, -0.5662,  ...,  2.1841,  0.7420,  0.6716],
         [ 0.6447,  0.1323, -0.0841,  ...,  0.7177,  1.7078, -0.0492],
         [-1.9008, -0.5443, -0.5578,  ...,  0.3047,  0.202

In [8]:






# Step-by-step vectorised version of the brute-force gather; reuses `sim` and
# `mat` from the previous cell so the results can be compared afterwards.
# Indices of the 5 largest similarities for every token.
_, topk_indices = torch.topk(sim, k=5, dim=2, largest=True)  
print("topk indices: ", topk_indices.shape)  


batch_size, channels, tokens = mat.shape
K = topk_indices.shape[-1]  
print("batch_size: ", batch_size)
print("channels: ", channels)
print("tokens: ", tokens)
print("K: ", K)

# Expand topk_indices: add a channel dimension at dim=1
indices_expanded = topk_indices.unsqueeze(1).expand(batch_size, channels, tokens, K)
print(indices_expanded)

print("indices_expanded shape:", indices_expanded.shape)  # torch.Size([32, 3, 28, 5])



print()

# Create index tensors for batch and channel dimensions:
batch_indices = torch.arange(batch_size).view(batch_size, 1, 1, 1).expand(batch_size, channels, tokens, K)

channel_indices = torch.arange(channels).view(1, channels, 1, 1).expand(batch_size, channels, tokens, K)

# Now use advanced indexing:
# For each [b, c, i, j], we select mat[b, c, indices_expanded[b, c, i, j]]
prime = mat[batch_indices, channel_indices, indices_expanded]
# prime will have shape [32, 3, 28, 5]

# Finally, reshape (flatten the token and neighbor dimensions) to get [32, 3, 28*5] = [32, 3, 140]
prime_new = prime.view(batch_size, channels, -1)

print("prime_new shape:", prime_new.shape)


topk indices:  torch.Size([32, 28, 5])
batch_size:  32
channels:  3
tokens:  28
K:  5
tensor([[[[ 0, 19, 13, 27,  9],
          [ 1,  3, 14,  8, 16],
          [ 2, 12, 27, 26,  6],
          ...,
          [25, 22, 13, 17, 24],
          [26,  6, 11, 15, 12],
          [27, 12,  2, 26,  6]],

         [[ 0, 19, 13, 27,  9],
          [ 1,  3, 14,  8, 16],
          [ 2, 12, 27, 26,  6],
          ...,
          [25, 22, 13, 17, 24],
          [26,  6, 11, 15, 12],
          [27, 12,  2, 26,  6]],

         [[ 0, 19, 13, 27,  9],
          [ 1,  3, 14,  8, 16],
          [ 2, 12, 27, 26,  6],
          ...,
          [25, 22, 13, 17, 24],
          [26,  6, 11, 15, 12],
          [27, 12,  2, 26,  6]]],


        [[[ 0,  6, 24, 15,  2],
          [ 1, 10, 21, 12,  3],
          [ 2, 24, 26,  0, 15],
          ...,
          [25, 19, 22,  4, 15],
          [26, 18, 22,  2, 15],
          [27, 14,  3, 20, 10]],

         [[ 0,  6, 24, 15,  2],
          [ 1, 10, 21, 12,  3],
          [ 

In [9]:
# The vectorised gather must reproduce the brute-force result.
print(torch.allclose(prime_b, prime_new))



True


### Optimized check


In [10]:
# Print tensors with 4 decimals and without scientific notation.
torch.set_printoptions(precision=4, sci_mode=False)


In [11]:
def calculate_similarity_matrix(matrix): 
    """Return the cosine similarity between every pair of tokens in ``matrix``.

    Token vectors (along dim 1) are scaled to unit L2 norm first, so the
    batched Gram matrix of the result holds cosine similarities.
    """
    unit_tokens = F.normalize(matrix, p=2, dim=1)
    return torch.bmm(unit_tokens.transpose(1, 2), unit_tokens)

In [12]:
def prime_brute(matrix, dist_matrix, K): 
    """Reference (loop-based) gather of each token's ``K`` top-scoring neighbours.

    For every batch element ``i`` and token ``j`` the ``K`` largest entries of
    ``dist_matrix[i, j, :]`` are located and the matching columns of
    ``matrix[i]`` are copied out; the per-token pieces are laid side by side.

    Args:
        matrix: Tensor indexed as ``[batch, channels, tokens]``.
        dist_matrix: Tensor indexed as ``[batch, tokens, candidates]``; the
            largest values are selected (``largest=True``).
        K: Number of neighbours kept per token.

    Returns:
        Tensor of shape ``[batch, channels, tokens * K]``.
    """
    stack_list = []
    for i in range(matrix.shape[0]):
        # One [channels, K] slab per token, in token order.
        concat_list = [
            matrix[i, :, torch.topk(dist_matrix[i, j, :], K, largest=True).indices]
            for j in range(matrix.shape[2])
        ]
        stack_list.append(torch.cat(concat_list, dim=1))
    return torch.stack(stack_list, dim=0)
        

In [13]:
# All Samples

def prime_vmap_2d(matrix, magnitude_matrix, num_nearest_neighbors, maximum): 
    """Gather every token's nearest neighbours for a whole batch, flattened.

    ``process_batch`` is mapped with ``torch.vmap`` over dim 0 of ``matrix``
    and ``magnitude_matrix``; ``num_nearest_neighbors`` is shared across the
    batch (``in_dims`` entry ``None``) and the keyword arguments are passed
    through unmapped.

    Args:
        matrix: Batched features.
        magnitude_matrix: Batched score matrix.
        num_nearest_neighbors: Neighbours kept per token.
        maximum: Forwarded to ``torch.topk`` as ``largest``.

    Returns:
        Per-sample results of ``process_batch(..., flatten=True)`` stacked on dim 0.
    """
    # Plain module-level function: the former @staticmethod decorator had no
    # class to attach to and made the name uncallable on Python < 3.10.
    batched_process = torch.vmap(process_batch, in_dims=(0, 0, None), out_dims=0)
    prime = batched_process(matrix, magnitude_matrix, num_nearest_neighbors, flatten=True, maximum=maximum)
    return prime

def prime_vmap_3d(matrix, magnitude_matrix, num_nearest_neighbors, maximum): 
    """Gather every token's nearest neighbours for a whole batch, unflattened.

    Same mapping as the 2-D variant (``process_batch`` vmapped over dim 0 of
    the first two arguments) but with ``flatten=False``, so the neighbour axis
    is kept separate.

    Args:
        matrix: Batched features.
        magnitude_matrix: Batched score matrix.
        num_nearest_neighbors: Neighbours kept per token.
        maximum: Forwarded to ``torch.topk`` as ``largest``.

    Returns:
        Per-sample results of ``process_batch(..., flatten=False)`` stacked on dim 0.
    """
    # Plain module-level function: the former @staticmethod decorator had no
    # class to attach to and made the name uncallable on Python < 3.10.
    batched_process = torch.vmap(process_batch, in_dims=(0, 0, None), out_dims=0)
    prime = batched_process(matrix, magnitude_matrix, num_nearest_neighbors, flatten=False, maximum=maximum)
    return prime

def process_batch(matrix, magnitude_matrix, num_nearest_neighbors, flatten, maximum): 
    """Gather the K best-scoring neighbours of every token for one sample.

    Args:
        matrix: Per-sample features indexed as ``[channels, tokens]``.
        magnitude_matrix: Per-sample score matrix; ``topk`` runs over its last dim.
        num_nearest_neighbors: Number of neighbours (K) kept per row.
        flatten: When true, merge the token and neighbour axes.
        maximum: Passed to ``torch.topk`` as ``largest``.

    Returns:
        ``[channels, tokens * K]`` when ``flatten`` is true, otherwise
        ``[channels, tokens, K]``.
    """
    # No decorator: this is a module-level function, and a @staticmethod
    # object is not callable outside a class on Python < 3.10.
    ind = torch.topk(magnitude_matrix, num_nearest_neighbors, largest=maximum).indices
    neigh = matrix[:, ind]
    if flatten:
        return torch.flatten(neigh, start_dim=1)
    return neigh


In [14]:
def prime_optimized(matrix, dist_matrix, K=5, maximum=False):
    """Vectorised gather of each token's ``K`` top-scoring neighbours.

    Args:
        matrix: Tensor of shape ``[batch, channels, tokens]``.
        dist_matrix: Tensor indexed as ``[batch, tokens, candidates]``.
        K: Number of neighbours kept per token.
        maximum: Accepted but NOT used; selection is always ``largest=True``.
            NOTE(review): the cross-check cell calls this with the default
            ``maximum=False`` and expects the largest entries, so honouring
            the flag would change that result. Left as is and only flagged.

    Returns:
        Tensor of shape ``[batch, channels, tokens * K]``.
    """
    batch_size, channels, tokens = matrix.shape

    _, topk_indices = torch.topk(dist_matrix, k=K, dim=2, largest=True)

    full_shape = (batch_size, channels, tokens, K)
    # Same neighbour indices for every channel.
    indices_expanded = topk_indices.unsqueeze(1).expand(full_shape)

    # Build the batch/channel index grids on the input's device, as _prime_N
    # does, instead of always on the default device.
    device = matrix.device
    batch_indices = torch.arange(batch_size, device=device).view(batch_size, 1, 1, 1).expand(full_shape)
    channel_indices = torch.arange(channels, device=device).view(1, channels, 1, 1).expand(full_shape)

    # prime[b, c, i, k] = matrix[b, c, topk_indices[b, i, k]]
    prime = matrix[batch_indices, channel_indices, indices_expanded]

    # Merge the token and neighbour axes.
    return prime.view(batch_size, channels, -1)

In [15]:
def _prime(matrix, magnitude_matrix, K, maximum):
    """Vectorised gather of each token's ``K`` best-scoring neighbours.

    Args:
        matrix: Tensor of shape ``[b, c, t]``.
        magnitude_matrix: Score tensor indexed as ``[b, t, candidates]``.
        K: Number of neighbours kept per token.
        maximum: Passed to ``torch.topk`` as ``largest`` (True picks the
            highest scores, False the lowest).

    Returns:
        Tensor of shape ``[b, c, t * K]``.
    """
    b, c, t = matrix.shape

    _, topk_indices = torch.topk(magnitude_matrix, k=K, dim=2, largest=maximum)

    tk = topk_indices.shape[-1]

    assert K == tk, "Error: K must be same as tk. K == tk."

    # Same neighbour indices for every channel.
    indices_expanded = topk_indices.unsqueeze(1).expand(b, c, t, tk)
    # Index grids are created on the input's device, matching _prime_N,
    # instead of always on the default device.
    batch_indices = torch.arange(b, device=matrix.device).view(b, 1, 1, 1).expand(b, c, t, tk)
    channel_indices = torch.arange(c, device=matrix.device).view(1, c, 1, 1).expand(b, c, t, tk)

    # prime[b, c, i, k] = matrix[b, c, topk_indices[b, i, k]]
    prime = matrix[batch_indices, channel_indices, indices_expanded]
    prime = prime.view(b, c, -1)
    return prime

In [16]:
# Cross-check: the brute-force, optimized, vmap and _prime gathers must all
# agree on the same random input.
mat = torch.randn(32, 3, 28)
sim = calculate_similarity_matrix(mat)
prime_b = prime_brute(mat, sim, 5)
prime_o = prime_optimized(mat, sim, 5)
prime_f = prime_vmap_2d(mat, sim, 5, True)
prime_ = _prime(mat, sim, 5, True)
print('original matrix shape', mat.shape)
print('original similarity matrix shape', sim.shape)
print()
print('prime brute shape', prime_b.shape)
print('prime optimized shape', prime_o.shape)
print('prime vmap shape', prime_f.shape)

# Pairwise equality checks between the implementations.
print(torch.allclose(prime_b, prime_o))
print(torch.allclose(prime_b, prime_f))
print(torch.allclose(prime_o, prime_f))
print(torch.allclose(prime_b, prime_))


original matrix shape torch.Size([32, 3, 28])
original similarity matrix shape torch.Size([32, 28, 28])

prime brute shape torch.Size([32, 3, 140])
prime optimized shape torch.Size([32, 3, 140])
prime vmap shape torch.Size([32, 3, 140])
True
True
True
True


#### N Samples

In [17]:
# Print tensors with 4 decimals and without scientific notation.
torch.set_printoptions(precision=4, sci_mode=False)
import numpy as np


In [18]:
# N Samples
def calculate_similarity_matrix_N(matrix, matrix_sample): 
    """Return cosine similarities between tokens of ``matrix`` and ``matrix_sample``.

    Both inputs are L2-normalised along dim 1 before the batched product, so
    entry ``[b, i, j]`` compares token ``i`` of ``matrix`` with token ``j`` of
    ``matrix_sample``.
    """
    unit_tokens = F.normalize(matrix, p=2, dim=1)
    unit_sample = F.normalize(matrix_sample, p=2, dim=1)
    return torch.bmm(unit_tokens.transpose(1, 2), unit_sample)

In [19]:

def prime_vmap_2d_N(matrix, magnitude_matrix, num_nearest_neighbors, rand_idx, maximum): 
    """Batched, flattened neighbour gather for the sampled (N) case.

    Maps ``process_batch_N`` over dim 0 of ``matrix`` and ``magnitude_matrix``
    while sharing ``num_nearest_neighbors`` and ``rand_idx`` across the batch.
    """
    per_sample = torch.vmap(process_batch_N, in_dims=(0, 0, None, None), out_dims=0)
    return per_sample(
        matrix,
        magnitude_matrix,
        num_nearest_neighbors,
        rand_idx,
        flatten=True,
        maximum=maximum,
    )

def prime_vmap_3d_N(matrix, magnitude_matrix, num_nearest_neighbors, rand_idx, maximum): 
    """Batched neighbour gather for the sampled (N) case, neighbour axis kept.

    Maps ``process_batch_N`` over dim 0 of ``matrix`` and ``magnitude_matrix``
    while sharing ``num_nearest_neighbors`` and ``rand_idx`` across the batch;
    the per-sample result is not flattened.
    """
    per_sample = torch.vmap(process_batch_N, in_dims=(0, 0, None, None), out_dims=0)
    return per_sample(
        matrix,
        magnitude_matrix,
        num_nearest_neighbors,
        rand_idx,
        flatten=False,
        maximum=maximum,
    )

def process_batch_N(matrix, magnitude_matrix, num_nearest_neighbors, rand_idx, flatten, maximum): 
    """Gather, for one sample, every token itself plus its best sampled neighbours.

    ``num_nearest_neighbors - 1`` columns of ``magnitude_matrix`` are selected
    per row, translated through ``rand_idx`` into token positions, and prefixed
    with the row's own index, giving ``num_nearest_neighbors`` indices per token.
    """
    best = torch.topk(magnitude_matrix, num_nearest_neighbors - 1, largest=maximum).indices
    target = best.device
    # Translate positions inside the sample back to positions in `matrix`.
    neighbours = rand_idx.to(target)[best]
    # Column vector holding each token's own index: [tokens, 1].
    own = torch.arange(0, matrix.shape[1], device=target).unsqueeze(1)
    gather_idx = torch.cat([own, neighbours], dim=1)
    picked = matrix[:, gather_idx]
    return torch.flatten(picked, start_dim=1) if flatten else picked

In [20]:
def _prime_N(matrix, magnitude_matrix, K, rand_idx, maximum):
    """
    Create prime tensor for the N-samples case.
    
    Args:
        matrix (torch.Tensor): Input tensor of shape [b, c, t].
        magnitude_matrix (torch.Tensor): Magnitude matrix of shape [b, t, s] 
                                           (where s is the number of sampled tokens).
        K (int): Total number of neighbors to consider (including the self index).
        rand_idx (torch.Tensor): 1D tensor of length s containing the sampled indices.
        maximum (bool): If True, select the largest values (e.g. for similarity); if False, select the smallest (e.g. for distance).
        
    Returns:
        torch.Tensor: A prime tensor of shape [b, c, t*K].
    """
    b, c, t = matrix.shape

    # K - 1 sampled neighbours per token; the token itself fills the K-th slot.
    _, topk_indices = torch.topk(magnitude_matrix, k=K - 1, dim=2, largest=maximum)
    tk = topk_indices.shape[-1]
    assert K == tk + 1, "Error: K must be same as tk + 1. K == tk + 1."

    # Indexing needs rand_idx on the same device as the index tensor;
    # process_batch_N does the same move.
    rand_idx = rand_idx.to(topk_indices.device)
    # Translate positions inside the sample into token positions: [b, t, K-1].
    mapped_tensor = rand_idx[topk_indices]

    # Each token's own index: [b, t, 1].
    token_indices = torch.arange(t, device=matrix.device).view(1, t, 1).expand(b, t, 1)

    # Self index first, then the neighbours: [b, t, K].
    final_indices = torch.cat([token_indices, mapped_tensor], dim=2)

    # Same indices for every channel: [b, c, t, K].
    indices_expanded = final_indices.unsqueeze(1).expand(b, c, t, K)

    batch_indices = torch.arange(b, device=matrix.device).view(b, 1, 1, 1).expand(b, c, t, K)
    channel_indices = torch.arange(c, device=matrix.device).view(1, c, 1, 1).expand(b, c, t, K)

    # prime[b, c, i, k] = matrix[b, c, final_indices[b, i, k]]
    prime = matrix[batch_indices, channel_indices, indices_expanded]

    prime = prime.view(b, c, -1)
    return prime



In [21]:


# Cross-check for the sampled case: the vmap gather and _prime_N must agree.
samples = 10
x1 = torch.randn(32, 3, 28)


rand_idx = torch.tensor(np.random.choice(x1.shape[2], samples, replace=False)) # tensor of `samples` distinct token indices
x1_sample = x1[:, :, rand_idx]


sim = calculate_similarity_matrix_N(x1, x1_sample)
p1 = prime_vmap_2d_N(x1, sim, 5, rand_idx, True)
o1 = _prime_N(x1, sim, 5, rand_idx, True)
print('p1 shape', p1.shape)
print('o1 shape', o1.shape)
print(torch.allclose(p1, o1))

p1 shape torch.Size([32, 3, 140])
o1 shape torch.Size([32, 3, 140])
True


In [22]:
# Two ways of drawing `samples` distinct token indices: NumPy's choice without
# replacement versus the first `samples` entries of a torch permutation.
rand_idx = torch.tensor(np.random.choice(x1.shape[2], samples, replace=False)) # tensor of sampled indices
print('rand_idx', rand_idx)

rand_idx_1 = torch.randperm(x1.shape[2], device=x1.device)[:samples]

print('rand_idx_1', rand_idx_1)

rand_idx tensor([20,  4,  0, 26,  6, 15,  8, 27,  9, 10])
rand_idx_1 tensor([23, 11,  1, 18, 15, 17,  7, 10,  0,  9])


In [23]:
# Result of the vmap gather, for visual comparison with o1 below.
print(p1)

tensor([[[    -1.0198,     -0.7261,     -0.3265,  ...,      0.7224,
               0.2761,      0.5273],
         [     0.3269,      0.7684,      0.5119,  ...,      0.1771,
               0.9181,      1.8624],
         [     0.9327,      0.5142,      1.9450,  ...,     -1.0296,
               0.2873,      0.7453]],

        [[     1.4551,      0.9084,      0.4332,  ...,     -0.3708,
               0.2935,     -1.3222],
         [    -0.4963,     -0.3979,     -0.0089,  ...,      0.4365,
               0.1287,      0.3454],
         [    -1.4958,     -0.5424,     -0.7425,  ...,     -1.2643,
              -0.3374,     -2.2412]],

        [[    -0.2315,      1.1534,      1.4498,  ...,      1.1534,
               1.4498,     -1.3786],
         [     0.4918,     -0.4089,     -0.1386,  ...,     -0.4089,
              -0.1386,      0.9560],
         [     1.9510,      1.9407,      1.8860,  ...,      1.9407,
               1.8860,      0.1638]],

        ...,

        [[     0.1174,      0.0060,

In [24]:
# Result of _prime_N, for visual comparison with p1 above.
print(o1)

tensor([[[    -1.0198,     -0.7261,     -0.3265,  ...,      0.7224,
               0.2761,      0.5273],
         [     0.3269,      0.7684,      0.5119,  ...,      0.1771,
               0.9181,      1.8624],
         [     0.9327,      0.5142,      1.9450,  ...,     -1.0296,
               0.2873,      0.7453]],

        [[     1.4551,      0.9084,      0.4332,  ...,     -0.3708,
               0.2935,     -1.3222],
         [    -0.4963,     -0.3979,     -0.0089,  ...,      0.4365,
               0.1287,      0.3454],
         [    -1.4958,     -0.5424,     -0.7425,  ...,     -1.2643,
              -0.3374,     -2.2412]],

        [[    -0.2315,      1.1534,      1.4498,  ...,      1.1534,
               1.4498,     -1.3786],
         [     0.4918,     -0.4089,     -0.1386,  ...,     -0.4089,
              -0.1386,      0.9560],
         [     1.9510,      1.9407,      1.8860,  ...,      1.9407,
               1.8860,      0.1638]],

        ...,

        [[     0.1174,      0.0060,

# Optimization for function: prime
- This function gathers the K nearest neighbors from each element and creates a new matrix

### i. All Samples

In [96]:
def _prime_v3(matrix, magnitude_matrix, K, maximum):
    """Loop-based gather that writes each token's neighbours into a preallocated output.

    One Python-level iteration is executed per (batch, token) pair, copying the
    ``K`` selected columns of ``matrix`` into the matching output slice.

    Returns:
        Tensor of shape ``[b, c, t * K]``.
    """
    b, c, t = matrix.shape
    neighbour_idx = torch.topk(magnitude_matrix, k=K, dim=2, largest=maximum)[1]

    out = torch.empty(b, c, t * K, dtype=matrix.dtype, device=matrix.device)

    for sample_no in range(b):
        sample = matrix[sample_no]
        for token in range(t):
            lo = token * K
            # Slot `token` of the output holds that token's K neighbours.
            out[sample_no, :, lo:lo + K] = sample[:, neighbour_idx[sample_no, token]]

    return out

In [97]:
@torch.jit.script
def _prime_v2(matrix: torch.Tensor, magnitude_matrix: torch.Tensor, 
             K: int, maximum: bool) -> torch.Tensor:
    """Gather each token's K best-scoring neighbours with a single ``torch.gather``.

    Returns a tensor of shape [b, c, t*K].
    """
    b, c, t = matrix.shape

    # Source for the gather: every token value repeated K times along a new
    # trailing axis, materialised as [b, c, t, K].
    source = matrix.unsqueeze(-1).expand(b, c, t, K).contiguous()

    # Neighbour indices [b, t, K], shared by all channels -> [b, c, t, K].
    _, neighbour_idx = torch.topk(magnitude_matrix, k=K, dim=2, largest=maximum)
    gather_idx = neighbour_idx.unsqueeze(1).expand(b, c, t, K)

    # Picking along the token axis (dim=2) yields
    # gathered[b, c, i, k] = matrix[b, c, neighbour_idx[b, i, k]].
    gathered = torch.gather(source, dim=2, index=gather_idx)

    # Merge the token and neighbour axes.
    return gathered.view(b, c, -1)

In [98]:
@torch.jit.script
def _prime(matrix: torch.Tensor, magnitude_matrix: torch.Tensor, K: int, maximum: bool) -> torch.Tensor:
    """Gather each token's K best-scoring neighbours via advanced indexing.

    Args:
        matrix: Tensor of shape [b, c, t].
        magnitude_matrix: Score tensor indexed as [b, t, candidates].
        K: Number of neighbours kept per token.
        maximum: Passed to torch.topk as `largest`.

    Returns:
        Tensor of shape [b, c, t*K].
    """
    b, c, t = matrix.shape

    _, topk_indices = torch.topk(magnitude_matrix, k=K, dim=2, largest=maximum)

    tk = topk_indices.shape[-1]

    assert K == tk, "Error: K must be same as tk. K == tk."

    # Same neighbour indices for every channel.
    indices_expanded = topk_indices.unsqueeze(1).expand(b, c, t, tk)
    # Index grids are created on the input's device, matching _prime_N,
    # instead of always on the default device.
    batch_indices = torch.arange(b, device=matrix.device).view(b, 1, 1, 1).expand(b, c, t, tk)
    channel_indices = torch.arange(c, device=matrix.device).view(1, c, 1, 1).expand(b, c, t, tk)

    # prime[b, c, i, k] = matrix[b, c, topk_indices[b, i, k]]
    prime = matrix[batch_indices, channel_indices, indices_expanded]
    prime = prime.view(b, c, -1)
    return prime

In [99]:
# Cross-check: the vmap gather and the _prime / _prime_v2 / _prime_v3 variants
# must agree on the same random input.
mat = torch.randn(32, 3, 28)
sim = calculate_similarity_matrix(mat)
prime_f = prime_vmap_2d(mat, sim, 5, True)
prime_ = _prime(mat, sim, 5, True)
prime_v2_ = _prime_v2(mat, sim, 5, True)
prime_v3_ = _prime_v3(mat, sim, 5, True)

print('original matrix shape', mat.shape)
print('original similarity matrix shape', sim.shape)
print('prime vmap shape', prime_f.shape)
print('prime shape', prime_.shape)
print('prime v2 shape', prime_v2_.shape)
# Equality checks between the implementations.
print(torch.allclose(prime_f, prime_))
print(torch.allclose(prime_f, prime_v2_))
print(torch.allclose(prime_, prime_v2_))
print(torch.allclose(prime_f, prime_v3_))



original matrix shape torch.Size([32, 3, 28])
original similarity matrix shape torch.Size([32, 28, 28])
prime vmap shape torch.Size([32, 3, 140])
prime shape torch.Size([32, 3, 140])
prime v2 shape torch.Size([32, 3, 140])
True
True
True
True


### ii. N Samples

In [100]:
@torch.jit.script
def _prime_N_v3(matrix: torch.Tensor, magnitude_matrix: torch.Tensor, 
                               K: int, rand_idx: torch.Tensor, maximum: bool) -> torch.Tensor:
    """Loop-based gather for the N-samples case, writing into a preallocated output.

    Args:
        matrix: Input tensor of shape [b, c, t].
        magnitude_matrix: Score tensor indexed as [b, t, s].
        K: Total neighbours per token, including the token itself.
        rand_idx: 1-D tensor mapping sample positions to token positions.
        maximum: Passed to torch.topk as `largest`.

    Returns:
        Tensor of shape [b, c, t*K].
    """
    b, c, t = matrix.shape

    # Get top-(K-1) indices
    _, topk_indices = torch.topk(magnitude_matrix, k=K - 1, dim=2, largest=maximum)
    tk = topk_indices.shape[-1]
    assert K == tk + 1, "Error: K must be same as tk + 1. K == tk + 1."

    # Map sample positions to token positions. rand_idx is first moved to the
    # index tensor's device, as process_batch_N does, so the lookup also
    # works when it was created elsewhere.
    token_map = rand_idx.to(topk_indices.device)
    mapped_tensor = token_map[topk_indices]

    # Create self indices for each token
    token_indices = torch.arange(t, device=matrix.device).view(1, t, 1).expand(b, t, 1)

    # Concatenate self index with neighbor indices
    final_indices = torch.cat([token_indices, mapped_tensor], dim=2)

    # Pre-allocate output tensor
    result = torch.empty(b, c, t*K, device=matrix.device, dtype=matrix.dtype)
    
    # One iteration per (batch, token) pair
    for i in range(b):
        for j in range(t):
            # Get the K indices for this position
            indices = final_indices[i, j]  # Shape: [K]
            # Fill the result tensor directly
            result[i, :, j*K:(j+1)*K] = matrix[i, :, indices]
    
    return result

In [101]:
@torch.jit.script
def _prime_N_v2(matrix: torch.Tensor, magnitude_matrix: torch.Tensor, K: int, rand_idx: torch.Tensor, maximum: bool) -> torch.Tensor:
    """Gather-based prime tensor for the N-samples case.

    Args:
        matrix: Input tensor of shape [b, c, t].
        magnitude_matrix: Score tensor indexed as [b, t, s].
        K: Total neighbours per token, including the token itself.
        rand_idx: 1-D tensor mapping sample positions to token positions.
        maximum: Passed to torch.topk as `largest`.

    Returns:
        Tensor of shape [b, c, t*K].
    """
    b, c, t = matrix.shape

    # Get top-(K-1) indices from the magnitude matrix; shape: [b, t, K-1]
    _, topk_indices = torch.topk(magnitude_matrix, k=K - 1, dim=2, largest=maximum)
    tk = topk_indices.shape[-1]
    assert K == tk + 1, "Error: K must be same as tk + 1. K == tk + 1."

    # Map indices from the sampled space to the full token indices using rand_idx.
    # rand_idx is first moved to the index tensor's device, as process_batch_N
    # does, so the lookup also works when it was created elsewhere.
    # mapped_tensor will have shape: [b, t, K-1]
    token_map = rand_idx.to(topk_indices.device)
    mapped_tensor = token_map[topk_indices]

    # Create self indices for each token; shape: [1, t, 1] then expand to [b, t, 1]
    token_indices = torch.arange(t, device=matrix.device).view(1, t, 1).expand(b, t, 1)

    # Concatenate self index with neighbor indices to form final indices; shape: [b, t, K]
    final_indices = torch.cat([token_indices, mapped_tensor], dim=2)

    # Expand final_indices to include the channel dimension; result shape: [b, c, t, K]
    indices_expanded = final_indices.unsqueeze(1).expand(b, c, t, K)

    # Expand matrix to shape [b, c, t, 1] and then to [b, c, t, K] (ensuring contiguous memory)
    matrix_expanded = matrix.unsqueeze(-1).expand(b, c, t, K).contiguous()

    # Gather neighbor features along the token dimension (dim=2)
    prime = torch.gather(matrix_expanded, dim=2, index=indices_expanded)  # shape: [b, c, t, K]

    # Flatten the token and neighbor dimensions into one: [b, c, t*K]
    prime = prime.view(b, c, -1)
    return prime

In [102]:
@torch.jit.script
def _prime_N(matrix: torch.Tensor, magnitude_matrix: torch.Tensor, K: int, rand_idx: torch.Tensor, maximum: bool) -> torch.Tensor:
    """
    Create prime tensor for the N-samples case.
    
    Args:
        matrix (torch.Tensor): Input tensor of shape [b, c, t].
        magnitude_matrix (torch.Tensor): Magnitude matrix of shape [b, t, s] 
                                           (where s is the number of sampled tokens).
        K (int): Total number of neighbors to consider (including the self index).
        rand_idx (torch.Tensor): 1D tensor of length s containing the sampled indices.
        maximum (bool): If True, select the largest values (e.g. for similarity); if False, select the smallest (e.g. for distance).
        
    Returns:
        torch.Tensor: A prime tensor of shape [b, c, t*K].
    """
    b, c, t = matrix.shape

    # K - 1 sampled neighbours per token; the token itself fills the K-th slot.
    _, topk_indices = torch.topk(magnitude_matrix, k=K - 1, dim=2, largest=maximum)
    tk = topk_indices.shape[-1]
    assert K == tk + 1, "Error: K must be same as tk + 1. K == tk + 1."

    # Indexing needs rand_idx on the same device as the index tensor;
    # process_batch_N does the same move.
    token_map = rand_idx.to(topk_indices.device)
    # Translate positions inside the sample into token positions: [b, t, K-1].
    mapped_tensor = token_map[topk_indices]

    # Each token's own index: [b, t, 1].
    token_indices = torch.arange(t, device=matrix.device).view(1, t, 1).expand(b, t, 1)

    # Self index first, then the neighbours: [b, t, K].
    final_indices = torch.cat([token_indices, mapped_tensor], dim=2)

    # Same indices for every channel: [b, c, t, K].
    indices_expanded = final_indices.unsqueeze(1).expand(b, c, t, K)

    batch_indices = torch.arange(b, device=matrix.device).view(b, 1, 1, 1).expand(b, c, t, K)
    channel_indices = torch.arange(c, device=matrix.device).view(1, c, 1, 1).expand(b, c, t, K)

    # prime[b, c, i, k] = matrix[b, c, final_indices[b, i, k]]
    prime = matrix[batch_indices, channel_indices, indices_expanded]

    prime = prime.view(b, c, -1)
    return prime



In [103]:
# Cross-check for the sampled case: the vmap gather and the three _prime_N
# variants must agree on the same random input.
samples = 10
x1 = torch.randn(32, 3, 28)


rand_idx = torch.tensor(np.random.choice(x1.shape[2], samples, replace=False)) # tensor of `samples` distinct token indices
x1_sample = x1[:, :, rand_idx]


sim = calculate_similarity_matrix_N(x1, x1_sample)
p1 = prime_vmap_2d_N(x1, sim, 5, rand_idx, True)
o1 = _prime_N(x1, sim, 5, rand_idx, True)
o1_v2 = _prime_N_v2(x1, sim, 5,rand_idx, True)
o1_v3 = _prime_N_v3(x1, sim, 5, rand_idx, True)
print('p1 shape', p1.shape)
print('o1 shape', o1.shape)
# Equality checks between the implementations.
print(torch.allclose(p1, o1))
print(torch.allclose(p1, o1_v2))
print(torch.allclose(o1, o1_v2))
print(torch.allclose(o1, o1_v3))


p1 shape torch.Size([32, 3, 140])
o1 shape torch.Size([32, 3, 140])
True
True
True
True


## Wall-clock timing for prime (original vmap), prime v1, prime v2, and prime v3

 All samples

In [77]:
# Number of calls timed in each benchmark loop below.
iteration = 30000

In [78]:
# Time `iteration` calls of the vmap gather on one fixed input.
# NOTE(review): a distance matrix is passed with maximum=True, so the largest
# distances are selected; harmless for timing, but confirm it is intended.
ex = torch.randn(32, 3, 28)
o1 = _calculate_distance_matrix(ex)

start = time.time()
for i in range(iteration):
    prime_vmap_2d(ex, o1, 5, True)
end = time.time()
print("Time taken for vmap: ", end - start)
    

Time taken for vmap:  10.530007123947144


In [79]:
# Time `iteration` calls of _prime (v1) on one fixed input.
# NOTE(review): a distance matrix is passed with maximum=True, so the largest
# distances are selected; harmless for timing, but confirm it is intended.
ex = torch.randn(32, 3, 28)
o1 = _calculate_distance_matrix(ex)

start = time.time()
for i in range(iteration):
    _prime(ex, o1, 5, True)
end = time.time()
print("Time taken for prime v1: ", end - start)
    

Time taken for prime v1:  9.789667844772339


In [105]:
# Time `iteration` calls of _prime_v2 on one fixed input.
# NOTE(review): a distance matrix is passed with maximum=True, so the largest
# distances are selected; harmless for timing, but confirm it is intended.
ex = torch.randn(32, 3, 28)
o1 = _calculate_distance_matrix(ex)

start = time.time()
for i in range(iteration):
    _prime_v2(ex, o1, 5, True)
end = time.time()
print("Time taken for prime v2: ", end - start)
    

Time taken for prime v2:  7.953621864318848


In [106]:
# Time `iteration` calls of _prime_v3 on one fixed input. The recorded run
# below was stopped with KeyboardInterrupt before it finished.
ex = torch.randn(32, 3, 28)
o1 = _calculate_distance_matrix(ex)

start = time.time()
for i in range(iteration):
    _prime_v3(ex, o1, 5, True)
end = time.time()
print("Time taken for prime v3: ", end - start)
    

KeyboardInterrupt: 

N Samples

In [81]:
samples = 10
ex = torch.randn(32, 3, 28)
rand_idx = torch.tensor(np.random.choice(ex.shape[2], samples, replace=False)) # tensor of sampled token indices
ex_sample = ex[:, :, rand_idx]
sim = calculate_similarity_matrix_N(ex, ex_sample)


# Benchmark on `ex`, the tensor that `sim` and `rand_idx` were derived from;
# the loop used to pass `x1`, a leftover from an earlier cell.
start = time.time()
for i in range(iteration):
    prime_vmap_2d_N(ex, sim, 5, rand_idx, True)
end = time.time()
print("Time taken for vmap N: ", end - start)


Time taken for vmap N:  7.737882852554321


In [82]:
samples = 10
ex = torch.randn(32, 3, 28)
rand_idx = torch.tensor(np.random.choice(ex.shape[2], samples, replace=False)) # tensor of sampled token indices
ex_sample = ex[:, :, rand_idx]
sim = calculate_similarity_matrix_N(ex, ex_sample)


# Benchmark on `ex`, the tensor that `sim` and `rand_idx` were derived from;
# the loop used to pass `x1`, a leftover from an earlier cell.
start = time.time()
for i in range(iteration):
    _prime_N(ex, sim, 5, rand_idx, True)
end = time.time()
print("Time taken for prime v1 N: ", end - start)


Time taken for prime v1 N:  6.5011138916015625


In [None]:
samples = 10
ex = torch.randn(32, 3, 28)
rand_idx = torch.tensor(np.random.choice(ex.shape[2], samples, replace=False)) # tensor of sampled token indices
ex_sample = ex[:, :, rand_idx]
sim = calculate_similarity_matrix_N(ex, ex_sample)


# Benchmark on `ex`, the tensor that `sim` and `rand_idx` were derived from;
# the loop used to pass `x1`, a leftover from an earlier cell.
start = time.time()
for i in range(iteration):
    _prime_N_v2(ex, sim, 5, rand_idx, True)
end = time.time()
print("Time taken for prime v2 N: ", end - start)


Time taken for prime v2 N:  5.222126007080078


In [107]:
samples = 10
ex = torch.randn(32, 3, 28)
rand_idx = torch.tensor(np.random.choice(ex.shape[2], samples, replace=False)) # tensor of sampled token indices
ex_sample = ex[:, :, rand_idx]
sim = calculate_similarity_matrix_N(ex, ex_sample)


# Benchmark on `ex`, the tensor that `sim` and `rand_idx` were derived from;
# the loop used to pass `x1`, a leftover from an earlier cell.
start = time.time()
for i in range(iteration):
    _prime_N_v3(ex, sim, 5, rand_idx, True)
end = time.time()
print("Time taken for prime v3 N: ", end - start)


KeyboardInterrupt: 