# ConvNN Prime New

In [1]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F

In [2]:

def _calculate_similarity_matrix(self, matrix):
    # p=2 (L2 Norm - Euclidean Distance), dim=1 (across the channels)
    norm_matrix = F.normalize(matrix, p=2, dim=1) 
    similarity_matrix = torch.bmm(norm_matrix.transpose(2, 1), norm_matrix)
    return similarity_matrix


In [3]:
matrix_x = torch.tensor([[[1, 2, 3, 4], [5, 6, 7, 8]]], dtype=torch.float32)

print("Input Matrix:", matrix_x.shape)
print("Input Matrix:", matrix_x)
print()

a = _calculate_similarity_matrix(None, matrix_x)
print("Similarity Matrix:", a.shape)
print("Similarity Matrix:", a)

# Input Matrix: torch.Size([1, 2, 4])
# Input Matrix: tensor(
#     [
#         [
#             [1., 2., 3., 4.],
#             [5., 6., 7., 8.]
#         ]
#     ]
# )

# Similarity Matrix: torch.Size([1, 4, 4])
# Similarity Matrix: tensor(
#     [
#         [      1.      2.      3.      4.
#          1. [1.0000, 0.9923, 0.9785, 0.9648],
#          2. [0.9923, 1.0000, 0.9965, 0.9899],
#          3. [0.9785, 0.9965, 1.0000, 0.9983],
#          4. [0.9648, 0.9899, 0.9983, 1.0000]
#         ]
#     ]
# )

# Prime Matrix: torch.Size([1, 2, 8])
# Prime Matrix: tensor(
#     [
#         [.     1.      2.      3.      4.
#             [1., 2., 2., 3., 3., 4., 4., 3.],
#             [5., 6., 6., 7., 7., 8., 8., 7.]
#         ]
#     ]
# )

# New Prime Matrix: torch.Size([1, 2, 8])
# New Prime Matrix: tensor(
#     [
#         [.          1.              2.              3.              4.
#             [1.0000, 1.9846, 2.0000, 2.9896, 3.0000, 3.9931, 4.0000, 2.9948],
#             [5.0000, 5.9537, 6.0000, 6.9758, 7.0000, 7.9862, 8.0000, 6.9879]
#         ]
#     ]
# )



Input Matrix: torch.Size([1, 2, 4])
Input Matrix: tensor([[[1., 2., 3., 4.],
         [5., 6., 7., 8.]]])

Similarity Matrix: torch.Size([1, 4, 4])
Similarity Matrix: tensor([[[1.0000, 0.9923, 0.9785, 0.9648],
         [0.9923, 1.0000, 0.9965, 0.9899],
         [0.9785, 0.9965, 1.0000, 0.9983],
         [0.9648, 0.9899, 0.9983, 1.0000]]])


In [4]:

def _prime(self, matrix, magnitude_matrix, K, maximum):
    b, c, t = matrix.shape
    _, topk_indices = torch.topk(magnitude_matrix, k=K, dim=2, largest=maximum)
    topk_indices_exp = topk_indices.unsqueeze(1).expand(b, c, t, K)    
    
    matrix_expanded = matrix.unsqueeze(-1).expand(b, c, t, K).contiguous()
    prime = torch.gather(matrix_expanded, dim=2, index=topk_indices_exp)
    prime = prime.view(b, c, -1)
    return prime

In [5]:
prime = _prime(None, matrix_x, a, K=2, maximum=True)
print("Prime Matrix:", prime.shape)
print("Prime Matrix:", prime)

Prime Matrix: torch.Size([1, 2, 8])
Prime Matrix: tensor([[[1., 2., 2., 3., 3., 4., 4., 3.],
         [5., 6., 6., 7., 7., 8., 8., 7.]]])


In [6]:

def _prime_new(self, matrix, magnitude_matrix, K, maximum):
    b, c, t = matrix.shape
    topk_values, topk_indices = torch.topk(magnitude_matrix, k=K, dim=2, largest=maximum)
    topk_indices_exp = topk_indices.unsqueeze(1).expand(b, c, t, K)    
    topk_values_exp = topk_values.unsqueeze(1).expand(b, c, t, K)    

    matrix_expanded = matrix.unsqueeze(-1).expand(b, c, t, K).contiguous()
    prime = torch.gather(matrix_expanded, dim=2, index=topk_indices_exp)
    prime = topk_values_exp * prime
    
    prime = prime.view(b, c, -1)
    return prime

In [7]:
new_prime = _prime_new(None, matrix_x, a, K=2, maximum=True)
print("New Prime Matrix:", new_prime.shape)
print("New Prime Matrix:", new_prime)

New Prime Matrix: torch.Size([1, 2, 8])
New Prime Matrix: tensor([[[1.0000, 1.9846, 2.0000, 2.9896, 3.0000, 3.9931, 4.0000, 2.9948],
         [5.0000, 5.9537, 6.0000, 6.9758, 7.0000, 7.9862, 8.0000, 6.9879]]])


### I. New Prime from Farias

In [8]:
def _prime(self, v, qk, K, maximum):
    b, c, t = v.shape 
    topk_values, topk_indices = torch.topk(qk, k=K, dim=-1, largest = maximum)
    topk_indices_exp = topk_indices.unsqueeze(1).expand(b, c, t, K)
    topk_values_exp = topk_values.unsqueeze(1).expand(b, c, t, K)
    v_expanded = v.unsqueeze(-1).expand(b, c, t, K)
    prime = torch.gather(v_expanded, dim=2, index=topk_indices_exp)
    prime = topk_values_exp * prime
    prime = prime.reshape(b, c, -1)
    return prime

In [17]:
def _prime_new(self, matrix, magnitude_matrix, K, maximum):
    b, c, t = matrix.shape
    topk_values, topk_indices = torch.topk(magnitude_matrix, k=K, dim=2, largest=maximum)
    topk_indices_exp = topk_indices.unsqueeze(1).expand(b, c, t, K)    
    topk_values_exp = topk_values.unsqueeze(1).expand(b, c, t, K)    
    print(topk_indices_exp)
    print(topk_values_exp)
    print("topk_indices_exp:", topk_indices_exp.shape)
    print("topk_values_exp:", topk_values_exp.shape)
    
    matrix_expanded = matrix.unsqueeze(-1).expand(b, c, t, K).contiguous()
    prime = torch.gather(matrix_expanded, dim=2, index=topk_indices_exp)
    prime = topk_values_exp * prime
    prime = prime.view(b, c, -1)
    return prime

In [18]:
new_prime = _prime_new(None, matrix_x, a, K=2, maximum=True)
print("New Prime Matrix:", new_prime.shape)
print("New Prime Matrix:", new_prime)

tensor([[[[0, 1],
          [1, 2],
          [2, 3],
          [3, 2]],

         [[0, 1],
          [1, 2],
          [2, 3],
          [3, 2]]]])
tensor([[[[1.0000, 0.9923],
          [1.0000, 0.9965],
          [1.0000, 0.9983],
          [1.0000, 0.9983]],

         [[1.0000, 0.9923],
          [1.0000, 0.9965],
          [1.0000, 0.9983],
          [1.0000, 0.9983]]]])
topk_indices_exp: torch.Size([1, 2, 4, 2])
topk_values_exp: torch.Size([1, 2, 4, 2])
New Prime Matrix: torch.Size([1, 2, 8])
New Prime Matrix: tensor([[[1.0000, 1.9846, 2.0000, 2.9896, 3.0000, 3.9931, 4.0000, 2.9948],
         [5.0000, 5.9537, 6.0000, 6.9758, 7.0000, 7.9862, 8.0000, 6.9879]]])


In [25]:
def _prime_N(matrix, magnitude_matrix, K, rand_idx, maximum):
    b, c, t = matrix.shape
    topk_values, topk_indices = torch.topk(magnitude_matrix, k=K - 1, dim=2, largest=maximum)
    tk = topk_indices.shape[-1]
    assert K == tk + 1, "Error: K must be same as tk + 1. K == tk + 1."

    mapped_tensor = rand_idx[topk_indices]
    token_indices = torch.arange(t, device=matrix.device).view(1, t, 1).expand(b, t, 1)
    final_indices = torch.cat([token_indices, mapped_tensor], dim=2)
    indices_expanded = final_indices.unsqueeze(1).expand(b, c, t, K)
    topk_values_exp = topk_values.unsqueeze(1).expand(b, c, t, K)
    ones = torch.ones((b, c, t, 1), device=matrix.device)
    topk_values_exp = torch.cat((ones, topk_values_exp), dim=-1)
    

    matrix_expanded = matrix.unsqueeze(-1).expand(b, c, t, K).contiguous()
    prime = torch.gather(matrix_expanded, dim=2, index=indices_expanded)  
    prime = topk_values_exp * prime
    prime = prime.view(b, c, -1)
    return prime

In [33]:
rand_idx = torch.tensor([1])

In [34]:
new_prime = _prime_N(matrix_x, a, K=2, rand_idx=rand_idx, maximum=True)
print("New Prime Matrix:", new_prime.shape)
print("New Prime Matrix:", new_prime)

IndexError: index 1 is out of bounds for dimension 0 with size 1

In [22]:
ones = torch.ones((1, 2, 4, 1))
print(ones)
zeros = torch.zeros((1, 2, 4, 1))
print(zeros)

combined = torch.cat((ones, zeros), dim=-1)
print(combined)

tensor([[[[1.],
          [1.],
          [1.],
          [1.]],

         [[1.],
          [1.],
          [1.],
          [1.]]]])
tensor([[[[0.],
          [0.],
          [0.],
          [0.]],

         [[0.],
          [0.],
          [0.],
          [0.]]]])
tensor([[[[1., 0.],
          [1., 0.],
          [1., 0.],
          [1., 0.]],

         [[1., 0.],
          [1., 0.],
          [1., 0.],
          [1., 0.]]]])


# New ConvNN

In [None]:

import torch 
import torch.nn as nn 
import torch.nn.functional as F 

class Conv2d_NN(nn.Module): 
    """Convolution 2D Nearest Neighbor Layer"""
    def __init__(self, 
                in_channels, 
                out_channels, 
                K,
                stride, 
                sampling_type, 
                num_samples, 
                sample_padding,
                shuffle_pattern, 
                shuffle_scale, 
                magnitude_type,
                coordinate_encoding=False
                ): 
        """
        Parameters: 
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            K (int): Number of Nearest Neighbors for consideration.
            stride (int): Stride size.
            sampling_type (str): Sampling type: "all", "random", "spatial".
            num_samples (int): Number of samples to consider. -1 for all samples.
            shuffle_pattern (str): Shuffle pattern: "B", "A", "BA".
            shuffle_scale (int): Shuffle scale factor.
            magnitude_type (str): Distance or Similarity.
        """
        super(Conv2d_NN, self).__init__()
        
        # Assertions 
        assert K == stride, "Error: K must be same as stride. K == stride."
        assert shuffle_pattern in ["B", "A", "BA", "NA"], "Error: shuffle_pattern must be one of ['B', 'A', 'BA', 'NA']"
        assert magnitude_type in ["distance", "similarity"], "Error: magnitude_type must be one of ['distance', 'similarity']"
        assert sampling_type in ["all", "random", "spatial"], "Error: sampling_type must be one of ['all', 'random', 'spatial']"
        assert int(num_samples) > 0 or int(num_samples) == -1, "Error: num_samples must be greater than 0 or -1 for all samples"
        assert (sampling_type == "all" and int(num_samples) == -1) or (sampling_type != "all" and isinstance(num_samples, int)), "Error: num_samples must be -1 for 'all' sampling or an integer for 'random' and 'spatial' sampling"
        
        # Initialize parameters
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.K = K
        self.stride = stride
        self.sampling_type = sampling_type
        self.num_samples = num_samples if num_samples != -1 else 'all'  # -1 for all samples
        self.sample_padding = sample_padding if sampling_type == "spatial" else 0
        self.shuffle_pattern = shuffle_pattern
        self.shuffle_scale = shuffle_scale
        self.magnitude_type = magnitude_type
        self.maximum = True if self.magnitude_type == 'similarity' else False

        # Positional Encoding (optional)
        self.coordinate_encoding = coordinate_encoding
        self.coordinate_cache = {} 
        self.in_channels = in_channels + 2 if self.coordinate_encoding else in_channels
        self.out_channels = out_channels + 2 if self.coordinate_encoding else out_channels

        # Shuffle2D/Unshuffle2D Layers
        self.shuffle_layer = nn.PixelShuffle(upscale_factor=self.shuffle_scale)
        self.unshuffle_layer = nn.PixelUnshuffle(downscale_factor=self.shuffle_scale)
        
        # Adjust Channels for PixelShuffle
        self.in_channels_1d = self.in_channels * (self.shuffle_scale**2) if self.shuffle_pattern in ["B", "BA"] else self.in_channels
        self.out_channels_1d = self.out_channels * (self.shuffle_scale**2) if self.shuffle_pattern in ["A", "BA"] else self.out_channels

        # Conv1d Layer
        self.conv1d_layer = nn.Conv1d(in_channels=self.in_channels_1d, 
                                      out_channels=self.out_channels_1d, 
                                      kernel_size=self.K, 
                                      stride=self.stride, 
                                      padding=0)

        # Flatten Layer
        self.flatten = nn.Flatten(start_dim=2)

        # Pointwise Convolution Layer
        self.pointwise_conv = nn.Conv2d(in_channels=self.out_channels,
                                         out_channels=self.out_channels - 2,
                                         kernel_size=1,
                                         stride=1,
                                         padding=0)
        
        

    def forward(self, x): 
        # Coordinate Channels (optional) + Unshuffle + Flatten 
        x = self._add_coordinate_encoding(x) if self.coordinate_encoding else x
        x_2d = self.unshuffle_layer(x) if self.shuffle_pattern in ["B", "BA"] else x
        x = self.flatten(x_2d)

        if self.sampling_type == "all":    
            # ConvNN Algorithm 
            matrix_magnitude = self._calculate_distance_matrix(x, sqrt=True) if self.magnitude_type == 'distance' else self._calculate_similarity_matrix(x)
            prime = self._prime_new(x, matrix_magnitude, self.K, self.maximum)
             
        elif self.sampling_type == "random":
            # Select random samples
            rand_idx = torch.randperm(x.shape[2], device=x.device)[:self.num_samples]
            x_sample = x[:, :, rand_idx]

            # ConvNN Algorithm 
            matrix_magnitude = self._calculate_distance_matrix_N(x, x_sample, sqrt=True) if self.magnitude_type == 'distance' else self._calculate_similarity_matrix_N(x, x_sample)
            range_idx = torch.arange(len(rand_idx), device=x.device)
            matrix_magnitude[:, rand_idx, range_idx] = float('inf') if self.magnitude_type == 'distance' else float('-inf')
            prime = self._prime_N_new(x, matrix_magnitude, self.K, rand_idx, self.maximum)
            
        elif self.sampling_type == "spatial":
            # Get spatial sampled indices
            x_ind = torch.linspace(0 + self.sample_padding, x_2d.shape[2] - self.sample_padding - 1, self.num_samples, device=x.device).to(torch.long)
            y_ind = torch.linspace(0 + self.sample_padding, x_2d.shape[3] - self.sample_padding - 1, self.num_samples, device=x.device).to(torch.long)
            x_grid, y_grid = torch.meshgrid(x_ind, y_ind, indexing='ij')
            x_idx_flat, y_idx_flat = x_grid.flatten(), y_grid.flatten()
            width = x_2d.shape[2] 
            flat_indices = y_idx_flat * width + x_idx_flat  
            x_sample = x[:, :, flat_indices]

            # ConvNN Algorithm
            matrix_magnitude = self._calculate_distance_matrix_N(x, x_sample, sqrt=True) if self.magnitude_type == 'distance' else self._calculate_similarity_matrix_N(x, x_sample)
            range_idx = torch.arange(len(flat_indices), device=x.device)
            matrix_magnitude[:, flat_indices, range_idx] = float('inf') if self.magnitude_type == 'distance' else float('-inf')
            prime = self._prime_N(x, matrix_magnitude, self.K, flat_indices, self.maximum)
        else: 
            raise ValueError("Invalid sampling_type. Must be one of ['all', 'random', 'spatial'].")

        # Post-Processing 
        x_conv = self.conv1d_layer(prime) 
        
        # Unflatten + Shuffle
        unflatten = nn.Unflatten(dim=2, unflattened_size=x_2d.shape[2:])
        x = unflatten(x_conv)  # [batch_size, out_channels
        x = self.shuffle_layer(x) if self.shuffle_pattern in ["A", "BA"] else x
        x = self.pointwise_conv(x) if self.coordinate_encoding else x
        return x

    def _calculate_distance_matrix(self, matrix, sqrt=False):
        norm_squared = torch.sum(matrix ** 2, dim=1, keepdim=True)
        dot_product = torch.bmm(matrix.transpose(2, 1), matrix)
        
        dist_matrix = norm_squared + norm_squared.transpose(2, 1) - 2 * dot_product
        dist_matrix = torch.clamp(dist_matrix, min=0) # remove negative values
        dist_matrix = torch.sqrt(dist_matrix) if sqrt else dist_matrix # take square root if needed
        
        return dist_matrix
    
    def _calculate_distance_matrix_N(self, matrix, matrix_sample, sqrt=False):
        norm_squared = torch.sum(matrix ** 2, dim=1, keepdim=True).permute(0, 2, 1)
        norm_squared_sample = torch.sum(matrix_sample ** 2, dim=1, keepdim=True).transpose(2, 1).permute(0, 2, 1)
        dot_product = torch.bmm(matrix.transpose(2, 1), matrix_sample)
        
        dist_matrix = norm_squared + norm_squared_sample - 2 * dot_product
        dist_matrix = torch.clamp(dist_matrix, min=0) # remove negative values
        dist_matrix = torch.sqrt(dist_matrix) if sqrt else dist_matrix

        return dist_matrix
    
    def _calculate_similarity_matrix(self, matrix):
        # p=2 (L2 Norm - Euclidean Distance), dim=1 (across the channels)
        norm_matrix = F.normalize(matrix, p=2, dim=1) 
        similarity_matrix = torch.bmm(norm_matrix.transpose(2, 1), norm_matrix)
        return similarity_matrix
    
    def _calculate_similarity_matrix_N(self, matrix, matrix_sample):
        # p=2 (L2 Norm - Euclidean Distance), dim=1 (across the channels)
        norm_matrix = F.normalize(matrix, p=2, dim=1) 
        norm_sample = F.normalize(matrix_sample, p=2, dim=1)
        similarity_matrix = torch.bmm(norm_matrix.transpose(2, 1), norm_sample)
        return similarity_matrix

    def _prime_new(self, matrix, magnitude_matrix, K, maximum):
        b, c, t = matrix.shape
        topk_values, topk_indices = torch.topk(magnitude_matrix, k=K, dim=2, largest=maximum)
        topk_indices_exp = topk_indices.unsqueeze(1).expand(b, c, t, K)    
        topk_values_exp = topk_values.unsqueeze(1).expand(b, c, t, K)    
        matrix_expanded = matrix.unsqueeze(-1).expand(b, c, t, K).contiguous()
        prime = torch.gather(matrix_expanded, dim=2, index=topk_indices_exp)
        prime = topk_values_exp * prime
        prime = prime.view(b, c, -1)
        return prime
    
    def _prime_N(self, matrix, magnitude_matrix, K, rand_idx, maximum):
        b, c, t = matrix.shape
        _, topk_indices = torch.topk(magnitude_matrix, k=K - 1, dim=2, largest=maximum)
        tk = topk_indices.shape[-1]
        assert K == tk + 1, "Error: K must be same as tk + 1. K == tk + 1."

        mapped_tensor = rand_idx[topk_indices]
        token_indices = torch.arange(t, device=matrix.device).view(1, t, 1).expand(b, t, 1)
        final_indices = torch.cat([token_indices, mapped_tensor], dim=2)
        indices_expanded = final_indices.unsqueeze(1).expand(b, c, t, K)

        matrix_expanded = matrix.unsqueeze(-1).expand(b, c, t, K).contiguous()
        prime = torch.gather(matrix_expanded, dim=2, index=indices_expanded)  
        prime = prime.view(b, c, -1)
        return prime
    
    def _prime_N_new(self, matrix, magnitude_matrix, K, rand_idx, maximum):
        b, c, t = matrix.shape
        topk_values, topk_indices = torch.topk(magnitude_matrix, k=K - 1, dim=2, largest=maximum)
        tk = topk_indices.shape[-1]
        assert K == tk + 1, "Error: K must be same as tk + 1. K == tk + 1."

        mapped_tensor = rand_idx[topk_indices]
        token_indices = torch.arange(t, device=matrix.device).view(1, t, 1).expand(b, t, 1)
        final_indices = torch.cat([token_indices, mapped_tensor], dim=2)
        indices_expanded = final_indices.unsqueeze(1).expand(b, c, t, K)
        print("Indices Expanded:", indices_expanded.shape)

        topk_values_exp = topk_values.unsqueeze(1).expand(b, c, t, K-1)
        ones = torch.ones((b, c, t, 1), device=matrix.device)
        topk_values_exp = torch.cat((ones, topk_values_exp), dim=-1)
        print("Topk Values Expanded:", topk_values_exp.shape)

        print("indices_expanded:", indices_expanded)
        print("topk_values_exp:", topk_values_exp)

        matrix_expanded = matrix.unsqueeze(-1).expand(b, c, t, K).contiguous()
        prime = torch.gather(matrix_expanded, dim=2, index=indices_expanded)  
        prime = topk_values_exp * prime
        prime = prime.view(b, c, -1)
        return prime
    
    def _add_coordinate_encoding(self, x):
        b, _, h, w = x.shape
        cache_key = f"{b}_{h}_{w}_{x.device}"

        if cache_key in self.coordinate_cache:
            expanded_grid = self.coordinate_cache[cache_key]
        else:
            y_coords_vec = torch.linspace(start=-1, end=1, steps=h, device=x.device)
            x_coords_vec = torch.linspace(start=-1, end=1, steps=w, device=x.device)

            y_grid, x_grid = torch.meshgrid(y_coords_vec, x_coords_vec, indexing='ij')
            grid = torch.stack((x_grid, y_grid), dim=0).unsqueeze(0)
            expanded_grid = grid.expand(b, -1, -1, -1)
            self.coordinate_cache[cache_key] = expanded_grid

        x_with_coords = torch.cat((x, expanded_grid), dim=1)
        return x_with_coords

In [45]:
ex = torch.randn(8, 2, 32, 32) 

model = Conv2d_NN(in_channels=2, 
                  out_channels=4, 
                  K=4,
                  stride=4, 
                  sampling_type="random", 
                  num_samples=10, 
                  sample_padding=0,
                  shuffle_pattern="BA", 
                  shuffle_scale=2, 
                  magnitude_type="similarity",
                  coordinate_encoding=False) 

out = model(ex)
print("Output shape:", out.shape)  # Expected shape: (8, 4, 32, 32)

Indices Expanded: torch.Size([8, 8, 256, 4])
Topk Values Expanded: torch.Size([8, 8, 256, 4])
indices_expanded: tensor([[[[  0, 219, 125, 253],
          [  1, 135,   8,  99],
          [  2, 135, 253, 125],
          ...,
          [253,  99, 136, 190],
          [254, 196, 253, 209],
          [255, 219, 125, 209]],

         [[  0, 219, 125, 253],
          [  1, 135,   8,  99],
          [  2, 135, 253, 125],
          ...,
          [253,  99, 136, 190],
          [254, 196, 253, 209],
          [255, 219, 125, 209]],

         [[  0, 219, 125, 253],
          [  1, 135,   8,  99],
          [  2, 135, 253, 125],
          ...,
          [253,  99, 136, 190],
          [254, 196, 253, 209],
          [255, 219, 125, 209]],

         ...,

         [[  0, 219, 125, 253],
          [  1, 135,   8,  99],
          [  2, 135, 253, 125],
          ...,
          [253,  99, 136, 190],
          [254, 196, 253, 209],
          [255, 219, 125, 209]],

         [[  0, 219, 125, 253],
     

# Coordinate Encoding

In [10]:
import torch 
import torch.nn as nn 


In [11]:
coordinate_cache = {}

In [12]:
def _add_coordinate_encoding( x):
    b, c, t = x.shape 
    cache_key = f"{t}_{x.device}"
    if cache_key in coordinate_cache:
        coords_vec = coordinate_cache[cache_key]
    else:
        coords_vec = torch.linspace(start=-1, end=1, steps=t, device=x.device).unsqueeze(0).expand(b, -1)
        coordinate_cache[cache_key] = coords_vec

    expanded_coords = coords_vec.unsqueeze(1).expand(b, -1, -1)
    x_with_coords = torch.cat((x, expanded_coords), dim=1)  
    return x_with_coords
    

In [13]:
x = torch.randn(256, 3, 10)
x_with_coords = _add_coordinate_encoding(x)
print(f"Input shape: {x.shape}")
print(f"With Coordinate Encoding shape: {x_with_coords.shape}")

print(f"Coordinate Cache Size: {len(coordinate_cache)}")
print(f"Coordinate Cache Keys: {list(coordinate_cache.keys())}")

print(f"coordinate cache: {coordinate_cache['10_cpu'][0, ]}")

print(f"x_with_coords: {x_with_coords[0, 3, :]}")  # Print the first channel of the first sample

Input shape: torch.Size([256, 3, 10])
With Coordinate Encoding shape: torch.Size([256, 4, 10])
Coordinate Cache Size: 1
Coordinate Cache Keys: ['10_cpu']
coordinate cache: tensor([-1.0000, -0.7778, -0.5556, -0.3333, -0.1111,  0.1111,  0.3333,  0.5556,
         0.7778,  1.0000])
x_with_coords: tensor([-1.0000, -0.7778, -0.5556, -0.3333, -0.1111,  0.1111,  0.3333,  0.5556,
         0.7778,  1.0000])


In [14]:
coordinate_cache = {} 
def _add_coordinate_encoding( x):
    b, c, t = x.shape 
    cache_key = f"{b}_{t}_{x.device}"
    if cache_key in coordinate_cache:
        expanded_coords = coordinate_cache[cache_key]
    else:
        coords_vec = torch.linspace(start=-1, end=1, steps=t, device=x.device).unsqueeze(0).expand(b, -1)
        expanded_coords = coords_vec.unsqueeze(1).expand(b, -1, -1)
        coordinate_cache[cache_key] = expanded_coords


    x_with_coords = torch.cat((x, expanded_coords), dim=1)  
    return x_with_coords
    

In [15]:
x = torch.randn(64, 3, 10)
x_with_coords = _add_coordinate_encoding(x)
print(f"Input shape: {x.shape}")
print(f"With Coordinate Encoding shape: {x_with_coords.shape}")

print(f"Coordinate Cache Size: {len(coordinate_cache)}")
print(f"Coordinate Cache Keys: {list(coordinate_cache.keys())}")

print(f"coordinate cache: {coordinate_cache['64_10_cpu'][0, ]}")

print(f"x_with_coords: {x_with_coords[0, 3, :]}")  # Print the first channel of the first sample

Input shape: torch.Size([64, 3, 10])
With Coordinate Encoding shape: torch.Size([64, 4, 10])
Coordinate Cache Size: 1
Coordinate Cache Keys: ['64_10_cpu']
coordinate cache: tensor([[-1.0000, -0.7778, -0.5556, -0.3333, -0.1111,  0.1111,  0.3333,  0.5556,
          0.7778,  1.0000]])
x_with_coords: tensor([-1.0000, -0.7778, -0.5556, -0.3333, -0.1111,  0.1111,  0.3333,  0.5556,
         0.7778,  1.0000])


## ConvNN 

In [16]:
"""
ConvNN
Total parameters: 185,150
Trainable parameters: 185,150

Input shape: torch.Size([64, 197, 64])
After Permute: torch.Size([64, 64, 197])
After Split_head: torch.Size([64, 4, 197, 16])
After Batch_Combine: torch.Size([256, 16, 197]) ### [256, 17, 197] added coordinate encoding
After Conv1d: torch.Size([256, 16, 197]) ###    [256, 17, 197] added coordinate encoding
After batch_split: torch.Size([64, 4, 197, 16]) ### [64, 4, 197, 17] added coordinate encoding
After Combine_Heads: torch.Size([64, 64, 197]) ### [64, 68, 197] added coordinate encoding

Output shape: torch.Size([64, 100])
"""

'\nConvNN\nTotal parameters: 185,150\nTrainable parameters: 185,150\n\nInput shape: torch.Size([64, 197, 64])\nAfter Permute: torch.Size([64, 64, 197])\nAfter Split_head: torch.Size([64, 4, 197, 16])\nAfter Batch_Combine: torch.Size([256, 16, 197]) ### [256, 17, 197] added coordinate encoding\nAfter Conv1d: torch.Size([256, 16, 197]) ###    [256, 17, 197] added coordinate encoding\nAfter batch_split: torch.Size([64, 4, 197, 16]) ### [64, 4, 197, 17] added coordinate encoding\nAfter Combine_Heads: torch.Size([64, 64, 197]) ### [64, 68, 197] added coordinate encoding\n\nOutput shape: torch.Size([64, 100])\n'

## ConvNNAttention

In [None]:
"""
ConvNNAttention
Total parameters: 83,716
Trainable parameters: 83,716

Input shape: torch.Size([64, 197, 64])
After Split_head: torch.Size([64, 4, 197, 16])
After Batch_Combine: torch.Size([256, 16, 197])
After Conv1d: torch.Size([256, 16, 197])
After permute: torch.Size([256, 197, 16])
After Batch_Split: torch.Size([64, 4, 197, 16])
After Combine_Heads: torch.Size([64, 197, 64])

Output shape: torch.Size([64, 100])

"""

╰─$ python -u "/Users/mingikang/Developer/Convolutional-Nearest-Neighbor/vit.py"
Regular Attention
Total parameters: 70,228
Trainable parameters: 70,228
Output shape: torch.Size([64, 100])

ConvNN
Total parameters: 218,546
Trainable parameters: 218,546
Output shape: torch.Size([64, 100])

ConvNNAttention
Total parameters: 71,930
Trainable parameters: 71,930
Output shape: torch.Size([64, 100])

╰─$ python -u "/Users/mingikang/Developer/Convolutional-Nearest-Neighbor/vit.py"
Regular Attention
Total parameters: 70,228
Trainable parameters: 70,228
Output shape: torch.Size([64, 100])

ConvNN
Total parameters: 218,295
Trainable parameters: 218,295
Output shape: torch.Size([64, 100])

ConvNNAttention
Total parameters: 71,679
Trainable parameters: 71,679
Output shape: torch.Size([64, 100])