## Speed comparison of different convolution layers

In [1]:
# Torch
import torch 
from torch import nn
from torch.nn import Conv1d
import torch.nn.functional as F

import numpy as np


# Train + Data 
import sys 
sys.path.append('../Layers')
from pixelshuffle import PixelShuffle1D, PixelUnshuffle1D

import time



### 1. Convolutional layers time test
- Conv1d(3, 3, kernel_size=3, stride=1, padding=1) vs Conv1d(3, 3, kernel_size=3, stride=3)
-  (32, 3, 224) vs. (32, 3, 224*3)


In [2]:
# x = torch.randn(32, 3, 224).to('mps')

# original = nn.Conv1d(3, 3, 3, stride=1, padding=1).to('mps')

# start = time.time()

# for i in range(1000):
#     o_out = original(x)
    
# end = time.time()
# print("Original Conv1d Time: ", end - start)

# print(o_out.shape)

In [3]:
# x1 = torch.randn(32, 3, 224*3).to('mps')

# prime = nn.Conv1d(3, 3, 3, stride = 3).to('mps')

# start = time.time()
# for i in range(1000):
#     p_out = prime(x1)
# end = time.time()

# print("Prime Conv1d Time: ", end - start)

# print(p_out.shape)

#### i. (800, 3, 150) Conv1d(3, 16, kernel_size=3, stride=1, padding=1)


In [4]:
import torch
import torch.nn as nn
import time

# The CPU handle always exists; it is kept for the optional CPU baseline.
device_cpu = torch.device("cpu")

# Prefer Apple's Metal (MPS) backend and fall back to the CPU when it is missing.
if not torch.backends.mps.is_available():
    device_mps = torch.device("cpu")
    print("MPS device not found, using CPU instead")
else:
    device_mps = torch.device("mps")
    print("MPS device is available!")

# Conv1d input layout: [batch_size, channels, length]
batch_size = 800
x = torch.randn(batch_size, 3, 150)

# Define models and optimizer
def benchmark_device(device, iterations=100, warmup=5):
    """Time training steps of ``nn.Conv1d(3, 16, kernel_size=3, stride=1, padding=1)``.

    One step is: forward pass, ``out.sum()`` as a stand-in loss, backward pass,
    SGD update, ``zero_grad``. The input is the notebook-level tensor ``x``.

    Parameters:
        device (torch.device): Device the input and the layer are moved to.
        iterations (int): Number of timed training steps.
        warmup (int): Number of untimed steps executed before the timer starts.

    Returns:
        float: Elapsed seconds for the ``iterations`` timed steps.
    """
    def synchronize():
        # Wait for outstanding device work so the timer only brackets finished steps.
        if device.type == 'cuda':
            torch.cuda.synchronize()
        elif device.type == 'mps':
            torch.mps.synchronize()

    def train_step():
        out = model(x_device)
        loss = out.sum()  # dummy scalar "loss" so that backward() can run
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # Move data and model to the specified device
    x_device = x.to(device)
    model = nn.Conv1d(3, 16, kernel_size=3, stride=1, padding=1).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # Warmup
    for _ in range(warmup):
        train_step()

    # Benchmark: perf_counter is the monotonic, high-resolution interval clock.
    synchronize()
    start = time.perf_counter()
    for _ in range(iterations):
        train_step()
    synchronize()
    return time.perf_counter() - start

# Run benchmarks: announce the batch size, then time the layer on the accelerator.
print(f"Running benchmark with batch size {batch_size}")

# Optional CPU baseline, currently disabled:
# time_cpu = benchmark_device(device_cpu)
# print(f"CPU Time: {time_cpu:.4f} seconds")

# The timed run only happens when MPS is present; on other machines nothing is measured.
if torch.backends.mps.is_available():
    time_mps = benchmark_device(device_mps)
    print(f"MPS Time: {time_mps:.4f} seconds")
    # Needs the CPU baseline above to be enabled:
    # print(f"MPS Speedup: {time_cpu/time_mps:.2f}x")

MPS device is available!
Running benchmark with batch size 800


  from .autonotebook import tqdm as notebook_tqdm


MPS Time: 0.0796 seconds


#### ii. (800, 3, 150*3) Conv1d(3, 16, kernel_size=3, stride=3)


In [1]:
import torch
import torch.nn as nn
import time

# The CPU handle always exists; it is kept for the optional CPU baseline.
device_cpu = torch.device("cpu")

# Prefer Apple's Metal (MPS) backend and fall back to the CPU when it is missing.
if not torch.backends.mps.is_available():
    device_mps = torch.device("cpu")
    print("MPS device not found, using CPU instead")
else:
    device_mps = torch.device("mps")
    print("MPS device is available!")

# Conv1d input layout: [batch_size, channels, length]; the length is tripled
# for the stride-3 layer benchmarked in this cell.
batch_size = 800
x = torch.randn(batch_size, 3, 150*3)

# Define models and optimizer
def benchmark_device(device, iterations=100, warmup=5):
    """Time training steps of ``nn.Conv1d(3, 16, kernel_size=3, stride=3)``.

    One step is: forward pass, ``out.sum()`` as a stand-in loss, backward pass,
    SGD update, ``zero_grad``. The input is the notebook-level tensor ``x``.

    Parameters:
        device (torch.device): Device the input and the layer are moved to.
        iterations (int): Number of timed training steps.
        warmup (int): Number of untimed steps executed before the timer starts.

    Returns:
        float: Elapsed seconds for the ``iterations`` timed steps.
    """
    def synchronize():
        # Wait for outstanding device work so the timer only brackets finished steps.
        if device.type == 'cuda':
            torch.cuda.synchronize()
        elif device.type == 'mps':
            torch.mps.synchronize()

    def train_step():
        out = model(x_device)
        loss = out.sum()  # dummy scalar "loss" so that backward() can run
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # Move data and model to the specified device
    x_device = x.to(device)
    model = nn.Conv1d(3, 16, kernel_size=3, stride=3).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # Warmup
    for _ in range(warmup):
        train_step()

    # Benchmark: perf_counter is the monotonic, high-resolution interval clock.
    synchronize()
    start = time.perf_counter()
    for _ in range(iterations):
        train_step()
    synchronize()
    return time.perf_counter() - start

# Run benchmarks: announce the batch size, then time the layer on the accelerator.
print(f"Running benchmark with batch size {batch_size}")

# Optional CPU baseline, currently disabled:
# time_cpu = benchmark_device(device_cpu)
# print(f"CPU Time: {time_cpu:.4f} seconds")

# The timed run only happens when MPS is present; on other machines nothing is measured.
if torch.backends.mps.is_available():
    time_mps = benchmark_device(device_mps)
    print(f"MPS Time: {time_mps:.4f} seconds")
    # Needs the CPU baseline above to be enabled:
    # print(f"MPS Speedup: {time_cpu/time_mps:.2f}x")

MPS device is available!
Running benchmark with batch size 800


  from .autonotebook import tqdm as notebook_tqdm


MPS Time: 0.0789 seconds


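- CPU time is much faster for `nn.Conv1d(3, 16, kernel_size=3, stride=3)` on the `(3, 224*3)` input.
- GPU time is slower for `nn.Conv1d(3, 16, kernel_size=3, stride=1, padding=1)` on the `(3, 224)` input.
- Note: the CPU runs are commented out in the cells above, and the recorded MPS timings (0.0796 s vs. 0.0789 s, measured with input lengths 150 and 150*3) are nearly identical, so these observations are not reproduced by the outputs shown here.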


### 2. A closer look at Conv1d_NN

1. Original

In [6]:
class Conv1d_NN(nn.Module):
    """
    Convolution 1D Nearest Neighbor Layer for Convolutional Neural Networks.

    For every position along the length axis the layer selects K positions by
    ranking a distance or cosine-similarity matrix, lays the selected columns
    out side by side, and runs a regular Conv1d over the result. This is the
    reference implementation (neighbour gathering goes through ``torch.vmap``),
    instrumented with per-stage timers.

    Attributes:
        in_channels (int): Number of input channels seen by the Conv1d layer.
        out_channels (int): Number of output channels produced by the Conv1d layer.
        K (int): Number of Nearest Neighbors for consideration.
        stride (int): Stride size.
        padding (int): Padding size.
        shuffle_pattern (str): Shuffle pattern.
        shuffle_scale (int): Shuffle scale factor.
        samples (int/str): Number of samples to consider.
        magnitude_type (str): Distance or Similarity.
        maximum (bool): True when neighbours are the largest entries (similarity).
        random_idx_times (list): Seconds spent drawing the random sample, per forward call.
        matrix_magnitude_times (list): Seconds spent on the magnitude matrix, per forward call.
        prime_times (list): Seconds spent gathering neighbours, per forward call.
        conv1d_times (list): Seconds spent in the Conv1d layer, per forward call.

    Notes:
        - K must be same as stride. K == stride.
        - NOTE(review): the stage timers do not synchronize the device, so on
          asynchronous backends (CUDA/MPS) they may reflect when work was
          queued rather than when it finished.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 K=3,
                 stride=3,
                 padding=0,
                 shuffle_pattern='N/A',
                 shuffle_scale=2,
                 samples='all',
                 magnitude_type='similarity'
                 ):

        """
        Initializes the Conv1d_NN module.

        Parameters:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            K (int): Number of Nearest Neighbors for consideration.
            stride (int): Stride size.
            padding (int): Padding size.
            shuffle_pattern (str): Shuffle pattern: "B", "A", "BA".
            shuffle_scale (int): Shuffle scale factor.
            samples (int/str): Number of samples to consider.
            magnitude_type (str): Distance or Similarity.

        Raises:
            ValueError: If magnitude_type is neither 'distance' nor 'similarity'.
        """

        super().__init__()

        # Fail early: forward() has no branch for any other value.
        if magnitude_type not in ('distance', 'similarity'):
            raise ValueError(
                f"magnitude_type must be 'distance' or 'similarity', got {magnitude_type!r}"
            )

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.K = K
        self.stride = stride
        self.padding = padding
        self.shuffle_pattern = shuffle_pattern
        self.shuffle_scale = shuffle_scale
        self.samples = int(samples) if samples != 'all' else samples
        self.magnitude_type = magnitude_type
        # Similarity ranks by largest value, distance by smallest.
        self.maximum = self.magnitude_type == 'similarity'

        # Unshuffle layer
        self.unshuffle_layer = PixelUnshuffle1D(downscale_factor=self.shuffle_scale)

        # Shuffle Layer
        self.shuffle_layer = PixelShuffle1D(upscale_factor=self.shuffle_scale)

        # Channels for Conv1d Layer (scaled when the matching shuffle step is active)
        self.in_channels = in_channels * shuffle_scale if self.shuffle_pattern in ["BA", "B"] else in_channels
        self.out_channels = out_channels * shuffle_scale if self.shuffle_pattern in ["BA", "A"] else out_channels

        # Conv1d Layer
        self.conv1d_layer = Conv1d(in_channels=self.in_channels,
                                    out_channels=self.out_channels,
                                    kernel_size=self.K,
                                    stride=self.stride,
                                    padding=self.padding)

        # Per-stage timers, one entry appended per forward call.
        self.random_idx_times = []
        self.matrix_magnitude_times = []
        self.prime_times = []
        self.conv1d_times = []

    def forward(self, x):
        """
        Gathers K neighbours per position, then applies the Conv1d layer.

        Parameters:
            x (torch.Tensor): Input indexed as [batch, channels, length].

        Returns:
            torch.Tensor: Conv1d output, passed through the shuffle layer when
            shuffle_pattern is "A" or "BA".
        """
        # Unshuffle Layer
        if self.shuffle_pattern in ["B", "BA"]:
            x1 = self.unshuffle_layer(x)
        else:
            x1 = x

        if self.samples == 'all':
            # Consider all samples: every position is a neighbour candidate.
            stage_start = time.perf_counter()
            if self.magnitude_type == 'distance':
                matrix_magnitude = self.calculate_distance_matrix(x1)
            elif self.magnitude_type == 'similarity':
                matrix_magnitude = self.calculate_similarity_matrix(x1)
            self.matrix_magnitude_times.append(time.perf_counter() - stage_start)

            stage_start = time.perf_counter()
            prime = self.prime_vmap_2d(x1, matrix_magnitude, self.K, self.maximum)
            self.prime_times.append(time.perf_counter() - stage_start)

        else:
            # Consider N samples: neighbours come from a random subset of positions.
            stage_start = time.perf_counter()
            rand_idx = torch.randperm(x1.shape[2], device=x1.device)[:self.samples]
            x1_sample = x1[:, :, rand_idx]
            self.random_idx_times.append(time.perf_counter() - stage_start)

            stage_start = time.perf_counter()
            if self.magnitude_type == 'distance':
                matrix_magnitude = self.calculate_distance_matrix_N(x1, x1_sample)
                excluded = float('inf')
            elif self.magnitude_type == 'similarity':
                matrix_magnitude = self.calculate_similarity_matrix_N(x1, x1_sample)
                excluded = float('-inf')

            # A sampled position must not select itself: process_batch_N already
            # places every position's own column first. The column index is
            # built on the input's device so both index tensors live together.
            range_idx = torch.arange(len(rand_idx), device=x1.device)
            matrix_magnitude[:, rand_idx, range_idx] = excluded
            # Recorded once per call so the summed stage time is not doubled.
            self.matrix_magnitude_times.append(time.perf_counter() - stage_start)

            stage_start = time.perf_counter()
            prime = self.prime_vmap_2d_N(x1, matrix_magnitude, self.K, rand_idx, self.maximum)
            self.prime_times.append(time.perf_counter() - stage_start)

        # Conv1d Layer
        stage_start = time.perf_counter()
        x2 = self.conv1d_layer(prime)
        self.conv1d_times.append(time.perf_counter() - stage_start)

        # Shuffle Layer
        if self.shuffle_pattern in ["A", "BA"]:
            x3 = self.shuffle_layer(x2)
        else:
            x3 = x2

        return x3

    ### All Samples ###
    @staticmethod
    def calculate_distance_matrix(matrix):
        """Calculates distance matrix of the input matrix"""
        norm_squared = torch.sum(matrix ** 2, dim=1, keepdim=True)
        dot_product = torch.bmm(matrix.transpose(2, 1), matrix)
        dist_matrix = norm_squared + norm_squared.transpose(2, 1) - 2 * dot_product
        # Round-off can leave tiny negative values (e.g. on the diagonal);
        # clamp them so sqrt never returns NaN.
        return torch.sqrt(torch.clamp(dist_matrix, min=0))

    @staticmethod
    def calculate_similarity_matrix(matrix):
        """Calculates similarity matrix of the input matrix"""
        normalized_matrix = F.normalize(matrix, p=2, dim=1) # p=2 (L2 Norm - Euclidean Distance), dim=1 (across the channels)
        similarity_matrix = torch.bmm(normalized_matrix.transpose(2, 1), normalized_matrix)
        return similarity_matrix

    @staticmethod
    def prime_vmap_2d(matrix, magnitude_matrix, num_nearest_neighbors, maximum):
        """Vectorization / Vmap Implementation for Nearest Neighbor Tensor 2D"""
        batched_process = torch.vmap(Conv1d_NN.process_batch, in_dims=(0, 0, None), out_dims=0)
        prime = batched_process(matrix, magnitude_matrix, num_nearest_neighbors, flatten=True, maximum=maximum)
        return prime

    @staticmethod
    def prime_vmap_3d(matrix, magnitude_matrix, num_nearest_neighbors, maximum):
        """Vectorization / Vmap Implementation for Nearest Neighbor Tensor 3D"""
        batched_process = torch.vmap(Conv1d_NN.process_batch, in_dims=(0, 0, None), out_dims=0)
        prime = batched_process(matrix, magnitude_matrix, num_nearest_neighbors, flatten=False, maximum=maximum)
        return prime

    @staticmethod
    def process_batch(matrix, magnitude_matrix, num_nearest_neighbors, flatten, maximum):
        """Process the batch of matrices by finding the K nearest neighbors with reshaping."""
        ind = torch.topk(magnitude_matrix, num_nearest_neighbors, largest=maximum).indices
        neigh = matrix[:, ind]
        if flatten:
            # Merge the position and neighbour axes into one length axis.
            return torch.flatten(neigh, start_dim=1)
        return neigh

    ### N Samples ###
    @staticmethod
    def calculate_distance_matrix_N(matrix, matrix_sample):
        """Calculates distance matrix between two input matrices"""
        norm_squared = torch.sum(matrix ** 2, dim=1, keepdim=True).permute(0, 2, 1)
        norm_squared_sample = torch.sum(matrix_sample ** 2, dim=1, keepdim=True).transpose(2, 1).permute(0, 2, 1)
        dot_product = torch.bmm(matrix.transpose(2, 1), matrix_sample)
        dist_matrix = norm_squared + norm_squared_sample - 2 * dot_product
        # Round-off can leave tiny negative values; clamp them so sqrt never returns NaN.
        return torch.sqrt(torch.clamp(dist_matrix, min=0))

    @staticmethod
    def calculate_similarity_matrix_N(matrix, matrix_sample):
        """Calculates similarity matrix between two input matrices"""
        normalized_matrix = F.normalize(matrix, p=2, dim=1) # p=2 (L2 Norm - Euclidean Distance), dim=1 (across the channels)
        normalized_matrix_sample = F.normalize(matrix_sample, p=2, dim=1)
        similarity_matrix = torch.bmm(normalized_matrix.transpose(2, 1), normalized_matrix_sample)
        return similarity_matrix

    @staticmethod
    def prime_vmap_2d_N(matrix, magnitude_matrix, num_nearest_neighbors, rand_idx, maximum):
        """Vectorization / Vmap Implementation for Nearest Neighbor Tensor 2D"""
        batched_process = torch.vmap(Conv1d_NN.process_batch_N, in_dims=(0, 0, None, None), out_dims=0)
        prime = batched_process(matrix, magnitude_matrix, num_nearest_neighbors, rand_idx, flatten=True, maximum=maximum)
        return prime

    @staticmethod
    def prime_vmap_3d_N(matrix, magnitude_matrix, num_nearest_neighbors, rand_idx, maximum):
        """Vectorization / Vmap Implementation for Nearest Neighbor Tensor 3D"""
        batched_process = torch.vmap(Conv1d_NN.process_batch_N, in_dims=(0, 0, None, None), out_dims=0)
        prime = batched_process(matrix, magnitude_matrix, num_nearest_neighbors, rand_idx, flatten=False, maximum=maximum)
        return prime

    @staticmethod
    def process_batch_N(matrix, magnitude_matrix, num_nearest_neighbors, rand_idx, flatten, maximum):
        """Process the batch of matrices by finding the K nearest neighbors with reshaping."""
        # Only K - 1 sampled neighbours are needed: the position itself fills the first slot.
        topk_ind = torch.topk(magnitude_matrix, num_nearest_neighbors - 1, largest=maximum).indices
        device = topk_ind.device
        rand_idx = rand_idx.to(device) # same device as topk_ind
        # Translate positions within the sample back to positions in the full input.
        mapped_tensor = rand_idx[topk_ind]
        index_tensor = torch.arange(0, matrix.shape[1], device=device).unsqueeze(1) # one row per position
        final_tensor = torch.cat([index_tensor, mapped_tensor], dim=1)
        neigh = matrix[:, final_tensor]
        if flatten:
            # Merge the position and neighbour axes into one length axis.
            return torch.flatten(neigh, start_dim=1)
        return neigh
    

In [7]:
'''All Samples'''

import torch
import torch.nn as nn
import time

# The CPU handle always exists; it is kept for the optional CPU baseline.
device_cpu = torch.device("cpu")

# Prefer Apple's Metal (MPS) backend and fall back to the CPU when it is missing.
if not torch.backends.mps.is_available():
    device_mps = torch.device("cpu")
    print("MPS device not found, using CPU instead")
else:
    device_mps = torch.device("mps")
    print("MPS device is available!")

# Conv1d input layout: [batch_size, channels, length]
batch_size = 800
x = torch.randn(batch_size, 3, 150)

# Define models and optimizer
def benchmark_device(device, iterations=100, warmup=5):
    """Time training steps of ``Conv1d_NN(3, 16, K=3, stride=3)`` (all samples).

    One step is: forward pass, ``out.sum()`` as a stand-in loss, backward pass,
    SGD update, ``zero_grad``. The input is the notebook-level tensor ``x``.

    Parameters:
        device (torch.device): Device the input and the layer are moved to.
        iterations (int): Number of timed training steps.
        warmup (int): Number of untimed steps executed before the timer starts.

    Returns:
        tuple: ``(elapsed_seconds, model)``. The model is returned so its
        per-stage timer lists can be inspected; those lists also contain the
        warmup steps, because every forward call appends to them.
    """
    def synchronize():
        # Wait for outstanding device work so the timer only brackets finished steps.
        if device.type == 'cuda':
            torch.cuda.synchronize()
        elif device.type == 'mps':
            torch.mps.synchronize()

    def train_step():
        out = model(x_device)
        loss = out.sum()  # dummy scalar "loss" so that backward() can run
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # Move data and model to the specified device
    x_device = x.to(device)
    model = Conv1d_NN(3, 16, K=3, stride=3).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # Warmup
    for _ in range(warmup):
        train_step()

    # Benchmark: perf_counter is the monotonic, high-resolution interval clock.
    synchronize()
    start = time.perf_counter()
    for _ in range(iterations):
        train_step()
    synchronize()
    end = time.perf_counter()

    return ((end - start), model)

# Run benchmarks
print(f"Running benchmark with batch size {batch_size}")

# Optional CPU baseline, currently disabled (benchmark_device returns a tuple here):
# time_cpu, _ = benchmark_device(device_cpu)
# print(f"CPU Time: {time_cpu:.4f} seconds")

if torch.backends.mps.is_available():
    benchmark = benchmark_device(device_mps)
    time_mps, model = benchmark
    print(f"MPS Time: {time_mps:.4f} seconds")
    # print(f"MPS Speedup: {time_cpu/time_mps:.2f}x")

    # The per-stage report lives inside the guard: `model` only exists when the
    # benchmark ran, so reading it unconditionally would raise NameError (or
    # silently report a model left over from an earlier cell).
    print()
    print("Matix Magnitude Times: ", sum(model.matrix_magnitude_times))
    print("Prime Times: ", sum(model.prime_times))
    print("Conv1d Times: ", sum(model.conv1d_times))
else:
    print("MPS device not available; benchmark skipped")

MPS device is available!
Running benchmark with batch size 800
MPS Time: 8.8121 seconds

Matix Magnitude Times:  0.04259228706359863
Prime Times:  9.166929483413696
Conv1d Times:  0.022477149963378906


In [None]:
'''50 Samples'''

import torch
import torch.nn as nn
import time

# The CPU handle always exists; it is kept for the optional CPU baseline.
device_cpu = torch.device("cpu")

# Prefer Apple's Metal (MPS) backend and fall back to the CPU when it is missing.
if not torch.backends.mps.is_available():
    device_mps = torch.device("cpu")
    print("MPS device not found, using CPU instead")
else:
    device_mps = torch.device("mps")
    print("MPS device is available!")

# Conv1d input layout: [batch_size, channels, length]
batch_size = 800
x = torch.randn(batch_size, 3, 150)

# Define models and optimizer
def benchmark_device(device, iterations=100, warmup=5):
    """Time training steps of ``Conv1d_NN(3, 16, K=3, stride=3, samples=50)``.

    One step is: forward pass, ``out.sum()`` as a stand-in loss, backward pass,
    SGD update, ``zero_grad``. The input is the notebook-level tensor ``x``.

    Parameters:
        device (torch.device): Device the input and the layer are moved to.
        iterations (int): Number of timed training steps.
        warmup (int): Number of untimed steps executed before the timer starts.

    Returns:
        tuple: ``(elapsed_seconds, model)``. The model is returned so its
        per-stage timer lists can be inspected; those lists also contain the
        warmup steps, because every forward call appends to them.
    """
    def synchronize():
        # Wait for outstanding device work so the timer only brackets finished steps.
        if device.type == 'cuda':
            torch.cuda.synchronize()
        elif device.type == 'mps':
            torch.mps.synchronize()

    def train_step():
        out = model(x_device)
        loss = out.sum()  # dummy scalar "loss" so that backward() can run
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # Move data and model to the specified device
    x_device = x.to(device)
    model = Conv1d_NN(3, 16, K=3, stride=3, samples=50).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # Warmup
    for _ in range(warmup):
        train_step()

    # Benchmark: perf_counter is the monotonic, high-resolution interval clock.
    synchronize()
    start = time.perf_counter()
    for _ in range(iterations):
        train_step()
    synchronize()
    end = time.perf_counter()

    return ((end - start), model)

# Run benchmarks
print(f"Running benchmark with batch size {batch_size}")

# Optional CPU baseline, currently disabled (benchmark_device returns a tuple here):
# time_cpu, _ = benchmark_device(device_cpu)
# print(f"CPU Time: {time_cpu:.4f} seconds")

if torch.backends.mps.is_available():
    benchmark = benchmark_device(device_mps)
    time_mps, model = benchmark
    print(f"MPS Time: {time_mps:.4f} seconds")
    # print(f"MPS Speedup: {time_cpu/time_mps:.2f}x")

    # The per-stage report lives inside the guard: `model` only exists when the
    # benchmark ran, so reading it unconditionally would raise NameError (or
    # silently report a model left over from an earlier cell).
    print()
    print("Random Index Times: ", sum(model.random_idx_times))
    print("Matix Magnitude Times: ", sum(model.matrix_magnitude_times))
    print("Prime Times: ", sum(model.prime_times))
    print("Conv1d Times: ", sum(model.conv1d_times))
else:
    print("MPS device not available; benchmark skipped")

MPS device is available!
Running benchmark with batch size 800
MPS Time: 6.6426 seconds

Random Index Times:  0.04142451286315918
Matix Magnitude Times:  0.5385489463806152
Prime Times:  6.592907428741455
Conv1d Times:  0.02121138572692871


2. Optimized

In [9]:

class Conv1d_NN_optimized(nn.Module):
    """
    Convolution 1D Nearest Neighbor Layer for Convolutional Neural Networks.

    Same layer as the vmap-based reference, but the neighbour gathering is done
    with batched advanced indexing instead of ``torch.vmap``. For every
    position along the length axis the layer selects K positions by ranking a
    distance or cosine-similarity matrix, lays the selected columns out side by
    side, and runs a regular Conv1d over the result. Per-stage timers are kept
    for benchmarking.

    Attributes:
        in_channels (int): Number of input channels seen by the Conv1d layer.
        out_channels (int): Number of output channels produced by the Conv1d layer.
        K (int): Number of Nearest Neighbors for consideration.
        stride (int): Stride size.
        padding (int): Padding size.
        shuffle_pattern (str): Shuffle pattern.
        shuffle_scale (int): Shuffle scale factor.
        samples (int/str): Number of samples to consider.
        magnitude_type (str): Distance or Similarity.
        maximum (bool): True when neighbours are the largest entries (similarity).
        random_idx_times (list): Seconds spent drawing the random sample, per forward call.
        matrix_magnitude_times (list): Seconds spent on the magnitude matrix, per forward call.
        prime_times (list): Seconds spent gathering neighbours, per forward call.
        conv1d_times (list): Seconds spent in the Conv1d layer, per forward call.

    Notes:
        - K must be same as stride. K == stride.
        - NOTE(review): the stage timers do not synchronize the device, so on
          asynchronous backends (CUDA/MPS) they may reflect when work was
          queued rather than when it finished.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 K=3,
                 stride=3,
                 padding=0,
                 shuffle_pattern='N/A',
                 shuffle_scale=2,
                 samples='all',
                 magnitude_type='similarity'
                 ):

        """
        Initializes the Conv1d_NN module.

        Parameters:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            K (int): Number of Nearest Neighbors for consideration.
            stride (int): Stride size.
            padding (int): Padding size.
            shuffle_pattern (str): Shuffle pattern: "B", "A", "BA".
            shuffle_scale (int): Shuffle scale factor.
            samples (int/str): Number of samples to consider.
            magnitude_type (str): Distance or Similarity.

        Raises:
            ValueError: If magnitude_type is neither 'distance' nor 'similarity'.
        """

        super().__init__()

        # Fail early: forward() has no branch for any other value.
        if magnitude_type not in ('distance', 'similarity'):
            raise ValueError(
                f"magnitude_type must be 'distance' or 'similarity', got {magnitude_type!r}"
            )

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.K = K
        self.stride = stride
        self.padding = padding
        self.shuffle_pattern = shuffle_pattern
        self.shuffle_scale = shuffle_scale
        self.samples = int(samples) if samples != 'all' else samples
        self.magnitude_type = magnitude_type
        # Similarity ranks by largest value, distance by smallest.
        self.maximum = self.magnitude_type == 'similarity'

        # Unshuffle layer
        self.unshuffle_layer = PixelUnshuffle1D(downscale_factor=self.shuffle_scale)

        # Shuffle Layer
        self.shuffle_layer = PixelShuffle1D(upscale_factor=self.shuffle_scale)

        # Channels for Conv1d Layer (scaled when the matching shuffle step is active)
        self.in_channels = in_channels * shuffle_scale if self.shuffle_pattern in ["BA", "B"] else in_channels
        self.out_channels = out_channels * shuffle_scale if self.shuffle_pattern in ["BA", "A"] else out_channels

        # Conv1d Layer
        self.conv1d_layer = Conv1d(in_channels=self.in_channels,
                                    out_channels=self.out_channels,
                                    kernel_size=self.K,
                                    stride=self.stride,
                                    padding=self.padding)

        # Per-stage timers, one entry appended per forward call.
        self.random_idx_times = []
        self.matrix_magnitude_times = []
        self.prime_times = []
        self.conv1d_times = []

    def forward(self, x):
        """
        Gathers K neighbours per position, then applies the Conv1d layer.

        Parameters:
            x (torch.Tensor): Input indexed as [batch, channels, length].

        Returns:
            torch.Tensor: Conv1d output, passed through the shuffle layer when
            shuffle_pattern is "A" or "BA".
        """
        # Unshuffle Layer
        if self.shuffle_pattern in ["B", "BA"]:
            x1 = self.unshuffle_layer(x)
        else:
            x1 = x

        if self.samples == 'all':
            # Consider all samples: every position is a neighbour candidate.
            stage_start = time.perf_counter()
            if self.magnitude_type == 'distance':
                matrix_magnitude = self._calculate_distance_matrix(x1)
            elif self.magnitude_type == 'similarity':
                matrix_magnitude = self._calculate_similarity_matrix(x1)
            self.matrix_magnitude_times.append(time.perf_counter() - stage_start)

            stage_start = time.perf_counter()
            prime = self._prime(x1, matrix_magnitude, self.K, self.maximum)
            self.prime_times.append(time.perf_counter() - stage_start)

        else:
            # Consider N samples: neighbours come from a random subset of positions.
            stage_start = time.perf_counter()
            rand_idx = torch.randperm(x1.shape[2], device=x1.device)[:self.samples]
            x1_sample = x1[:, :, rand_idx]
            self.random_idx_times.append(time.perf_counter() - stage_start)

            stage_start = time.perf_counter()
            if self.magnitude_type == 'distance':
                matrix_magnitude = self._calculate_distance_matrix_N(x1, x1_sample)
                excluded = float('inf')
            elif self.magnitude_type == 'similarity':
                matrix_magnitude = self._calculate_similarity_matrix_N(x1, x1_sample)
                excluded = float('-inf')

            # A sampled position must not select itself: _prime_N already
            # places every position's own column first.
            range_idx = torch.arange(len(rand_idx), device=x1.device)
            matrix_magnitude[:, rand_idx, range_idx] = excluded
            self.matrix_magnitude_times.append(time.perf_counter() - stage_start)

            stage_start = time.perf_counter()
            prime = self._prime_N(x1, matrix_magnitude, self.K, rand_idx, self.maximum)
            self.prime_times.append(time.perf_counter() - stage_start)

        # Conv1d Layer
        stage_start = time.perf_counter()
        x2 = self.conv1d_layer(prime)
        self.conv1d_times.append(time.perf_counter() - stage_start)

        # Shuffle Layer
        if self.shuffle_pattern in ["A", "BA"]:
            x3 = self.shuffle_layer(x2)
        else:
            x3 = x2

        return x3

    def _calculate_distance_matrix(self, matrix, sqrt=False):
        """
        Calculates the pairwise distance matrix of the input matrix.

        Parameters:
            matrix (torch.Tensor): Input indexed as [batch, channels, length].
            sqrt (bool): Return Euclidean distances instead of squared ones.
                Squared distances rank identically because sqrt is monotonic.

        Returns:
            torch.Tensor: [batch, length, length] matrix of (squared) distances.
        """
        norm_squared = torch.sum(matrix ** 2, dim=1, keepdim=True)
        dot_product = torch.bmm(matrix.transpose(2, 1), matrix)
        dist_matrix = norm_squared + norm_squared.transpose(2, 1) - 2 * dot_product

        dist_matrix = torch.clamp(dist_matrix, min=0) # remove negative values

        if sqrt:
            dist_matrix = torch.sqrt(dist_matrix)
        return dist_matrix

    def _calculate_distance_matrix_N(self, matrix, matrix_sample, sqrt=False):
        """
        Calculates the distance matrix between all positions and the sampled ones.

        Parameters:
            matrix (torch.Tensor): Input indexed as [batch, channels, length].
            matrix_sample (torch.Tensor): Sampled columns, [batch, channels, samples].
            sqrt (bool): Return Euclidean distances instead of squared ones.

        Returns:
            torch.Tensor: [batch, length, samples] matrix of (squared) distances.
        """
        norm_squared = torch.sum(matrix ** 2, dim=1, keepdim=True).permute(0, 2, 1)
        norm_squared_sample = torch.sum(matrix_sample ** 2, dim=1, keepdim=True).transpose(2, 1).permute(0, 2, 1)

        dot_product = torch.bmm(matrix.transpose(2, 1), matrix_sample)

        dist_matrix = norm_squared + norm_squared_sample - 2 * dot_product

        dist_matrix = torch.clamp(dist_matrix, min=0) # remove negative values

        if sqrt:
            dist_matrix = torch.sqrt(dist_matrix)
        return dist_matrix

    def _calculate_similarity_matrix(self, matrix):
        """Calculates the cosine-similarity matrix of the input matrix."""
        norm_matrix = F.normalize(matrix, p=2, dim=1) # p=2 (L2 Norm - Euclidean Distance), dim=1 (across the channels)
        similarity_matrix = torch.bmm(norm_matrix.transpose(2, 1), norm_matrix)
        return similarity_matrix

    def _calculate_similarity_matrix_N(self, matrix, matrix_sample):
        """Calculates the cosine-similarity matrix between all positions and the sampled ones."""
        norm_matrix = F.normalize(matrix, p=2, dim=1) # p=2 (L2 Norm - Euclidean Distance), dim=1 (across the channels)
        norm_sample = F.normalize(matrix_sample, p=2, dim=1)
        similarity_matrix = torch.bmm(norm_matrix.transpose(2, 1), norm_sample)
        return similarity_matrix

    def _prime(self, matrix, magnitude_matrix, K, maximum):
        """
        Gathers, for every position, the columns of its K top-ranked positions.

        Parameters:
            matrix (torch.Tensor): Input indexed as [batch, channels, length].
            magnitude_matrix (torch.Tensor): [batch, length, length] ranking matrix.
            K (int): Number of neighbours kept per position.
            maximum (bool): Keep the largest entries when True, the smallest otherwise.

        Returns:
            torch.Tensor: [batch, channels, length * K] tensor with each
            position's K selected columns stored consecutively.
        """
        b, c, t = matrix.shape

        _, topk_indices = torch.topk(magnitude_matrix, k=K, dim=2, largest=maximum)

        # The same neighbour choice applies to every channel.
        indices_expanded = topk_indices.unsqueeze(1).expand(b, c, t, K)
        # Build the companion index tensors on the input's device (as _prime_N
        # does) so the gather never mixes CPU and accelerator index tensors.
        batch_indices = torch.arange(b, device=matrix.device).view(b, 1, 1, 1).expand(b, c, t, K)
        channel_indices = torch.arange(c, device=matrix.device).view(1, c, 1, 1).expand(b, c, t, K)

        prime = matrix[batch_indices, channel_indices, indices_expanded]
        prime = prime.view(b, c, -1)
        return prime

    def _prime_N(self, matrix, magnitude_matrix, K, rand_idx, maximum):
        """
        Gathers, for every position, its own column plus K - 1 sampled neighbours.

        Parameters:
            matrix (torch.Tensor): Input indexed as [batch, channels, length].
            magnitude_matrix (torch.Tensor): [batch, length, samples] ranking matrix.
            K (int): Number of columns kept per position (itself + K - 1 neighbours).
            rand_idx (torch.Tensor): Positions of the sampled columns in the full input.
            maximum (bool): Keep the largest entries when True, the smallest otherwise.

        Returns:
            torch.Tensor: [batch, channels, length * K] tensor with each
            position's K selected columns stored consecutively.
        """
        b, c, t = matrix.shape

        # Only K - 1 sampled neighbours are needed: the position itself fills the first slot.
        _, topk_indices = torch.topk(magnitude_matrix, k=K - 1, dim=2, largest=maximum)

        # Translate positions within the sample back to positions in the full input.
        mapped_tensor = rand_idx[topk_indices]

        token_indices = torch.arange(t, device=matrix.device).view(1, t, 1).expand(b, t, 1)

        final_indices = torch.cat([token_indices, mapped_tensor], dim=2)

        # The same neighbour choice applies to every channel.
        indices_expanded = final_indices.unsqueeze(1).expand(b, c, t, K)

        batch_indices = torch.arange(b, device=matrix.device).view(b, 1, 1, 1).expand(b, c, t, K)
        channel_indices = torch.arange(c, device=matrix.device).view(1, c, 1, 1).expand(b, c, t, K)

        prime = matrix[batch_indices, channel_indices, indices_expanded]

        prime = prime.view(b, c, -1)
        return prime


In [10]:
'''All Samples'''

import torch
import torch.nn as nn
import time

# Devices: CPU is always available; the "MPS" handle falls back to CPU when the backend is missing
device_cpu = torch.device("cpu")

if not torch.backends.mps.is_available():
    device_mps = torch.device("cpu")
    print("MPS device not found, using CPU instead")
else:
    device_mps = torch.device("mps")
    print("MPS device is available!")

# Random benchmark input in Conv1d layout: [batch_size, channels, length]
batch_size = 800
x = torch.randn((batch_size, 3, 150))

# Define models and optimizer
def benchmark_device(device, iterations=100, warmup=5):
    """
    Time SGD training steps of Conv1d_NN_optimized (all samples) on `device`.

    Uses the notebook-level input tensor `x`.

    Parameters:
        device (torch.device): Device the input and the model are moved to.
        iterations (int): Number of timed forward/backward/optimizer steps.
        warmup (int): Number of untimed steps executed before the timer starts.

    Returns:
        tuple: (elapsed_seconds, model) where `model` is the trained layer, so the
        caller can read its per-stage timing lists.
    """
    # Move data and model to the specified device
    x_device = x.to(device)
    model = Conv1d_NN_optimized(3, 16, K=3, stride=3).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    def train_step():
        # One forward pass, a dummy "loss", backward pass and weight update.
        out = model(x_device)
        loss = out.sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    def synchronize():
        # Wait for queued device work so the timer brackets exactly the timed steps.
        if device.type == 'cuda':
            torch.cuda.synchronize()
        elif device.type == 'mps':
            torch.mps.synchronize()

    # Warmup
    for _ in range(warmup):
        train_step()

    # Benchmark (perf_counter is the monotonic clock intended for measuring intervals)
    synchronize()
    start = time.perf_counter()

    for _ in range(iterations):
        train_step()

    synchronize()
    end = time.perf_counter()

    return ((end - start), model)

# Run benchmarks
print(f"Running benchmark with batch size {batch_size}")

# time_cpu = benchmark_device(device_cpu)
# print(f"CPU Time: {time_cpu:.4f} seconds")

# `device_mps` already falls back to CPU when MPS is unavailable, so the benchmark
# always runs. This guarantees `model` is bound for the summary below instead of
# raising NameError (or silently reporting a stale `model` from an earlier cell).
benchmark = benchmark_device(device_mps)
time_mps, model = benchmark
print(f"{device_mps.type.upper()} Time: {time_mps:.4f} seconds")
# print(f"MPS Speedup: {time_cpu/time_mps:.2f}x")

# Per-stage totals accumulated by the model over every forward call it has made
print()
print("Random Index Times: ", sum(model.random_idx_times))
print("Matix Magnitude Times: ", sum(model.matrix_magnitude_times))
print("Prime Times: ", sum(model.prime_times))
print("Conv1d Times: ", sum(model.conv1d_times))

MPS device is available!
Running benchmark with batch size 800
MPS Time: 8.8645 seconds

Random Index Times:  0
Matix Magnitude Times:  0.025761127471923828
Prime Times:  9.175337314605713
Conv1d Times:  0.021328210830688477


In [11]:
'''50 Samples'''

import torch
import torch.nn as nn
import time

# Devices: CPU is always available; the "MPS" handle falls back to CPU when the backend is missing
device_cpu = torch.device("cpu")

if not torch.backends.mps.is_available():
    device_mps = torch.device("cpu")
    print("MPS device not found, using CPU instead")
else:
    device_mps = torch.device("mps")
    print("MPS device is available!")

# Random benchmark input in Conv1d layout: [batch_size, channels, length]
batch_size = 800
x = torch.randn((batch_size, 3, 150))

# Define models and optimizer
def benchmark_device(device, iterations=100, warmup=5):
    """
    Time SGD training steps of Conv1d_NN_optimized (samples=50) on `device`.

    Uses the notebook-level input tensor `x`.

    Parameters:
        device (torch.device): Device the input and the model are moved to.
        iterations (int): Number of timed forward/backward/optimizer steps.
        warmup (int): Number of untimed steps executed before the timer starts.

    Returns:
        tuple: (elapsed_seconds, model) where `model` is the trained layer, so the
        caller can read its per-stage timing lists.
    """
    # Move data and model to the specified device
    x_device = x.to(device)
    model = Conv1d_NN_optimized(3, 16, K=3, stride=3, samples=50).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    def train_step():
        # One forward pass, a dummy "loss", backward pass and weight update.
        out = model(x_device)
        loss = out.sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    def synchronize():
        # Wait for queued device work so the timer brackets exactly the timed steps.
        if device.type == 'cuda':
            torch.cuda.synchronize()
        elif device.type == 'mps':
            torch.mps.synchronize()

    # Warmup
    for _ in range(warmup):
        train_step()

    # Benchmark (perf_counter is the monotonic clock intended for measuring intervals)
    synchronize()
    start = time.perf_counter()

    for _ in range(iterations):
        train_step()

    synchronize()
    end = time.perf_counter()

    return ((end - start), model)

# Run benchmarks
print(f"Running benchmark with batch size {batch_size}")

# time_cpu = benchmark_device(device_cpu)
# print(f"CPU Time: {time_cpu:.4f} seconds")

# `device_mps` already falls back to CPU when MPS is unavailable, so the benchmark
# always runs. This guarantees `model` is bound for the summary below instead of
# raising NameError (or silently reporting a stale `model` from an earlier cell).
benchmark = benchmark_device(device_mps)
time_mps, model = benchmark
print(f"{device_mps.type.upper()} Time: {time_mps:.4f} seconds")
# print(f"MPS Speedup: {time_cpu/time_mps:.2f}x")

# Per-stage totals accumulated by the model over every forward call it has made
print()
print("Random Index Times: ", sum(model.random_idx_times))
print("Matix Magnitude Times: ", sum(model.matrix_magnitude_times))
print("Prime Times: ", sum(model.prime_times))
print("Conv1d Times: ", sum(model.conv1d_times))

MPS device is available!
Running benchmark with batch size 800
MPS Time: 6.3435 seconds

Random Index Times:  0.02911543846130371
Matix Magnitude Times:  0.08023738861083984
Prime Times:  2.34340500831604
Conv1d Times:  0.011876106262207031


Optimized V2

In [12]:

class Conv1d_NN_optimized_v2(nn.Module):
    """
    Convolution 1D Nearest Neighbor Layer for Convolutional Neural Networks.

    For every token (position on the last axis) of a [batch, channels, tokens]
    input, the layer selects K tokens by distance or similarity, lays them out
    next to each other, and applies a Conv1d (kernel_size=K, stride=stride) to
    the resulting [batch, channels, tokens * K] tensor.

    Attributes:
        in_channels (int): Input channels of the Conv1d layer (multiplied by
            shuffle_scale for shuffle patterns "B" and "BA").
        out_channels (int): Output channels of the Conv1d layer (multiplied by
            shuffle_scale for shuffle patterns "A" and "BA").
        K (int): Number of Nearest Neighbors for consideration.
        stride (int): Stride size.
        padding (int): Padding size.
        shuffle_pattern (str): Shuffle pattern.
        shuffle_scale (int): Shuffle scale factor.
        samples (int/str): Number of samples to consider, or 'all'.
        magnitude_type (str): 'distance' or 'similarity'.
        maximum (bool): True for 'similarity' (largest values are neighbors),
            False for 'distance' (smallest values are neighbors).
        random_idx_times, matrix_magnitude_times, prime_times, conv1d_times (list):
            Wall-clock seconds of each stage, one entry appended per forward call
            that runs the stage.

    Notes:
        - K must be same as stride. K == stride.
        - NOTE(review): the stage timers call time.time() without any device
          synchronization, so on asynchronous backends the split between stages
          may not reflect where the work actually runs - confirm before drawing
          conclusions from the individual lists.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 K=3,
                 stride=3,
                 padding=0,
                 shuffle_pattern='N/A',
                 shuffle_scale=2,
                 samples='all',
                 magnitude_type='similarity'
                 ):

        """
        Initializes the Conv1d_NN module.

        Parameters:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            K (int): Number of Nearest Neighbors for consideration.
            stride (int): Stride size.
            padding (int): Padding size.
            shuffle_pattern (str): Shuffle pattern: "B", "A", "BA".
            shuffle_scale (int): Shuffle scale factor.
            samples (int/str): Number of samples to consider.
            magnitude_type (str): Distance or Similarity.

        Raises:
            ValueError: if magnitude_type is neither 'distance' nor 'similarity'.
        """

        super().__init__()

        # Fail at construction time: an unknown type would otherwise only surface
        # inside forward() as an unbound `matrix_magnitude` variable.
        if magnitude_type not in ('distance', 'similarity'):
            raise ValueError(
                f"magnitude_type must be 'distance' or 'similarity', got {magnitude_type!r}"
            )

        self.K = K
        self.stride = stride
        self.padding = padding
        self.shuffle_pattern = shuffle_pattern
        self.shuffle_scale = shuffle_scale
        self.samples = int(samples) if samples != 'all' else samples
        self.magnitude_type = magnitude_type
        self.maximum = self.magnitude_type == 'similarity'

        # Unshuffle layer
        self.unshuffle_layer = PixelUnshuffle1D(downscale_factor=self.shuffle_scale)

        # Shuffle Layer
        self.shuffle_layer = PixelShuffle1D(upscale_factor=self.shuffle_scale)

        # Channels for Conv1d Layer
        self.in_channels = in_channels * shuffle_scale if self.shuffle_pattern in ["BA", "B"] else in_channels
        self.out_channels = out_channels * shuffle_scale if self.shuffle_pattern in ["BA", "A"] else out_channels

        # Conv1d Layer
        self.conv1d_layer = Conv1d(in_channels=self.in_channels,
                                    out_channels=self.out_channels,
                                    kernel_size=self.K,
                                    stride=self.stride,
                                    padding=self.padding)

        # Per-stage wall-clock timings (seconds), appended to on every forward call
        self.random_idx_times = []
        self.matrix_magnitude_times = []
        self.prime_times = []
        self.conv1d_times = []


    def forward(self, x):
        """
        Applies the (optional) unshuffle, neighbor gathering, Conv1d and (optional) shuffle.

        Parameters:
            x (Tensor): input passed to the unshuffle layer or used directly;
                the neighbor stages unpack it as [batch, channels, tokens].

        Returns:
            Tensor: output of the Conv1d layer, passed through the shuffle layer
            for shuffle patterns "A" and "BA".
        """
        # Unshuffle Layer
        if self.shuffle_pattern in ["B", "BA"]:
            x1 = self.unshuffle_layer(x)
        else:
            x1 = x

        # Neighbor gathering: all tokens are candidates, or only a random sample of them
        if self.samples == 'all':
            prime = self._neighbors_all(x1)
        else:
            prime = self._neighbors_sampled(x1)

        # Conv1d Layer
        conv1d_start = time.time()
        x2 = self.conv1d_layer(prime)
        conv1d_end = time.time()
        self.conv1d_times.append(conv1d_end - conv1d_start)

        # Shuffle Layer
        if self.shuffle_pattern in ["A", "BA"]:
            x3 = self.shuffle_layer(x2)
        else:
            x3 = x2

        return x3

    def _neighbors_all(self, x1):
        """Builds the [b, c, t * K] neighbor tensor using every token as a candidate."""
        matrix_magnitude_start = time.time()
        if self.magnitude_type == 'distance':
            matrix_magnitude = self._calculate_distance_matrix(x1)
        else:  # 'similarity' (validated in __init__)
            matrix_magnitude = self._calculate_similarity_matrix(x1)
        matrix_magnitude_end = time.time()
        self.matrix_magnitude_times.append(matrix_magnitude_end - matrix_magnitude_start)

        prime_start = time.time()
        prime = self._prime(x1, matrix_magnitude, self.K, self.maximum)
        prime_end = time.time()
        self.prime_times.append(prime_end - prime_start)
        return prime

    def _neighbors_sampled(self, x1):
        """Builds the [b, c, t * K] neighbor tensor using `self.samples` random tokens as candidates."""
        # Draw the random candidate tokens
        random_idx_start = time.time()
        rand_idx = torch.randperm(x1.shape[2], device=x1.device)[:self.samples]
        x1_sample = x1[:, :, rand_idx]
        random_idx_end = time.time()
        self.random_idx_times.append(random_idx_end - random_idx_start)

        matrix_magnitude_start = time.time()
        if self.magnitude_type == 'distance':
            matrix_magnitude = self._calculate_distance_matrix_N(x1, x1_sample)
        else:  # 'similarity' (validated in __init__)
            matrix_magnitude = self._calculate_similarity_matrix_N(x1, x1_sample)

        # A sampled token must not pick itself from the sample set (it is added
        # explicitly as the first neighbor in _prime_N), so its own entry gets the
        # worst possible value: +inf for distance, -inf for similarity.
        range_idx = torch.arange(len(rand_idx), device=x1.device)
        matrix_magnitude[:, rand_idx, range_idx] = float('-inf') if self.maximum else float('inf')

        matrix_magnitude_end = time.time()
        self.matrix_magnitude_times.append(matrix_magnitude_end - matrix_magnitude_start)

        prime_start = time.time()
        prime = self._prime_N(x1, matrix_magnitude, self.K, rand_idx, self.maximum)
        prime_end = time.time()
        self.prime_times.append(prime_end - prime_start)
        return prime


    def _calculate_distance_matrix(self, matrix, sqrt=False):
        """
        Pairwise squared Euclidean distance between all tokens of `matrix` ([b, c, t]).

        Uses ||a||^2 + ||b||^2 - 2 * a.b and clamps at zero. Returns [b, t, t];
        with sqrt=True the square root is applied.
        """
        norm_squared = torch.sum(matrix ** 2, dim=1, keepdim=True)       # [b, 1, t]
        dot_product = torch.bmm(matrix.transpose(2, 1), matrix)          # [b, t, t]
        dist_matrix = norm_squared + norm_squared.transpose(2, 1) - 2 * dot_product

        dist_matrix = torch.clamp(dist_matrix, min=0) # remove negative values

        if sqrt:
            dist_matrix = torch.sqrt(dist_matrix)
        return dist_matrix

    def _calculate_distance_matrix_N(self, matrix, matrix_sample, sqrt=False):
        """
        Squared Euclidean distance between all tokens of `matrix` ([b, c, t]) and
        all tokens of `matrix_sample` ([b, c, n]).

        Returns [b, t, n]; with sqrt=True the square root is applied.
        """
        norm_squared = torch.sum(matrix ** 2, dim=1, keepdim=True).permute(0, 2, 1)   # [b, t, 1]
        norm_squared_sample = torch.sum(matrix_sample ** 2, dim=1, keepdim=True)      # [b, 1, n]

        dot_product = torch.bmm(matrix.transpose(2, 1), matrix_sample)                # [b, t, n]

        dist_matrix = norm_squared + norm_squared_sample - 2 * dot_product

        dist_matrix = torch.clamp(dist_matrix, min=0) # remove negative values

        if sqrt:
            dist_matrix = torch.sqrt(dist_matrix)
        return dist_matrix


    def _calculate_similarity_matrix(self, matrix):
        """Pairwise cosine similarity between all tokens of `matrix` ([b, c, t]); returns [b, t, t]."""
        norm_matrix = F.normalize(matrix, p=2, dim=1) # p=2 (L2 Norm - Euclidean Distance), dim=1 (across the channels)
        similarity_matrix = torch.bmm(norm_matrix.transpose(2, 1), norm_matrix)
        return similarity_matrix

    def _calculate_similarity_matrix_N(self, matrix, matrix_sample):
        """Cosine similarity between tokens of `matrix` ([b, c, t]) and `matrix_sample` ([b, c, n]); returns [b, t, n]."""
        norm_matrix = F.normalize(matrix, p=2, dim=1) # p=2 (L2 Norm - Euclidean Distance), dim=1 (across the channels)
        norm_sample = F.normalize(matrix_sample, p=2, dim=1)
        similarity_matrix = torch.bmm(norm_matrix.transpose(2, 1), norm_sample)
        return similarity_matrix

    def _prime(self, matrix, magnitude_matrix, K, maximum):
        """
        Gathers, for every token, the K tokens selected by top-k over `magnitude_matrix`.

        Parameters:
            matrix (Tensor): input of shape [b, c, t].
            magnitude_matrix (Tensor): magnitudes of shape [b, t, t]; top-k runs along dim=2.
            K (int): number of tokens gathered per position.
            maximum (bool): passed to torch.topk as `largest`.

        Returns:
            Tensor: shape [b, c, t * K], the K selected tokens of each position side by side.
        """
        b, c, t = matrix.shape
        # Get top-K indices: shape [b, t, K]
        _, topk_indices = torch.topk(magnitude_matrix, k=K, dim=2, largest=maximum)

        # Flatten (token, neighbor) into one axis and broadcast over channels: [b, c, t*K].
        # Gathering straight from `matrix` along the token axis yields
        # out[b, c, i*K + k] = matrix[b, c, topk_indices[b, i, k]] without first
        # materializing a K-times larger copy of the input.
        flat_indices = topk_indices.reshape(-1, 1, t * K).expand(b, c, t * K)
        return torch.gather(matrix, dim=2, index=flat_indices)

    def _prime_N(self, matrix, magnitude_matrix, K, rand_idx, maximum):
        """
        Gathers, for every token, itself plus the K - 1 sampled tokens selected by top-k.

        Parameters:
            matrix (Tensor): input of shape [b, c, t].
            magnitude_matrix (Tensor): magnitudes of shape [b, t, n] between tokens and samples.
            K (int): number of tokens gathered per position (self + K - 1 samples).
            rand_idx (Tensor): maps a sample position to its token index in `matrix`.
            maximum (bool): passed to torch.topk as `largest`.

        Returns:
            Tensor: shape [b, c, t * K]; each position's own token comes first,
            followed by its K - 1 selected samples.
        """
        b, c, t = matrix.shape

        # Get top-(K-1) indices from the magnitude matrix; shape: [b, t, K-1]
        _, topk_indices = torch.topk(magnitude_matrix, k=K - 1, dim=2, largest=maximum)

        # Map indices from the sampled space to the full token indices using rand_idx.
        mapped_tensor = rand_idx[topk_indices]

        # Create self indices for each token; shape: [1, t, 1] then expand to [b, t, 1]
        token_indices = torch.arange(t, device=matrix.device).view(1, t, 1).expand(b, t, 1)

        # Concatenate self index with neighbor indices to form final indices; shape: [b, t, K]
        final_indices = torch.cat([token_indices, mapped_tensor], dim=2)

        # Flatten (token, neighbor) and broadcast over channels, then gather directly
        # from `matrix` along the token axis; result shape: [b, c, t*K]
        flat_indices = final_indices.reshape(b, 1, t * K).expand(b, c, t * K)
        return torch.gather(matrix, dim=2, index=flat_indices)

In [15]:
'''All Samples'''

import torch
import torch.nn as nn
import time

# Devices: CPU is always available; the "MPS" handle falls back to CPU when the backend is missing
device_cpu = torch.device("cpu")

if not torch.backends.mps.is_available():
    device_mps = torch.device("cpu")
    print("MPS device not found, using CPU instead")
else:
    device_mps = torch.device("mps")
    print("MPS device is available!")

# Random benchmark input in Conv1d layout: [batch_size, channels, length]
batch_size = 800
x = torch.randn((batch_size, 3, 150))

# Define models and optimizer
def benchmark_device(device, iterations=100, warmup=5):
    """
    Time SGD training steps of Conv1d_NN_optimized_v2 (all samples) on `device`.

    Uses the notebook-level input tensor `x`.

    Parameters:
        device (torch.device): Device the input and the model are moved to.
        iterations (int): Number of timed forward/backward/optimizer steps.
        warmup (int): Number of untimed steps executed before the timer starts.

    Returns:
        tuple: (elapsed_seconds, model) where `model` is the trained layer, so the
        caller can read its per-stage timing lists.
    """
    # Move data and model to the specified device
    x_device = x.to(device)
    model = Conv1d_NN_optimized_v2(3, 16, K=3, stride=3).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    def train_step():
        # One forward pass, a dummy "loss", backward pass and weight update.
        out = model(x_device)
        loss = out.sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    def synchronize():
        # Wait for queued device work so the timer brackets exactly the timed steps.
        if device.type == 'cuda':
            torch.cuda.synchronize()
        elif device.type == 'mps':
            torch.mps.synchronize()

    # Warmup
    for _ in range(warmup):
        train_step()

    # Benchmark (perf_counter is the monotonic clock intended for measuring intervals)
    synchronize()
    start = time.perf_counter()

    for _ in range(iterations):
        train_step()

    synchronize()
    end = time.perf_counter()

    return ((end - start), model)

# Run benchmarks
print(f"Running benchmark with batch size {batch_size}")

# time_cpu = benchmark_device(device_cpu)
# print(f"CPU Time: {time_cpu:.4f} seconds")

# `device_mps` already falls back to CPU when MPS is unavailable, so the benchmark
# always runs. This guarantees `model` is bound for the summary below instead of
# raising NameError (or silently reporting a stale `model` from an earlier cell).
benchmark = benchmark_device(device_mps)
time_mps, model = benchmark
print(f"{device_mps.type.upper()} Time: {time_mps:.4f} seconds")
# print(f"MPS Speedup: {time_cpu/time_mps:.2f}x")

# Per-stage totals accumulated by the model over every forward call it has made
print()
print("Random Index Times: ", sum(model.random_idx_times))
print("Matix Magnitude Times: ", sum(model.matrix_magnitude_times))
print("Prime Times: ", sum(model.prime_times))
print("Conv1d Times: ", sum(model.conv1d_times))

MPS device is available!
Running benchmark with batch size 800
MPS Time: 8.6103 seconds

Random Index Times:  0
Matix Magnitude Times:  5.652441740036011
Prime Times:  0.09897708892822266
Conv1d Times:  0.01589179039001465


In [16]:
'''50 Samples'''

import torch
import torch.nn as nn
import time

# Devices: CPU is always available; the "MPS" handle falls back to CPU when the backend is missing
device_cpu = torch.device("cpu")

if not torch.backends.mps.is_available():
    device_mps = torch.device("cpu")
    print("MPS device not found, using CPU instead")
else:
    device_mps = torch.device("mps")
    print("MPS device is available!")

# Random benchmark input in Conv1d layout: [batch_size, channels, length]
batch_size = 800
x = torch.randn((batch_size, 3, 150))

# Define models and optimizer
def benchmark_device(device, iterations=100, warmup=5):
    """
    Time SGD training steps of Conv1d_NN_optimized_v2 (samples=50) on `device`.

    Uses the notebook-level input tensor `x`.

    Parameters:
        device (torch.device): Device the input and the model are moved to.
        iterations (int): Number of timed forward/backward/optimizer steps.
        warmup (int): Number of untimed steps executed before the timer starts.

    Returns:
        tuple: (elapsed_seconds, model) where `model` is the trained layer, so the
        caller can read its per-stage timing lists.
    """
    # Move data and model to the specified device
    x_device = x.to(device)
    model = Conv1d_NN_optimized_v2(3, 16, K=3, stride=3, samples=50).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    def train_step():
        # One forward pass, a dummy "loss", backward pass and weight update.
        out = model(x_device)
        loss = out.sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    def synchronize():
        # Wait for queued device work so the timer brackets exactly the timed steps.
        if device.type == 'cuda':
            torch.cuda.synchronize()
        elif device.type == 'mps':
            torch.mps.synchronize()

    # Warmup
    for _ in range(warmup):
        train_step()

    # Benchmark (perf_counter is the monotonic clock intended for measuring intervals)
    synchronize()
    start = time.perf_counter()

    for _ in range(iterations):
        train_step()

    synchronize()
    end = time.perf_counter()

    return ((end - start), model)

# Run benchmarks
print(f"Running benchmark with batch size {batch_size}")

# time_cpu = benchmark_device(device_cpu)
# print(f"CPU Time: {time_cpu:.4f} seconds")

# `device_mps` already falls back to CPU when MPS is unavailable, so the benchmark
# always runs. This guarantees `model` is bound for the summary below instead of
# raising NameError (or silently reporting a stale `model` from an earlier cell).
benchmark = benchmark_device(device_mps)
time_mps, model = benchmark
print(f"{device_mps.type.upper()} Time: {time_mps:.4f} seconds")
# print(f"MPS Speedup: {time_cpu/time_mps:.2f}x")

# Per-stage totals accumulated by the model over every forward call it has made
print()
print("Random Index Times: ", sum(model.random_idx_times))
print("Matix Magnitude Times: ", sum(model.matrix_magnitude_times))
print("Prime Times: ", sum(model.prime_times))
print("Conv1d Times: ", sum(model.conv1d_times))

MPS device is available!
Running benchmark with batch size 800
MPS Time: 6.2555 seconds

Random Index Times:  0.05554485321044922
Matix Magnitude Times:  0.08896684646606445
Prime Times:  2.3018600940704346
Conv1d Times:  0.0177919864654541


## SOMETHING IS WRONG WITH ConvNN1d Optimized N Random Sampling