In [1]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F

In [2]:
coordinate_cache = {}
def _add_coordinate_encoding(x):
    b, _, h, w = x.shape
    cache_key = f"{b}_{h}_{w}_{x.device}"

    if cache_key in coordinate_cache:
        expanded_grid = coordinate_cache[cache_key]
    else:
        y_coords_vec = torch.linspace(start=-1, end=1, steps=h, device=x.device)
        x_coords_vec = torch.linspace(start=-1, end=1, steps=w, device=x.device)

        y_grid, x_grid = torch.meshgrid(y_coords_vec, x_coords_vec, indexing='ij')
        grid = torch.stack((x_grid, y_grid), dim=0).unsqueeze(0)
        expanded_grid = grid.expand(b, -1, -1, -1)
        coordinate_cache[cache_key] = expanded_grid

    x_with_coords = torch.cat((x, expanded_grid), dim=1)
    return x_with_coords

In [3]:
ex = torch.rand(1, 3, 5, 5) 
ex_coord = _add_coordinate_encoding(ex)
print(ex_coord.shape)


torch.Size([1, 5, 5, 5])


In [4]:
coord = ex_coord[:, -2:, :, :]
print(coord.shape)

torch.Size([1, 2, 5, 5])


In [5]:
def _calculate_similarity_matrix(matrix):
    # p=2 (L2 Norm - Euclidean Distance), dim=1 (across the channels)
    norm_matrix = F.normalize(matrix, p=2, dim=1) 
    similarity_matrix = torch.bmm(norm_matrix.transpose(2, 1), norm_matrix)
    similarity_matrix = torch.clamp(similarity_matrix, min=-1.0, max=1.0) 
    return similarity_matrix

In [6]:
coord = nn.Flatten(start_dim=2)(ex_coord[:, -2:, :, :])
print(coord.shape)
coord_mat = _calculate_similarity_matrix(coord)
print(coord_mat)

torch.Size([1, 2, 25])
tensor([[[ 1.0000e+00,  9.4868e-01,  7.0711e-01,  3.1623e-01,  1.2688e-08,
           9.4868e-01,  1.0000e+00,  7.0711e-01,  1.2688e-08, -3.1623e-01,
           7.0711e-01,  7.0711e-01,  0.0000e+00, -7.0711e-01, -7.0711e-01,
           3.1623e-01, -1.2688e-08, -7.0711e-01, -1.0000e+00, -9.4868e-01,
          -1.2688e-08, -3.1623e-01, -7.0711e-01, -9.4868e-01, -1.0000e+00],
         [ 9.4868e-01,  1.0000e+00,  8.9443e-01,  6.0000e-01,  3.1623e-01,
           8.0000e-01,  9.4868e-01,  8.9443e-01,  3.1623e-01,  1.4263e-08,
           4.4721e-01,  4.4721e-01,  0.0000e+00, -4.4721e-01, -4.4721e-01,
          -1.4263e-08, -3.1623e-01, -8.9443e-01, -9.4868e-01, -8.0000e-01,
          -3.1623e-01, -6.0000e-01, -8.9443e-01, -1.0000e+00, -9.4868e-01],
         [ 7.0711e-01,  8.9443e-01,  1.0000e+00,  8.9443e-01,  7.0711e-01,
           4.4721e-01,  7.0711e-01,  1.0000e+00,  7.0711e-01,  4.4721e-01,
           0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,


In [7]:
ex = torch.rand(1, 3, 5)
print(ex)
print()
sim_mat = _calculate_similarity_matrix(ex)
print(sim_mat)
print() 



tensor([[[0.9541, 0.0684, 0.7752, 0.0143, 0.3638],
         [0.5275, 0.3893, 0.6406, 0.2623, 0.3145],
         [0.3377, 0.2503, 0.4977, 0.4110, 0.5179]]])

tensor([[[1.0000, 0.6650, 0.9727, 0.5224, 0.8528],
         [0.6650, 1.0000, 0.8133, 0.9026, 0.8375],
         [0.9727, 0.8133, 1.0000, 0.7010, 0.9347],
         [0.5224, 0.9026, 0.7010, 1.0000, 0.8719],
         [0.8528, 0.8375, 0.9347, 0.8719, 1.0000]]])



In [8]:
def _calculate_similarity_matrix(matrix):
    # p=2 (L2 Norm - Euclidean Distance), dim=1 (across the channels)
    norm_matrix = F.normalize(matrix, p=2, dim=1) 
    similarity_matrix = torch.bmm(norm_matrix.transpose(2, 1), norm_matrix)
    similarity_matrix = torch.clamp(similarity_matrix, min=-1.0, max=1.0) 
    return similarity_matrix

In [9]:
sim_mat = ex.transpose(2, 1) @ ex
print(sim_mat)
print()
sim_mat = nn.functional.normalize(sim_mat, p=2, dim=1)
print(sim_mat)
print()
sim_mat = _calculate_similarity_matrix(ex)
print(sim_mat)
print() 


tensor([[[1.3027, 0.3551, 1.2456, 0.2909, 0.6879],
         [0.3551, 0.2189, 0.4269, 0.2060, 0.2769],
         [1.2456, 0.4269, 1.2590, 0.3837, 0.7412],
         [0.2909, 0.2060, 0.3837, 0.2380, 0.3006],
         [0.6879, 0.2769, 0.7412, 0.3006, 0.4994]]])

tensor([[[0.6569, 0.5150, 0.6216, 0.4482, 0.5734],
         [0.1791, 0.3174, 0.2130, 0.3174, 0.2308],
         [0.6281, 0.6192, 0.6283, 0.5912, 0.6179],
         [0.1467, 0.2987, 0.1915, 0.3666, 0.2506],
         [0.3469, 0.4016, 0.3699, 0.4631, 0.4163]]])

tensor([[[1.0000, 0.6650, 0.9727, 0.5224, 0.8528],
         [0.6650, 1.0000, 0.8133, 0.9026, 0.8375],
         [0.9727, 0.8133, 1.0000, 0.7010, 0.9347],
         [0.5224, 0.9026, 0.7010, 1.0000, 0.8719],
         [0.8528, 0.8375, 0.9347, 0.8719, 1.0000]]])



In [10]:
print(sim_mat.shape)

torch.Size([1, 5, 5])


# SANITY CHECK

In [11]:
class Conv2d_NN(nn.Module): 
    """Convolution 2D Nearest Neighbor Layer"""
    def __init__(self, 
                in_channels, 
                out_channels, 
                K,
                stride, 
                sampling_type, 
                num_samples, 
                sample_padding,
                shuffle_pattern, 
                shuffle_scale, 
                magnitude_type,
                coordinate_encoding
                ): 
        """
        Parameters: 
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            K (int): Number of Nearest Neighbors for consideration.
            stride (int): Stride size.
            sampling_type (str): Sampling type: "all", "random", "spatial".
            num_samples (int): Number of samples to consider. -1 for all samples.
            shuffle_pattern (str): Shuffle pattern: "B", "A", "BA".
            shuffle_scale (int): Shuffle scale factor.
            magnitude_type (str): Distance or Similarity.
        """
        super(Conv2d_NN, self).__init__()
        
        # Assertions 
        assert K == stride, "Error: K must be same as stride. K == stride."
        assert shuffle_pattern in ["B", "A", "BA", "NA"], "Error: shuffle_pattern must be one of ['B', 'A', 'BA', 'NA']"
        assert magnitude_type in ["distance", "similarity"], "Error: magnitude_type must be one of ['distance', 'similarity']"
        assert sampling_type in ["all", "random", "spatial"], "Error: sampling_type must be one of ['all', 'random', 'spatial']"
        assert int(num_samples) > 0 or int(num_samples) == -1, "Error: num_samples must be greater than 0 or -1 for all samples"
        assert (sampling_type == "all" and int(num_samples) == -1) or (sampling_type != "all" and isinstance(num_samples, int)), "Error: num_samples must be -1 for 'all' sampling or an integer for 'random' and 'spatial' sampling"
        
        # Initialize parameters
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.K = K
        self.stride = stride
        self.sampling_type = sampling_type
        self.num_samples = num_samples if num_samples != -1 else 'all'  # -1 for all samples
        self.sample_padding = sample_padding if sampling_type == "spatial" else 0
        self.shuffle_pattern = shuffle_pattern
        self.shuffle_scale = shuffle_scale
        self.magnitude_type = magnitude_type
        self.maximum = True if self.magnitude_type == 'similarity' else False
        self.INF_DISTANCE = 1e10
        self.NEG_INF_DISTANCE = -1e10

        # Positional Encoding (optional)
        self.coordinate_encoding = coordinate_encoding
        self.coordinate_cache = {} 
        self.in_channels = in_channels + 2 if self.coordinate_encoding else in_channels
        self.out_channels = out_channels # + 2 if self.coordinate_encoding else out_channels

        # Shuffle2D/Unshuffle2D Layers
        self.shuffle_layer = nn.PixelShuffle(upscale_factor=self.shuffle_scale)
        self.unshuffle_layer = nn.PixelUnshuffle(downscale_factor=self.shuffle_scale)
        
        # Adjust Channels for PixelShuffle
        self.in_channels_1d = self.in_channels * (self.shuffle_scale**2) if self.shuffle_pattern in ["B", "BA"] else self.in_channels
        self.out_channels_1d = self.out_channels * (self.shuffle_scale**2) if self.shuffle_pattern in ["A", "BA"] else self.out_channels

        # Conv1d Layer
        self.in_channels_1d = 1
        self.conv1d_layer = nn.Conv1d(in_channels=self.in_channels_1d, 
                                      out_channels=self.out_channels_1d, 
                                      kernel_size=self.K, 
                                      stride=self.stride, 
                                      padding=0)

        # Flatten Layer
        self.flatten = nn.Flatten(start_dim=2)

        # # Pointwise Convolution Layer
        # self.pointwise_conv = nn.Conv2d(in_channels=self.out_channels,
        #                                  out_channels=self.out_channels - 2,
        #                                  kernel_size=1,
        #                                  stride=1,
        #                                  padding=0)
        
        

    def forward(self, x): 
        # Coordinate Channels (optional) + Unshuffle + Flatten 
        x = F.pad(x, (1, 1, 1, 1), mode='constant', value=0)
        print("x padded: ")
        print(x)
        print()
        x = self._add_coordinate_encoding(x) if self.coordinate_encoding else x
        print("x after coordinate encoding: ")
        print(x)
        x_2d = self.unshuffle_layer(x) if self.shuffle_pattern in ["B", "BA"] else x
        x = self.flatten(x_2d)
        print("x: ")
        print(x)
        print() 
        print(x.shape)
        
        if self.sampling_type == "all":    
            # ConvNN Algorithm 
            x_dist = x[:, -2:, :]
            print("x_dist: ")
            print(x_dist)
            print(x_dist.shape)

            matrix_magnitude = self._calculate_distance_matrix(x_dist, sqrt=True) if self.magnitude_type == 'distance' else self._calculate_similarity_matrix(x_dist)
            print("matrix_magnitude: ")
            print(matrix_magnitude)
            x = x[:, 0, :].unsqueeze(1)
            print(x)
            print(x.shape)
            prime = self._prime(x, matrix_magnitude, self.K, self.maximum)

        elif self.sampling_type == "random":
            # Select random samples
            rand_idx = torch.randperm(x.shape[2], device=x.device)[:self.num_samples]
            x_sample = x[:, :, rand_idx]

            # ConvNN Algorithm 
            matrix_magnitude = self._calculate_distance_matrix_N(x, x_sample, sqrt=True) if self.magnitude_type == 'distance' else self._calculate_similarity_matrix_N(x, x_sample)
            range_idx = torch.arange(len(rand_idx), device=x.device)
            matrix_magnitude[:, rand_idx, range_idx] = self.INF_DISTANCE if self.magnitude_type == 'distance' else self.NEG_INF_DISTANCE
            prime = self._prime_N(x, matrix_magnitude, self.K, rand_idx, self.maximum)
            
        elif self.sampling_type == "spatial":
            # Get spatial sampled indices
            x_ind = torch.linspace(0 + self.sample_padding, x_2d.shape[2] - self.sample_padding - 1, self.num_samples, device=x.device).to(torch.long)
            y_ind = torch.linspace(0 + self.sample_padding, x_2d.shape[3] - self.sample_padding - 1, self.num_samples, device=x.device).to(torch.long)
            x_grid, y_grid = torch.meshgrid(x_ind, y_ind, indexing='ij')
            x_idx_flat, y_idx_flat = x_grid.flatten(), y_grid.flatten()
            width = x_2d.shape[2] 
            flat_indices = y_idx_flat * width + x_idx_flat  
            x_sample = x[:, :, flat_indices]

            # ConvNN Algorithm
            matrix_magnitude = self._calculate_distance_matrix_N(x, x_sample, sqrt=True) if self.magnitude_type == 'distance' else self._calculate_similarity_matrix_N(x, x_sample)
            range_idx = torch.arange(len(flat_indices), device=x.device)
            matrix_magnitude[:, flat_indices, range_idx] = self.INF_DISTANCE if self.magnitude_type == 'distance' else self.NEG_INF_DISTANCE
            prime = self._prime_N(x, matrix_magnitude, self.K, flat_indices, self.maximum)
        else: 
            raise ValueError("Invalid sampling_type. Must be one of ['all', 'random', 'spatial'].")
        print("\n prime: ")
        print(prime.shape)
        print(prime)
        # Post-Processing 
        x_conv = self.conv1d_layer(prime) 
        
        # Unflatten + Shuffle
        unflatten = nn.Unflatten(dim=2, unflattened_size=x_2d.shape[2:])
        x = unflatten(x_conv)  # [batch_size, out_channels
        x = self.shuffle_layer(x) if self.shuffle_pattern in ["A", "BA"] else x
        # x = self.pointwise_conv(x) if self.coordinate_encoding else x
        return x

    def _calculate_distance_matrix(self, matrix, sqrt=False):
        norm_squared = torch.sum(matrix ** 2, dim=1, keepdim=True)
        dot_product = torch.bmm(matrix.transpose(2, 1), matrix)
        
        dist_matrix = norm_squared + norm_squared.transpose(2, 1) - 2 * dot_product
        dist_matrix = torch.clamp(dist_matrix, min=-1.0, max=1.0) 
        dist_matrix = torch.sqrt(dist_matrix) if sqrt else dist_matrix 
        
        return dist_matrix
    
    def _calculate_distance_matrix_N(self, matrix, matrix_sample, sqrt=False):
        norm_squared = torch.sum(matrix ** 2, dim=1, keepdim=True).permute(0, 2, 1)
        norm_squared_sample = torch.sum(matrix_sample ** 2, dim=1, keepdim=True).transpose(2, 1).permute(0, 2, 1)
        dot_product = torch.bmm(matrix.transpose(2, 1), matrix_sample)
        
        dist_matrix = norm_squared + norm_squared_sample - 2 * dot_product
        dist_matrix = torch.clamp(dist_matrix, min=-1.0, max=1.0) 
        dist_matrix = torch.sqrt(dist_matrix) if sqrt else dist_matrix

        return dist_matrix
    
    # def _calculate_similarity_matrix(self, matrix):
    #     # p=2 (L2 Norm - Euclidean Distance), dim=1 (across the channels)
    #     norm_matrix = F.normalize(matrix, p=2, dim=1) 
    #     similarity_matrix = torch.bmm(norm_matrix.transpose(2, 1), norm_matrix)
    #     similarity_matrix = torch.clamp(similarity_matrix, min=-1.0, max=1.0) 
    #     return similarity_matrix

    def _calculate_similarity_matrix(self, matrix, sigma=0.1):
        """Calculate similarity matrix based on coordinate distance"""
        b, c, t = matrix.shape  # c should be 2 for (x, y) coordinates
        
        # Calculate pairwise Euclidean distances between coordinates
        coord_expanded_1 = matrix.unsqueeze(3)  # [B, 2, T, 1]
        coord_expanded_2 = matrix.unsqueeze(2)  # [B, 2, 1, T]
        
        # Euclidean distance between coordinates
        coord_diff = coord_expanded_1 - coord_expanded_2  # [B, 2, T, T]
        coord_dist = torch.sqrt(torch.sum(coord_diff ** 2, dim=1) + 1e-8)  # [B, T, T]
        
        # Convert distance to similarity using Gaussian kernel
        similarity_matrix = torch.exp(-coord_dist ** 2 / (2 * sigma ** 2))
        
        return similarity_matrix
    
    def _calculate_similarity_matrix_N(self, matrix, matrix_sample):
        # p=2 (L2 Norm - Euclidean Distance), dim=1 (across the channels)
        norm_matrix = F.normalize(matrix, p=2, dim=1) 
        norm_sample = F.normalize(matrix_sample, p=2, dim=1)
        similarity_matrix = torch.bmm(norm_matrix.transpose(2, 1), norm_sample)
        similarity_matrix = torch.clamp(similarity_matrix, min=-1.0, max=1.0) 
        return similarity_matrix
    
    def _prime(self, matrix, magnitude_matrix, K, maximum):
        b, c, t = matrix.shape
        _, topk_indices = torch.topk(magnitude_matrix, k=K, dim=2, largest=maximum)
        topk_indices_exp = topk_indices.unsqueeze(1).expand(b, c, t, K)    
        
        matrix_expanded = matrix.unsqueeze(-1).expand(b, c, t, K).contiguous()
        prime = torch.gather(matrix_expanded, dim=2, index=topk_indices_exp)
        print("prime: ")
        print(prime)
        print()
        prime = prime.view(b, c, -1)
        return prime
        
    def _prime_N(self, matrix, magnitude_matrix, K, rand_idx, maximum):
        b, c, t = matrix.shape
        
        _, topk_indices = torch.topk(magnitude_matrix, k=K - 1, dim=2, largest=maximum)
        tk = topk_indices.shape[-1]
        assert K == tk + 1, "Error: K must be same as tk + 1. K == tk + 1."

        # Map sample indices back to original matrix positions
        mapped_tensor = rand_idx[topk_indices]
        token_indices = torch.arange(t, device=matrix.device).view(1, t, 1).expand(b, t, 1)
        final_indices = torch.cat([token_indices, mapped_tensor], dim=2)
        indices_expanded = final_indices.unsqueeze(1).expand(b, c, t, K)

        # Gather matrix values and apply similarity weighting
        matrix_expanded = matrix.unsqueeze(-1).expand(b, c, t, K).contiguous()
        prime = torch.gather(matrix_expanded, dim=2, index=indices_expanded)  
        prime = prime.view(b, c, -1)
        return prime
    
    def _add_coordinate_encoding(self, x):
        b, _, h, w = x.shape
        cache_key = f"{b}_{h}_{w}_{x.device}"

        if cache_key in self.coordinate_cache:
            expanded_grid = self.coordinate_cache[cache_key]
        else:
            y_coords_vec = torch.linspace(start=-1, end=1, steps=h, device=x.device)
            x_coords_vec = torch.linspace(start=-1, end=1, steps=w, device=x.device)

            y_grid, x_grid = torch.meshgrid(y_coords_vec, x_coords_vec, indexing='ij')
            grid = torch.stack((x_grid, y_grid), dim=0).unsqueeze(0)
            expanded_grid = grid.expand(b, -1, -1, -1)
            self.coordinate_cache[cache_key] = expanded_grid

        x_with_coords = torch.cat((x, expanded_grid), dim=1)
        return x_with_coords ### Last two channels are coordinate channels 

In [12]:
ex = torch.Tensor(
    [
        [
            [
                [1, 2, 3],
                [4, 5, 6],
                [7, 8, 9]
            ]
        ]
    ]
)

convnn = Conv2d_NN(in_channels=1, out_channels=6, K=9, stride=9, sampling_type='all', num_samples=-1, sample_padding=0, shuffle_pattern='NA', shuffle_scale=0.0, magnitude_type='similarity', coordinate_encoding=True)
out = convnn(ex)
print(out.shape)

"""
tensor([[[0., 0., 0., 0., 0., 
          0., 1., 2., 3., 0., 
          0., 4., 5., 6., 0., 
          0., 7., 8., 9., 0., 
          0., 0., 0., 0., 0.]]])

          [1., 4., 0., 2., 0., 5., 0., 0., 0.],
          [2., 1., 0., 3., 5., 0., 0., 6., 4.],
          [3., 6., 0., 0., 2., 5., 0., 0., 0.],  
          [4., 5., 1., 0., 7., 0., 2., 0., 8.],
          [5., 6., 8., 4., 2., 7., 1., 9., 3.],
          [6., 0., 5., 3., 9., 8., 2., 0., 0.],
          [7., 0., 8., 0., 4., 5., 0., 0., 0.],
          [8., 5., 0., 7., 9., 4., 0., 0., 6.],
          [9., 0., 0., 8., 6., 0., 0., 0., 5.],
          
          [1., 4., 0., 2., 0., 5., 0., 0., 0.],
          [2., 1., 0., 3., 5., 0., 0., 6., 4.],
          [3., 6., 0., 0., 2., 5., 0., 0., 0.],
          [4., 5., 1., 0., 7., 0., 2., 0., 8.],
          [5., 6., 8., 4., 2., 7., 1., 9., 3.],
          [6., 0., 5., 3., 9., 8., 2., 0., 0.],
          [7., 0., 8., 0., 4., 5., 0., 0., 0.],
          [8., 5., 0., 7., 9., 4., 0., 0., 6.],
          [9., 0., 0., 8., 6., 0., 0., 0., 5.]]]])
          
Window 0 (position [0, 0]):
tensor([[0., 0., 0.],
        [0., 1., 2.],
        [0., 4., 5.]])
Flattened: [0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 0.0, 4.0, 5.0]

Window 1 (position [0, 1]):
tensor([[0., 0., 0.],
        [1., 2., 3.],
        [4., 5., 6.]])
Flattened: [0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

Window 2 (position [0, 2]):
tensor([[0., 0., 0.],
        [2., 3., 0.],
        [5., 6., 0.]])
Flattened: [0.0, 0.0, 0.0, 2.0, 3.0, 0.0, 5.0, 6.0, 0.0]

Window 3 (position [1, 0]):
tensor([[0., 1., 2.],
        [0., 4., 5.],
        [0., 7., 8.]])
Flattened: [0.0, 1.0, 2.0, 0.0, 4.0, 5.0, 0.0, 7.0, 8.0]

Window 4 (position [1, 1]):
tensor([[1., 2., 3.],
        [4., 5., 6.],
        [7., 8., 9.]])
Flattened: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]

Window 5 (position [1, 2]):
tensor([[2., 3., 0.],
        [5., 6., 0.],
        [8., 9., 0.]])
Flattened: [2.0, 3.0, 0.0, 5.0, 6.0, 0.0, 8.0, 9.0, 0.0]

Window 6 (position [2, 0]):
tensor([[0., 4., 5.],
        [0., 7., 8.],
        [0., 0., 0.]])
Flattened: [0.0, 4.0, 5.0, 0.0, 7.0, 8.0, 0.0, 0.0, 0.0]

Window 7 (position [2, 1]):
tensor([[4., 5., 6.],
        [7., 8., 9.],
        [0., 0., 0.]])
Flattened: [4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 0.0, 0.0, 0.0]

Window 8 (position [2, 2]):
tensor([[5., 6., 0.],
        [8., 9., 0.],
        [0., 0., 0.]])
Flattened: [5.0, 6.0, 0.0, 8.0, 9.0, 0.0, 0.0, 0.0, 0.0]

          
torch.Size([1, 1, 25])
prime: 
tensor([[[[0., 0., 0., 1., 0., 0., 4., 2., 5.],
          [0., 0., 0., 1., 0., 2., 0., 4., 3.],
          [0., 0., 0., 2., 3., 1., 5., 0., 0.],
          [0., 0., 0., 3., 0., 2., 6., 0., 1.],
          [0., 0., 0., 3., 0., 0., 6., 2., 5.],
          [0., 0., 0., 1., 0., 4., 0., 2., 5.],
          [1., 4., 0., 2., 0., 5., 0., 0., 0.],
          [2., 1., 0., 3., 5., 0., 0., 6., 4.],
          [3., 6., 0., 0., 2., 5., 0., 0., 0.],
          [0., 0., 0., 3., 6., 0., 0., 2., 5.],
          [0., 0., 4., 0., 7., 1., 5., 0., 0.],
          [4., 5., 1., 0., 7., 0., 2., 0., 8.],
          [5., 6., 8., 4., 2., 7., 1., 9., 3.],
          [6., 0., 5., 3., 9., 8., 2., 0., 0.],
          [0., 0., 6., 0., 9., 3., 0., 5., 0.],
          [0., 0., 0., 7., 4., 0., 8., 0., 5.],
          [7., 0., 8., 0., 4., 5., 0., 0., 0.],
          [8., 5., 0., 7., 9., 4., 0., 0., 6.],
          [9., 0., 0., 8., 6., 0., 0., 0., 5.],
          [0., 0., 9., 0., 0., 6., 8., 0., 3.],
          [0., 0., 0., 7., 0., 0., 8., 4., 5.],
          [0., 0., 0., 7., 8., 0., 4., 0., 9.],
          [0., 0., 0., 8., 9., 7., 0., 5., 0.],
          [0., 0., 0., 9., 0., 8., 0., 6., 5.],
          [0., 0., 0., 9., 0., 0., 8., 6., 5.]]]])

            
"""

"""
The results may be different because of the different positions of the kernel window from ConvNN. 
Example: 
Convolution = 
tensor([[0., 0., 0.],
        [0., 1., 2.],
        [0., 4., 5.]])
Flattened: [0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 0.0, 4.0, 5.0]

ConvNN = 
tensor([[0., 0., 0.],
        [0., 1., 2.],
        [0., 4., 5.]])
Topk:      [1.0, 4.0, 0.0, 2.0, 0.0, 5.0, 0.0, 0.0, 0.0],
"""

# SOLVED

x padded: 
tensor([[[[0., 0., 0., 0., 0.],
          [0., 1., 2., 3., 0.],
          [0., 4., 5., 6., 0.],
          [0., 7., 8., 9., 0.],
          [0., 0., 0., 0., 0.]]]])

x after coordinate encoding: 
tensor([[[[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.0000,  1.0000,  2.0000,  3.0000,  0.0000],
          [ 0.0000,  4.0000,  5.0000,  6.0000,  0.0000],
          [ 0.0000,  7.0000,  8.0000,  9.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

         [[-1.0000, -0.5000,  0.0000,  0.5000,  1.0000],
          [-1.0000, -0.5000,  0.0000,  0.5000,  1.0000],
          [-1.0000, -0.5000,  0.0000,  0.5000,  1.0000],
          [-1.0000, -0.5000,  0.0000,  0.5000,  1.0000],
          [-1.0000, -0.5000,  0.0000,  0.5000,  1.0000]],

         [[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000],
          [-0.5000, -0.5000, -0.5000, -0.5000, -0.5000],
          [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.5000,  0.5000,  0.5000,  0.5000,  0.

'\nThe results may be different because of the different positions of the kernel window from ConvNN. \nExample: \nConvolution = \ntensor([[0., 0., 0.],\n        [0., 1., 2.],\n        [0., 4., 5.]])\nFlattened: [0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 0.0, 4.0, 5.0]\n\nConvNN = \ntensor([[0., 0., 0.],\n        [0., 1., 2.],\n        [0., 4., 5.]])\nTopk:      [1.0, 4.0, 0.0, 2.0, 0.0, 5.0, 0.0, 0.0, 0.0],\n'

In [13]:
import torch

def filter_non_zero_starting_rows_multichannel(tensor):
    """
    Filter rows based on the first element of the first channel being non-zero
    
    Args:
        tensor: Input tensor of shape [B, C, num_rows, row_length]
    
    Returns:
        Filtered tensor with only rows where first channel's first element != 0
    """
    # Get the shape
    b, c, num_rows, row_length = tensor.shape
    
    # Create mask based on first channel only
    # tensor[:, 0, :, 0] gets first element of each row in first channel
    mask = tensor[:, 0, :, 0] != 0  # Shape: [b, num_rows]
    
    # Get indices of non-zero starting rows
    non_zero_indices = torch.where(mask[0])[0]  # [0] because batch dimension
    
    # Select rows from ALL channels
    filtered_tensor = tensor[:, :, non_zero_indices, :]
    
    return filtered_tensor, non_zero_indices

# Test with your example (assuming you have multiple channels)
# For demonstration, let's create a multi-channel version
prime_single_channel = torch.tensor([[[[0., 0., 0., 1., 0., 0., 4., 2., 5.],
                                      [0., 0., 0., 1., 0., 2., 0., 4., 3.],
                                      [0., 0., 0., 2., 3., 1., 5., 0., 0.],
                                      [0., 0., 0., 3., 0., 2., 6., 0., 1.],
                                      [0., 0., 0., 3., 0., 0., 6., 2., 5.],
                                      [0., 0., 0., 1., 0., 4., 0., 2., 5.],
                                      [1., 4., 0., 2., 0., 5., 0., 0., 0.],
                                      [2., 1., 0., 3., 5., 0., 0., 6., 4.],
                                      [3., 6., 0., 0., 2., 5., 0., 0., 0.],
                                      [0., 0., 0., 3., 6., 0., 0., 2., 5.],
                                      [0., 0., 4., 0., 7., 1., 5., 0., 0.],
                                      [4., 5., 1., 0., 7., 0., 2., 0., 8.],
                                      [5., 6., 8., 4., 2., 7., 1., 9., 3.],
                                      [6., 0., 5., 3., 9., 8., 2., 0., 0.],
                                      [0., 0., 6., 0., 9., 3., 0., 5., 0.],
                                      [0., 0., 0., 7., 4., 0., 8., 0., 5.],
                                      [7., 0., 8., 0., 4., 5., 0., 0., 0.],
                                      [8., 5., 0., 7., 9., 4., 0., 0., 6.],
                                      [9., 0., 0., 8., 6., 0., 0., 0., 5.],
                                      [0., 0., 9., 0., 0., 6., 8., 0., 3.],
                                      [0., 0., 0., 7., 0., 0., 8., 4., 5.],
                                      [0., 0., 0., 7., 8., 0., 4., 0., 9.],
                                      [0., 0., 0., 8., 9., 7., 0., 5., 0.],
                                      [0., 0., 0., 9., 0., 8., 0., 6., 5.],
                                      [0., 0., 0., 9., 0., 0., 8., 6., 5.]]]])

# Create a multi-channel version for demonstration
# Let's say we have 3 channels
prime_multi_channel = torch.cat([
    prime_single_channel,  # Channel 0 (your original data)
    torch.rand_like(prime_single_channel),  # Channel 1 (random data)
    torch.rand_like(prime_single_channel)   # Channel 2 (random data)
], dim=1)

print(f"Multi-channel tensor shape: {prime_multi_channel.shape}")

# Filter based on first channel only
filtered_tensor, kept_indices = filter_non_zero_starting_rows_multichannel(prime_multi_channel)

print(f"Original shape: {prime_multi_channel.shape}")
print(f"Filtered shape: {filtered_tensor.shape}")
print(f"Kept row indices: {kept_indices}")

# Show the filtered first channel
print("\nFiltered first channel:")
print(filtered_tensor[0, 0, :, :])

Multi-channel tensor shape: torch.Size([1, 3, 25, 9])
Original shape: torch.Size([1, 3, 25, 9])
Filtered shape: torch.Size([1, 3, 9, 9])
Kept row indices: tensor([ 6,  7,  8, 11, 12, 13, 16, 17, 18])

Filtered first channel:
tensor([[1., 4., 0., 2., 0., 5., 0., 0., 0.],
        [2., 1., 0., 3., 5., 0., 0., 6., 4.],
        [3., 6., 0., 0., 2., 5., 0., 0., 0.],
        [4., 5., 1., 0., 7., 0., 2., 0., 8.],
        [5., 6., 8., 4., 2., 7., 1., 9., 3.],
        [6., 0., 5., 3., 9., 8., 2., 0., 0.],
        [7., 0., 8., 0., 4., 5., 0., 0., 0.],
        [8., 5., 0., 7., 9., 4., 0., 0., 6.],
        [9., 0., 0., 8., 6., 0., 0., 0., 5.]])


In [14]:
import torch

# Your original tensor
prime = torch.tensor([[[[0., 0., 0., 1., 0., 0., 4., 2., 5.],
                        [0., 0., 0., 1., 0., 2., 0., 4., 3.],
                        [0., 0., 0., 2., 3., 1., 5., 0., 0.],
                        [0., 0., 0., 3., 0., 2., 6., 0., 1.],
                        [0., 0., 0., 3., 0., 0., 6., 2., 5.],
                        [0., 0., 0., 1., 0., 4., 0., 2., 5.],
                        [1., 4., 0., 2., 0., 5., 0., 0., 0.],
                        [2., 1., 0., 3., 5., 0., 0., 6., 4.],
                        [3., 6., 0., 0., 2., 5., 0., 0., 0.],
                        [0., 0., 0., 3., 6., 0., 0., 2., 5.],
                        [0., 0., 4., 0., 7., 1., 5., 0., 0.],
                        [4., 5., 1., 0., 7., 0., 2., 0., 8.],
                        [5., 6., 8., 4., 2., 7., 1., 9., 3.],
                        [6., 0., 5., 3., 9., 8., 2., 0., 0.],
                        [0., 0., 6., 0., 9., 3., 0., 5., 0.],
                        [0., 0., 0., 7., 4., 0., 8., 0., 5.],
                        [7., 0., 8., 0., 4., 5., 0., 0., 0.],
                        [8., 5., 0., 7., 9., 4., 0., 0., 6.],
                        [9., 0., 0., 8., 6., 0., 0., 0., 5.],
                        [0., 0., 9., 0., 0., 6., 8., 0., 3.],
                        [0., 0., 0., 7., 0., 0., 8., 4., 5.],
                        [0., 0., 0., 7., 8., 0., 4., 0., 9.],
                        [0., 0., 0., 8., 9., 7., 0., 5., 0.],
                        [0., 0., 0., 9., 0., 8., 0., 6., 5.],
                        [0., 0., 0., 9., 0., 0., 8., 6., 5.]]]])

# Method 1: Create a mask for rows that don't start with 0
def filter_non_zero_starting_rows(tensor):
    # Get the shape
    b, c, num_rows, row_length = tensor.shape
    
    # Create mask for rows where first element is not 0
    mask = tensor[:, :, :, 0] != 0  # Shape: [b, c, num_rows]
    
    # Get indices of non-zero starting rows
    non_zero_indices = torch.where(mask[0, 0])[0]
    
    # Select only the rows that don't start with 0
    filtered_tensor = tensor[:, :, non_zero_indices, :]
    
    return filtered_tensor

# Apply the filter
filtered_prime = filter_non_zero_starting_rows(prime)

print(f"Original shape: {prime.shape}")
print(f"Filtered shape: {filtered_prime.shape}")
print("\nFiltered tensor:")
print(filtered_prime)

# Method 2: More concise using boolean indexing
def filter_non_zero_starting_rows_concise(tensor):
    # Reshape to 2D for easier indexing
    reshaped = tensor.squeeze()  # Remove batch and channel dimensions
    
    # Create mask for rows that don't start with 0
    mask = reshaped[:, 0] != 0
    
    # Filter rows
    filtered_rows = reshaped[mask]
    
    # Restore original dimensions
    return filtered_rows.unsqueeze(0).unsqueeze(0)

# Alternative method
filtered_prime_v2 = filter_non_zero_starting_rows_concise(prime)
print(f"\nMethod 2 - Filtered shape: {filtered_prime_v2.shape}")
print("Filtered tensor (Method 2):")
print(filtered_prime_v2)

Original shape: torch.Size([1, 1, 25, 9])
Filtered shape: torch.Size([1, 1, 9, 9])

Filtered tensor:
tensor([[[[1., 4., 0., 2., 0., 5., 0., 0., 0.],
          [2., 1., 0., 3., 5., 0., 0., 6., 4.],
          [3., 6., 0., 0., 2., 5., 0., 0., 0.],
          [4., 5., 1., 0., 7., 0., 2., 0., 8.],
          [5., 6., 8., 4., 2., 7., 1., 9., 3.],
          [6., 0., 5., 3., 9., 8., 2., 0., 0.],
          [7., 0., 8., 0., 4., 5., 0., 0., 0.],
          [8., 5., 0., 7., 9., 4., 0., 0., 6.],
          [9., 0., 0., 8., 6., 0., 0., 0., 5.]]]])

Method 2 - Filtered shape: torch.Size([1, 1, 9, 9])
Filtered tensor (Method 2):
tensor([[[[1., 4., 0., 2., 0., 5., 0., 0., 0.],
          [2., 1., 0., 3., 5., 0., 0., 6., 4.],
          [3., 6., 0., 0., 2., 5., 0., 0., 0.],
          [4., 5., 1., 0., 7., 0., 2., 0., 8.],
          [5., 6., 8., 4., 2., 7., 1., 9., 3.],
          [6., 0., 5., 3., 9., 8., 2., 0., 0.],
          [7., 0., 8., 0., 4., 5., 0., 0., 0.],
          [8., 5., 0., 7., 9., 4., 0., 0., 6.],

In [15]:
ones = torch.ones(1, 1, 3, 3) 
print(ones)
ones = F.pad(ones, (1, 1, 1, 1), mode='constant', value=0)
print(ones)
print(ones.shape)

tensor([[[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]]])
tensor([[[[0., 0., 0., 0., 0.],
          [0., 1., 1., 1., 0.],
          [0., 1., 1., 1., 0.],
          [0., 1., 1., 1., 0.],
          [0., 0., 0., 0., 0.]]]])
torch.Size([1, 1, 5, 5])


In [16]:
import torch
import torch.nn.functional as F

def get_3x3_windows_from_flattened(x_flat, original_shape=(5, 5)):
    """
    Extract all 3x3 convolutional windows from a flattened tensor
    
    Args:
        x_flat: Flattened tensor [1, 1, H*W]
        original_shape: Original 2D shape (H, W)
    
    Returns:
        windows: All 3x3 windows in 1D format [num_windows, 9]
    """
    h, w = original_shape
    
    # Reshape back to 2D
    x_2d = x_flat.reshape(1, 1, h, w)
    
    # Use unfold to get 3x3 patches
    # unfold(dim, size, step) extracts sliding windows
    patches = x_2d.unfold(2, 3, 1).unfold(3, 3, 1)  # [1, 1, H-2, W-2, 3, 3]
    
    # Reshape to get each window as a 1D vector
    num_windows_h, num_windows_w = patches.shape[2], patches.shape[3]
    windows = patches.reshape(num_windows_h * num_windows_w, 9)
    
    return windows, num_windows_h, num_windows_w

# Your input tensor
x_flat = torch.tensor([[[0., 0., 0., 0., 0., 
                        0., 1., 2., 3., 0., 
                        0., 4., 5., 6., 0., 
                        0., 7., 8., 9., 0., 
                        0., 0., 0., 0., 0.]]])

# Extract all 3x3 windows
windows, num_h, num_w = get_3x3_windows_from_flattened(x_flat)

print(f"Number of 3x3 windows: {windows.shape[0]} ({num_h}x{num_w})")
print(f"Each window has {windows.shape[1]} elements (3x3 = 9)")
print()

# Display all windows
for i, window in enumerate(windows):
    row = i // num_w
    col = i % num_w
    print(f"Window {i} (position [{row}, {col}]):")
    print(window.reshape(3, 3))
    print("Flattened:", window.tolist())
    print()
    

Number of 3x3 windows: 9 (3x3)
Each window has 9 elements (3x3 = 9)

Window 0 (position [0, 0]):
tensor([[0., 0., 0.],
        [0., 1., 2.],
        [0., 4., 5.]])
Flattened: [0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 0.0, 4.0, 5.0]

Window 1 (position [0, 1]):
tensor([[0., 0., 0.],
        [1., 2., 3.],
        [4., 5., 6.]])
Flattened: [0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

Window 2 (position [0, 2]):
tensor([[0., 0., 0.],
        [2., 3., 0.],
        [5., 6., 0.]])
Flattened: [0.0, 0.0, 0.0, 2.0, 3.0, 0.0, 5.0, 6.0, 0.0]

Window 3 (position [1, 0]):
tensor([[0., 1., 2.],
        [0., 4., 5.],
        [0., 7., 8.]])
Flattened: [0.0, 1.0, 2.0, 0.0, 4.0, 5.0, 0.0, 7.0, 8.0]

Window 4 (position [1, 1]):
tensor([[1., 2., 3.],
        [4., 5., 6.],
        [7., 8., 9.]])
Flattened: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]

Window 5 (position [1, 2]):
tensor([[2., 3., 0.],
        [5., 6., 0.],
        [8., 9., 0.]])
Flattened: [2.0, 3.0, 0.0, 5.0, 6.0, 0.0, 8.0, 9.0, 0.0]

Window 6 (pos

In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def _add_coordinate_encoding(x):
    """Add coordinate channels to input tensor"""
    b, c, h, w = x.shape
    
    # Create coordinate grids
    y_coords = torch.linspace(-1, 1, h, device=x.device)
    x_coords = torch.linspace(-1, 1, w, device=x.device)
    
    y_grid, x_grid = torch.meshgrid(y_coords, x_coords, indexing='ij')
    
    # Stack coordinates and add batch dimension
    coord_grid = torch.stack([x_grid, y_grid], dim=0).unsqueeze(0)  # [1, 2, H, W]
    coord_grid = coord_grid.expand(b, -1, -1, -1)  # [B, 2, H, W]
    
    return coord_grid

def _calculate_coordinate_similarity_matrix(coord_matrix, sigma=0.5):
    """Calculate similarity matrix based on coordinate distance"""
    b, c, t = coord_matrix.shape  # c should be 2 for (x, y) coordinates
    
    # Calculate pairwise Euclidean distances between coordinates
    coord_expanded_1 = coord_matrix.unsqueeze(3)  # [B, 2, T, 1]
    coord_expanded_2 = coord_matrix.unsqueeze(2)  # [B, 2, 1, T]
    
    # Euclidean distance between coordinates
    coord_diff = coord_expanded_1 - coord_expanded_2  # [B, 2, T, T]
    coord_dist = torch.sqrt(torch.sum(coord_diff ** 2, dim=1))  # [B, T, T]
    
    # Convert distance to similarity using Gaussian kernel
    similarity_matrix = torch.exp(-coord_dist ** 2 / (2 * sigma ** 2))
    
    return similarity_matrix

def get_spatial_neighbors(x, K=9, sigma=0.5):
    """
    Get K nearest spatial neighbors using coordinate-based similarity
    
    Args:
        x: Input tensor [B, C, H, W]
        K: Number of neighbors to select
        sigma: Standard deviation for Gaussian similarity kernel
    """
    b, c, h, w = x.shape
    
    # Add padding
    x_padded = F.pad(x, (1, 1, 1, 1), mode='constant', value=0)
    
    # Get coordinate encoding for the padded tensor
    coord_grid = _add_coordinate_encoding(x_padded)  # [B, 2, H+2, W+2]
    
    # Flatten spatial dimensions
    x_flat = x_padded.flatten(2)  # [B, C, (H+2)*(W+2)]
    coord_flat = coord_grid.flatten(2)  # [B, 2, (H+2)*(W+2)]
    
    print(f"Original shape: {x.shape}")
    print(f"Padded shape: {x_padded.shape}")
    print(f"Flattened coordinate shape: {coord_flat.shape}")
    
    # Calculate similarity matrix based on coordinates
    similarity_matrix = _calculate_coordinate_similarity_matrix(coord_flat, sigma=sigma)
    print(f"Similarity matrix shape: {similarity_matrix.shape}")
    
    # Get top K neighbors for each position
    topk_values, topk_indices = torch.topk(similarity_matrix, k=K, dim=2, largest=True)
    
    # Expand indices to match all channels
    topk_indices_exp = topk_indices.unsqueeze(1).expand(b, c, -1, K)  # [B, C, T, K]
    
    # Gather the neighboring values
    x_expanded = x_flat.unsqueeze(-1).expand(b, c, -1, K)
    neighbors = torch.gather(x_expanded, dim=2, index=topk_indices_exp)
    
    return neighbors, topk_indices, similarity_matrix

# Test with your example
def test_spatial_neighbors():
    # Create test tensor - your example reshaped to 2D
    x = torch.tensor([[[
        [0., 0., 0., 0., 0.],
        [0., 1., 2., 3., 0.],
        [0., 4., 5., 6., 0.],
        [0., 7., 8., 9., 0.],
        [0., 0., 0., 0., 0.]
    ]]])
    
    print("Input tensor:")
    print(x.squeeze())
    print()
    
    # Get 9 nearest neighbors
    neighbors, indices, sim_matrix = get_spatial_neighbors(x, K=9, sigma=0.3)
    
    # Look at neighbors for the position containing value '5' (center of 3x3 grid)
    # In the padded 5x5 grid, value '5' should be at position (2,2)
    # In flattened coordinates: 2*5 + 2 = 12
    center_pos = 12  # This should correspond to value '5'
    
    print(f"Neighbors for center position (value 5):")
    center_neighbors = neighbors[0, 0, center_pos, :]  # [K]
    print(center_neighbors)
    
    # Let's also check a few other positions
    print(f"\nNeighbors for position with value 1:")
    pos_1 = 6  # Position of value '1'
    neighbors_1 = neighbors[0, 0, pos_1, :]
    print(neighbors_1)
    
    return neighbors, indices, sim_matrix

# Alternative approach: Direct coordinate-based neighbor selection
def get_3x3_neighbors_direct(x):
    """
    Directly get 3x3 neighbors without similarity matrix calculation
    This mimics standard convolution receptive field
    """
    b, c, h, w = x.shape
    
    # Pad the input
    x_padded = F.pad(x, (1, 1, 1, 1), mode='constant', value=0)
    
    # Use unfold to get 3x3 patches
    patches = x_padded.unfold(2, 3, 1).unfold(3, 3, 1)  # [B, C, H, W, 3, 3]
    patches = patches.reshape(b, c, h*w, 9)  # [B, C, H*W, 9]
    
    return patches

# Test both approaches
if __name__ == "__main__":
    print("=== Coordinate-based similarity approach ===")
    neighbors, indices, sim_matrix = test_spatial_neighbors()
    
    print("\n=== Direct 3x3 patch approach ===")
    x = torch.tensor([[[
        [0., 0., 0., 0., 0.],
        [0., 1., 2., 3., 0.],
        [0., 4., 5., 6., 0.],
        [0., 7., 8., 9., 0.],
        [0., 0., 0., 0., 0.]
    ]]])
    
    patches = get_3x3_neighbors_direct(x)
    
    # Get patch for center position (1, 1) in original coordinates
    # This corresponds to value '5'
    center_patch = patches[0, 0, 1*3 + 1, :]  # 3x3 grid, position (1,1)
    print("3x3 patch around value 5:")
    print(center_patch.reshape(3, 3))

=== Coordinate-based similarity approach ===
Input tensor:
tensor([[0., 0., 0., 0., 0.],
        [0., 1., 2., 3., 0.],
        [0., 4., 5., 6., 0.],
        [0., 7., 8., 9., 0.],
        [0., 0., 0., 0., 0.]])

Original shape: torch.Size([1, 1, 5, 5])
Padded shape: torch.Size([1, 1, 7, 7])
Flattened coordinate shape: torch.Size([1, 2, 49])
Similarity matrix shape: torch.Size([1, 49, 49])
Neighbors for center position (value 5):
tensor([0., 0., 0., 0., 0., 3., 0., 0., 0.])

Neighbors for position with value 1:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 3.])

=== Direct 3x3 patch approach ===
3x3 patch around value 5:
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [3., 0., 0.]])


In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def _add_coordinate_encoding(x):
    """Add coordinate channels to input tensor"""
    b, c, h, w = x.shape
    
    # Create coordinate grids
    y_coords = torch.linspace(-1, 1, h, device=x.device)
    x_coords = torch.linspace(-1, 1, w, device=x.device)
    
    y_grid, x_grid = torch.meshgrid(y_coords, x_coords, indexing='ij')
    
    # Stack coordinates and add batch dimension
    coord_grid = torch.stack([x_grid, y_grid], dim=0).unsqueeze(0)  # [1, 2, H, W]
    coord_grid = coord_grid.expand(b, -1, -1, -1)  # [B, 2, H, W]
    
    return coord_grid

def _calculate_coordinate_similarity_matrix(coord_matrix, sigma=0.3):
    """Calculate similarity matrix based on coordinate distance"""
    b, c, t = coord_matrix.shape  # c should be 2 for (x, y) coordinates
    
    # Calculate pairwise Euclidean distances between coordinates
    coord_expanded_1 = coord_matrix.unsqueeze(3)  # [B, 2, T, 1]
    coord_expanded_2 = coord_matrix.unsqueeze(2)  # [B, 2, 1, T]
    
    # Euclidean distance between coordinates
    coord_diff = coord_expanded_1 - coord_expanded_2  # [B, 2, T, T]
    coord_dist = torch.sqrt(torch.sum(coord_diff ** 2, dim=1) + 1e-8)  # [B, T, T]
    
    # Convert distance to similarity using Gaussian kernel
    similarity_matrix = torch.exp(-coord_dist ** 2 / (2 * sigma ** 2))
    
    return similarity_matrix

def get_spatial_neighbors(x, K=9, sigma=0.3):
    """
    Get K nearest spatial neighbors using coordinate-based similarity
    """
    b, c, h, w = x.shape
    
    # Add padding
    x_padded = F.pad(x, (1, 1, 1, 1), mode='constant', value=0)
    padded_h, padded_w = x_padded.shape[2], x_padded.shape[3]
    
    # Get coordinate encoding for the padded tensor
    coord_grid = _add_coordinate_encoding(x_padded)  # [B, 2, H+2, W+2]
    
    # Flatten spatial dimensions
    x_flat = x_padded.flatten(2)  # [B, C, (H+2)*(W+2)]
    coord_flat = coord_grid.flatten(2)  # [B, 2, (H+2)*(W+2)]
    
    print(f"Original shape: {x.shape}")
    print(f"Padded shape: {x_padded.shape}")
    print(f"Flattened coordinate shape: {coord_flat.shape}")
    
    # Calculate similarity matrix based on coordinates
    similarity_matrix = _calculate_coordinate_similarity_matrix(coord_flat, sigma=sigma)
    print(f"Similarity matrix shape: {similarity_matrix.shape}")
    
    # Get top K neighbors for each position
    topk_values, topk_indices = torch.topk(similarity_matrix, k=K, dim=2, largest=True)
    
    # Expand indices to match all channels
    topk_indices_exp = topk_indices.unsqueeze(1).expand(b, c, -1, K)  # [B, C, T, K]
    
    # Gather the neighboring values
    x_expanded = x_flat.unsqueeze(-1).expand(b, c, -1, K)
    neighbors = torch.gather(x_expanded, dim=2, index=topk_indices_exp)
    
    return neighbors, topk_indices, similarity_matrix, x_flat, padded_h, padded_w

def find_value_position(x_flat, value, padded_h, padded_w):
    """Find the flattened position of a specific value"""
    # Find where the value occurs
    positions = (x_flat[0, 0, :] == value).nonzero(as_tuple=False).flatten()
    if len(positions) > 0:
        pos = positions[0].item()
        # Convert back to 2D coordinates for verification
        row = pos // padded_w
        col = pos % padded_w
        print(f"Value {value} found at flattened position {pos} (2D: row={row}, col={col})")
        return pos
    else:
        print(f"Value {value} not found")
        return None

def test_spatial_neighbors():
    # Create test tensor - your example
    x = torch.tensor([[[
        [0., 0., 0., 0., 0.],
        [0., 1., 2., 3., 0.],
        [0., 4., 5., 6., 0.],
        [0., 7., 8., 9., 0.],
        [0., 0., 0., 0., 0.]
    ]]])
    
    print("Input tensor:")
    print(x.squeeze())
    print()
    
    # Get 9 nearest neighbors
    neighbors, indices, sim_matrix, x_flat, padded_h, padded_w = get_spatial_neighbors(x, K=9, sigma=0.3)
    
    print("Padded and flattened tensor:")
    print("Shape:", x_flat.shape)
    print("Values:", x_flat[0, 0, :])
    print()
    
    # Find positions of values 5 and 1
    pos_5 = find_value_position(x_flat, 5.0, padded_h, padded_w)
    pos_1 = find_value_position(x_flat, 1.0, padded_h, padded_w)
    
    if pos_5 is not None:
        print(f"\nNeighbors for center position (value 5) at position {pos_5}:")
        center_neighbors = neighbors[0, 0, pos_5, :]
        print("Values:", center_neighbors)
        print("Indices:", indices[0, pos_5, :])
        
        # Show the neighbor positions in 2D grid for verification
        print("Neighbor positions in 2D grid:")
        for i, idx in enumerate(indices[0, pos_5, :]):
            row = idx.item() // padded_w
            col = idx.item() % padded_w
            val = center_neighbors[i].item()
            print(f"  Index {idx.item()}: (row={row}, col={col}), value={val}")
    
    if pos_1 is not None:
        print(f"\nNeighbors for position with value 1 at position {pos_1}:")
        neighbors_1 = neighbors[0, 0, pos_1, :]
        print("Values:", neighbors_1)
        print("Indices:", indices[0, pos_1, :])
        
        # Show the neighbor positions in 2D grid for verification
        print("Neighbor positions in 2D grid:")
        for i, idx in enumerate(indices[0, pos_1, :]):
            row = idx.item() // padded_w
            col = idx.item() % padded_w
            val = neighbors_1[i].item()
            print(f"  Index {idx.item()}: (row={row}, col={col}), value={val}")
    
    return neighbors, indices, sim_matrix

def get_3x3_neighbors_direct(x):
    """
    Directly get 3x3 neighbors without similarity matrix calculation
    """
    b, c, h, w = x.shape
    
    # Pad the input
    x_padded = F.pad(x, (1, 1, 1, 1), mode='constant', value=0)
    
    # Use unfold to get 3x3 patches
    patches = x_padded.unfold(2, 3, 1).unfold(3, 3, 1)  # [B, C, H, W, 3, 3]
    patches = patches.reshape(b, c, h*w, 9)  # [B, C, H*W, 9]
    
    return patches

# Test both approaches
if __name__ == "__main__":
    print("=== Coordinate-based similarity approach ===")
    neighbors, indices, sim_matrix = test_spatial_neighbors()
    

=== Coordinate-based similarity approach ===
Input tensor:
tensor([[0., 0., 0., 0., 0.],
        [0., 1., 2., 3., 0.],
        [0., 4., 5., 6., 0.],
        [0., 7., 8., 9., 0.],
        [0., 0., 0., 0., 0.]])

Original shape: torch.Size([1, 1, 5, 5])
Padded shape: torch.Size([1, 1, 7, 7])
Flattened coordinate shape: torch.Size([1, 2, 49])
Similarity matrix shape: torch.Size([1, 49, 49])
Padded and flattened tensor:
Shape: torch.Size([1, 1, 49])
Values: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 2.,
        3., 0., 0., 0., 0., 4., 5., 6., 0., 0., 0., 0., 7., 8., 9., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

Value 5.0 found at flattened position 24 (2D: row=3, col=3)
Value 1.0 found at flattened position 16 (2D: row=2, col=2)

Neighbors for center position (value 5) at position 24:
Values: tensor([5., 2., 4., 6., 8., 1., 3., 7., 9.])
Indices: tensor([24, 17, 23, 25, 31, 16, 18, 30, 32])
Neighbor positions in 2D grid:
  Ind

# MyTopK function
- need to make selecting Ks more stable (same pattern for when comparing same similarity)

In [1]:
import torch

# Original tensor
data = torch.tensor([0.1, 0.5, 0.2, 0.5, 0.3])
k = 2

# Get sorted values and indices with stable tie-breaking
sorted_values, sorted_indices = torch.sort(data, descending=True, stable=True)

# Select the top-k elements and their original indices
topk_values = sorted_values[:k]
topk_original_indices = sorted_indices[:k]

print(f"Top-k values: {topk_values}")
print(f"Original indices of top-k values: {topk_original_indices}")

Top-k values: tensor([0.5000, 0.5000])
Original indices of top-k values: tensor([1, 3])


In [3]:
# Original tensor
data = torch.tensor([0.1, 0.5, 0.2, 0.5, 0.3])
k = 2

topk_values, topk_indices = torch.topk(data, k, largest=True)
print(f"Top-k values: {topk_values}")
print(f"Original indices of top-k values: {topk_indices}")

Top-k values: tensor([0.5000, 0.5000])
Original indices of top-k values: tensor([1, 3])
