# FlashKAN: Grid size-independent computation of Kolmogorov Arnold networks using BSpline bases

## Timing a "regular" single layer KAN on the MNIST dataset

In [1]:
import torch
from torch import nn
import torchvision
import torchvision.transforms as transforms
from time import time
from Regular_KAN import Regular_KAN

In [2]:
# Use the first CUDA GPU when one is present; otherwise run everything on CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Loading the dataset:

In [3]:
transform = transforms.ToTensor()

batch = 200
trainset = torchvision.datasets.MNIST("./Data", train=True, download=True,
                                      transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch,
                                          shuffle=True)

testset = torchvision.datasets.MNIST(root='./Data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch,
                                         shuffle=False)

We're less concerned with predictive performance (loss/accuracy) than with the training/inference speed of our models, so we benchmark the training time of three models with different grid sizes: $G = 10, 50, 100$.

In [4]:
def create_reg_kan(G):
    return nn.Sequential(
        nn.Flatten(),
        Regular_KAN(28*28, 10, G)
    ).to(device)

reg_kan1 = create_reg_kan(10)
reg_kan2 = create_reg_kan(50)
reg_kan3 = create_reg_kan(100)

In [5]:
criterion = nn.CrossEntropyLoss()
metric = lambda out, labels: (torch.argmax(out,1) == labels).float().mean()

opt1 = torch.optim.Adam(reg_kan1.parameters(), lr=0.001)
opt2 = torch.optim.Adam(reg_kan2.parameters(), lr=0.001)
opt3 = torch.optim.Adam(reg_kan3.parameters(), lr=0.001)

In [6]:
# Trains on the MNIST training set (one epoch by default)
def train_loop(net, opt, epochs=1):
    """Train ``net`` on the global ``trainloader`` for ``epochs`` passes.

    Args:
        net: model whose outputs are passed to the global ``criterion``
            and ``metric`` together with the labels.
        opt: optimizer over ``net``'s parameters.
        epochs: number of passes over ``trainloader``.

    Returns:
        ``(mean_loss, mean_acc)`` over the samples of the LAST epoch.
        Returns ``(nan, nan)`` if no batch was processed, i.e. ``epochs == 0``
        or an empty loader. The original version crashed with
        UnboundLocalError in that case and hard-coded a 60 000-sample epoch.
    """
    mean_loss = mean_acc = float('nan')
    for _ in range(epochs):  # loop over the dataset multiple times
        running_loss = 0.
        running_acc = 0.
        n_seen = 0
        for inputs, labels in trainloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            n = labels.size(0)  # samples in this batch (last one may be short)

            # zero the parameter gradients
            opt.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            opt.step()
            with torch.no_grad():
                acc = metric(outputs, labels)

            # criterion (default mean reduction) and metric both return
            # per-batch means, so weight them by batch size. This gives an
            # exact per-sample epoch average even when the final batch is
            # smaller.
            running_loss += loss.item() * n
            running_acc += acc.item() * n
            n_seen += n

        if n_seen:
            mean_loss = running_loss / n_seen
            mean_acc = running_acc / n_seen
    return mean_loss, mean_acc

In [7]:
def train_time(net, opt, epochs=1):
    """Return the elapsed seconds for ``train_loop(net, opt, epochs)``.

    Args:
        net: model to train.
        opt: optimizer over ``net``'s parameters.
        epochs: number of training epochs to time (default 1, as before).

    Uses ``time.perf_counter``, which is monotonic and high-resolution.
    ``time.time`` follows the system clock and can jump during a long
    benchmark.
    """
    from time import perf_counter

    start = perf_counter()
    train_loop(net, opt, epochs)  # loss/accuracy are not needed for timing
    return perf_counter() - start

In [8]:
# Time one training epoch per grid size, in increasing order of G.
for _label, _net, _opt in (
    ("G: 10 ", reg_kan1, opt1),
    ("G: 50 ", reg_kan2, opt2),
    ("G: 100 ", reg_kan3, opt3),
):
    print(_label, "Time taken: ", train_time(_net, _opt))

G: 10  Time taken:  38.593632221221924
G: 50  Time taken:  133.68002247810364
G: 100  Time taken:  337.3856027126312


The time needed to train a single-layer KAN for one epoch on MNIST grows markedly with the grid size $G$: almost 9× from $G = 10$ to $G = 100$.

Implementation details can be found in `Regular_KAN.py`. Its speed is comparable to that of implementations from other sources such as [efficient KAN](https://github.com/Blealtan/efficient-kan), especially in how it scales with the grid size.

## Training a single layer FlashKAN on the MNIST dataset

In [10]:
from FlashKAN import FlashKAN

This time, we use a different implementation of KAN:

In [11]:
def create_flash_kan(G):
    return nn.Sequential(
        nn.Flatten(),
        FlashKAN(28*28, 10, G)
    ).to(device)

flash_kan1 = create_flash_kan(10)
flash_kan2 = create_flash_kan(50)
flash_kan3 = create_flash_kan(100)

In [12]:
criterion = nn.CrossEntropyLoss()
metric = lambda out, labels: (torch.argmax(out,1) == labels).float().mean()

opt1 = torch.optim.Adam(flash_kan1.parameters(), lr=0.001)
opt2 = torch.optim.Adam(flash_kan2.parameters(), lr=0.001)
opt3 = torch.optim.Adam(flash_kan3.parameters(), lr=0.001)

In [13]:
# Same benchmark as for the regular KANs, now on the FlashKAN models.
for _label, _net, _opt in zip(
    ("G: 10 ", "G: 50 ", "G: 100 "),
    (flash_kan1, flash_kan2, flash_kan3),
    (opt1, opt2, opt3),
):
    print(_label, "Time taken: ", train_time(_net, _opt))

G: 10  Time taken:  12.132895708084106
G: 50  Time taken:  13.694677352905273
G: 100  Time taken:  15.494819164276123


The training time stays roughly the same as the grid size grows! Note, however, that the number of model parameters is the same as for the regular KANs.

This implementation exploits the limited support of the B-spline basis functions on the given grid, i.e. the part of the domain where each function is non-zero. For each input data point, it slices out only the parts of the weight array whose basis functions are non-zero at that point, and multiplies only those. Slicing the weight array also requires a custom gradient, defined by subclassing `torch.autograd.Function`.

See implementation in `KAN_Linear.py` for more details.

## Not just fast but actually works!

Let's train a single layer FlashKAN on MNIST for 10 epochs:

In [20]:
flash_kan = create_flash_kan(G=100)  # fresh G = 100 model for the 10-epoch run

In [21]:
# Slightly tweaked training loop that logs mean loss/accuracy every epoch
def train_loop2(net, opt, epochs=1):
    """Train ``net`` on the global ``trainloader``, printing stats per epoch.

    After each epoch, prints the per-sample mean training loss and accuracy
    for that epoch. If ``trainloader`` yields no batches, the stats print as
    ``nan``; the original raised NameError because its batch counter ``i``
    was never bound.

    Args:
        net: model whose outputs are passed to the global ``criterion``
            and ``metric`` together with the labels.
        opt: optimizer over ``net``'s parameters.
        epochs: number of passes over ``trainloader``.
    """
    for epoch in range(epochs):  # loop over the dataset multiple times
        running_loss = 0.
        running_acc = 0.
        n_seen = 0
        for inputs, labels in trainloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            n = labels.size(0)  # samples in this batch (last one may be short)

            # zero the parameter gradients
            opt.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            opt.step()
            with torch.no_grad():
                acc = metric(outputs, labels)

            # Both values are per-batch means; weight by batch size so the
            # epoch average is exact even if the final batch is smaller.
            running_loss += loss.item() * n
            running_acc += acc.item() * n
            n_seen += n

        mean_loss = running_loss / n_seen if n_seen else float('nan')
        mean_acc = running_acc / n_seen if n_seen else float('nan')
        print(f'Epoch: {epoch+1}  Loss: {mean_loss:.3f}',
              f'Accuracy: {mean_acc:.4f}')

In [22]:
# New Adam optimizer for the G = 100 FlashKAN, then 10 logged training epochs.
opt = torch.optim.Adam(params=flash_kan.parameters(), lr=1e-3)
train_loop2(flash_kan, opt, epochs=10)

Epoch: 1  Loss: 0.962 Accuracy: 0.7042
Epoch: 2  Loss: 0.461 Accuracy: 0.8626
Epoch: 3  Loss: 0.380 Accuracy: 0.8859
Epoch: 4  Loss: 0.320 Accuracy: 0.9051
Epoch: 5  Loss: 0.267 Accuracy: 0.9214
Epoch: 6  Loss: 0.244 Accuracy: 0.9288
Epoch: 7  Loss: 0.220 Accuracy: 0.9360
Epoch: 8  Loss: 0.194 Accuracy: 0.9449
Epoch: 9  Loss: 0.168 Accuracy: 0.9544
Epoch: 10  Loss: 0.161 Accuracy: 0.9553


In [23]:
def test_fn(net):
    """Evaluate ``net`` on the global ``testloader`` and print loss/accuracy.

    Prints the per-sample mean loss and accuracy over the whole test set.
    Batch means are weighted by batch size, so a short final batch does not
    skew the result. Prints ``nan`` if ``testloader`` yields no batches; the
    original raised NameError on its unbound counter ``i``.

    Args:
        net: model whose outputs are passed to the global ``criterion``
            and ``metric`` together with the labels.
    """
    running_loss, running_acc, n_seen = 0., 0., 0
    # Evaluation only: no gradients are needed anywhere in this pass.
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            n = labels.size(0)  # samples in this batch

            # forward pass only
            outputs = net(inputs)
            running_loss += criterion(outputs, labels).item() * n
            running_acc += metric(outputs, labels).item() * n
            n_seen += n

    mean_loss = running_loss / n_seen if n_seen else float('nan')
    mean_acc = running_acc / n_seen if n_seen else float('nan')
    print(f'loss: {mean_loss:.3f}',
          f'accuracy: {mean_acc:.4f}')

In [24]:
test_fn(net=flash_kan)  # evaluate the trained model on the MNIST test split

loss: 0.325 accuracy: 0.9045


The accuracy on both the training and test sets is above 90%, so we get roughly the same performance as other KAN implementations at a much better speed!