### Employing Kolmogorov-Arnold Networks In Context Of MNIST Dataset:
Courtesy of the developers of TorchKan: https://github.com/1ssb/torchkan 

In [None]:
#Feature Extractor
import torch
import torch.nn as nn
import torch.nn.functional as F


# Step 1 & 2
#Convolutional Feature Extraction: The model begins with two convolutional layers, 
# each paired with ReLU activation and max-pooling. The first layer employs 16 filters of size 3x3, 
# while the second increases the feature maps to 32 channels.


def feature_extractor():
    """Build the convolutional front end used before the KAN layers.

    Two stages run back to back; each is a 3x3 convolution (stride 1,
    padding 1), a ReLU, and a 2x2 max-pool with stride 2. The first stage
    maps 1 input channel to 16 feature maps and the second maps 16 to 32.
    For a 1x28x28 input the result is a 32x7x7 feature volume.

    Returns:
        nn.Sequential: the six layers in execution order.
    """
    # Stage 1: single-channel image -> 16 feature maps, spatial size halved.
    first_stage = [
        nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
    ]
    # Stage 2: 16 -> 32 feature maps, spatial size halved again.
    second_stage = [
        nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
    ]
    return nn.Sequential(*first_stage, *second_stage)


Main Idea:
1. Extract the features via Convolutional Feature Extraction

2. Apply a polynomial feature transform (up to the n-th order) to the flattened convolutional output in order to discern non-linear relationships

3. Monomials (powers of the input features) are calculated up to a specific order to capture non-linear interactions => a richer and more informative representation downstream.

4. The calculated monomials are used to adjust the output of the linear layers before activation => introducing an additional dimension of feature interaction => enabling more complex learning of the data

In [None]:
class KANvolver(nn.Module):
    """Convolutional feature extractor followed by polynomial-augmented linear layers.

    Each hidden layer adds the output of a plain linear map to the output of
    a second linear map applied to the monomials ``x**0 .. x**poly_order`` of
    the layer input. The sum is batch-normalised and then passed through the
    activation.

    Args:
        layers_hidden: sequence of output widths, one per layer. The input
            width of the first layer (32 * 7 * 7, the flattened conv output
            for 28x28 inputs) is prepended automatically.
        poly_order: highest monomial power to compute (inclusive).
        base_activation: either an activation *class* (e.g. ``nn.ReLU``),
            which is instantiated here, or an already-built callable
            (module instance or function), which is used as-is.
    """

    def __init__(self, layers_hidden, poly_order=2, base_activation=nn.ReLU):
        super().__init__()
        self.poly_order = poly_order

        # The default is a class; calling a class on a tensor would build a
        # module instead of activating the tensor, so instantiate it once.
        if isinstance(base_activation, type):
            base_activation = base_activation()
        self.base_activation = base_activation

        self.feature_extractor = feature_extractor()

        # 32 channels at 7x7 spatial resolution after two 2x2 pools on 28x28.
        flat_features = 32 * 7 * 7
        # list(...) so tuples (or other sequences) are accepted as well.
        self.layers_hidden = [flat_features] + list(layers_hidden)

        self.base_weights = nn.ModuleList()
        self.poly_weights = nn.ModuleList()
        self.batch_norms = nn.ModuleList()

        for in_features, out_features in zip(
            self.layers_hidden[:-1], self.layers_hidden[1:]
        ):
            self.base_weights.append(nn.Linear(in_features, out_features))
            # One input column per (feature, power) pair, powers 0..poly_order.
            self.poly_weights.append(
                nn.Linear(in_features * (poly_order + 1), out_features)
            )
            self.batch_norms.append(nn.BatchNorm1d(out_features))

    # How monomials work: in the context of this model, monomials are
    # polynomial powers of the input features. By computing monomials up to a
    # specified order, the model can capture non-linear interactions between
    # the features, potentially leading to richer and more informative
    # representations for downstream tasks.
    #
    # They are used to adjust the output of the hidden layers before
    # activation => an additional dimension of feature interaction =>
    # allowing more complex patterns to be learnt.

    def compute_eff_monomials(self, x, order):
        """Return ``x`` raised to every power from 0 to ``order``.

        Args:
            x: tensor of shape (batch, features).
            order: highest power (inclusive).

        Returns:
            Tensor of shape (batch, features, order + 1) whose last axis
            holds ``x**0, x**1, ..., x**order``.
        """
        power = torch.arange(order + 1, device=x.device, dtype=x.dtype)
        # Broadcasting against the trailing axis gives the same result as
        # repeating x order+1 times, without materialising the copies.
        return torch.pow(x.unsqueeze(-1), power)

    def forward(self, x):
        """Run the conv front end and then every polynomial-augmented layer.

        Args:
            x: tensor reshapeable to (-1, 1, 28, 28).

        Returns:
            Tensor of shape (batch, layers_hidden[-1]); the activation is
            applied after every layer, including the last.
        """
        x = x.view(-1, 1, 28, 28)
        x = self.feature_extractor(x)

        x = x.view(x.size(0), -1)  # Flatten the features from conv layers

        for base, poly, batch_norm in zip(
            self.base_weights, self.poly_weights, self.batch_norms
        ):
            output = base(x)
            monomial_base = self.compute_eff_monomials(x, self.poly_order)
            # Flatten (features, powers) into one axis for the linear layer.
            monomial_base = monomial_base.reshape(x.size(0), -1)
            output_poly = poly(monomial_base)
            # Normalise the combined pre-activation, then activate.
            x = self.base_activation(batch_norm(output + output_poly))

        return x