In [1]:
from dataset import ImageNet, CIFAR10, CIFAR100
from train_eval import Train_Eval
from types import SimpleNamespace


In [2]:
import torch 
import torch.nn as nn
import torch.nn.functional as F

class Conv2d_NN(nn.Module):
    def __init__(self, 
                 in_channels, 
                 out_channels, 
                 K, 
                 stride, 
                 padding, 
                 sampling_type, 
                 num_samples, 
                 sample_padding,
                 shuffle_pattern, 
                 shuffle_scale, 
                 magnitude_type, 
                 similarity_type, 
                 aggregation_type, 
                 lambda_param
                ):

        super(Conv2d_NN, self).__init__()

        assert K == stride, "K must be equal to stride for ConvNN"
        assert padding > 0 or padding == 0, "Cannot have Negative Padding"
        assert shuffle_pattern in ["B", "A", "BA", "NA"], "Shuffle pattern must be: Before, After, Before After, Not Applicable"
        assert magnitude_type in ["cosine", "euclidean"], "Similarity Matrix must be either cosine similarity or euclidean distance"
        assert sampling_type in ["all", "random", "spatial"], "Consider all neighbors, random neighbors, or spatial neighbors"
        assert int(num_samples) > 0 or int(num_samples) == -1, "Number of samples to consider must be greater than 0 or -1 for all samples"
        assert (sampling_type == "all" and int(num_samples) == -1) or (sampling_type != "all" and isinstance(num_samples, int)), "Number of samples must be -1 for all samples or integer for random and spatial sampling"

        assert similarity_type in ["Loc", "Col", "Loc_Col"], "Similarity Matrix based on Location, Color, or both"
        assert aggregation_type in ["Col", "Loc_Col"], "Aggregation based on Color or Location and Color"

        # Core Parameters
        self.in_channels = in_channels 
        self.out_channels = out_channels 
        self.K = K
        self.stride = stride 
        self.padding = padding 

        # 3 Sampling Types: all, random, spatial
        self.sampling_type = sampling_type
        self.num_samples = int(num_samples)
        self.sample_padding = int(sample_padding) if sampling_type == "spatial" else 0

        # Pixel Shuffling (optional) 
        self.shuffle_pattern = shuffle_pattern
        self.shuffle_scale = shuffle_scale

        # Similarity Metric
        self.magnitude_type = magnitude_type
        self.maximum = True if magnitude_type == "cosine" else False

        # Similarity and Aggregation Types
        self.similarity_type = similarity_type
        self.aggregation_type = aggregation_type
        
        # Positional Encoding (optional)
        self.coordinate_encoding = True if (similarity_type in ["Loc", "Loc_Col"] or aggregation_type == "Loc_Col") else False
        self.coordinate_cache = {}

        # Pixel Shuffle Adjustments
        self.shuffle_layer = nn.PixelShuffle(upscale_factor=self.shuffle_scale) 
        self.unshuffle_layer = nn.PixelUnshuffle(downscale_factor=self.shuffle_scale)

        self.in_channels_1d = self.in_channels * (self.shuffle_scale ** 2) if self.shuffle_pattern in ["B", "BA"] else self.in_channels
        self.out_channels_1d = self.out_channels * (self.shuffle_scale ** 2) if self.shuffle_pattern in ["A", "BA"] else self.out_channels

        self.in_channels_1d = self.in_channels_1d + 2 if self.aggregation_type == "Loc_Col" else self.in_channels_1d

        # Conv1d Layer
        self.conv1d_layer = nn.Conv1d(
            in_channels = self.in_channels_1d,
            out_channels = self.out_channels_1d,
            kernel_size = self.K, 
            stride = self.stride, 
            padding = 0, 
            # bias = False # Only if similarity_type is "Loc" (make ConvNN exactly same as Conv2d)
        )

        # Flatten * Unflatten layers 
        self.flatten = nn.Flatten(start_dim=2)
        self.unflatten = None

        # Shapes
        self.og_shape = None 
        self.padded_shape = None

        # Utility Variables
        self.INF = 1e5
        self.NEG_INF = -1e5

        self.lambda_param = lambda_param
        # self.lambda_param = nn.Parameter(torch.tensor(0.5), requires_grad=True)


    def forward(self, x):  
        # 1. Pixel Unshuffle Layer
        x = self.unshuffle_layer(x) if self.shuffle_pattern in ["B", "BA"] else x
        self.og_shape = x.shape

        # 2. Add Padding 
        if self.padding > 0:
            x = F.pad(x, (self.padding, self.padding, self.padding, self.padding), mode='constant', value=0)
            self.padded_shape = x.shape

        # 3. Add Coordinate Encoding
        x = self._add_coordinate_encoding(x) if self.coordinate_encoding else x

        # 4. Flatten Layer
        x = self.flatten(x) 

        # 5. Similarity and Aggregation Type 
        if self.similarity_type == "Loc":
            x_sim = x[:, -2:, :]
        elif self.similarity_type == "Loc_Col":
            x_sim = x
        elif self.similarity_type == "Col" and self.aggregation_type == "Col":
            x_sim = x
        elif self.similarity_type == "Col" and self.aggregation_type == "Loc_Col":
            x_sim = x[:, :-2, :]

        if self.similarity_type in ["Loc", "Loc_Col"] and self.aggregation_type == "Col":
            x = x[:, :-2, :]
        else: 
            x = x

        if self.similarity_type == "Loc_Col":      
            # Normalize each modality to unit variance before combining
            color_feats = x_sim[:, :-2, :]
            color_std = torch.std(color_feats, dim=[1,2], keepdim=True) + 1e-6
            color_norm = color_feats / color_std

            coord_feats = x_sim[:, -2:, :]  # Already in [-1,1]
            x_sim = torch.cat([self.lambda_param * color_norm, 
                            (1-self.lambda_param) * coord_feats], dim=1)
            
        # 6. Sampling + Similarity Calculation + Aggregation
        if self.sampling_type == "all":
            similarity_matrix = self._calculate_euclidean_matrix(x_sim) if self.magnitude_type == "euclidean" else self._calculate_cosine_matrix(x_sim)
            prime = self._prime(x, similarity_matrix, self.K, self.maximum)
            
        elif self.sampling_type == "random":
            if self.num_samples > x.shape[-1]:
                x_sample = x_sim
                similarity_matrix = self._calculate_euclidean_matrix_N(x_sim, x_sample) if self.magnitude_type == "euclidean" else self._calculate_cosine_matrix_N(x_sim, x_sample)
                torch.diagonal(similarity_matrix, dim1=1, dim2=2).fill_(-0.1 if self.magnitude_type == "euclidean" else 1.1)
                prime = self._prime(x, similarity_matrix, self.K, self.maximum)

            else:
                rand_idx = torch.randperm(x.shape[-1], device=x.device)[:self.num_samples]
                x_sample = x_sim[:, :, rand_idx]
                similarity_matrix = self._calculate_euclidean_matrix_N(x_sim, x_sample) if self.magnitude_type == "euclidean" else self._calculate_cosine_matrix_N(x_sim, x_sample)
                range_idx = torch.arange(len(rand_idx), device=x.device)
                similarity_matrix[:, rand_idx, range_idx] = self.INF if self.magnitude_type == "euclidean" else self.NEG_INF
                prime = self._prime_N(x, similarity_matrix, self.K, rand_idx, self.maximum)
            

        elif self.sampling_type == "spatial":
            if self.num_samples > self.og_shape[-2]:
                x_sample = x_sim
                similarity_matrix = self._calculate_euclidean_matrix_N(x_sim, x_sample) if self.magnitude_type == "euclidean" else self._calculate_cosine_matrix_N(x_sim, x_sample)
                torch.diagonal(similarity_matrix, dim1=1, dim2=2).fill_(-0.1 if self.magnitude_type == "euclidean" else 1.1)
                prime = self._prime(x, similarity_matrix, self.K, self.maximum)
            else:
                x_ind = torch.linspace(0 + self.sample_padding, self.og_shape[-2] - self.sample_padding - 1, self.num_samples, device=x.device).to(torch.long)
                y_ind = torch.linspace(0 + self.sample_padding, self.og_shape[-1] - self.sample_padding - 1, self.num_samples, device=x.device).to(torch.long)
                x_grid, y_grid = torch.meshgrid(x_ind, y_ind, indexing='ij')
                x_idx_flat, y_idx_flat = x_grid.flatten(), y_grid.flatten()
                width = self.og_shape[-2]
                flat_indices = y_idx_flat * width + x_idx_flat
                x_sample = x_sim[:, :, flat_indices]

                similarity_matrix = self._calculate_euclidean_matrix_N(x_sim, x_sample) if self.magnitude_type == "euclidean" else self._calculate_cosine_matrix_N(x_sim, x_sample)

                range_idx = torch.arange(len(flat_indices), device=x.device)    
                similarity_matrix[:, flat_indices, range_idx] = self.INF if self.magnitude_type == "euclidean" else self.NEG_INF
                
                prime = self._prime_N(x, similarity_matrix, self.K, flat_indices, self.maximum)
        else:
            raise NotImplementedError("Sampling Type not Implemented")
        
        # 7. Conv1d Layer
        x = self.conv1d_layer(prime)

        # 8. Unflatten Layer
        if not self.unflatten: 
            self.unflatten = nn.Unflatten(dim=2, unflattened_size=self.og_shape[2:])
        x = self.unflatten(x)

        # 9. Pixel Shuffle Layer
        x = self.shuffle_layer(x) if self.shuffle_pattern in ["A", "BA"] else x 
        return x 

    def _calculate_euclidean_matrix(self, matrix, sqrt=False):
        norm_squared = torch.sum(matrix ** 2, dim=1, keepdim=True)
        dot_product = torch.matmul(matrix.transpose(1, 2), matrix)

        dist_matrix = norm_squared.transpose(1, 2) + norm_squared - 2 * dot_product
        dist_matrix = torch.clamp(dist_matrix, min=0.0)
        dist_matrix = torch.sqrt(dist_matrix) if sqrt else dist_matrix
        torch.diagonal(dist_matrix, dim1=1, dim2=2).fill_(-0.1)
        return dist_matrix
    
    def _calculate_euclidean_matrix_N(self, matrix, matrix_sample, sqrt=False):
        norm_squared = torch.sum(matrix ** 2, dim=1, keepdim=True)
        norm_squared_sample = torch.sum(matrix_sample ** 2, dim=1, keepdim=True)
        dot_product = torch.matmul(matrix.transpose(1, 2), matrix_sample)
        
        dist_matrix = norm_squared.transpose(1, 2) + norm_squared_sample - 2 * dot_product
        dist_matrix = torch.clamp(dist_matrix, min=0.0) 
        dist_matrix = torch.sqrt(dist_matrix) if sqrt else dist_matrix
        return dist_matrix
    
    def _calculate_cosine_matrix(self, matrix):
        norm_matrix = F.normalize(matrix, p=2, dim=1)
        similarity_matrix = torch.matmul(norm_matrix.transpose(1, 2), norm_matrix)
        similarity_matrix = torch.clamp(similarity_matrix, min=-1.0, max=1.0) 
        torch.diagonal(similarity_matrix, dim1=1, dim2=2).fill_(1.1)
        return similarity_matrix
    
    def _calculate_cosine_matrix_N(self, matrix, matrix_sample):
        norm_matrix = F.normalize(matrix, p=2, dim=1) 
        norm_sample = F.normalize(matrix_sample, p=2, dim=1)
        similarity_matrix = torch.matmul(norm_matrix.transpose(1, 2), norm_sample)
        similarity_matrix = torch.clamp(similarity_matrix, min=-1.0, max=1.0) 
        return similarity_matrix
    
    def _prime(self, matrix, magnitude_matrix, K, maximum):
        b, c, t = matrix.shape

        if self.similarity_type == "Loc":
            topk_values, topk_indices = torch.sort(magnitude_matrix, dim=2, descending=maximum, stable=True)
            topk_indices = topk_indices[:, :, :K]
            topk_indices, _ = torch.sort(topk_indices, dim=-1)
        else:
            topk_values, topk_indices = torch.topk(magnitude_matrix, k=K, dim=2, largest=maximum)

        topk_indices_exp = topk_indices.unsqueeze(1).expand(b, c, t, K)    
        matrix_expanded = matrix.unsqueeze(-1).expand(b, c, t, K).contiguous()
        prime = torch.gather(matrix_expanded, dim=2, index=topk_indices_exp)

        if self.padding > 0: 
            prime = prime.view(b, c, self.padded_shape[-2], self.padded_shape[-1], K)
            prime = prime[:, :, self.padding:-self.padding, self.padding:-self.padding, :]
            prime = prime.reshape(b, c, K * self.og_shape[-2] * self.og_shape[-1])
        else: 
            prime = prime.view(b, c, -1)

        return prime
        
    def _prime_N(self, matrix, magnitude_matrix, K, rand_idx, maximum):
        b, c, t = matrix.shape
        
        topk_values, topk_indices = torch.topk(magnitude_matrix, k=K-1, dim=2, largest=maximum)
        tk = topk_indices.shape[-1]
        assert K == tk + 1, "Error: K must be same as tk + 1. K == tk + 1."

        # Map sample indices back to original matrix positions
        mapped_tensor = rand_idx[topk_indices]
        token_indices = torch.arange(t, device=matrix.device).view(1, t, 1).expand(b, t, 1)
        final_indices = torch.cat([token_indices, mapped_tensor], dim=2)
        if self.similarity_type == "Loc":
            final_indices, _ = torch.sort(final_indices, dim=-1)
        indices_expanded = final_indices.unsqueeze(1).expand(b, c, t, K)

        # Gather matrix values and apply similarity weighting
        matrix_expanded = matrix.unsqueeze(-1).expand(b, c, t, K).contiguous()
        prime = torch.gather(matrix_expanded, dim=2, index=indices_expanded)  

        if self.padding > 0:
            prime = prime.view(b, c, self.padded_shape[-2], self.padded_shape[-1], K)
            prime = prime[:, :, self.padding:-self.padding, self.padding:-self.padding, :]
            prime = prime.reshape(b, c, K * self.og_shape[-2] * self.og_shape[-1])
        else:
            prime = prime.view(b, c, -1)
        return prime

    def _add_coordinate_encoding(self, x):
        b, _, h, w = x.shape
        cache_key = f"{b}_{h}_{w}_{x.device}"

        if cache_key in self.coordinate_cache:
            expanded_grid = self.coordinate_cache[cache_key]
        else:
            y_coords_vec = torch.linspace(start=-1, end=1, steps=h, device=x.device)
            x_coords_vec = torch.linspace(start=-1, end=1, steps=w, device=x.device)

            y_grid, x_grid = torch.meshgrid(y_coords_vec, x_coords_vec, indexing='ij')
            grid = torch.stack((x_grid, y_grid), dim=0).unsqueeze(0)
            expanded_grid = grid.expand(b, -1, -1, -1)
            self.coordinate_cache[cache_key] = expanded_grid

        x_with_coords = torch.cat((x, expanded_grid), dim=1)
        return x_with_coords ### Last two channels are coordinate channels 


class MultiHeadConvNNAttention(nn.Module):
    def __init__(self, 
                 d_hidden, 
                 num_heads, 
                 attention_dropout,
                 K, 
                 sampling_type, 
                 num_samples, 
                 sample_padding, 
                 magnitude_type, 
                 seq_length=197, 
                 coordinate_encoding=False
                 ):
        
        super(MultiHeadConvNNAttention, self).__init__()
        assert d_hidden % num_heads == 0, "d_hidden must be divisible by num_heads"

        # Core Parameters
        self.d_hidden = d_hidden
        self.num_heads = num_heads
        self.attention_dropout = attention_dropout
        self.d_k = d_hidden // num_heads

        # ConvNN Parameters
        self.K = K
        self.seq_length = seq_length

        # 3 types of sampling: all, random, spatial
        self.sampling_type = sampling_type
        self.num_samples = int(num_samples) 
        self.sample_padding = int(sample_padding) if sampling_type == 'spatial' else 0    

        # Similarity Metric 
        self.magnitude_type = magnitude_type
        self.maximum = True if self.magnitude_type == 'cosine' else False

        # Coordinate Encoding (optional) 
        self.coordinate_encoding = coordinate_encoding
        self.coordinate_cache = {}
        
        # Linear projections for query, key, value
        self.W_q = nn.Linear(d_hidden, d_hidden)
        self.W_k = nn.Linear(d_hidden, d_hidden)
        self.W_v = nn.Linear(d_hidden, d_hidden)
        self.W_o = nn.Linear(d_hidden, d_hidden)   
        self.dropout = nn.Dropout(attention_dropout)

        self.in_channels = (d_hidden // num_heads) + 1 if coordinate_encoding else (d_hidden // num_heads)
        self.out_channels = (d_hidden // num_heads) 
        
        self.conv = nn.Conv1d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.K,
            stride=self.K,
            padding=0,
        )

        # Utility Variables 
        self.INF = 1.1
        self.NEG_INF = -0.1 
        
    def split_head(self, x): 
        batch_size, seq_length, d_hidden = x.size()
        self.batch_size = batch_size
        # self.seq_length = seq_length
        return x.contiguous().view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2) # (B, num_heads, seq_length, d_k)
        
    def combine_heads(self, x): 
        
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_hidden) 
    
    def batch_split(self, x): 
        x = x.reshape(self.batch_size, -1, self.d_k, self.seq_length)
        return x.permute(0, 1, 3, 2).contiguous()
        
    def batch_combine(self, x): 
        batch_size, _, seq_length, d_k = x.size()
        x = x.permute(0, 1, 3, 2).contiguous() 
        return x.view(-1, self.d_k, seq_length)
        
    def forward(self, x):
        # Note: x shape: (B, seq_length, d_hidden)
        # 1. Splithead & Batch Combine
        k = self.batch_combine(self.split_head(self.W_k(x)))
        v = self.batch_combine(self.split_head(self.W_v(x)))
        

        # 2. Add Coordinate Encoding 
        k = self._add_coordinate_encoding(k) if self.coordinate_encoding else k
        v = self._add_coordinate_encoding(v) if self.coordinate_encoding else v


        # 3. Sampling & Similarity Calculation
        if self.sampling_type == 'all': # All Samples
            q = self.batch_combine(self.split_head(self.W_q(x)))
            
            q = self._add_coordinate_encoding(q) if self.coordinate_encoding else q

            similarity_matrix = self._calculate_cosine_matrix(k, q) if self.magnitude_type == 'cosine' else self._calculate_euclidean_matrix(k, q, sqrt=True)

            
            prime = self._prime(v, similarity_matrix, self.K, self.maximum)

        elif self.sampling_type == 'random': # Random Samples
            rand_idx = torch.randperm(x.shape[1], device=x.device)[:self.num_samples]
            x_sample = x[:, rand_idx, :]            
            q = self.batch_combine(self.split_head(self.W_q(x_sample)))
            q = self._add_coordinate_encoding(q) if self.coordinate_encoding else q

            similarity_matrix = self._calculate_cosine_matrix_N(k, q) if self.magnitude_type == 'cosine' else self._calculate_euclidean_matrix_N(k, q, sqrt=True)
            range_idx = torch.arange(len(rand_idx), device=q.device)
            similarity_matrix[:, rand_idx, range_idx] = self.INF if self.magnitude_type == 'euclidean' else self.NEG_INF


            prime = self._prime_N(v, similarity_matrix, self.K, rand_idx, self.maximum)

        elif self.sampling_type == 'spatial': # Spatial Samples
            spat_idx = torch.linspace(0 + self.sample_padding, x.shape[1] - self.sample_padding - 1, self.num_samples, device=x.device).long()
            x_sample = x[:, spat_idx, :]
            q = self.batch_combine(self.split_head(self.W_q(x_sample)))
            q = self._add_coordinate_encoding(q) if self.coordinate_encoding else q

            similarity_matrix = self._calculate_cosine_matrix_N(k, q) if self.magnitude_type == 'cosine' else self._calculate_euclidean_matrix_N(k, q, sqrt=True)
            range_idx = torch.arange(len(spat_idx), device=q.device)
            similarity_matrix[:, spat_idx, range_idx] = self.INF if self.magnitude_type == 'euclidean' else self.NEG_INF


            prime = self._prime_N(v, similarity_matrix, self.K, spat_idx, self.maximum)
            
        else: 
            raise ValueError("Invalid sampling_type. Must be one of ['all', 'random', 'spatial']")

        # 4. Conv1d Layer
        x = self.conv(prime)  

        # 5. Dropout + Reshape (B, seq_length, d_hidden)
        x = self.dropout(x)
        x = x.permute(0, 2, 1) 

        # 6. Final Linear Projection
        x = self.W_o(self.combine_heads(self.batch_split(x)))
        return x       

    def _calculate_euclidean_matrix(self, K, Q, sqrt=False):
        k_norm_squared = torch.sum(K**2, dim=1, keepdim=True)
        q_norm_squared = torch.sum(Q**2, dim=1, keepdim=True)
        dot_product = torch.bmm(K.transpose(1, 2), Q)

        dist_matrix = k_norm_squared.transpose(1, 2) + q_norm_squared - 2 * dot_product
        dist_matrix = torch.clamp(dist_matrix, min=0.0)
        dist_matrix = torch.sqrt(dist_matrix) if sqrt else dist_matrix
        torch.diagonal(dist_matrix, dim1=1, dim2=2).fill_(-0.1) 
        return dist_matrix 

    def _calculate_euclidean_matrix_N(self, K, Q, sqrt=False):
        k_norm_squared = torch.sum(K**2, dim=1, keepdim=True)
        q_norm_squared = torch.sum(Q**2, dim=1, keepdim=True)
        dot_product = torch.bmm(K.transpose(1, 2), Q)

        dist_matrix = k_norm_squared.transpose(1, 2) + q_norm_squared - 2 * dot_product
        dist_matrix = torch.clamp(dist_matrix, min=0.0)
        dist_matrix = torch.sqrt(dist_matrix) if sqrt else dist_matrix
        return dist_matrix 

    def _calculate_cosine_matrix(self, K, Q):
        k_norm = F.normalize(K, p=2, dim=1)
        q_norm = F.normalize(Q, p=2, dim=1)
        similarity_matrix = torch.matmul(k_norm.transpose(1, 2), q_norm)
        torch.diagonal(similarity_matrix, dim1=1, dim2=2).fill_(1.1)  # Fill diagonal with 1.1 to self-select
        return similarity_matrix

    def _calculate_cosine_matrix_N(self, K, Q):
        norm_k = F.normalize(K, p=2, dim=1)
        norm_q = F.normalize(Q, p=2, dim=1)
        similarity_matrix = torch.matmul(norm_k.transpose(1, 2), norm_q)
        similarity_matrix = torch.softmax(similarity_matrix, dim=-1)
        return similarity_matrix

    def _prime(self, v, qk, K, maximum):
        b, c, t = v.shape
        topk_values, topk_indices = torch.topk(qk, k=K, dim=2, largest=maximum)
        topk_values = torch.softmax(topk_values, dim=-1)
        topk_indices_exp = topk_indices.unsqueeze(1).expand(b, c, t, K)
        topk_values_exp = topk_values.unsqueeze(1).expand(b, c, t, K)

    

        v_expanded = v.unsqueeze(-1).expand(b, c, t, K).contiguous()
        prime = torch.gather(v_expanded, dim=2, index=topk_indices_exp)
        prime = topk_values_exp * prime 

        prime = prime.view(b, c, -1)

        return prime

    def _prime_N(self, v, qk, K, rand_idx, maximum):
        b, c, t = v.shape
        topk_values, topk_indices = torch.topk(qk, k=K-1, dim=2, largest=maximum)
        tk = topk_indices.shape[-1]
        assert K == tk + 1, "Error: K must be same as tk + 1. K == tk + 1."

        # Map sample indicies back to original matrix positions 
        mapped_tensor = rand_idx[topk_indices]
        token_indices = torch.arange(t, device=v.device).view(1, t, 1).expand(b, t, 1)
        final_indices = torch.cat([token_indices, mapped_tensor], dim=-1)
        topk_indices_exp = final_indices.unsqueeze(1).expand(b, c, t, K)

        # Expand topk values to match the shape of indices
        topk_values = torch.softmax(topk_values, dim=-1)
        topk_values_exp = topk_values.unsqueeze(1).expand(b, c, t, K-1)
        zeros = torch.zeros((b, c, t, 1), device=v.device)
        topk_values_exp = torch.cat((zeros, topk_values_exp), dim=-1)

        # Gather matrix values and apply similarity weighting 
        v_expanded = v.unsqueeze(-1).expand(b, c, t, K).contiguous()    
        prime = torch.gather(v_expanded, dim=2, index=topk_indices_exp)
        prime = topk_values_exp * prime

        prime = prime.view(b, c, -1)
        return prime
    
    def _add_coordinate_encoding(self, x):
        b, c, t = x.shape 
        cache_key = f"{b}_{t}_{x.device}"
        if cache_key in self.coordinate_cache: 
            expanded_coords = self.coordinate_cache[cache_key]
        else: 
            coords_vec = torch.linspace(start=-1, end=1, steps=t, device=x.device).unsqueeze(0).expand(b, -1) 
            expanded_coords = coords_vec.unsqueeze(1).expand(b, -1, -1) 
            self.coordinate_cache[cache_key] = expanded_coords

        x_with_coords = torch.cat([x, expanded_coords], dim=1) 
        return x_with_coords 

class MKNet(nn.Module):
    def __init__(self, 
                d_hidden = 128, 
                d_ffn = 512,
                num_layers = 8, 
                dropout = 0.1, 
                convnn_args = {
                    'K': 9,
                    'stride': 9,
                    'padding': 1,
                    'sampling_type': 'all',
                    'num_samples': -1,
                    'sample_padding': 0,
                    'shuffle_pattern': "NA", 
                    'shuffle_scale': 0,
                    'magnitude_type': "cosine",
                    'similarity_type': "Col",
                    'aggregation_type': "Col",
                    'lambda_param': 0.5
                }, 
                convnn_attn_args = {
                    "num_heads": 1,
                    "attention_dropout": 0.1, 
                    "K": 9,
                    "sampling_type": "all",
                    "num_samples": -1,
                    "sample_padding": 0,
                    "magnitude_type": "cosine",
                    "seq_length": 196, 
                    "coordinate_encoding": False
                }
                 ):

        super(MKNet, self).__init__()
        # Implementation of MKNet architecture goes here
        """
        MKNet architecture implementation
        """

        self.d_hidden = d_hidden
        self.num_layers = num_layers
        self.dropout = dropout
        self.convnn_args = convnn_args
        self.convnn_attn_args = convnn_attn_args

        self.patch_embedding = PatchEmbedding(
            d_hidden=d_hidden, 
            img_size=224, 
            patch_size=16, 
            n_channels=3
            )

        self.cross_layers = nn.Sequential(*[MKCrossBlock(
            d_hidden=self.d_hidden, 
            dropout = self.dropout, 
            convnn_args=self.convnn_args,
            convnn_attn_args=self.convnn_attn_args
            ) for _ in range(self.num_layers)])

        self.ffn_layers = MKFFN(
            d_hidden=self.d_hidden, 
            d_ffn=d_ffn, 
            dropout=self.dropout
            )

    def forward(self, x): # input [B, C, H, W]      
        x = self.patch_embedding(x)  # [B, d_hidden, H/patch_size, W/patch_size]
        x = self.cross_layers(x)     # [B, d_hidden, H/patch_size, W/patch_size]
        x = self.ffn_layers(x)       # [B, d_hidden, H/patch_size, W/patch_size]
        return x

    def parameter_count(self): 
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return total_params, trainable_params


class MKLocalBlock(nn.Module):
    def __init__(self, 
                d_hidden, 
                convnn_args):
        super(MKLocalBlock, self).__init__()
        # Implementation of MKNet block goes here
        """
        MKNet block implementation
        """
        self.d_hidden = d_hidden
        self.convnn = Conv2d_NN(
            in_channels=d_hidden,
            out_channels=d_hidden,
            K=convnn_args['K'],
            stride=convnn_args['stride'],
            padding=convnn_args['padding'],
            sampling_type=convnn_args['sampling_type'],
            num_samples=convnn_args['num_samples'],
            sample_padding=convnn_args['sample_padding'],
            shuffle_pattern=convnn_args['shuffle_pattern'],
            shuffle_scale=convnn_args['shuffle_scale'],
            magnitude_type=convnn_args['magnitude_type'],
            similarity_type=convnn_args['similarity_type'],
            aggregation_type=convnn_args['aggregation_type'],
            lambda_param=convnn_args['lambda_param']
        )
        self.conv = nn.Conv2d(d_hidden, d_hidden, kernel_size=3, padding=1, stride=1)

        self.gelu = nn.GELU()


    def forward(self, x): # input [B, C, H, W]
        x = self.convnn(x)
        # x = self.conv(x)
        x = self.gelu(x)
        return x

class MKGlobalBlock(nn.Module):
    def __init__(self, d_hidden, convnn_attn_args):
        super(MKGlobalBlock, self).__init__()
        # Implementation of MKNet global block goes here
        """
        MKNet global block implementation
        """
        self.convnn_attn = MultiHeadConvNNAttention(
            d_hidden=d_hidden,
            num_heads=convnn_attn_args["num_heads"],
            attention_dropout=convnn_attn_args["attention_dropout"],
            K=convnn_attn_args["K"],
            sampling_type=convnn_attn_args["sampling_type"],
            num_samples=convnn_attn_args["num_samples"],
            sample_padding=convnn_attn_args["sample_padding"],
            magnitude_type=convnn_attn_args["magnitude_type"],
            seq_length=convnn_attn_args["seq_length"],
            coordinate_encoding=convnn_attn_args["coordinate_encoding"]
        )
        self.norm = nn.LayerNorm(d_hidden)
        self.dropout = nn.Dropout(0.1) 

    def forward(self, x): # input [B, C, H, W]
        B, C, H, W = x.size()
        x_reshaped = x.permute(0, 2, 3, 1).contiguous().view(B, H * W, C)  # [B, H*W, C]
        x_attn = self.convnn_attn(x_reshaped)  # [B, H*W, C]
        x_attn = self.norm(x_attn)
        x_attn = self.dropout(x_attn)
        x = x + x_attn.view(B, H, W, C).permute(0, 3, 1, 2)  # [B, C, H, W]
        return x

class MKCrossBlock(nn.Module):
    def __init__(self, 
                 d_hidden, 
                 dropout, 
                 convnn_args, 
                 convnn_attn_args
                 ):
        super(MKCrossBlock, self).__init__()
        # Implementation of MKNet cross block goes here
        """
        MKNet cross block implementation
        """

        self.local_block = MKLocalBlock(d_hidden=d_hidden, convnn_args=convnn_args)
        self.global_block = MKGlobalBlock(d_hidden=d_hidden, convnn_attn_args=convnn_attn_args)

        self.local_conv = nn.Conv2d(in_channels=d_hidden, 
                                    out_channels=d_hidden, 
                                    kernel_size=1,
                                    stride=1,
                                    padding=0, 
                                    bias=False)

        self.global_conv = nn.Conv2d(in_channels=d_hidden, 
                                     out_channels=d_hidden, 
                                     kernel_size=1,
                                     stride=1,
                                     padding=0, 
                                     bias=False)

        self.combine_conv = nn.Conv2d(in_channels=d_hidden, 
                                      out_channels=d_hidden, 
                                      kernel_size=1,
                                      stride=1,
                                      padding=0, 
                                      bias=False)

        self.dropout_local = nn.Dropout(dropout)
        self.dropout_global = nn.Dropout(dropout)

        self.norm_local = nn.LayerNorm(d_hidden)
        self.norm_global = nn.LayerNorm(d_hidden)
        self.norm_combine = nn.LayerNorm(d_hidden)

    def forward(self, x): # input [B, C, H, W]
        identity = x 
        
        # Local Branch
        x_local = self.norm_local(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) 
        x_local = self.local_block(x)
        x_local = x + self.dropout_local(x_local)
        x_local = self.local_conv(x_local)

        # Global Branch
        x_global = self.norm_global(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) 
        x_global = self.global_block(x)
        x_global = x + self.dropout_global(x_global)
        x_global = self.global_conv(x_global)

        # Combine Local and Global
        x_combine = x_local + x_global + identity
        x_combine = self.norm_combine(x_combine.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        x_combine = self.combine_conv(x_combine)

        x = x_combine
        return x

class MKFFN(nn.Module):
    def __init__(self, d_hidden, d_ffn, num_classes=100, dropout=0.1):
        super(MKFFN, self).__init__()
        self.fc1 = nn.Linear(d_hidden, d_ffn)
        self.fc2 = nn.Linear(d_ffn, d_hidden)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.GELU()
        
        # Use adaptive pooling instead of flattening all spatial dimensions
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.output_layer = nn.Linear(d_hidden, num_classes) 

    def forward(self, x):  # [B, C, H, W]
        # Apply FFN to spatial features
        B, C, H, W = x.shape
        x = x.permute(0, 2, 3, 1).contiguous().view(B, H*W, C)  # [B, H*W, C]
        x = self.fc1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        x = x.view(B, H, W, C).permute(0, 3, 1, 2)  # [B, C, H, W]
        
        # Global pooling and classification
        x = self.pool(x)  # [B, C, 1, 1]
        x = x.flatten(1)  # [B, C]
        x = self.output_layer(x)  # [B, num_classes]
        return x

class PatchEmbedding(nn.Module):
    def __init__(self, d_hidden, img_size, patch_size, n_channels=3):
        super(PatchEmbedding, self).__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        
        self.proj = nn.Conv2d(n_channels, d_hidden, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(d_hidden)
        
        # Add learnable positional embeddings
        self.pos_embed = nn.Parameter(torch.zeros(1, d_hidden, 
                                                   img_size // patch_size, 
                                                   img_size // patch_size))

    def forward(self, x):
        x = self.proj(x)  # [B, d_hidden, H/P, W/P]
        x = x + self.pos_embed  # Add positional encoding
        x = x.permute(0, 2, 3, 1)  # [B, H/P, W/P, d_hidden]
        x = self.norm(x)
        x = x.permute(0, 3, 1, 2)  # [B, d_hidden, H/P, W/P]
        return x




In [3]:
mknet = MKNet()

args = SimpleNamespace()
args.resize = 224
args.augment = True
args.noise = 0
args.data_path = "./Data"
args.batch_size = 512
args.seed = 42
args.num_epochs = 100
args.criterion = "CrossEntropy"
args.optimizer = "adamw"
args.lr = 1e-3
args.weight_decay = 2e-4
args.scheduler = "cosine"
args.device = "cuda"
args.use_amp = False

mknet = MKNet().to(args.device)

dataset = CIFAR100(args)
args.num_classes = dataset.num_classes 

# Training Modules 
train_eval_results = Train_Eval(args, 
                                    mknet, 
                                    dataset.train_loader, 
                                    dataset.test_loader
                                    )


Files already downloaded and verified
Files already downloaded and verified


STAGE:2025-10-25 21:14:17 670132:670132 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2025-10-25 21:14:18 670132:670132 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2025-10-25 21:14:18 670132:670132 ActivityProfilerController.cpp:324] Completed Stage: Post Processing


   - Trainable Parameters: 4.74019600 M
Model Complexity (Profiler):
   - GFLOPs: 0.66533120
   - Trainable Parameters: 4.74019600 M
[Epoch 001] Time: 55.5390s | [Train] Loss: 4.28792878 Accuracy: Top1: 4.0585%, Top5: 16.0179% | [Test] Loss: 4.10508718 Accuracy: Top1: 6.9267%, Top5: 23.0480%
[Epoch 002] Time: 57.3784s | [Train] Loss: 3.96277302 Accuracy: Top1: 8.3790%, Top5: 27.1630% | [Test] Loss: 3.87655861 Accuracy: Top1: 10.0236%, Top5: 30.0408%
[Epoch 003] Time: 55.6566s | [Train] Loss: 3.74826213 Accuracy: Top1: 11.7288%, Top5: 34.2549% | [Test] Loss: 3.66287577 Accuracy: Top1: 13.6259%, Top5: 37.7625%
[Epoch 004] Time: 57.6250s | [Train] Loss: 3.57093498 Accuracy: Top1: 14.4844%, Top5: 39.4177% | [Test] Loss: 3.50172971 Accuracy: Top1: 16.0811%, Top5: 41.9784%
[Epoch 005] Time: 57.7203s | [Train] Loss: 3.39235998 Accuracy: Top1: 17.8017%, Top5: 44.5895% | [Test] Loss: 3.34094841 Accuracy: Top1: 19.2900%, Top5: 47.1783%
[Epoch 006] Time: 57.4128s | [Train] Loss: 3.27581040 Accura