# ConvNN Spatial Sampling Performance Test

- Spatial Sampling method's performance is drastically poor, need to figure out why and what is going on. 
- Try to use the same random seed for all tests 
- Try n = n for sample, this should be same as doing all samples

In [1]:
# Torch
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch import optim 



# Train + Data 
import sys 
sys.path.append('../2D_Modules')
from data import MNIST, FashionMNIST, CIFAR10
from train import train_model, evaluate_accuracy 
from pixelshuffle import * 
from Conv1d_NN_spatial import *
from Conv1d_NN import * 
from utils import * 


# other
import time


  from .autonotebook import tqdm as notebook_tqdm


In [5]:
set_seed(40)

## I. 2D Testing - Performance Test

### i. Conv2d_NN_spatial

In [47]:
class Conv2d_NN_spatial(nn.Module): 
   def __init__(self, in_channels, out_channels, K=3, stride=3, padding=0, shuffle_pattern="N/A", shuffle_scale=2, samples=3, sample_padding=0, magnitude_type="distance"): 
      super().__init__()
      ### in_channels + out_channels must be shuffle_scale**2
      self.in_channels = in_channels
      self.out_channels = out_channels
      self.K = K
      self.stride = stride
      self.padding = padding
      self.shuffle_pattern = shuffle_pattern
      self.shuffle_scale = shuffle_scale
      self.samples = int(samples)
      self.sample_padding = sample_padding
      self.magnitude_type = magnitude_type
      
      self.upscale = PixelShuffle1D(upscale_factor=self.shuffle_scale)
      
      self.downscale = PixelUnshuffle1D(downscale_factor=self.shuffle_scale)
      
      self.Conv1d_NN_spatial = Conv1d_NN_spatial(in_channels=self.in_channels,
                                 out_channels=self.out_channels,
                                 K=self.K,
                                 stride=self.stride,
                                 padding=self.padding,
                                 shuffle_pattern=self.shuffle_pattern,
                                 shuffle_scale=self.shuffle_scale, 
                                 samples=self.samples, 
                                 magnitude_type=self.magnitude_type
                                 )
                                 
      
      
      self.flatten = nn.Flatten(start_dim=2)
      
      
   def forward(self, x): 
      # Ex. Original Size (32, 1, 28, 28) 
      first_time = time.time()
      
      start_time = time.time()

      x_ind = torch.round(torch.linspace(0 + self.sample_padding, x.shape[2] - self.sample_padding - 1, self.samples)).to(torch.int)
      y_ind = torch.round(torch.linspace(0 + self.sample_padding, x.shape[3] - self.sample_padding - 1, self.samples)).to(torch.int)
      x_grid, y_grid = torch.meshgrid(x_ind, y_ind, indexing='ij')
      x_sample = torch.flatten(x[:, :, x_grid, y_grid], 2) # shape [32, 1, 25] if sample == 5 
      
      end_time = time.time()
      print(f"x_sample grid : {end_time - start_time}")
      
      
      # Flatten Layer : size (32, 1, 784)
      x1 = self.flatten(x)
      
      start_time = time.time()
      # Conv1d_NN Layer
      x2 = self.Conv1d_NN_spatial(x1, x_sample)
      end_time = time.time()
      print(f"Conv1d_NN_spatial : {end_time - start_time}")
      
      # Unflatten Layer 
      unflatten = nn.Unflatten(dim=2, unflattened_size=x.shape[2:])
      x3 = unflatten(x2)

      final_time = time.time()
      print(f"Final Time : {final_time - first_time}")
      return x3

In [48]:
ex = torch.rand(32, 1, 28, 28) 
print("Input: ", ex.shape)

conv2d_nn_spatial = Conv2d_NN_spatial(in_channels=1, out_channels=3, K=3, stride=3, padding=0, shuffle_pattern="N/A", shuffle_scale=2, samples=5, sample_padding= 3, magnitude_type="similarity")
output = conv2d_nn_spatial(ex)
print("Output: ", output.shape) # [32, 3, 784]

Input:  torch.Size([32, 1, 28, 28])
x_sample grid : 0.0007679462432861328
Conv1d_NN_spatial : 0.0058858394622802734
Final Time : 0.006787776947021484
Output:  torch.Size([32, 3, 28, 28])


## i. Conv2d_NN


In [49]:
class Conv2d_NN(nn.Module): 
   def __init__(self, in_channels, out_channels, K=3, stride=3, padding=0, shuffle_pattern="N/A", shuffle_scale=2, samples="all", magnitude_type="distance"): 
      super().__init__()
      ### in_channels + out_channels must be shuffle_scale**2
      self.in_channels = in_channels
      self.out_channels = out_channels
      self.K = K
      self.stride = stride
      self.padding = padding
      self.shuffle_pattern = shuffle_pattern
      self.shuffle_scale = shuffle_scale
      self.samples = samples
      self.magnitude_type = magnitude_type
      
      self.upscale = PixelShuffle1D(upscale_factor=self.shuffle_scale)
      
      self.downscale = PixelUnshuffle1D(downscale_factor=self.shuffle_scale)
      
      self.Conv1d_NN = Conv1d_NN(in_channels=self.in_channels * shuffle_scale **2,
                                 out_channels=self.out_channels * shuffle_scale **2,
                                 K=self.K,
                                 stride=self.stride,
                                 padding=self.padding,
                                 shuffle_pattern=self.shuffle_pattern,
                                 shuffle_scale=self.shuffle_scale, 
                                 samples=self.samples, 
                                 magnitude_type=self.magnitude_type
                                 )
                                 
      
      
      self.flatten = nn.Flatten(start_dim=2)
      
      
   def forward(self, x): 
      # Ex. Original Size (32, 1, 28, 28) 
      
      first_time = time.time()
      # Unshuffle Layer 
      # Ex. (32, 16, 7, 7) if upscale_factor = 4
      x1 = nn.functional.pixel_unshuffle(x, self.shuffle_scale)

      # print("Unshuffle: ", x1.shape)
      
      # Flatten Layer 
      # Ex. (32, 16, 49) 
      x2 = self.flatten(x1)
      # print("Flatten: ", x2.shape)
      
      # Conv1d_NN Layer
      # Ex. (32, 16, 49) 
      x3 = self.Conv1d_NN(x2)  
      # print("Conv1d_NN: ", x3.shape)
      
      # Unflatten Layer 
      # Ex. (32, 16, 7, 7)
      unflatten = nn.Unflatten(dim=2, unflattened_size=x1.shape[2:])
      x4 = unflatten(x3)
      # print("Unflatten: ", x4.shape)
      
      # Shuffle Layer 
      # Ex. (32, 16, 28, 28)
      x5 = nn.functional.pixel_shuffle(x4, self.shuffle_scale)
      # print("Shuffle: ", x5.shape)
      
      final_time = time.time()
      print(f"Final Time : {final_time - first_time}")
      return x5


In [51]:
ex = torch.rand(32, 1, 28, 28) 
print("Input: ", ex.shape)

conv2d_nn = Conv2d_NN(in_channels=1, out_channels=3, K=3, stride=3, padding=0, shuffle_pattern="N/A", shuffle_scale=2, samples=5, magnitude_type="similarity")
output = conv2d_nn(ex)
print("Output: ", output.shape) # [32, 3, 784]

Input:  torch.Size([32, 1, 28, 28])
Final Time : 0.011232137680053711
Output:  torch.Size([32, 3, 28, 28])


In [34]:

class Conv1d_NN_spatial(nn.Module): 
    def __init__(self, in_channels, out_channels, K=3, stride=3, padding=0, 
                 shuffle_pattern='N/A', shuffle_scale=2, 
                 samples='all', 
                 magnitude_type='distance'): 
        
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.K = K
        self.stride = stride 
        self.padding = padding
        self.shuffle_pattern = shuffle_pattern 
        self.shuffle_scale = shuffle_scale
        
        self.samples = int(samples) if samples != 'all' else samples # Number of samples to consider
        self.magnitude_type = magnitude_type # Nearest Neighbor based on Distance or Similarity 
        self.maximum = True if self.magnitude_type == 'similarity' else False # Minimum or Maximum for Distance or Similarity
        
        # Unshuffle layer 
        self.unshuffle_layer = PixelUnshuffle1D(downscale_factor=self.shuffle_scale)
        
        # Shuffle Layer 
        self.shuffle_layer = PixelShuffle1D(upscale_factor=self.shuffle_scale)
                
        # Conv1d Layer
        self.in_channels = in_channels * shuffle_scale if self.shuffle_pattern in ["BA", "B"] else in_channels
        self.out_channels = out_channels * shuffle_scale if self.shuffle_pattern in ["BA", "A"] else out_channels

        # Conv1d Layer 
        self.conv1d_layer = nn.Conv1d(in_channels=self.in_channels, 
                                      out_channels=self.out_channels, 
                                      kernel_size=self.K, 
                                      stride=self.stride, 
                                      padding=self.padding)

        self.relu = nn.ReLU()

    def forward(self, x, y): 

        # Unshuffle Layer 
        if self.shuffle_pattern in ["B", "BA"]:
            x1 = self.unshuffle_layer(x)
        else:
            x1 = x
            
        if self.magnitude_type == 'distance':
            matrix_magnitude = self.calculate_distance_matrix_N(x1, y)
        elif self.magnitude_type == 'similarity':
            matrix_magnitude = self.calculate_similarity_matrix_N(x1, y)        
        
        prime = self.prime_vmap_2d_N(x1, matrix_magnitude, self.K, self.maximum)
        
        # Conv1d Layer
        x2 = self.conv1d_layer(prime)
        
        # ReLU Activation
        x3 = self.relu(x2)
        
        # Shuffle Layer
        if self.shuffle_pattern in ["A", "BA"]:
            x4 = self.shuffle_layer(x3)
        else:
            x4 = x3
        
        return x4
        
    ### N Samples ### 
    '''Distance Matrix Calculations for N Sample'''
    @staticmethod 
    def calculate_distance_matrix_N(matrix, matrix_sample):
        '''Calculate distance matrix between two input matrices''' 
        norm_squared = torch.sum(matrix ** 2, dim=1, keepdim=True).permute(0, 2, 1)
        norm_squared_sample = torch.sum(matrix_sample ** 2, dim=1, keepdim=True).transpose(2, 1).permute(0, 2, 1)
        dot_product = torch.bmm(matrix.transpose(2, 1), matrix_sample)
        dist_matrix = norm_squared + norm_squared_sample - 2 * dot_product
        return torch.sqrt(dist_matrix)
        
    '''Similarity Matrix Calculations for N Sample'''
    @staticmethod
    def calculate_similarity_matrix_N(matrix, matrix_sample): 
        normalized_matrix = F.normalize(matrix, p=2, dim=1) # p=2 (L2 Norm - Euclidean Distance), dim=1 (across the channels)
        normalized_matrix_sample = F.normalize(matrix_sample, p=2, dim=1)
        similarity_matrix = dot_product = torch.bmm(normalized_matrix.transpose(2, 1), normalized_matrix_sample)
        return similarity_matrix

    '''N Sample Methods'''
    @staticmethod
    def prime_vmap_2d_N(matrix, magnitude_matrix, num_nearest_neighbors, maximum): 
        '''Vectorization / Vmap Implementation for Nearest Neighbor Tensor 2D'''
        batched_process = torch.vmap(Conv1d_NN_spatial.process_batch_N, in_dims=(0, 0, None), out_dims=0)
        prime = batched_process(matrix, magnitude_matrix, num_nearest_neighbors, flatten=True, maximum=maximum)
        return prime 
    
    @staticmethod
    def prime_vmap_3d_N(matrix, magnitude_matrix, num_nearest_neighbors, maximum): 
        '''Vectorization / Vmap Implementation for Nearest Neighbor Tensor 3D'''
        batched_process = torch.vmap(Conv1d_NN_spatial.process_batch_N, in_dims=(0, 0, None), out_dims=0)
        prime = batched_process(matrix, magnitude_matrix, num_nearest_neighbors, flatten=False, maximum=maximum)
        return prime
    
    @staticmethod 
    def process_batch_N(matrix, magnitude_matrix, num_nearest_neighbors, flatten, maximum): 
        # Process the batch of matrices
        ind = torch.topk(magnitude_matrix, num_nearest_neighbors, largest=maximum).indices 
        neigh = matrix[:, ind]
        if flatten: 
            reshape = torch.flatten(neigh, start_dim=1)
            return reshape
        return neigh
    


In [35]:
# Example
ex = torch.rand(32, 1, 28, 28) 
print("Input: ", ex.shape)

conv2d_nn = Conv2d_NN(in_channels=1, out_channels=3, K=3, stride=3, padding=0, shuffle_pattern="N/A", shuffle_scale=2, samples=5)
output = conv2d_nn(ex)
print("Output: ", output.shape)
      

Input:  torch.Size([32, 1, 28, 28])
Output:  torch.Size([32, 3, 28, 28])


### ii. Models

In [36]:
# all sample model
conv2d_nn = nn.Sequential(
   Conv2d_NN(
      in_channels=1,
      out_channels=5,
      K=5,
      stride=5,
      padding=0,
      samples="all", 
      shuffle_scale=2
   ), 
   Conv2d_NN(
      in_channels=5,
      out_channels=10,
      K=5,
      stride=5,
      padding=0,
      samples="all", 
      shuffle_scale=2
   ),
   Conv2d_NN(
      in_channels=10,
      out_channels=20,
      K=5,
      stride=5,
      padding=0,
      samples="all", 
      shuffle_scale=2
   ),
   nn.Flatten(), 
   nn.Linear(15680, 10)
   
).to('cpu')
   

from torchsummary import summary
summary(conv2d_nn, (1, 28, 28))


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
           Flatten-1               [-1, 4, 196]               0
            Conv1d-2              [-1, 20, 196]             420
              ReLU-3              [-1, 20, 196]               0
         Conv1d_NN-4              [-1, 20, 196]               0
         Conv2d_NN-5            [-1, 5, 28, 28]               0
           Flatten-6              [-1, 20, 196]               0
            Conv1d-7              [-1, 40, 196]           4,040
              ReLU-8              [-1, 40, 196]               0
         Conv1d_NN-9              [-1, 40, 196]               0
        Conv2d_NN-10           [-1, 10, 28, 28]               0
          Flatten-11              [-1, 40, 196]               0
           Conv1d-12              [-1, 80, 196]          16,080
             ReLU-13              [-1, 80, 196]               0
        Conv1d_NN-14              [-1, 

In [37]:
# spatial model with all samples
conv2d_nn_spatial_all = nn.Sequential(
   Conv2d_NN_spatial(
      in_channels=1,
      out_channels=5,
      K=5,
      stride=5,
      padding=0,
      shuffle_scale=2, 
      samples=28
      
   ), 
   Conv2d_NN_spatial(
      in_channels=5,
      out_channels=10,
      K=5,
      stride=5,
      padding=0,
      shuffle_scale=2, 
      samples=28
   ),
   Conv2d_NN_spatial(
      in_channels=10,
      out_channels=20,
      K=5,
      stride=5,
      padding=0,
      shuffle_scale=2, 
      samples=28
   ),
   nn.Flatten(), 
   nn.Linear(15680, 10)
   
).to('cpu')
   

from torchsummary import summary
summary(conv2d_nn_spatial_all, (1, 28, 28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
           Flatten-1               [-1, 1, 784]               0
            Conv1d-2               [-1, 5, 784]              30
              ReLU-3               [-1, 5, 784]               0
 Conv1d_NN_spatial-4               [-1, 5, 784]               0
 Conv2d_NN_spatial-5            [-1, 5, 28, 28]               0
           Flatten-6               [-1, 5, 784]               0
            Conv1d-7              [-1, 10, 784]             260
              ReLU-8              [-1, 10, 784]               0
 Conv1d_NN_spatial-9              [-1, 10, 784]               0
Conv2d_NN_spatial-10           [-1, 10, 28, 28]               0
          Flatten-11              [-1, 10, 784]               0
           Conv1d-12              [-1, 20, 784]           1,020
             ReLU-13              [-1, 20, 784]               0
Conv1d_NN_spatial-14              [-1, 

In [38]:
# spatial model with 1 sample
conv2d_nn_spatial_1 = nn.Sequential(
   Conv2d_NN_spatial(
      in_channels=1,
      out_channels=5,
      K=1,
      stride=1,
      padding=0,
      shuffle_scale=2, 
      samples=1
      
   ), 
   Conv2d_NN_spatial(
      in_channels=5,
      out_channels=10,
      K=1,
      stride=1,
      padding=0,
      shuffle_scale=2, 
      samples=1
   ),
   Conv2d_NN_spatial(
      in_channels=10,
      out_channels=20,
      K=1,
      stride=1,
      padding=0,
      shuffle_scale=2, 
      samples=1
   ),
   nn.Flatten(), 
   nn.Linear(15680, 10)
   
).to('cpu')
   

from torchsummary import summary
summary(conv2d_nn_spatial_all, (1, 28, 28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
           Flatten-1               [-1, 1, 784]               0
            Conv1d-2               [-1, 5, 784]              30
              ReLU-3               [-1, 5, 784]               0
 Conv1d_NN_spatial-4               [-1, 5, 784]               0
 Conv2d_NN_spatial-5            [-1, 5, 28, 28]               0
           Flatten-6               [-1, 5, 784]               0
            Conv1d-7              [-1, 10, 784]             260
              ReLU-8              [-1, 10, 784]               0
 Conv1d_NN_spatial-9              [-1, 10, 784]               0
Conv2d_NN_spatial-10           [-1, 10, 28, 28]               0
          Flatten-11              [-1, 10, 784]               0
           Conv1d-12              [-1, 20, 784]           1,020
             ReLU-13              [-1, 20, 784]               0
Conv1d_NN_spatial-14              [-1, 

### i. Data + Training

In [39]:
# MNIST
mnist = MNIST()

In [40]:
# ConvNN_2D
conv2d_nn.to('mps')

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(conv2d_nn.parameters(), lr=0.001)
num_epochs = 10 
train_model(conv2d_nn, mnist.train_loader, criterion, optimizer, num_epochs)
evaluate_accuracy(conv2d_nn, mnist.test_loader)

Epoch 1, Time: 42.84804010391235, Loss: 0.373313202230788
Epoch 2, Time: 42.819385051727295, Loss: 0.22975348151410058
Epoch 3, Time: 42.46953296661377, Loss: 0.19280044169925742
Epoch 4, Time: 42.88373398780823, Loss: 0.16800292709600062
Epoch 5, Time: 41.99836182594299, Loss: 0.15230869689682272
Epoch 6, Time: 42.4278609752655, Loss: 0.14082590549159596
Epoch 7, Time: 44.04428005218506, Loss: 0.1331104949942784
Epoch 8, Time: 42.47032809257507, Loss: 0.11953459918669769
Epoch 9, Time: 41.90473818778992, Loss: 0.11008217206288344
Epoch 10, Time: 41.969505071640015, Loss: 0.10433385084901474

 Average epoch time: 42.58357663154602
Accuracy on test set: 94.86%


94.86

In [41]:
# ConvNN_2D_Spatial all samples
conv2d_nn_spatial_all.to('mps')

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(conv2d_nn_spatial_all.parameters(), lr=0.001)
num_epochs = 10 
train_model(conv2d_nn_spatial_all, mnist.train_loader, criterion, optimizer, num_epochs)
evaluate_accuracy(conv2d_nn_spatial_all, mnist.test_loader)

KeyboardInterrupt: 

In [42]:
# ConvNN_2D_Spatial 1 sample
conv2d_nn_spatial_1.to('mps')

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(conv2d_nn_spatial_1.parameters(), lr=0.001)
num_epochs = 10 
train_model(conv2d_nn_spatial_1, mnist.train_loader, criterion, optimizer, num_epochs)
evaluate_accuracy(conv2d_nn_spatial_1, mnist.test_loader)

Epoch 1, Time: 601.8950850963593, Loss: 2.3103534903353466
Epoch 2, Time: 3276.250557899475, Loss: 2.301367968892746


KeyboardInterrupt: 