In [16]:
import torch
import torch.nn as nn
import sys
sys.path.append('..')
from ptflops import get_model_complexity_info
from src import get_model
import argparse
from src.layers.KANLinear import KANLinear
from copy import deepcopy

In [17]:
import torch
import torch.nn as nn
import numpy as np
from src.layers.spline import *


class KANLinear(nn.Module):
    """
    KANLayer class
    

    Attributes:
    -----------
        in_dim: int
            input dimension
        out_dim: int
            output dimension
        num: int
            the number of grid intervals
        k: int
            the piecewise polynomial order of splines
        noise_scale: float
            spline scale at initialization
        coef: 2D torch.tensor
            coefficients of B-spline bases
        scale_base_mu: float
            magnitude of the residual function b(x) is drawn from N(mu, sigma^2), mu = sigma_base_mu
        scale_base_sigma: float
            magnitude of the residual function b(x) is drawn from N(mu, sigma^2), mu = sigma_base_sigma
        scale_sp: float
            mangitude of the spline function spline(x)
        base_activation: fun
            residual function b(x)
        grid_eps: float in [0,1]
            a hyperparameter used in update_grid_from_samples. When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes.
            the id of activation functions that are locked
        device: str
            device
    """

    def __init__(self, 
        in_dim=3,
        out_dim=2,
        num=5, 
        k=3, 
        noise_scale=0.5, 
        scale_base_mu=0.0, 
        scale_base_sigma=1.0, 
        scale_sp=1.0, 
        base_activation=torch.nn.SiLU(), 
        grid_eps=0.02, 
        grid_range=[-1, 1], 
        sp_trainable=True, 
        sb_trainable=True, 
        save_plot_data = True,
        stochastic_variance = 0.1,
        device='cpu', 
        neuron_fun=None,
        noise_type=None
        ):
        ''''
        initialize a KANLayer
        
        Args:
        -----
            in_dim : int
                input dimension. Default: 2.
            out_dim : int
                output dimension. Default: 3.
            num : int
                the number of grid intervals = G. Default: 5.
            k : int
                the order of piecewise polynomial. Default: 3.
            noise_scale : float
                the scale of noise injected at initialization. Default: 0.5.
            scale_base_mu : float
                the scale of the residual function b(x) is intialized to be N(scale_base_mu, scale_base_sigma^2).
            scale_base_sigma : float
                the scale of the residual function b(x) is intialized to be N(scale_base_mu, scale_base_sigma^2).
            scale_sp : float
                the scale of the base function spline(x).
            base_activation: function
                residual function b(x). Default: torch.nn.SiLU()
            grid_eps : float
                When grid_eps = 1, the grid is uniform; when grid_eps = 0, the grid is partitioned using percentiles of samples. 0 < grid_eps < 1 interpolates between the two extremes.
            grid_range : list/np.array of shape (2,)
                setting the range of grids. Default: [-1,1].
            sp_trainable : bool
                If true, scale_sp is trainable
            sb_trainable : bool
                If true, scale_base is trainable
            device : str
                device
            sparse_init : bool
                if sparse_init = True, sparse initialization is applied.
            
        Returns:
        --------
            self
            
        Example
        -------
        >>> from kan.KANLayer import *
        >>> model = KANLayer(in_dim=3, out_dim=5)
        >>> (model.in_dim, model.out_dim)
        '''
        super(KANLinear, self).__init__()
        print(f"Noise type: {noise_type}")
        print(f"Neuron fun: {neuron_fun}")
        print(f"Out dim: : {out_dim}")
        print(f"In dim:  {in_dim}")
        print(f"Grid size:  {num}")
        print(f"Spline order:  {k}")

        # size 
        self.out_dim = out_dim
        self.in_dim = in_dim
        self.num = num
        self.k = k
        print(f"num: {num}, grid_range: {grid_range}")

        grid = torch.linspace(grid_range[0], grid_range[1], steps=num + 1)[None,:].expand(self.in_dim, num+1)

        grid = extend_grid(grid, k_extend=k)

        self.grid = torch.nn.Parameter(grid).requires_grad_(False)

        noises = (torch.rand(self.num+1, self.in_dim, self.out_dim) - 1/2) * noise_scale / num

        self.coef = torch.nn.Parameter(curve2coef(self.grid[:,k:-k].permute(1,0), noises, self.grid, k))
        
        self.scale_base = torch.nn.Parameter(scale_base_mu * 1 / np.sqrt(in_dim) + \
                         scale_base_sigma * (torch.rand(in_dim, out_dim)*2-1) * 1/np.sqrt(in_dim)).requires_grad_(sb_trainable)

        self.scale_sp = torch.nn.Parameter(torch.ones(in_dim, out_dim) * scale_sp * 1 / np.sqrt(in_dim)).requires_grad_(sp_trainable)  # make scale trainable

        self.base_activation = base_activation
        
        self.stochastic_variance = torch.nn.Parameter(torch.tensor(stochastic_variance)).requires_grad_(True)

        self.neuron_fun = neuron_fun

        self.grid_eps = grid_eps

        self.noise_type = noise_type
        print(f"Noise type: {noise_type}")
        print(f"Neuron fun: {neuron_fun}")
        print(f"Out dim: : {out_dim}")
        print(f"In dim:  {in_dim}")
        print(f"Grid size:  {num}")
        print(f"Spline order:  {k}")


        self.to(device)
        
    def to(self, device):
        super(KANLinear, self).to(device)
        self.device = device    
        return self

    def forward(self, x):
        '''
        KANLayer forward given input x
        
        Args:
        -----
            x : 2D torch.float
                inputs, shape (number of samples, input dimension)
            
        Returns:
        --------
            y : 2D torch.float
                outputs, shape (number of samples, output dimension)
            preacts : 3D torch.float
                fan out x into activations, shape (number of sampels, output dimension, input dimension)
            postacts : 3D torch.float
                the outputs of activation functions with preacts as inputs
            postspline : 3D torch.float
                the outputs of spline functions with preacts as inputs
        
        Example
        -------
        >>> from kan.KANLayer import *
        >>> model = KANLayer(in_dim=3, out_dim=5)
        >>> x = torch.normal(0,1,size=(100,3))
        >>> y, preacts, postacts, postspline = model(x)
        >>> y.shape, preacts.shape, postacts.shape, postspline.shape
        '''
        batch = x.shape[0]            
        base = self.base_activation(x) # (batch, in_dim)
        y = coef2curve(x_eval=x, grid=self.grid, coef=self.coef, k=self.k)
        y = self.scale_base[None,:,:] * base[:,:,None] + self.scale_sp[None,:,:] * y
        # print(f"Shape of y before summation: {y.shape}")  # Add this line
        if self.neuron_fun == "sum":
            y = torch.sum(y, dim=1)
        elif self.neuron_fun == "mean":
            y = torch.mean(y, dim=1)
            
        # if self.training:
        #     # print("I'm training")
        #     if self.noise_type == "uniform":
        #         y += (torch.rand_like(y) * 2 - 1 ) * self.stochastic_variance        # Uniform range [-1, 1)
        #     elif self.noise_type == "normal":
        #         y += torch.randn_like(y) * self.stochastic_variance                  # Normal distribution

        return y


In [18]:
from timm.layers import trunc_normal_
from src.layers.custom_layers import LayerNorm, PositionalEncodingFourier
from src.layers.sdta_encoder import SDTAEncoder
from src.layers.conv_encoder import ConvEncoder
from src.layers.LoRaLin import LoRaLin
# from src.layers.KANLinear import KANLinear

class EdgeFaceKAN(nn.Module):
    def __init__(self, in_chans=3, num_features=512, rank_ratio = 0.6,
                 depths=[3, 3, 9, 3], dims=[32, 64, 100, 192],
                 global_block=[0, 0, 0, 3], global_block_type=['None', 'None', 'None', 'SDTA'],
                 drop_path_rate=0., layer_scale_init_value=1e-6, head_init_scale=1., expan_ratio=4,
                 kernel_sizes=[7, 7, 7, 7], heads=[8, 8, 8, 8], use_pos_embd_xca=[False, False, False, False],
                 use_pos_embd_global=False, d2_scales=[2, 3, 4, 5], grid_size=5, spline_order=3,
                 base_activation=nn.SiLU(), neuron_fun=None, noise_type=None):
        super().__init__()
        for g in global_block_type:
            assert g in ['None', 'SDTA']
        if use_pos_embd_global:
            self.pos_embd = PositionalEncodingFourier(dim=dims[0])
        else:
            self.pos_embd = None
        self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList()  # 4 feature resolution stages, each consisting of multiple residual blocks
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(4):
            stage_blocks = []
            for j in range(depths[i]):
                if j > depths[i] - global_block[i] - 1:
                    if global_block_type[i] == 'SDTA':
                        stage_blocks.append(SDTAEncoder(dim=dims[i], rank_ratio=rank_ratio,drop_path=dp_rates[cur + j],
                                                        expan_ratio=expan_ratio, scales=d2_scales[i],
                                                        use_pos_emb=use_pos_embd_xca[i], num_heads=heads[i]))
                    else:
                        raise NotImplementedError
                else:
                    stage_blocks.append(ConvEncoder(dim=dims[i], rank_ratio=rank_ratio,drop_path=dp_rates[cur + j],
                                                    layer_scale_init_value=layer_scale_init_value,
                                                    expan_ratio=expan_ratio, kernel_size=kernel_sizes[i]))

            self.stages.append(nn.Sequential(*stage_blocks))
            cur += depths[i]
        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # Final norm layer
        print(f"grid_size in EdgeFaceKAN: {grid_size}")
        self.head = KANLinear(dims[-1], num_features,
            num=grid_size,
            k=spline_order,
            base_activation=base_activation,
            neuron_fun=neuron_fun,
            noise_type=noise_type,
        )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

    def forward_features(self, x):
        x = self.downsample_layers[0](x)
        x = self.stages[0](x)
        if self.pos_embd:
            B, C, H, W = x.shape
            x = x + self.pos_embd(B, H, W)
        for i in range(1, 4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)

        return self.norm(x.mean([-2, -1]))  # Global average pooling, (N, C, H, W) -> (N, C)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x

In [19]:
model = EdgeFaceKAN(num_features=512, grid_size=5)
model

grid_size in EdgeFaceKAN: 5
Noise type: None
Neuron fun: None
Out dim: : 512
In dim:  192
Grid size:  5
Spline order:  3
num: 5, grid_range: [-1, 1]
Noise type: None
Neuron fun: None
Out dim: : 512
In dim:  192
Grid size:  5
Spline order:  3


EdgeFaceKAN(
  (downsample_layers): ModuleList(
    (0): Sequential(
      (0): Conv2d(3, 32, kernel_size=(4, 4), stride=(4, 4))
      (1): LayerNorm()
    )
    (1): Sequential(
      (0): LayerNorm()
      (1): Conv2d(32, 64, kernel_size=(2, 2), stride=(2, 2))
    )
    (2): Sequential(
      (0): LayerNorm()
      (1): Conv2d(64, 100, kernel_size=(2, 2), stride=(2, 2))
    )
    (3): Sequential(
      (0): LayerNorm()
      (1): Conv2d(100, 192, kernel_size=(2, 2), stride=(2, 2))
    )
  )
  (stages): ModuleList(
    (0): Sequential(
      (0): ConvEncoder(
        (dwconv): Conv2d(32, 32, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=32)
        (norm): LayerNorm()
        (pwconv1): LoRaLin(
          (linear1): Linear(in_features=32, out_features=19, bias=False)
          (linear2): Linear(in_features=19, out_features=128, bias=True)
        )
        (act): GELU(approximate='none')
        (pwconv2): LoRaLin(
          (linear1): Linear(in_features=128, out_features=

In [5]:
FLOPs_MAP = {
    "zero": 0,
    "identity": 0,
    "relu": 1,
    'square_relu': 2,
    "sigmoid":4,
    "silu":5,
    "tanh":6,
    "gelu": 14,
    "polynomial2": 1+2+3-1,
    "polynomial3": 1+2+3+4-1,
    "polynomial5": 1+2+3+4+5-1,
}

In [6]:
class CustomIdentity(nn.Module):
    def __init__(self, in_dim=None, out_dim=None):
        super(CustomIdentity, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

    def forward(self, x):
        return x  # Pass through the input without modification

In [7]:
class ModelWithoutKANLinear(nn.Module):
    def __init__(self, model):
        super(ModelWithoutKANLinear, self).__init__()
        # Create a deepcopy of the model
        self.model = deepcopy(model)
        # Replace all KANLinear layers with CustomIdentity
        for name, module in list(self.model.named_modules()):  # Use list() to avoid size modification issues
            if isinstance(module, KANLinear):
                # Create a CustomIdentity layer with the same in_features and out_features
                custom_identity = CustomIdentity(in_dim=module.in_dim, out_dim=module.out_dim)
                # Locate the parent module and replace the KANLinear layer
                parent_module, attr_name = self._find_parent_module(name)
                if parent_module is not None:
                    setattr(parent_module, attr_name, custom_identity)

    def _find_parent_module(self, module_name):
        """
        Find the parent module and the attribute name corresponding to `module_name`.
        """
        names = module_name.split('.')
        parent = self.model
        for name in names[:-1]:
            parent = getattr(parent, name, None)
            if parent is None:
                return None, None
        return parent, names[-1]

    def forward(self, x):
        return self.model(x)

In [None]:
model = ModelWithoutKANLinear()

In [8]:
def custom_kan_linear_hook(module, input, output):
    # Extract input and output dimensions of the KANLinear layer
    din = input[0].shape[1]  # Input size (features in the input tensor)
    dout = output.shape[1]  # Output size (features in the output tensor)
    grid_size = module.num  # Grid size defined in KANLinear
    spline_order = module.k  # Spline order defined in KANLinear

    print(f"din: {din}, dout: {dout}, grid_size: {grid_size}, spline_order: {spline_order}")

    # Automatically calculate FLOPs and Parameters for KANLinear
    custom_flops = layer_flops(din, dout, shortcut_name="silu", grid=grid_size, k=spline_order)
    custom_params = layer_parameters(din, dout, shortcut_name="silu", grid=grid_size, k=spline_order)
    # Store custom FLOPs and Parameters as an attribute on the module
    module.__custom_flops__ = custom_flops
    module.__custom_params__ = custom_params

    # We return the output without modifying it
    return output

In [9]:
def register_kan_linear_hooks(model):
    for name, module in model.named_modules():
        if isinstance(module, KANLinear):
            module.register_forward_hook(custom_kan_linear_hook)

In [10]:
def layer_flops(din, dout, shortcut_name="silu", grid=5, k=3):
    """
    Custom FLOPs calculation for KANLinear.
    Args:
        din (int): Input dimensions.
        dout (int): Output dimensions.
        shortcut_name (str): Name of the shortcut activation. Default is "silu".
        grid (int): Grid size parameter.
        k (int): Spline order parameter.
    Returns:
        int: Calculated FLOPs for KANLinear.
    """
    flops = (din * dout) * (9 * k * (grid + 1.5 * k) + 2 * grid - 2.5 * k + 1)
    
    # Shortcut FLOPs
    if shortcut_name == "zero":
        shortcut_flops = 0
    else:
        shortcut_flops = FLOPs_MAP[shortcut_name] * din + 2 * din * dout
    
    return flops + shortcut_flops

In [11]:
def layer_parameters(din, dout, shortcut_name="silu", grid=5, k=3):
    """
    Custom Parameters calculation for KANLinear.
    Args:
        din (int): Input dimensions.
        dout (int): Output dimensions.
        shortcut_name (str): Name of the shortcut activation. Default is "silu".
        grid (int): Grid size parameter.
        k (int): Spline order parameter.
    Returns:
        int: Calculated Parameters for KANLinear.
    """
    parameters = din * dout * (grid + k + 2) + dout
    if shortcut_name == "zero":
        shortcut_parameters = 0
    else:
        shortcut_parameters = din * dout
    return parameters + shortcut_parameters

In [12]:

def info(grid_size=15, neuron_fun="mean", num_features=128, noise_type=None):
    # Load the model
    net = EdgeFaceKAN(grid_size=grid_size, neuron_fun=neuron_fun, num_features=num_features, noise_type=noise_type)
    # Step 1: Compute Base FLOPs (excluding KANLinear)
    net_without_kan = ModelWithoutKANLinear(net)
    macs, params = get_model_complexity_info(
        net_without_kan, (3, 112, 112), backend='pytorch', as_strings=False,
        print_per_layer_stat=False, verbose=True
    )
    base_flops = int(macs) * 2
    base_params = int(params)

    # Step 2: Register the custom hooks to KANLinear layers
    register_kan_linear_hooks(net)

    # Step 3: Calculate Total FLOPs (including custom KANLinear FLOPs)
    macs, params = get_model_complexity_info(
        net, (3, 112, 112), backend='pytorch', as_strings=False,
        print_per_layer_stat=False, verbose=True
    )

    total_flops = base_flops
    total_params = base_params
    for name, module in net.named_modules():
        if isinstance(module, KANLinear):
            total_flops += getattr(module, "__custom_flops__", 0)  # Add custom FLOPs from hook
            total_params += getattr(module, "__custom_params__", 0)  # Add custom Parameters from hook

    # Step 4: Print Results
    print(f"Base FLOPs (excluding KANLinear): {base_flops} FLOPs")
    print(f"KANLinear FLOPs: {total_flops - base_flops} FLOPs")
    print(f"Total FLOPs (with custom KANLinear FLOPs): {total_flops} FLOPs")
    print(f"Base Parameters (excluding KANLinear): {base_params} Parameters")
    print(f"KANLinear Parameters: {total_params - base_params} Parameters")
    print(f"Total Parameters (with custom KANLinear Parameters): {total_params} Parameters")

In [13]:
info(grid_size=5, neuron_fun="mean", num_features=128, noise_type=None)

grid_size in EdgeFaceKAN: 5
Noise type: None
Neuron fun: mean
Out dim: : 128
In dim:  192
Grid size:  5
Spline order:  3
num: 5, grid_range: [-1, 1]
Noise type: None
Neuron fun: mean
Out dim: : 128
In dim:  192
Grid size:  5
Spline order:  3
din: 192, dout: 128, grid_size: 5, spline_order: 3
Base FLOPs (excluding KANLinear): 158808464 FLOPs
KANLinear FLOPs: 6439872.0 FLOPs
Total FLOPs (with custom KANLinear FLOPs): 165248336.0 FLOPs
Base Parameters (excluding KANLinear): 1884916 Parameters
KANLinear Parameters: 270464 Parameters
Total Parameters (with custom KANLinear Parameters): 2155380 Parameters


In [14]:
from torchinfo import summary

model = EdgeFaceKAN(grid_size=5, neuron_fun="mean", num_features=128, noise_type=None)
summary(model, input_size=(1, 3, 112, 112))

grid_size in EdgeFaceKAN: 5
Noise type: None
Neuron fun: mean
Out dim: : 128
In dim:  192
Grid size:  5
Spline order:  3
num: 5, grid_range: [-1, 1]
Noise type: None
Neuron fun: mean
Out dim: : 128
In dim:  192
Grid size:  5
Spline order:  3


Layer (type:depth-idx)                        Output Shape              Param #
EdgeFaceKAN                                   [1, 128]                  --
├─ModuleList: 1-7                             --                        (recursive)
│    └─Sequential: 2-1                        [1, 32, 28, 28]           --
│    │    └─Conv2d: 3-1                       [1, 32, 28, 28]           1,568
│    │    └─LayerNorm: 3-2                    [1, 32, 28, 28]           64
├─ModuleList: 1-8                             --                        (recursive)
│    └─Sequential: 2-2                        [1, 32, 28, 28]           --
│    │    └─ConvEncoder: 3-3                  [1, 32, 28, 28]           7,936
│    │    └─ConvEncoder: 3-4                  [1, 32, 28, 28]           7,936
│    │    └─ConvEncoder: 3-5                  [1, 32, 28, 28]           7,936
├─ModuleList: 1-7                             --                        (recursive)
│    └─Sequential: 2-3                        [1, 64, 14

In [15]:
info(num=128, grid=10, noise=None)

TypeError: info() got an unexpected keyword argument 'num'