In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Recursively print the full path of every file found under /kaggle/input.
for root, _subdirs, names in os.walk('/kaggle/input'):
    for name in names:
        full_path = os.path.join(root, name)
        print(full_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import math
import matplotlib.pyplot as plt
import matplotlib as mpl

def laguerre_polynomials(x, degree):
    """
    Evaluate the Laguerre polynomials L_0 ... L_{degree-1} elementwise on `x`.

    Uses the three-term recurrence
        n * L_n(x) = (2n - 1 - x) * L_{n-1}(x) - (n - 1) * L_{n-2}(x)
    seeded with L_0(x) = 1 and L_1(x) = 1 - x.

    Args:
        x: tensor of any shape [...].
        degree: number of polynomials to return.

    Returns:
        Tensor of shape [..., degree]; index k along the last axis holds L_k(x).
    """
    prev, curr = torch.ones_like(x), 1 - x  # L_0, L_1
    terms = [prev, curr]

    order = 2
    while order < degree:
        nxt = ((2 * order - 1 - x) * curr - (order - 1) * prev) / order
        terms.append(nxt)
        prev, curr = curr, nxt
        order += 1

    # Both seeds are always built, so slice down when degree < 2.
    return torch.stack(terms[:degree], dim=-1)  # [..., degree]

class KANLayer(nn.Module):
    """
    Linear projection over a Laguerre-polynomial expansion of each input feature.

    Every one of the `input_dim` features is multiplied by its own learnable
    scale, expanded into `grid` values by `laguerre_polynomials`, and the
    concatenated expansion (`input_dim * grid` values per position) is mapped
    to `output_dim` by a single linear layer. When `use_resid` is true, a
    second path `Linear(resid(x))` is added to that output.

    Args:
        input_dim: size of the last dimension of the input.
        output_dim: size of the last dimension of the output.
        grid: number of Laguerre polynomials evaluated per input feature.
        use_resid: whether to add the activation + linear residual path.
        resid: zero-argument callable (e.g. an ``nn.Module`` class) that builds
            the activation used on the residual path.
    """

    def __init__(self, input_dim, output_dim, grid=3, use_resid=False, resid=nn.SiLU):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.grid = grid
        self.use_resid = use_resid

        # One learnable multiplier per input feature, applied before the expansion.
        self.laguerre_scale = nn.Parameter(torch.ones(input_dim))

        # Mixes all (feature, polynomial) pairs into the output features.
        self.kan_proj = nn.Linear(input_dim * grid, output_dim)

        if self.use_resid:
            self.resid_act = resid()
            self.resid_linear = nn.Linear(input_dim, output_dim)
        else:
            self.resid_act = None
            self.resid_linear = None

        self._reset_parameters()

    def _reset_parameters(self):
        """Kaiming-normal weights (linear gain) and zero biases for every linear layer."""
        nn.init.kaiming_normal_(self.kan_proj.weight, nonlinearity='linear')
        nn.init.zeros_(self.kan_proj.bias)
        if self.use_resid:
            nn.init.kaiming_normal_(self.resid_linear.weight, nonlinearity='linear')
            nn.init.zeros_(self.resid_linear.bias)

    def kan_transform(self, x):
        """
        Expand `x` in the Laguerre basis and project it to `output_dim`.

        Args:
            x: tensor of shape [..., input_dim]. Any number of leading
                dimensions is accepted, and the tensor need not be contiguous.

        Returns:
            Tensor of shape [..., output_dim].
        """
        x_scaled = x * self.laguerre_scale  # [..., D]
        # laguerre_polynomials is elementwise and appends the basis axis last,
        # so no flatten/unflatten round trip is needed. Skipping `.view` also
        # keeps non-contiguous inputs and ranks other than 3 working.
        basis = laguerre_polynomials(x_scaled, self.grid)  # [..., D, G]
        phi = basis.flatten(start_dim=-2)  # [..., D * G], feature-major
        return self.kan_proj(phi)

    def forward(self, x):
        """Return `kan_transform(x)`, plus the residual path when enabled."""
        out = self.kan_transform(x)
        if self.use_resid:
            res = self.resid_linear(self.resid_act(x))
            out = out + res
        return out


In [3]:
def image_to_patches(x, patch_size):
    """
    Cut a batch of images into non-overlapping square patches and flatten each.

    Args:
        x: tensor of shape [B, C, H, W].
        patch_size: side length of each patch; also used as the stride.

    Returns:
        Tensor of shape [B, num_patches, C * patch_size * patch_size]. Patches
        are ordered row by row, and each patch is flattened channel-first.
    """
    batch, channels, _height, _width = x.shape
    flat_dim = channels * patch_size * patch_size

    # Window the height axis, then the width axis; each unfold appends the
    # in-patch axis at the end: [B, C, ny, nx, pH, pW].
    tiles = x.unfold(2, patch_size, patch_size)
    tiles = tiles.unfold(3, patch_size, patch_size)

    # Bring the patch-grid axes forward: [B, ny, nx, C, pH, pW].
    tiles = tiles.permute(0, 2, 3, 1, 4, 5).contiguous()

    return tiles.view(batch, -1, flat_dim)  # [B, num_patches, patch_dim]

In [4]:
class KANBlock(nn.Module):
    """
    A residual KAN layer followed by layer normalization and dropout.

    Computes ``dropout(norm(kan(x) + shortcut(x)))``, where `kan` is a
    `KANLayer` built with ``use_resid=True``. `shortcut` is the identity when
    the input and output widths match, and a linear projection otherwise.

    Args:
        input_dim: size of the last dimension of the input.
        output_dim: size of the last dimension of the output.
        grid: forwarded to `KANLayer` as its `grid` argument.
        dropout: dropout probability applied after normalization.
    """

    def __init__(self, input_dim, output_dim, grid=4, dropout=0.5):
        super().__init__()
        self.kan = KANLayer(input_dim, output_dim, grid=grid, use_resid=True)
        self.norm = nn.LayerNorm(output_dim)
        self.dropout = nn.Dropout(dropout)

        # The skip path only needs a projection when the width changes.
        if input_dim == output_dim:
            self.res_connection = nn.Identity()
        else:
            self.res_connection = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Apply the block to `x` along its last dimension."""
        shortcut = self.res_connection(x)
        summed = self.kan(x) + shortcut
        # Normalize the residual sum, then regularize.
        return self.dropout(self.norm(summed))

class KANForCIFAR(nn.Module):
    """
    Patch-based image classifier built from two `KANBlock`s.

    The input image batch is cut into non-overlapping patches. Every patch is
    passed through two KAN blocks, the patch tokens are average-pooled, and a
    linear head produces the class logits.

    Args:
        input_dim: flattened size of one patch. It must equal
            ``C * patch_size * patch_size`` for the images passed to `forward`.
        hidden_dim: feature width of both KAN blocks.
        num_classes: number of output logits.
        grid: forwarded to each `KANBlock` as its `grid` argument.
        dropout: dropout probability used inside each `KANBlock`.
        patch_size: side length of the square patches. Defaults to 4, the value
            that was previously hard-coded in `forward`.
    """

    def __init__(self, input_dim, hidden_dim, num_classes, grid=4, dropout=0.1, patch_size=4):
        super().__init__()
        # Kept on the module so forward() and the caller's input_dim share one source.
        self.patch_size = patch_size
        self.kan1 = KANBlock(input_dim=input_dim, output_dim=hidden_dim, grid=grid, dropout=dropout)
        self.kan2 = KANBlock(input_dim=hidden_dim, output_dim=hidden_dim, grid=grid, dropout=dropout)
        self.pool = nn.AdaptiveAvgPool1d(1)  # Pool over patch tokens
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """
        Classify a batch of images.

        Args:
            x: tensor of shape [B, C, H, W].

        Returns:
            Logits of shape [B, num_classes].
        """
        x = image_to_patches(x, patch_size=self.patch_size)  # [B, P, D]
        x = self.kan1(x)                        # [B, P, H]
        x = self.kan2(x)                        # [B, P, H]
        x = x.permute(0, 2, 1)                  # [B, H, P] - pooling runs over the last axis
        x = self.pool(x).squeeze(-1)            # [B, H]
        return self.fc(x)



In [5]:
def train():
    """
    Train a `KANForCIFAR` classifier on CIFAR-10 and return the trained model.

    Downloads the dataset into ``./data`` when needed, trains with AdamW and
    cross-entropy on the GPU if one is available (CPU otherwise), and after
    every epoch prints the mean training loss, training accuracy and test
    accuracy.

    Returns:
        The trained model, left on the device used for training.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Hyperparameters
    patch_size = 4
    patch_dim = 3 * patch_size * patch_size  # 3 colour channels per patch
    hidden_dim, num_classes, grid = 128, 10, 4
    epochs, batch_size = 100, 128

    model = KANForCIFAR(
        input_dim=patch_dim,
        hidden_dim=hidden_dim,
        num_classes=num_classes,
        grid=grid,
    ).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)
    criterion = nn.CrossEntropyLoss()

    to_tensor = transforms.Compose([transforms.ToTensor()])

    train_data = datasets.CIFAR10(root="./data", train=True, transform=to_tensor, download=True)
    test_data = datasets.CIFAR10(root="./data", train=False, transform=to_tensor)

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

    n_train = len(train_loader.dataset)
    n_test = len(test_loader.dataset)

    def count_hits(logits, targets):
        # Number of samples whose top-scoring class equals the label.
        return (logits.argmax(dim=1) == targets).sum().item()

    for epoch in range(epochs):
        # ---- optimisation pass over the training set ----
        model.train()
        loss_sum = 0.0
        hits_train = 0
        for batch_x, batch_y in train_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            logits = model(batch_x)
            loss = criterion(logits, batch_y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch mean is per-sample.
            loss_sum += loss.item() * batch_x.size(0)
            hits_train += count_hits(logits, batch_y)

        train_acc = hits_train / n_train
        avg_loss = loss_sum / n_train

        # ---- evaluation pass over the test set ----
        model.eval()
        hits_test = 0
        with torch.no_grad():
            for batch_x, batch_y in test_loader:
                batch_x = batch_x.to(device)
                batch_y = batch_y.to(device)
                hits_test += count_hits(model(batch_x), batch_y)

        test_acc = hits_test / n_test

        # Print both train and test accuracy
        print(f"Epoch {epoch+1:03d}: Loss = {avg_loss:.4f}, Train Acc = {train_acc:.4f}, Test Acc = {test_acc:.4f}")

    return model


In [6]:
# Run the full training loop and bind the model it returns to `model`.
model = train()

100%|██████████| 170M/170M [00:11<00:00, 14.5MB/s]


Epoch 001: Loss = 1.9632, Train Acc = 0.2704, Test Acc = 0.3187
Epoch 002: Loss = 1.7718, Train Acc = 0.3346, Test Acc = 0.3429
Epoch 003: Loss = 1.7150, Train Acc = 0.3585, Test Acc = 0.3772
Epoch 004: Loss = 1.6734, Train Acc = 0.3804, Test Acc = 0.3792
Epoch 005: Loss = 1.6310, Train Acc = 0.4025, Test Acc = 0.4114
Epoch 006: Loss = 1.5991, Train Acc = 0.4147, Test Acc = 0.4125
Epoch 007: Loss = 1.5670, Train Acc = 0.4274, Test Acc = 0.4301
Epoch 008: Loss = 1.5484, Train Acc = 0.4362, Test Acc = 0.4318
Epoch 009: Loss = 1.5344, Train Acc = 0.4408, Test Acc = 0.4383
Epoch 010: Loss = 1.5192, Train Acc = 0.4497, Test Acc = 0.4401
Epoch 011: Loss = 1.5068, Train Acc = 0.4546, Test Acc = 0.4500
Epoch 012: Loss = 1.5013, Train Acc = 0.4556, Test Acc = 0.4594
Epoch 013: Loss = 1.4910, Train Acc = 0.4582, Test Acc = 0.4525
Epoch 014: Loss = 1.4811, Train Acc = 0.4643, Test Acc = 0.4566
Epoch 015: Loss = 1.4770, Train Acc = 0.4660, Test Acc = 0.4671
Epoch 016: Loss = 1.4693, Train Acc = 0.