In [1]:
from efficient_kan import KAN

# Train on MNIST
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

In [2]:
# Load MNIST
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
trainset = torchvision.datasets.MNIST(
    root="./data", train=True, download=True, transform=transform
)
valset = torchvision.datasets.MNIST(
    root="./data", train=False, download=True, transform=transform
)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform)

trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
valloader = DataLoader(valset, batch_size=64, shuffle=False)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

In [3]:
# Define model
model = KAN([28 * 28, 64, 10])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
# Define learning rate scheduler
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8)

In [4]:
# Define loss
criterion = nn.CrossEntropyLoss()
for epoch in range(10):
    # Train
    model.train()
    with tqdm(trainloader) as pbar:
        for i, (images, labels) in enumerate(pbar):
            images = images.view(-1, 28 * 28).to(device)
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels.to(device))
            loss.backward()
            optimizer.step()
            accuracy = (output.argmax(dim=1) == labels.to(device)).float().mean()
            pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item(), lr=optimizer.param_groups[0]['lr'])

    # Validation
    model.eval()
    val_loss = 0
    val_accuracy = 0
    with torch.no_grad():
        for images, labels in valloader:
            images = images.view(-1, 28 * 28).to(device)
            output = model(images)
            val_loss += criterion(output, labels.to(device)).item()
            val_accuracy += (
                (output.argmax(dim=1) == labels.to(device)).float().mean().item()
            )
    val_loss /= len(valloader)
    val_accuracy /= len(valloader)

    # Update learning rate
    scheduler.step()

    print(
        f"Epoch {epoch + 1}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}"
    )

100%|████████████████████████████████| 938/938 [00:16<00:00, 56.10it/s, accuracy=0.875, loss=0.453, lr=0.001]


Epoch 1, Val Loss: 0.2179765210125097, Val Accuracy: 0.9352109872611465


100%|██████████████████████████████| 938/938 [00:15<00:00, 61.95it/s, accuracy=0.969, loss=0.0653, lr=0.0008]


Epoch 2, Val Loss: 0.1636626766697996, Val Accuracy: 0.9527269108280255


100%|██████████████████████████████| 938/938 [00:15<00:00, 62.02it/s, accuracy=0.906, loss=0.403, lr=0.00064]


Epoch 3, Val Loss: 0.13273732252656276, Val Accuracy: 0.9610867834394905


100%|████████████████████████████████| 938/938 [00:15<00:00, 60.96it/s, accuracy=1, loss=0.0401, lr=0.000512]


Epoch 4, Val Loss: 0.11652722207170906, Val Accuracy: 0.9655652866242038


100%|█████████████████████████████| 938/938 [00:15<00:00, 60.54it/s, accuracy=0.938, loss=0.0946, lr=0.00041]


Epoch 5, Val Loss: 0.11061898619529738, Val Accuracy: 0.966062898089172


100%|████████████████████████████████| 938/938 [00:14<00:00, 62.53it/s, accuracy=1, loss=0.0639, lr=0.000328]


Epoch 6, Val Loss: 0.1027873326389558, Val Accuracy: 0.9682523885350318


100%|████████████████████████████████| 938/938 [00:15<00:00, 59.65it/s, accuracy=1, loss=0.0129, lr=0.000262]


Epoch 7, Val Loss: 0.10016763444610272, Val Accuracy: 0.9692476114649682


100%|█████████████████████████████| 938/938 [00:15<00:00, 61.16it/s, accuracy=0.969, loss=0.0619, lr=0.00021]


Epoch 8, Val Loss: 0.09976962811586441, Val Accuracy: 0.9690485668789809


100%|████████████████████████████████| 938/938 [00:15<00:00, 59.73it/s, accuracy=1, loss=0.0077, lr=0.000168]


Epoch 9, Val Loss: 0.09615444730460097, Val Accuracy: 0.9707404458598726


100%|████████████████████████████| 938/938 [00:15<00:00, 59.18it/s, accuracy=0.969, loss=0.0438, lr=0.000134]


Epoch 10, Val Loss: 0.09370725156122427, Val Accuracy: 0.9717356687898089
