In [1]:
import torch 
import torch.nn as nn 


In [2]:

import argparse 
from pathlib import Path
import os 

# Datasets 
from dataset import ImageNet, CIFAR10, CIFAR100
from train_eval import Train_Eval

# Models 
from models.allconvnet import AllConvNet 

# Utilities 
from utils import write_to_file, set_seed


In [3]:
import os
print(f"Current directory: {os.getcwd()}")  # show where relative paths (./Data, ./Output) resolve



Current directory: /mnt/research/j.farias/mkang2/Convolutional-Nearest-Neighbor


### I. CONTROL Conv2d

In [4]:
from types import SimpleNamespace

# Create default args
args = SimpleNamespace(**{
    # Layer / architecture
    "layer": "Conv2d",
    "num_layers": 3,
    "channels": [8, 16, 32],
    "K": 9,
    "kernel_size": 3,
    "sampling_type": "all",
    "num_samples": -1,
    "sample_padding": 0,
    "num_heads": 4,
    "attention_dropout": 0.1,
    "shuffle_pattern": "BA",
    "shuffle_scale": 2,
    "magnitude_type": "similarity",
    "coordinate_encoding": False,
    # Data
    "dataset": "cifar10",
    "data_path": "./Data",
    "batch_size": 64,
    # Optimization
    "num_epochs": 100,
    "use_amp": False,
    "clip_grad_norm": None,
    "criterion": "CrossEntropy",
    "optimizer": "adamw",
    "momentum": 0.9,
    "weight_decay": 1e-6,
    "lr": 1e-3,
    "lr_step": 20,
    "lr_gamma": 0.1,
    "scheduler": "step",
    # Runtime / bookkeeping
    "device": "cuda",
    "seed": 0,
    "output_dir": "./Output/Simple/Conv2d_Control",
    "resize": False,
})
    

In [5]:
# Seed first so that dataset construction and the model's random weight
# initialization are covered by the seed (seeding only after AllConvNet(args)
# left the initial weights unseeded).
# NOTE(review): assumes utils.set_seed seeds torch's global RNG — confirm.
set_seed(args.seed)

# Check if the output directory exists, if not create it
if args.output_dir:
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

# Dataset: look the class up by name instead of repeating the same assignments per branch
dataset_classes = {"cifar10": CIFAR10, "cifar100": CIFAR100, "imagenet": ImageNet}
if args.dataset not in dataset_classes:
    raise ValueError("Dataset not supported")
dataset = dataset_classes[args.dataset](args)
args.num_classes = dataset.num_classes
args.img_size = dataset.img_size

# Model 
model = AllConvNet(args)
print(f"Model: {model.name}")

# Parameters
total_params, trainable_params = model.parameter_count()
print(f"Total Parameters: {total_params}")
print(f"Trainable Parameters: {trainable_params}")
args.total_params = total_params
args.trainable_params = trainable_params

# Training Modules 
train_eval_results = Train_Eval(args, 
                                model, 
                                dataset.train_loader, 
                                dataset.test_loader)

# Storing Results in output directory 
write_to_file(os.path.join(args.output_dir, "args.txt"), args)
write_to_file(os.path.join(args.output_dir, "model.txt"), model)
write_to_file(os.path.join(args.output_dir, "train_eval_results.txt"), train_eval_results)

Files already downloaded and verified




Upscale transform not defined. Skipping dataset upscale.
Model: All Convolutional Network Conv2d
Total Parameters: 6362
Trainable Parameters: 6362




[Epoch 001] Time: 7.2123s | [Train] Loss: 2.01643919 Accuracy: Top1: 25.9211%, Top5: 76.5465% | [Test] Loss: 1.78093931 Accuracy: Top1: 37.7886%, Top5: 86.6740%
[Epoch 002] Time: 6.9808s | [Train] Loss: 1.80737409 Accuracy: Top1: 33.8675%, Top5: 84.3950% | [Test] Loss: 1.67001616 Accuracy: Top1: 42.2373%, Top5: 88.5052%
[Epoch 003] Time: 7.0129s | [Train] Loss: 1.74238550 Accuracy: Top1: 36.7807%, Top5: 86.1273% | [Test] Loss: 1.61141083 Accuracy: Top1: 44.2377%, Top5: 90.2070%
[Epoch 004] Time: 7.2755s | [Train] Loss: 1.69237972 Accuracy: Top1: 38.7148%, Top5: 87.2822% | [Test] Loss: 1.57868538 Accuracy: Top1: 44.9542%, Top5: 90.6150%
[Epoch 005] Time: 7.1388s | [Train] Loss: 1.65905210 Accuracy: Top1: 40.0036%, Top5: 87.9616% | [Test] Loss: 1.54086816 Accuracy: Top1: 45.5215%, Top5: 91.3416%
[Epoch 006] Time: 7.1853s | [Train] Loss: 1.63314448 Accuracy: Top1: 40.9407%, Top5: 88.4970% | [Test] Loss: 1.52009429 Accuracy: Top1: 46.4471%, Top5: 91.5605%
[Epoch 007] Time: 7.2083s | [Train

### II. Original ConvNN

In [8]:
from types import SimpleNamespace

# Create default args
args = SimpleNamespace(**{
    # Layer / architecture
    "layer": "ConvNN",
    "num_layers": 3,
    "channels": [8, 16, 32],
    "K": 9,
    "kernel_size": 3,
    "sampling_type": "spatial",
    "num_samples": 6,
    "sample_padding": 0,
    "num_heads": 4,
    "attention_dropout": 0.1,
    "shuffle_pattern": "BA",
    "shuffle_scale": 2,
    "magnitude_type": "similarity",
    "coordinate_encoding": False,
    # Data
    "dataset": "cifar10",
    "data_path": "./Data",
    "batch_size": 64,
    # Optimization
    "num_epochs": 100,
    "use_amp": False,
    "clip_grad_norm": None,
    "criterion": "CrossEntropy",
    "optimizer": "adamw",
    "momentum": 0.9,
    "weight_decay": 1e-6,
    "lr": 1e-3,
    "lr_step": 20,
    "lr_gamma": 0.1,
    "scheduler": "step",
    # Runtime / bookkeeping
    "device": "cuda",
    "seed": 0,
    "output_dir": "./Output/Simple/ConvNN_Spat",
    "resize": False,
})
    

In [9]:
# Seed first so that dataset construction and the model's random weight
# initialization are covered by the seed (seeding only after AllConvNet(args)
# left the initial weights unseeded).
# NOTE(review): assumes utils.set_seed seeds torch's global RNG — confirm.
set_seed(args.seed)

# Check if the output directory exists, if not create it
if args.output_dir:
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

# Dataset: look the class up by name instead of repeating the same assignments per branch
dataset_classes = {"cifar10": CIFAR10, "cifar100": CIFAR100, "imagenet": ImageNet}
if args.dataset not in dataset_classes:
    raise ValueError("Dataset not supported")
dataset = dataset_classes[args.dataset](args)
args.num_classes = dataset.num_classes
args.img_size = dataset.img_size

# Model 
model = AllConvNet(args)
print(f"Model: {model.name}")

# Parameters
total_params, trainable_params = model.parameter_count()
print(f"Total Parameters: {total_params}")
print(f"Trainable Parameters: {trainable_params}")
args.total_params = total_params
args.trainable_params = trainable_params

# Training Modules 
train_eval_results = Train_Eval(args, 
                                model, 
                                dataset.train_loader, 
                                dataset.test_loader)

# Storing Results in output directory 
write_to_file(os.path.join(args.output_dir, "args.txt"), args)
write_to_file(os.path.join(args.output_dir, "model.txt"), model)
write_to_file(os.path.join(args.output_dir, "train_eval_results.txt"), train_eval_results)

Files already downloaded and verified
Upscale transform not defined. Skipping dataset upscale.
Model: All Convolutional Network ConvNN
Total Parameters: 97452
Trainable Parameters: 97452
[Epoch 001] Time: 10.0946s | [Train] Loss: 2.10904080 Accuracy: Top1: 24.8242%, Top5: 73.3256% | [Test] Loss: 1.96597065 Accuracy: Top1: 32.3746%, Top5: 82.4841%
[Epoch 002] Time: 10.1116s | [Train] Loss: 1.94094051 Accuracy: Top1: 33.2041%, Top5: 82.9344% | [Test] Loss: 1.86873518 Accuracy: Top1: 36.1266%, Top5: 85.6389%
[Epoch 003] Time: 10.1402s | [Train] Loss: 1.86896201 Accuracy: Top1: 36.6428%, Top5: 85.7057% | [Test] Loss: 1.81621471 Accuracy: Top1: 39.3710%, Top5: 87.1318%
[Epoch 004] Time: 10.1520s | [Train] Loss: 1.82236847 Accuracy: Top1: 39.0225%, Top5: 86.9945% | [Test] Loss: 1.78607956 Accuracy: Top1: 40.7842%, Top5: 88.1369%
[Epoch 005] Time: 10.2137s | [Train] Loss: 1.78477724 Accuracy: Top1: 40.9767%, Top5: 88.1074% | [Test] Loss: 1.75255738 Accuracy: Top1: 42.0780%, Top5: 89.7094%
[Ep

# New `_prime_new`: multiply each gathered top-K neighbor by its top-K magnitude (similarity/distance) value

In [14]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F 

class Conv2d_NN(nn.Module): 
    """Convolution 2D Nearest Neighbor (ConvNN) layer.

    The input feature map is optionally augmented with two coordinate channels and
    pixel-unshuffled, then flattened into tokens of shape [B, C, H*W]. For every
    token, the K most similar (``magnitude_type="similarity"``) or nearest
    (``"distance"``) tokens are gathered and scaled by their magnitude values. In the
    sampled variants the query token itself is kept with weight 1.

    A Conv1d with ``kernel_size == stride == K`` then reduces each group of K gathered
    tokens to one output token. The result is unflattened and optionally
    pixel-shuffled. When coordinate encoding is on, a 1x1 conv projects it back to
    ``out_channels``.
    """
    def __init__(self, 
                in_channels, 
                out_channels, 
                K,
                stride, 
                sampling_type, 
                num_samples, 
                sample_padding,
                shuffle_pattern, 
                shuffle_scale, 
                magnitude_type,
                coordinate_encoding=False
                ): 
        """
        Parameters: 
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            K (int): Number of Nearest Neighbors for consideration.
            stride (int): Stride of the Conv1d; must equal K.
            sampling_type (str): Sampling type: "all", "random", "spatial".
            num_samples (int): Number of samples to consider; -1 for all samples.
                For "spatial" this is the number of samples per axis, so a
                num_samples x num_samples grid is used.
            sample_padding (int): Border, in pixels, excluded from the spatial
                sampling grid. Ignored unless sampling_type == "spatial".
            shuffle_pattern (str): "B" (unshuffle before), "A" (shuffle after),
                "BA" (both) or "NA" (none).
            shuffle_scale (int): Shuffle scale factor.
            magnitude_type (str): "distance" (K smallest) or "similarity" (K largest).
            coordinate_encoding (bool): Append 2 normalized (x, y) coordinate
                channels to the input and remove them again with a 1x1 conv.
        """
        super(Conv2d_NN, self).__init__()
        
        # Assertions 
        assert K == stride, "Error: K must be same as stride. K == stride."
        assert shuffle_pattern in ["B", "A", "BA", "NA"], "Error: shuffle_pattern must be one of ['B', 'A', 'BA', 'NA']"
        assert magnitude_type in ["distance", "similarity"], "Error: magnitude_type must be one of ['distance', 'similarity']"
        assert sampling_type in ["all", "random", "spatial"], "Error: sampling_type must be one of ['all', 'random', 'spatial']"
        assert int(num_samples) > 0 or int(num_samples) == -1, "Error: num_samples must be greater than 0 or -1 for all samples"
        assert (sampling_type == "all" and int(num_samples) == -1) or (sampling_type != "all" and isinstance(num_samples, int)), "Error: num_samples must be -1 for 'all' sampling or an integer for 'random' and 'spatial' sampling"
        
        # Initialize parameters
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.K = K
        self.stride = stride
        self.sampling_type = sampling_type
        self.num_samples = num_samples if num_samples != -1 else 'all'  # -1 for all samples
        self.sample_padding = sample_padding if sampling_type == "spatial" else 0
        self.shuffle_pattern = shuffle_pattern
        self.shuffle_scale = shuffle_scale
        self.magnitude_type = magnitude_type
        # topk direction: largest values for similarity, smallest for distance
        self.maximum = True if self.magnitude_type == 'similarity' else False

        # Positional Encoding (optional): two extra (x, y) channels in and out
        self.coordinate_encoding = coordinate_encoding
        self.coordinate_cache = {} 
        self.in_channels = in_channels + 2 if self.coordinate_encoding else in_channels
        self.out_channels = out_channels + 2 if self.coordinate_encoding else out_channels

        # Shuffle2D/Unshuffle2D Layers
        self.shuffle_layer = nn.PixelShuffle(upscale_factor=self.shuffle_scale)
        self.unshuffle_layer = nn.PixelUnshuffle(downscale_factor=self.shuffle_scale)
        
        # Adjust Channels for PixelShuffle
        self.in_channels_1d = self.in_channels * (self.shuffle_scale**2) if self.shuffle_pattern in ["B", "BA"] else self.in_channels
        self.out_channels_1d = self.out_channels * (self.shuffle_scale**2) if self.shuffle_pattern in ["A", "BA"] else self.out_channels

        # Conv1d Layer: kernel_size == stride == K, so each group of K gathered tokens -> 1 token
        self.conv1d_layer = nn.Conv1d(in_channels=self.in_channels_1d, 
                                      out_channels=self.out_channels_1d, 
                                      kernel_size=self.K, 
                                      stride=self.stride, 
                                      padding=0)

        # Flatten Layer: [B, C, H, W] -> [B, C, H*W] (row-major over H, W)
        self.flatten = nn.Flatten(start_dim=2)

        # Pointwise Convolution Layer: only used to drop the 2 coordinate channels again.
        # It is therefore only created when coordinate encoding is on; otherwise it would
        # add unused parameters, and it cannot even be built when out_channels < 2.
        self.pointwise_conv = nn.Conv2d(in_channels=self.out_channels,
                                        out_channels=self.out_channels - 2,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0) if self.coordinate_encoding else None

    def forward(self, x): 
        # Coordinate Channels (optional) + Unshuffle + Flatten 
        x = self._add_coordinate_encoding(x) if self.coordinate_encoding else x
        x_2d = self.unshuffle_layer(x) if self.shuffle_pattern in ["B", "BA"] else x
        x = self.flatten(x_2d)

        if self.sampling_type == "all":    
            # ConvNN Algorithm 
            matrix_magnitude = self._calculate_distance_matrix(x, sqrt=True) if self.magnitude_type == 'distance' else self._calculate_similarity_matrix(x)
            prime = self._prime_new(x, matrix_magnitude, self.K, self.maximum)
             
        elif self.sampling_type == "random":
            # Select random samples
            rand_idx = torch.randperm(x.shape[2], device=x.device)[:self.num_samples]
            x_sample = x[:, :, rand_idx]

            # ConvNN Algorithm; a token never picks itself among the samples
            # (it is prepended explicitly in _prime_N_new)
            matrix_magnitude = self._calculate_distance_matrix_N(x, x_sample, sqrt=True) if self.magnitude_type == 'distance' else self._calculate_similarity_matrix_N(x, x_sample)
            range_idx = torch.arange(len(rand_idx), device=x.device)
            matrix_magnitude[:, rand_idx, range_idx] = float('inf') if self.magnitude_type == 'distance' else float('-inf')
            
            prime = self._prime_N_new(x, matrix_magnitude, self.K, rand_idx, self.maximum)
            
        elif self.sampling_type == "spatial":
            # Evenly spaced num_samples x num_samples grid of (row, col) positions
            height, width = x_2d.shape[2], x_2d.shape[3]
            row_ind = torch.linspace(0 + self.sample_padding, height - self.sample_padding - 1, self.num_samples, device=x.device).to(torch.long)
            col_ind = torch.linspace(0 + self.sample_padding, width - self.sample_padding - 1, self.num_samples, device=x.device).to(torch.long)
            row_grid, col_grid = torch.meshgrid(row_ind, col_ind, indexing='ij')
            # self.flatten is row-major, so pixel (row, col) is token row * W + col.
            # Mixing up H and W here would select the wrong pixels whenever H != W.
            flat_indices = row_grid.flatten() * width + col_grid.flatten()
            x_sample = x[:, :, flat_indices]

            # ConvNN Algorithm; a token never picks itself among the samples
            matrix_magnitude = self._calculate_distance_matrix_N(x, x_sample, sqrt=True) if self.magnitude_type == 'distance' else self._calculate_similarity_matrix_N(x, x_sample)
            range_idx = torch.arange(len(flat_indices), device=x.device)
            matrix_magnitude[:, flat_indices, range_idx] = float('inf') if self.magnitude_type == 'distance' else float('-inf')
            prime = self._prime_N_new(x, matrix_magnitude, self.K, flat_indices, self.maximum)
        else: 
            raise ValueError("Invalid sampling_type. Must be one of ['all', 'random', 'spatial'].")

        # Post-Processing: [B, C_1d, H*W*K] -> [B, out_channels_1d, H*W]
        x_conv = self.conv1d_layer(prime) 
        
        # Unflatten + Shuffle
        unflatten = nn.Unflatten(dim=2, unflattened_size=x_2d.shape[2:])
        x = unflatten(x_conv)  # [batch_size, out_channels_1d, H, W]
        x = self.shuffle_layer(x) if self.shuffle_pattern in ["A", "BA"] else x
        x = self.pointwise_conv(x) if self.coordinate_encoding else x
        return x

    def _calculate_distance_matrix(self, matrix, sqrt=False):
        """Pairwise Euclidean (squared unless sqrt=True) distances between tokens: [B, T, T]."""
        norm_squared = torch.sum(matrix ** 2, dim=1, keepdim=True)
        dot_product = torch.bmm(matrix.transpose(2, 1), matrix)
        
        dist_matrix = norm_squared + norm_squared.transpose(2, 1) - 2 * dot_product
        # NOTE(review): without this clamp, tiny negative values from floating-point
        # cancellation make sqrt return NaN. It is left disabled to match the current
        # experiment; re-enable it for distance runs.
        # dist_matrix = torch.clamp(dist_matrix, min=0) # remove negative values
        dist_matrix = torch.sqrt(dist_matrix) if sqrt else dist_matrix # take square root if needed
        
        return dist_matrix
    
    def _calculate_distance_matrix_N(self, matrix, matrix_sample, sqrt=False):
        """Euclidean distances from every token to every sampled token: [B, T, S]."""
        norm_squared = torch.sum(matrix ** 2, dim=1, keepdim=True).permute(0, 2, 1)
        norm_squared_sample = torch.sum(matrix_sample ** 2, dim=1, keepdim=True).transpose(2, 1).permute(0, 2, 1)
        dot_product = torch.bmm(matrix.transpose(2, 1), matrix_sample)
        
        dist_matrix = norm_squared + norm_squared_sample - 2 * dot_product
        # NOTE(review): same NaN risk as above while the clamp stays disabled.
        # dist_matrix = torch.clamp(dist_matrix, min=0) # remove negative values
        dist_matrix = torch.sqrt(dist_matrix) if sqrt else dist_matrix

        return dist_matrix
    
    def _calculate_similarity_matrix(self, matrix):
        """Cosine similarity between every pair of tokens: [B, T, T]."""
        # p=2 (L2 Norm), dim=1 (across the channels)
        norm_matrix = F.normalize(matrix, p=2, dim=1) 
        similarity_matrix = torch.bmm(norm_matrix.transpose(2, 1), norm_matrix)
        return similarity_matrix
    
    def _calculate_similarity_matrix_N(self, matrix, matrix_sample):
        """Cosine similarity from every token to every sampled token: [B, T, S]."""
        # p=2 (L2 Norm), dim=1 (across the channels)
        norm_matrix = F.normalize(matrix, p=2, dim=1) 
        norm_sample = F.normalize(matrix_sample, p=2, dim=1)
        similarity_matrix = torch.bmm(norm_matrix.transpose(2, 1), norm_sample)
        return similarity_matrix

    def _prime_new(self, matrix, magnitude_matrix, K, maximum):
        """Gather each token's top-K tokens, each scaled by its magnitude value.

        Returns [B, C, T*K]: the K neighbours of token t occupy positions t*K .. t*K+K-1.
        """
        b, c, t = matrix.shape
        topk_values, topk_indices = torch.topk(magnitude_matrix, k=K, dim=2, largest=maximum)
        topk_indices_exp = topk_indices.unsqueeze(1).expand(b, c, t, K)    
        topk_values_exp = topk_values.unsqueeze(1).expand(b, c, t, K)    
        matrix_expanded = matrix.unsqueeze(-1).expand(b, c, t, K).contiguous()
        prime = torch.gather(matrix_expanded, dim=2, index=topk_indices_exp)
        prime = topk_values_exp * prime
        prime = prime.view(b, c, -1)
        return prime
    
    def _prime_N(self, matrix, magnitude_matrix, K, rand_idx, maximum):
        """Unweighted variant: the token itself followed by its top-(K-1) sampled tokens; [B, C, T*K]."""
        b, c, t = matrix.shape
        _, topk_indices = torch.topk(magnitude_matrix, k=K - 1, dim=2, largest=maximum)
        tk = topk_indices.shape[-1]
        assert K == tk + 1, "Error: K must be same as tk + 1. K == tk + 1."

        mapped_tensor = rand_idx[topk_indices]
        token_indices = torch.arange(t, device=matrix.device).view(1, t, 1).expand(b, t, 1)
        final_indices = torch.cat([token_indices, mapped_tensor], dim=2)
        indices_expanded = final_indices.unsqueeze(1).expand(b, c, t, K)

        matrix_expanded = matrix.unsqueeze(-1).expand(b, c, t, K).contiguous()
        prime = torch.gather(matrix_expanded, dim=2, index=indices_expanded)  
        prime = prime.view(b, c, -1)
        return prime
    
    def _prime_N_new(self, matrix, magnitude_matrix, K, rand_idx, maximum):
        """Gather the token itself (weight 1) plus its top-(K-1) sampled tokens (scaled by magnitude).

        ``rand_idx`` maps sample columns of ``magnitude_matrix`` back to token indices.
        Returns [B, C, T*K].
        """
        b, c, t = matrix.shape
        topk_values, topk_indices = torch.topk(magnitude_matrix, k=K - 1, dim=2, largest=maximum)
        tk = topk_indices.shape[-1]
        assert K == tk + 1, "Error: K must be same as tk + 1. K == tk + 1."

        mapped_tensor = rand_idx[topk_indices]
        token_indices = torch.arange(t, device=matrix.device).view(1, t, 1).expand(b, t, 1)
        final_indices = torch.cat([token_indices, mapped_tensor], dim=2)
        indices_expanded = final_indices.unsqueeze(1).expand(b, c, t, K)

        topk_values_exp = topk_values.unsqueeze(1).expand(b, c, t, K-1)
        # Weight 1 for the query token itself; match dtype so AMP / float64 stay consistent
        ones = torch.ones((b, c, t, 1), device=matrix.device, dtype=topk_values.dtype)
        topk_values_exp = torch.cat((ones, topk_values_exp), dim=-1)

        matrix_expanded = matrix.unsqueeze(-1).expand(b, c, t, K).contiguous()
        prime = torch.gather(matrix_expanded, dim=2, index=indices_expanded)  
        prime = topk_values_exp * prime
        prime = prime.view(b, c, -1)
        return prime

    def _add_coordinate_encoding(self, x):
        """Append normalized x and y coordinate channels in [-1, 1]: [B, C, H, W] -> [B, C+2, H, W]."""
        b, _, h, w = x.shape
        cache_key = f"{b}_{h}_{w}_{x.device}"

        if cache_key in self.coordinate_cache:
            expanded_grid = self.coordinate_cache[cache_key]
        else:
            y_coords_vec = torch.linspace(start=-1, end=1, steps=h, device=x.device)
            x_coords_vec = torch.linspace(start=-1, end=1, steps=w, device=x.device)

            y_grid, x_grid = torch.meshgrid(y_coords_vec, x_coords_vec, indexing='ij')
            grid = torch.stack((x_grid, y_grid), dim=0).unsqueeze(0)
            expanded_grid = grid.expand(b, -1, -1, -1)
            self.coordinate_cache[cache_key] = expanded_grid

        x_with_coords = torch.cat((x, expanded_grid), dim=1)
        return x_with_coords

In [None]:

class Conv2d_New(nn.Module): 
    """Standard 2D convolution with optional pixel (un)shuffle and coordinate channels.

    This is the Conv2d control counterpart of ``Conv2d_NN``: the same optional
    coordinate encoding and shuffle pattern wrap a plain ``nn.Conv2d`` with
    ``padding="same"``.
    """
    def __init__(self, 
                in_channels, 
                out_channels, 
                kernel_size,
                stride, 
                shuffle_pattern, 
                shuffle_scale, 
                coordinate_encoding
                ): 
        """
        Parameters:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Conv2d kernel size.
            stride (int): Conv2d stride. PyTorch only supports padding="same" with
                stride 1.
            shuffle_pattern (str): "B" (unshuffle before), "A" (shuffle after),
                "BA" (both) or "NA" (none).
            shuffle_scale (int): Shuffle scale factor.
            coordinate_encoding (bool): Append 2 normalized (x, y) coordinate
                channels to the input and remove them again with a 1x1 conv.
        """
        super(Conv2d_New, self).__init__()
        
        # Assertions 
        assert shuffle_pattern in ["B", "A", "BA", "NA"], "Error: shuffle_pattern must be one of ['B', 'A', 'BA', 'NA']"
        
        # Initialize parameters
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.shuffle_pattern = shuffle_pattern
        self.shuffle_scale = shuffle_scale

        # Positional Encoding (optional): two extra (x, y) channels in and out
        self.coordinate_encoding = coordinate_encoding
        self.coordinate_cache = {} 
        self.in_channels = in_channels + 2 if self.coordinate_encoding else in_channels
        self.out_channels = out_channels + 2 if self.coordinate_encoding else out_channels

        # Shuffle2D/Unshuffle2D Layers
        self.shuffle_layer = nn.PixelShuffle(upscale_factor=self.shuffle_scale)
        self.unshuffle_layer = nn.PixelUnshuffle(downscale_factor=self.shuffle_scale)
        
        # Adjust Channels for PixelShuffle
        self.in_channels_shuff = self.in_channels * (self.shuffle_scale**2) if self.shuffle_pattern in ["B", "BA"] else self.in_channels
        self.out_channels_shuff = self.out_channels * (self.shuffle_scale**2) if self.shuffle_pattern in ["A", "BA"] else self.out_channels

        # Conv2d Layer
        self.conv2d_layer = nn.Conv2d(in_channels=self.in_channels_shuff, 
                                      out_channels=self.out_channels_shuff, 
                                      kernel_size=self.kernel_size, 
                                      stride=self.stride, 
                                      padding="same")

        # Pointwise Convolution Layer: only used to drop the 2 coordinate channels again.
        # It is therefore only created when coordinate encoding is on; otherwise it would
        # add unused parameters, and it cannot even be built when out_channels < 2.
        self.pointwise_conv = nn.Conv2d(in_channels=self.out_channels,
                                        out_channels=self.out_channels - 2,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0) if self.coordinate_encoding else None
        
    def forward(self, x): 
        # Coordinate Channels (optional) + Unshuffle
        x = self._add_coordinate_encoding(x) if self.coordinate_encoding else x
        x_2d = self.unshuffle_layer(x) if self.shuffle_pattern in ["B", "BA"] else x

        # Conv2d Layer
        x = self.conv2d_layer(x_2d)

        # Shuffle (optional) + drop coordinate channels (optional)
        x = self.shuffle_layer(x) if self.shuffle_pattern in ["A", "BA"] else x
        x = self.pointwise_conv(x) if self.coordinate_encoding else x
        return x

    def _add_coordinate_encoding(self, x):
        """Append normalized x and y coordinate channels in [-1, 1]: [B, C, H, W] -> [B, C+2, H, W]."""
        b, _, h, w = x.shape
        cache_key = f"{b}_{h}_{w}_{x.device}"

        if cache_key in self.coordinate_cache:
            expanded_grid = self.coordinate_cache[cache_key]
        else:
            y_coords_vec = torch.linspace(start=-1, end=1, steps=h, device=x.device)
            x_coords_vec = torch.linspace(start=-1, end=1, steps=w, device=x.device)

            y_grid, x_grid = torch.meshgrid(y_coords_vec, x_coords_vec, indexing='ij')
            grid = torch.stack((x_grid, y_grid), dim=0).unsqueeze(0)
            expanded_grid = grid.expand(b, -1, -1, -1)
            self.coordinate_cache[cache_key] = expanded_grid

        x_with_coords = torch.cat((x, expanded_grid), dim=1)
        return x_with_coords

In [None]:
class AllConvNet(nn.Module): 
    """Stack of ``num_layers`` (InstanceNorm -> layer -> ReLU) blocks, followed by global
    average pooling and a dropout + linear classifier.

    ``args.layer`` selects the per-block layer: "Conv2d" builds ``Conv2d_New`` and
    "ConvNN" builds ``Conv2d_NN``. Any other value raises ``ValueError``. The model is
    moved to ``args.device`` at construction time.
    """
    def __init__(self, args): 
        """
        Parameters:
            args: Namespace providing at least:
                - layer, num_layers, channels
                - img_size as (C, H, W), num_classes, device
                - shuffle_pattern, shuffle_scale
                - layer-specific fields: kernel_size and coordinate_encoding for
                  "Conv2d"; K, sampling_type, num_samples, sample_padding,
                  magnitude_type and coordinate_encoding for "ConvNN".
        """
        super(AllConvNet, self).__init__()
        self.args = args
        self.model = "All Convolutional Network"
        self.name = f"{self.model} {self.args.layer}"
        
        layers = []
        in_ch = self.args.img_size[0] 

        for i in range(self.args.num_layers):
            out_ch = self.args.channels[i]

            # A dictionary to hold parameters for the current layer
            layer_params = {
                "in_channels": in_ch,
                "out_channels": out_ch,
                "shuffle_pattern": self.args.shuffle_pattern,
                "shuffle_scale": self.args.shuffle_scale,
            }

            if self.args.layer == "Conv2d":
                layer_params.update({
                    "kernel_size": self.args.kernel_size,
                    "stride": 1,  # Conv2d_New uses padding="same", which needs stride 1
                    "coordinate_encoding": self.args.coordinate_encoding
                })
                layer = Conv2d_New(**layer_params)
            
            elif self.args.layer == "ConvNN":
                layer_params.update({
                    "K": self.args.K,
                    "stride": self.args.K, # Stride is always K
                    "sampling_type": self.args.sampling_type,
                    "num_samples": self.args.num_samples,
                    "sample_padding": self.args.sample_padding,
                    "magnitude_type": self.args.magnitude_type,
                    "coordinate_encoding": self.args.coordinate_encoding
                })
                layer = Conv2d_NN(**layer_params)

            # elif self.args.layer == "ConvNN_Attn":
            #     layer_params.update({
            #         "K": self.args.K,
            #         "stride": self.args.K,
            #         "sampling_type": self.args.sampling_type,
            #         "num_samples": self.args.num_samples,
            #         "sample_padding": self.args.sample_padding,
            #         "magnitude_type": self.args.magnitude_type,
            #         "img_size": self.args.img_size[1:], # Pass H, W
            #         "attention_dropout": self.args.attention_dropout,
            #         "coordinate_encoding": self.args.coordinate_encoding
            #     })
            #     layer = Conv2d_NN_Attn(**layer_params)
            
            # elif self.args.layer == "Attention":
            #     layer_params.update({
            #         "num_heads": self.args.num_heads,
            #     })
            #     layer = Attention2d(**layer_params)
            # elif "/" in self.args.layer: # Handle all branching cases
            #     ch1 = out_ch // 2 if out_ch % 2 == 0 else out_ch // 2 + 1
            #     ch2 = out_ch - ch1
                
            #     layer_params.update({"channel_ratio": (ch1, ch2)})
                
            #     # --- Check all sub-cases for branching layers ---
            #     if self.args.layer == "Conv2d/ConvNN":
            #         layer_params.update({
            #             "kernel_size": self.args.kernel_size,
            #             "K": self.args.K, "stride": self.args.K,
            #             "sampling_type": self.args.sampling_type, "num_samples": self.args.num_samples,
            #             "sample_padding": self.args.sample_padding, "magnitude_type": self.args.magnitude_type,
            #             "coordinate_encoding": self.args.coordinate_encoding
            #         })
            #         layer = Conv2d_ConvNN_Branching(**layer_params)
                
            #     elif self.args.layer == "Conv2d/ConvNN_Attn":
            #         layer_params.update({
            #             "kernel_size": self.args.kernel_size,
            #             "K": self.args.K, "stride": self.args.K,
            #             "sampling_type": self.args.sampling_type, "num_samples": self.args.num_samples,
            #             "sample_padding": self.args.sample_padding, "magnitude_type": self.args.magnitude_type,
            #             "img_size": self.args.img_size[1:],
            #             "coordinate_encoding": self.args.coordinate_encoding
            #         })
            #         layer = Conv2d_ConvNN_Attn_Branching(**layer_params)
                
            #     elif self.args.layer == "Attention/ConvNN":
            #         layer_params.update({
            #             "num_heads": self.args.num_heads,
            #             "K": self.args.K, "stride": self.args.K,
            #             "sampling_type": self.args.sampling_type, "num_samples": self.args.num_samples,
            #             "sample_padding": self.args.sample_padding, "magnitude_type": self.args.magnitude_type,
            #             "coordinate_encoding": self.args.coordinate_encoding
            #         })
            #         layer = Attention_ConvNN_Branching(**layer_params)

            #     elif self.args.layer == "Attention/ConvNN_Attn":
            #         layer_params.update({
            #             "num_heads": self.args.num_heads,
            #             "K": self.args.K, "stride": self.args.K,
            #             "sampling_type": self.args.sampling_type, "num_samples": self.args.num_samples,
            #             "sample_padding": self.args.sample_padding, "magnitude_type": self.args.magnitude_type,
            #             "img_size": self.args.img_size[1:],
            #             "coordinate_encoding": self.args.coordinate_encoding
            #         })
            #         layer = Attention_ConvNN_Attn_Branching(**layer_params)
                
            #     # This is the specific case that was failing
            #     elif self.args.layer == "Conv2d/Attention":
            #         layer_params.update({
            #             "num_heads": self.args.num_heads,
            #             "kernel_size": self.args.kernel_size, 
            #             "coordinate_encoding": self.args.coordinate_encoding
            #         })
            #         layer = Attention_Conv2d_Branching(**layer_params)
                
            #     else:
            #         # This else now only catches unknown branching types
            #         raise ValueError(f"Unknown branching layer type: {self.args.layer}")

            else:
                # This is the final else for non-branching types
                raise ValueError(f"Layer type {self.args.layer} not supported in AllConvNet")

            # Pre-layer normalization. It normalizes this layer's INPUT, which has in_ch
            # channels. With num_features=out_ch it mismatched the input; with
            # affine=False PyTorch only warns about that, but with affine=True it raises.
            layers.append(nn.InstanceNorm2d(num_features=in_ch))
            layers.append(layer)
            # TODO: once "ConvNN_Attn" is re-enabled, consider appending
            # nn.Dropout(p=self.args.attention_dropout) here.
            layers.append(nn.ReLU(inplace=True))
            
            # Update in_ch for the next layer
            in_ch = out_ch
            
        self.features = nn.Sequential(*layers)
        
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()
            
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.1),
            nn.Linear(in_ch, self.args.num_classes) # Use the final in_ch value
        )
        
        self.to(self.args.device)

    def forward(self, x): 
        """[B, C, H, W] images -> [B, num_classes] logits."""
        x = self.features(x)
        x = self.pool(x)
        x = self.flatten(x)
        x = self.classifier(x)
        return x
    
    def summary(self): 
        """Print a layer summary computed on CPU, then restore the original device.

        NOTE(review): relies on a module-level ``summary`` function (presumably
        ``from torchsummary import summary``) that is not imported in the visible
        cells. Without it, the NameError is caught and reported.
        """
        original_device = next(self.parameters()).device
        try:
            self.to("cpu")
            print(f"--- Summary for {self.name} ---")
            # input_size is (C, H, W), i.e. without the batch dimension
            summary(self, input_size=self.args.img_size, device="cpu") 
        except Exception as e:
            print(f"Could not generate summary: {e}")
        finally:
            # Move model back to its original device
            self.to(original_device)
        
    def parameter_count(self): 
        """Return (total_params, trainable_params) as element counts."""
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return total_params, trainable_params

In [16]:
from types import SimpleNamespace

# Experiment configuration: ConvNN layer with coordinate encoding enabled.
# Fields are listed in the same order as the other experiment cells.
args = SimpleNamespace(**{
    # --- layer / architecture ---
    "layer": "ConvNN",
    "num_layers": 3,
    "channels": [8, 16, 32],
    "K": 9,
    "kernel_size": 3,
    "sampling_type": "all",
    "num_samples": -1,
    "sample_padding": 0,
    "num_heads": 4,
    "attention_dropout": 0.1,
    "shuffle_pattern": "BA",
    "shuffle_scale": 2,
    "magnitude_type": "similarity",
    "coordinate_encoding": True,
    # --- data ---
    "dataset": "cifar10",
    "data_path": "./Data",
    "batch_size": 64,
    # --- training & optimiser ---
    "num_epochs": 100,
    "use_amp": False,
    "clip_grad_norm": None,
    "criterion": "CrossEntropy",
    "optimizer": "adamw",
    "momentum": 0.9,
    "weight_decay": 1e-6,
    "lr": 1e-3,
    "lr_step": 20,
    "lr_gamma": 0.1,
    "scheduler": "step",
    # --- runtime, output and misc ---
    "device": "cuda",
    "seed": 0,
    "output_dir": "./Output/Simple/ConvNN_Coord_New_Prime_No_Clamp",
    "resize": False,
})
    

In [17]:
# Seed all RNGs *before* the dataset and model are constructed. Seeding only after
# AllConvNet(args) left weight initialisation unseeded, so two runs with the same
# args.seed could start from different weights.
# NOTE(review): assumes utils.set_seed seeds torch's global RNG — confirm.
set_seed(args.seed)

# Create the output directory (including parents) if it does not exist yet
if args.output_dir:
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

# Dataset: look up the dataset class by name. Then record on args the
# dataset-derived fields (number of classes, image size) that the model reads.
DATASET_CLASSES = {
    "cifar10": CIFAR10,
    "cifar100": CIFAR100,
    "imagenet": ImageNet,
}
if args.dataset not in DATASET_CLASSES:
    raise ValueError("Dataset not supported")
dataset = DATASET_CLASSES[args.dataset](args)
args.num_classes = dataset.num_classes
args.img_size = dataset.img_size

# Model (its classifier reads args.num_classes, so this must follow the dataset step)
model = AllConvNet(args)
print(f"Model: {model.name}")

# Parameters: also stored on args so they are written out with it below
total_params, trainable_params = model.parameter_count()
print(f"Total Parameters: {total_params}")
print(f"Trainable Parameters: {trainable_params}")
args.total_params = total_params
args.trainable_params = trainable_params

# Training and evaluation
train_eval_results = Train_Eval(args,
                                model,
                                dataset.train_loader,
                                dataset.test_loader
                                )

# Store args, model description and results in the output directory
write_to_file(os.path.join(args.output_dir, "args.txt"), args)
write_to_file(os.path.join(args.output_dir, "model.txt"), model)
write_to_file(os.path.join(args.output_dir, "train_eval_results.txt"), train_eval_results)

Files already downloaded and verified
Upscale transform not defined. Skipping dataset upscale.
Model: All Convolutional Network ConvNN
Total Parameters: 123338
Trainable Parameters: 123338
[Epoch 001] Time: 9.6555s | [Train] Loss: 2.00876740 Accuracy: Top1: 28.9242%, Top5: 79.1820% | [Test] Loss: 1.82081600 Accuracy: Top1: 37.6095%, Top5: 87.2811%
[Epoch 002] Time: 9.7364s | [Train] Loss: 1.77392406 Accuracy: Top1: 40.9607%, Top5: 88.5730% | [Test] Loss: 1.68247658 Accuracy: Top1: 45.3921%, Top5: 90.6847%
[Epoch 003] Time: 9.6794s | [Train] Loss: 1.67284842 Accuracy: Top1: 46.1857%, Top5: 90.9607% | [Test] Loss: 1.60171279 Accuracy: Top1: 49.9701%, Top5: 92.0880%
[Epoch 004] Time: 9.7322s | [Train] Loss: 1.61192420 Accuracy: Top1: 49.4865%, Top5: 92.3054% | [Test] Loss: 1.56095310 Accuracy: Top1: 51.7217%, Top5: 93.0434%
[Epoch 005] Time: 9.4963s | [Train] Loss: 1.57910795 Accuracy: Top1: 50.7992%, Top5: 92.9408% | [Test] Loss: 1.53900927 Accuracy: Top1: 52.6373%, Top5: 93.9192%
[Epoch

KeyboardInterrupt: 

# New Conv2d with pixel shuffle and coordinate encoding

In [None]:
from types import SimpleNamespace

# Experiment configuration: plain Conv2d layer with coordinate encoding enabled.
# Fields are listed in the same order as the other experiment cells.
args = SimpleNamespace(**{
    # --- layer / architecture ---
    "layer": "Conv2d",
    "num_layers": 3,
    "channels": [8, 16, 32],
    "K": 9,
    "kernel_size": 3,
    "sampling_type": "all",
    "num_samples": -1,
    "sample_padding": 0,
    "num_heads": 4,
    "attention_dropout": 0.1,
    "shuffle_pattern": "BA",
    "shuffle_scale": 2,
    "magnitude_type": "similarity",
    "coordinate_encoding": True,
    # --- data ---
    "dataset": "cifar10",
    "data_path": "./Data",
    "batch_size": 64,
    # --- training & optimiser ---
    "num_epochs": 100,
    "use_amp": False,
    "clip_grad_norm": None,
    "criterion": "CrossEntropy",
    "optimizer": "adamw",
    "momentum": 0.9,
    "weight_decay": 1e-6,
    "lr": 1e-3,
    "lr_step": 20,
    "lr_gamma": 0.1,
    "scheduler": "step",
    # --- runtime, output and misc ---
    "device": "cuda",
    "seed": 0,
    "output_dir": "./Output/Simple/Conv2d_New",
    "resize": False,
})
    

In [None]:
# Seed all RNGs *before* the dataset and model are constructed. Seeding only after
# AllConvNet(args) left weight initialisation unseeded, so two runs with the same
# args.seed could start from different weights.
# NOTE(review): assumes utils.set_seed seeds torch's global RNG — confirm.
set_seed(args.seed)

# Create the output directory (including parents) if it does not exist yet
if args.output_dir:
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

# Dataset: look up the dataset class by name. Then record on args the
# dataset-derived fields (number of classes, image size) that the model reads.
DATASET_CLASSES = {
    "cifar10": CIFAR10,
    "cifar100": CIFAR100,
    "imagenet": ImageNet,
}
if args.dataset not in DATASET_CLASSES:
    raise ValueError("Dataset not supported")
dataset = DATASET_CLASSES[args.dataset](args)
args.num_classes = dataset.num_classes
args.img_size = dataset.img_size

# Model (its classifier reads args.num_classes, so this must follow the dataset step)
model = AllConvNet(args)
print(f"Model: {model.name}")

# Parameters: also stored on args so they are written out with it below
total_params, trainable_params = model.parameter_count()
print(f"Total Parameters: {total_params}")
print(f"Trainable Parameters: {trainable_params}")
args.total_params = total_params
args.trainable_params = trainable_params

# Training and evaluation
train_eval_results = Train_Eval(args,
                                model,
                                dataset.train_loader,
                                dataset.test_loader
                                )

# Store args, model description and results in the output directory
write_to_file(os.path.join(args.output_dir, "args.txt"), args)
write_to_file(os.path.join(args.output_dir, "model.txt"), model)
write_to_file(os.path.join(args.output_dir, "train_eval_results.txt"), train_eval_results)