# Task IX. Kolmogorov-Arnold Network (KAN)
Implementation of a classical Kolmogorov-Arnold Network with b-splines applied to the MNIST dataset. 

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

In [None]:
transform = transforms.Compose([transforms.ToTensor()])
train_data = MNIST(root='./data', train=True, download=True, transform=transform)
test_data = MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
test_loader = DataLoader(test_data, batch_size=128)

In [31]:
class BSplineActivation(nn.Module):
    def __init__(self, num_bases=10, degree=3, input_dim=784):
        super().__init__()
        self.num_bases = num_bases
        self.degree = degree
        self.input_dim = input_dim

        self.knots = torch.linspace(0, 1, num_bases + degree + 1)
        self.register_buffer("knots_buffer", self.knots)

        self.coeffs = nn.Parameter(torch.randn(input_dim, num_bases))

    def forward(self, x):
        x = (x - x.min()) / (x.max() - x.min() + 1e-6)
        B = self.bspline_basis(x)  # (batch, input_dim, num_bases)
        B = B.squeeze()
        out = torch.einsum('bik,ik->bi', B, self.coeffs)  # (batch, input_dim)
        return out

    def bspline_basis(self, x):
        batch_size, input_dim = x.shape
        x = x.unsqueeze(-1)  # (batch, input_dim, 1)
        knots = self.knots_buffer
        degree = self.degree
        num_bases = self.num_bases

        B = []
        for i in range(num_bases):
            left = knots[i]
            right = knots[i + 1]
            B.append(((x >= left) & (x < right)).float())
        B = torch.stack(B, dim=2)  # (batch, input_dim, num_bases)

        for d in range(1, degree + 1):
            B_new = []
            for i in range(num_bases):
                denom1 = knots[i + d] - knots[i]
                denom2 = knots[i + d + 1] - knots[i + 1]

                term1 = 0
                if denom1 > 0:
                    term1 = ((x.squeeze(-1) - knots[i]) / denom1).unsqueeze(-1) * B[:, :, i]

                term2 = 0
                if i + 1 < num_bases and denom2 > 0:
                    term2 = ((knots[i + d + 1] - x.squeeze(-1)) / denom2).unsqueeze(-1) * B[:, :, i + 1]

                B_new.append(term1 + term2)

            B = torch.stack(B_new, dim=2)

        return B

class BSplineKAN(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear1 = nn.Linear(784, 784)
        self.bkan = BSplineActivation(input_dim=784)
        self.linear2 = nn.Linear(784, 10)

    def forward(self, x):
        x = self.flatten(x)
        x = self.linear1(x)
        x = self.bkan(x)
        x = self.linear2(x)
        return x


In [32]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BSplineKAN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

epochs = 5
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader):.4f}")



Epoch 1, Loss: 0.2803
Epoch 2, Loss: 0.1053
Epoch 3, Loss: 0.0727
Epoch 4, Loss: 0.0435
Epoch 5, Loss: 0.0325


In [33]:
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

print(f"Test Accuracy: {100 * correct / total:.2f}%")


Test Accuracy: 97.98%


## Quantum KAN

A major concern with classical KANs is computational efficiency. The activation functions are not able to be learned in parallel using GPUs, so MLP still dominate deep learning. In the quantum architecture, we can try to compute multiple parameters in parallel, speeding up the training process of KANs and increasing their practicality. Inspired by https://arxiv.org/pdf/2410.04435