# FlashKAN: Grid size-independent computation of Kolmogorov Arnold networks using BSpline bases

## Timing a "regular" single layer KAN on the MNIST dataset

In [1]:
import torch
from torch import nn
import torchvision
import torchvision.transforms as transforms
from time import time
from layers import Regular_KAN, KANLinear, FlashKAN

In [2]:
# Prefer the first CUDA device when one is present; otherwise run on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Loading the dataset:

In [3]:
# Mini-batch size shared by the training and test loaders.
batch = 200

# Every image is converted to a tensor; no other preprocessing is applied.
transform = transforms.ToTensor()

# Training split (downloaded into ./Data when missing), reshuffled each epoch.
trainset = torchvision.datasets.MNIST(
    root="./Data", train=True, download=True, transform=transform
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch, shuffle=True
)

# Test split, iterated in a fixed order.
testset = torchvision.datasets.MNIST(
    root="./Data", train=False, download=True, transform=transform
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch, shuffle=False
)

We're not concerned about performance (in terms of loss/accuracy) as much as the training/inference speed of our models. So we benchmark the training speed of 3 models with different grid sizes: $G = 10, 100, 500$

In [13]:
def create_reg_kan(G):
    """Build a single-layer "regular" KAN classifier for MNIST.

    The input is flattened and fed to one ``KANLinear`` layer with
    28*28 inputs and 10 outputs; ``G`` (the grid size used throughout this
    notebook) is forwarded as the layer's third positional argument.
    The assembled model is moved to ``device`` before it is returned.
    """
    layers = [nn.Flatten(), KANLinear(28 * 28, 10, G)]
    model = nn.Sequential(*layers)
    return model.to(device)

# One regular KAN per benchmarked grid size (10, 100 and 500).
reg_kan1, reg_kan2, reg_kan3 = (create_reg_kan(G) for G in (10, 100, 500))

In [14]:
# Classification loss used by the training and evaluation loops.
criterion = nn.CrossEntropyLoss()


def metric(out, labels):
    """Return the fraction of rows of ``out`` whose arg-max equals ``labels``."""
    predicted = torch.argmax(out, 1)
    return (predicted == labels).float().mean()


# One Adam optimizer per regular KAN, all with the same learning rate.
opt1, opt2, opt3 = (
    torch.optim.Adam(model.parameters(), lr=0.001)
    for model in (reg_kan1, reg_kan2, reg_kan3)
)

In [15]:
# Trains on the MNIST training loader for `epochs` epochs (one by default)
def train_loop(net, opt, epochs=1):
    """Train ``net`` on ``trainloader`` and report the last epoch's averages.

    Args:
        net: Model mapping a batch of inputs to class scores.
        opt: Optimizer that updates the parameters of ``net``.
        epochs: Number of passes over ``trainloader``.

    Returns:
        A ``(mean_loss, mean_accuracy)`` tuple with the per-batch averages of
        the final epoch, or ``(nan, nan)`` when no batch was processed
        (``epochs == 0`` or an empty loader).
    """
    # Initialised up front so the return statement is valid for epochs == 0.
    running_loss = 0.
    running_acc = 0.
    n_batches = 0

    for _ in range(epochs):  # loop over the dataset multiple times
        # Statistics restart every epoch, so only the last epoch is reported.
        running_loss = 0.
        running_acc = 0.
        n_batches = 0

        for inputs, labels in trainloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # zero the parameter gradients
            opt.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            opt.step()
            with torch.no_grad():
                acc = metric(outputs, labels)

            # accumulate statistics
            running_acc += acc.item()
            running_loss += loss.item()
            n_batches += 1

    if n_batches == 0:
        return float("nan"), float("nan")

    # Divide by the number of batches actually seen rather than by a
    # hard-coded dataset size, so the averages stay correct for any loader.
    return running_loss / n_batches, running_acc / n_batches

In [16]:
def train_time(net, opt):
    """Return the elapsed seconds of one ``train_loop(net, opt)`` call.

    Args:
        net: Model to train.
        opt: Optimizer bound to the parameters of ``net``.

    Returns:
        Elapsed time in seconds as a float.
    """
    # perf_counter is monotonic and high-resolution, unlike time().
    from time import perf_counter

    def _sync():
        # CUDA kernels run asynchronously; wait for queued work so that it is
        # attributed to the right side of the timing window.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    _sync()
    start = perf_counter()
    train_loop(net, opt)
    _sync()
    return perf_counter() - start

In [17]:
# Time one training epoch of each regular KAN, smallest grid first.
for _label, _model, _optim in (
    ("G: 10 ", reg_kan1, opt1),
    ("G: 100 ", reg_kan2, opt2),
    ("G: 500 ", reg_kan3, opt3),
):
    print(_label, "Time taken: ", train_time(_model, _optim))

G: 10  Time taken:  4.063426494598389
G: 100  Time taken:  7.6076741218566895
G: 500  Time taken:  28.56937551498413


The time taken to perform a single epoch of training a single layer KAN on MNIST definitely increases with the grid size $G$.

This implementation was taken from an existing work found in [efficient KAN](https://github.com/Blealtan/efficient-kan).

## Training a single layer FlashKAN on the MNIST dataset

In [18]:
from layers.FlashKAN import FlashKAN

This time, we use a different implementation of KAN:

In [22]:
def create_flash_kan(G):
    """Build a single-layer FlashKAN classifier for MNIST.

    The input is flattened and fed to one ``FlashKAN`` layer with 28*28
    inputs and 10 outputs; ``G`` (the grid size used throughout this
    notebook) is forwarded as the layer's third positional argument.
    The assembled model is moved to ``device`` before it is returned.
    """
    layers = [nn.Flatten(), FlashKAN(28 * 28, 10, G)]
    model = nn.Sequential(*layers)
    return model.to(device)

# One FlashKAN per benchmarked grid size (10, 100 and 500).
flash_kan1, flash_kan2, flash_kan3 = (
    create_flash_kan(G) for G in (10, 100, 500)
)

In [23]:
# Classification loss used by the training and evaluation loops.
criterion = nn.CrossEntropyLoss()


def metric(out, labels):
    """Return the fraction of rows of ``out`` whose arg-max equals ``labels``."""
    predicted = torch.argmax(out, 1)
    return (predicted == labels).float().mean()


# One Adam optimizer per FlashKAN, all with the same learning rate.
opt1, opt2, opt3 = (
    torch.optim.Adam(model.parameters(), lr=0.001)
    for model in (flash_kan1, flash_kan2, flash_kan3)
)

In [24]:
# Time one training epoch of each FlashKAN, smallest grid first.
for _label, _model, _optim in (
    ("G: 10 ", flash_kan1, opt1),
    ("G: 100 ", flash_kan2, opt2),
    ("G: 500 ", flash_kan3, opt3),
):
    print(_label, "Time taken: ", train_time(_model, _optim))

G: 10  Time taken:  7.005167007446289
G: 100  Time taken:  7.101221323013306
G: 500  Time taken:  7.40746808052063


The training time remains roughly the same for larger grid sizes! The benefits are especially apparent for $G>100$. However, the number of model parameters and the memory complexity remain roughly the same as for the regular KAN.

This implementation exploits the limited support of the BSpline basis functions on the given grid (the support of a basis function is the part of the domain where it is non-zero): for each input data point, only the parts of the weight array belonging to basis functions whose support contains that point are sliced out before multiplication. Slicing the weight array also necessitates a custom gradient, defined by subclassing `torch.autograd.Function`.

See implementation in `FlashKAN.py` for more details.

## Not just fast but actually works!

Let's train a single layer FlashKAN on MNIST for 10 epochs:

In [25]:
# Fresh FlashKAN with grid size 100 for the multi-epoch accuracy check.
flash_kan = create_flash_kan(G=100)

In [26]:
# Variant of the training loop that logs a summary after every epoch
def train_loop2(net, opt, epochs=1):
    """Train ``net`` on ``trainloader`` and print per-epoch averages.

    Args:
        net: Model mapping a batch of inputs to class scores.
        opt: Optimizer that updates the parameters of ``net``.
        epochs: Number of passes over ``trainloader``.

    After each epoch the mean per-batch loss and accuracy are printed.
    Nothing is returned.
    """
    for epoch_idx in range(epochs):
        loss_sum = 0.
        acc_sum = 0.

        # Batches are counted from 1 so the final count is the divisor below.
        for n_batches, (inputs, labels) in enumerate(trainloader, 1):
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Reset gradients, then forward, backward and parameter update.
            opt.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            opt.step()

            # The accuracy is only monitored, so no graph is needed for it.
            with torch.no_grad():
                acc = metric(outputs, labels)

            acc_sum += acc.item()
            loss_sum += loss.item()

        mean_loss = loss_sum / n_batches
        mean_acc = acc_sum / n_batches
        print(f'Epoch: {epoch_idx+1}  Loss: {mean_loss:.3f}',
              f'Accuracy: {mean_acc:.4f}')

In [27]:
# Adam optimizer for the 10-epoch FlashKAN run.
opt = torch.optim.Adam(flash_kan.parameters(), lr=0.001)

train_loop2(flash_kan, opt, epochs=10)

Epoch: 1  Loss: 0.998 Accuracy: 0.6991
Epoch: 2  Loss: 0.490 Accuracy: 0.8539
Epoch: 3  Loss: 0.382 Accuracy: 0.8850
Epoch: 4  Loss: 0.325 Accuracy: 0.9053
Epoch: 5  Loss: 0.268 Accuracy: 0.9218
Epoch: 6  Loss: 0.240 Accuracy: 0.9307
Epoch: 7  Loss: 0.224 Accuracy: 0.9353
Epoch: 8  Loss: 0.225 Accuracy: 0.9323
Epoch: 9  Loss: 0.187 Accuracy: 0.9469
Epoch: 10  Loss: 0.170 Accuracy: 0.9507


In [28]:
def test_fn(net):
    """Evaluate ``net`` on ``testloader`` and print its mean loss and accuracy.

    Per-batch loss and accuracy are weighted by the batch size, so the
    printed values are true per-sample averages even when the last batch is
    smaller than the others. This relies on ``criterion`` and ``metric``
    both returning batch means, as they are defined in this notebook.

    Args:
        net: Model mapping a batch of inputs to class scores.

    Raises:
        ValueError: If ``testloader`` yields no samples.
    """
    running_loss, running_acc = 0., 0.
    n_samples = 0

    # Evaluation only: no gradients are needed anywhere in the loop.
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # forward pass only
            outputs = net(inputs)
            batch_size = labels.size(0)

            # Turn batch means back into batch sums before accumulating.
            running_loss += criterion(outputs, labels).item() * batch_size
            running_acc += metric(outputs, labels).item() * batch_size
            n_samples += batch_size

    if n_samples == 0:
        raise ValueError("testloader yielded no samples")

    print(f'loss: {running_loss / n_samples:.3f}',
          f'accuracy: {running_acc / n_samples:.4f}')

In [29]:
# Report loss and accuracy of the trained FlashKAN on the MNIST test split.
test_fn(net=flash_kan)

loss: 0.322 accuracy: 0.9061


The accuracy on both the training and test datasets is above 90%, which suggests we get pretty much the same performance as other KAN implementations, at better speeds!