In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from torch.optim import Adam
from tqdm import tqdm
import argparse
import os


def laguerre_polynomial(x, degree):
    if degree == 0:
        return torch.ones_like(x)
    elif degree == 1:
        return 1 - x

    L_prev_prev = torch.ones_like(x)
    L_prev = 1 - x
    for n in range(2, degree + 1):
        L = ((2 * n - 1 - x) * L_prev - (n - 1) * L_prev_prev) / n
        L_prev_prev, L_prev = L_prev, L
    return L

class KANLayer(nn.Module):
    def __init__(self, input_dim, output_dim, degree=3, grid_size=4, use_resid=False, resid=nn.SiLU):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.degree = degree
        self.grid_size = grid_size
        self.use_resid = use_resid

        self.grid_centers = nn.Parameter(torch.linspace(0, 1, grid_size).view(1, 1, 1, grid_size))
        self.grid_scales = nn.Parameter(torch.ones(input_dim, grid_size))
        self.coefficients = nn.Parameter(torch.randn(input_dim, grid_size, output_dim))

        if self.use_resid:
            self.resid_act = resid()
            self.resid_linear = nn.Linear(input_dim, output_dim)

        self._reset_parameters()

    def _reset_parameters(self):
        nn.init.constant_(self.grid_scales, 1.0)

        nn.init.normal_(self.coefficients, std=1 / (self.grid_size ** 0.5))
        if self.use_resid:
            nn.init.kaiming_normal_(self.resid_linear.weight)
            nn.init.zeros_(self.resid_linear.bias)

    def kan_transform(self, x):
        x_transformed = (x.unsqueeze(-1) - self.grid_centers) * self.grid_scales.abs()
        x_transformed = torch.clamp(x_transformed, -1, 1)  # Clamp to avoid large values

        basis = laguerre_polynomial(x_transformed, self.degree)
        out = torch.einsum('bldg,dgo->blo', basis, self.coefficients)
        return out

    def forward(self, x):
        out = self.kan_transform(x)
        if self.use_resid:
            res = self.resid_linear(self.resid_act(x))
            out = out + res
        return out



class ResNet18WithKAN(nn.Module):
    def __init__(self, num_classes, hidden_dims=[256, 128], kan_degree=3, grid_size=4):
        super().__init__()
        self.backbone = models.resnet18(pretrained=False)
        self.backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.backbone.maxpool = nn.Identity()
        self.backbone.fc = nn.Identity()

        in_features = 512
        layers = []
        input_dim = in_features
        for hidden_dim in hidden_dims:
            layers.append(KANLayer(input_dim, hidden_dim, degree=kan_degree, grid_size=grid_size, use_resid=True))
            input_dim = hidden_dim

        # Final classification layer as a KANLayer
        layers.append(KANLayer(input_dim, num_classes, degree=kan_degree, grid_size=grid_size, use_resid=True))
        self.kan_head = nn.Sequential(*layers)

    def forward(self, x):
        x = self.backbone(x)
        x = x.unsqueeze(1)  # [B, 1, D]
        x = self.kan_head(x)  # [B, 1, num_classes]
        return x.squeeze(1)




In [2]:
def train_one_epoch(model, loader, optimizer, criterion, device):
    model.train()
    total_loss, correct, total = 0, 0, 0
    for images, labels in tqdm(loader, desc="Train", leave=False):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        total_loss += loss.item() * labels.size(0)
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    return total_loss / total, correct / total

@torch.no_grad()
def evaluate(model, loader, criterion, device):
    model.eval()
    total_loss, correct, total = 0, 0, 0
    for images, labels in tqdm(loader, desc="Eval", leave=False):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)

        total_loss += loss.item() * labels.size(0)
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    return total_loss / total, correct / total





In [3]:
def main(args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    # Load CIFAR-10 or CIFAR-100
    if args.dataset == 'cifar10':
        trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
        testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
        num_classes = 10
    else:
        trainset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
        testset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)
        num_classes = 100

    train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=4)
    test_loader = DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=4)

    model = ResNet18WithKAN(num_classes=num_classes,
                            hidden_dims=[256, 128],
                            kan_degree=args.degree,
                            grid_size=args.grid_size).to(device)

    optimizer = Adam(model.parameters(), lr=args.lr)
    criterion = nn.CrossEntropyLoss()

    best_acc = 0
    for epoch in range(args.epochs):
        print(f"\nEpoch {epoch+1}/{args.epochs}")
        train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
        test_loss, test_acc = evaluate(model, test_loader, criterion, device)

        print(f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f}")
        print(f"Test  Loss: {test_loss:.4f}, Acc: {test_acc:.4f}")

        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), f'best_model_{args.dataset}.pth')

    print(f"\nBest Accuracy on {args.dataset}: {best_acc:.4f}")



In [4]:
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100'])
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--degree', type=int, default=3)
    parser.add_argument('--grid_size', type=int, default=4)


    args, _ = parser.parse_known_args()
    return args

if __name__ == '__main__':
    args = get_args()
    main(args)





Epoch 1/100




Train Loss: 23.7257, Acc: 0.1335
Test  Loss: 19.8254, Acc: 0.1387

Epoch 2/100




Train Loss: 13.9062, Acc: 0.1331
Test  Loss: 11.2631, Acc: 0.1471

Epoch 3/100




Train Loss: 10.3081, Acc: 0.1464
Test  Loss: 7.8683, Acc: 0.1548

Epoch 4/100




Train Loss: 8.1679, Acc: 0.1495
Test  Loss: 5.4112, Acc: 0.1411

Epoch 5/100




Train Loss: 8.2081, Acc: 0.1301
Test  Loss: 7.9060, Acc: 0.1293

Epoch 6/100




Train Loss: 5.9967, Acc: 0.1431
Test  Loss: 6.6687, Acc: 0.1610

Epoch 7/100




Train Loss: 5.4053, Acc: 0.1517
Test  Loss: 4.7777, Acc: 0.1697

Epoch 8/100




Train Loss: 4.4721, Acc: 0.1764
Test  Loss: 4.1053, Acc: 0.2125

Epoch 9/100




Train Loss: 3.7352, Acc: 0.1882
Test  Loss: 4.3028, Acc: 0.1604

Epoch 10/100




Train Loss: 3.6369, Acc: 0.1950
Test  Loss: 3.4858, Acc: 0.1976

Epoch 11/100




Train Loss: 3.6537, Acc: 0.1750
Test  Loss: 3.0397, Acc: 0.1784

Epoch 12/100




Train Loss: 3.3779, Acc: 0.1805
Test  Loss: 2.7328, Acc: 0.2216

Epoch 13/100




Train Loss: 3.3451, Acc: 0.1902
Test  Loss: 3.3320, Acc: 0.1913

Epoch 14/100




Train Loss: 3.4006, Acc: 0.1832
Test  Loss: 4.6605, Acc: 0.1669

Epoch 15/100




Train Loss: 3.2147, Acc: 0.1927
Test  Loss: 2.8948, Acc: 0.1856

Epoch 16/100




Train Loss: 2.9005, Acc: 0.2027
Test  Loss: 2.5556, Acc: 0.2311

Epoch 17/100




Train Loss: 2.6629, Acc: 0.2173
Test  Loss: 2.8212, Acc: 0.2278

Epoch 18/100




Train Loss: 2.5728, Acc: 0.2252
Test  Loss: 2.4684, Acc: 0.2217

Epoch 19/100




Train Loss: 2.4590, Acc: 0.2383
Test  Loss: 2.5389, Acc: 0.2018

Epoch 20/100




Train Loss: 2.4992, Acc: 0.2403
Test  Loss: 2.4137, Acc: 0.2621

Epoch 21/100




Train Loss: 2.4786, Acc: 0.2448
Test  Loss: 2.2618, Acc: 0.2435

Epoch 22/100




Train Loss: 2.3802, Acc: 0.2523
Test  Loss: 2.4256, Acc: 0.2465

Epoch 23/100




Train Loss: 2.3794, Acc: 0.2442
Test  Loss: 2.3651, Acc: 0.2385

Epoch 24/100




Train Loss: 2.2371, Acc: 0.2590
Test  Loss: 2.1843, Acc: 0.2808

Epoch 25/100




Train Loss: 2.1561, Acc: 0.2738
Test  Loss: 1.9979, Acc: 0.2857

Epoch 26/100




Train Loss: 2.1185, Acc: 0.2783
Test  Loss: 2.4236, Acc: 0.2326

Epoch 27/100




Train Loss: 2.1605, Acc: 0.2588
Test  Loss: 2.4910, Acc: 0.2092

Epoch 28/100




Train Loss: 2.4037, Acc: 0.2326
Test  Loss: 2.3781, Acc: 0.2458

Epoch 29/100




Train Loss: 2.1823, Acc: 0.2541
Test  Loss: 2.1517, Acc: 0.2822

Epoch 30/100




Train Loss: 2.0407, Acc: 0.2868
Test  Loss: 2.0500, Acc: 0.2674

Epoch 31/100




Train Loss: 1.9925, Acc: 0.2998
Test  Loss: 1.9186, Acc: 0.3170

Epoch 32/100




Train Loss: 1.9923, Acc: 0.3161
Test  Loss: 2.0424, Acc: 0.2961

Epoch 33/100




Train Loss: 2.1321, Acc: 0.3104
Test  Loss: 1.9245, Acc: 0.3295

Epoch 34/100




Train Loss: 1.9700, Acc: 0.3225
Test  Loss: 2.1076, Acc: 0.3176

Epoch 35/100




Train Loss: 2.0632, Acc: 0.2999
Test  Loss: 1.8058, Acc: 0.3365

Epoch 36/100




Train Loss: 2.0639, Acc: 0.2989
Test  Loss: 2.0710, Acc: 0.2938

Epoch 37/100




Train Loss: 1.9820, Acc: 0.3212
Test  Loss: 2.1257, Acc: 0.3004

Epoch 38/100




Train Loss: 1.9741, Acc: 0.3296
Test  Loss: 1.7902, Acc: 0.3629

Epoch 39/100




Train Loss: 1.8559, Acc: 0.3441
Test  Loss: 2.0263, Acc: 0.3309

Epoch 40/100




Train Loss: 1.8523, Acc: 0.3461
Test  Loss: 1.6600, Acc: 0.3928

Epoch 41/100




Train Loss: 1.7668, Acc: 0.3661
Test  Loss: 1.8282, Acc: 0.3699

Epoch 42/100




Train Loss: 1.7739, Acc: 0.3731
Test  Loss: 1.7178, Acc: 0.3794

Epoch 43/100




Train Loss: 1.7324, Acc: 0.3817
Test  Loss: 1.6441, Acc: 0.4006

Epoch 44/100




Train Loss: 1.6988, Acc: 0.3973
Test  Loss: 1.6008, Acc: 0.4073

Epoch 45/100




Train Loss: 1.6977, Acc: 0.4122
Test  Loss: 1.6226, Acc: 0.4409

Epoch 46/100




Train Loss: 1.6418, Acc: 0.4222
Test  Loss: 1.6253, Acc: 0.4102

Epoch 47/100




Train Loss: 1.6312, Acc: 0.4293
Test  Loss: 1.7807, Acc: 0.3618

Epoch 48/100




Train Loss: 1.6339, Acc: 0.4181
Test  Loss: 1.4585, Acc: 0.4839

Epoch 49/100




Train Loss: 1.5932, Acc: 0.4402
Test  Loss: 1.6592, Acc: 0.4106

Epoch 50/100




Train Loss: 1.5089, Acc: 0.4638
Test  Loss: 1.4371, Acc: 0.4947

Epoch 51/100




Train Loss: 1.4432, Acc: 0.4883
Test  Loss: 1.5563, Acc: 0.4534

Epoch 52/100




Train Loss: 1.4497, Acc: 0.4870
Test  Loss: 1.4624, Acc: 0.4991

Epoch 53/100




Train Loss: 1.4632, Acc: 0.4996
Test  Loss: 2.4327, Acc: 0.3083

Epoch 54/100




Train Loss: 1.8692, Acc: 0.3720
Test  Loss: 1.8455, Acc: 0.3570

Epoch 55/100




Train Loss: 1.6528, Acc: 0.4236
Test  Loss: 1.6521, Acc: 0.4214

Epoch 56/100




Train Loss: 1.4394, Acc: 0.4906
Test  Loss: 1.5554, Acc: 0.4629

Epoch 57/100




Train Loss: 1.4117, Acc: 0.5107
Test  Loss: 1.5974, Acc: 0.4738

Epoch 58/100




Train Loss: 1.5059, Acc: 0.4838
Test  Loss: 1.2836, Acc: 0.5363

Epoch 59/100




Train Loss: 1.3551, Acc: 0.5238
Test  Loss: 1.3891, Acc: 0.5214

Epoch 60/100




Train Loss: 1.3068, Acc: 0.5471
Test  Loss: 1.2717, Acc: 0.5478

Epoch 61/100




Train Loss: 1.2788, Acc: 0.5581
Test  Loss: 1.1695, Acc: 0.5781

Epoch 62/100




Train Loss: 1.2514, Acc: 0.5660
Test  Loss: 1.3646, Acc: 0.5488

Epoch 63/100




Train Loss: 1.1895, Acc: 0.5833
Test  Loss: 1.3269, Acc: 0.5743

Epoch 64/100




Train Loss: 1.1996, Acc: 0.5858
Test  Loss: 1.2567, Acc: 0.5853

Epoch 65/100




Train Loss: 1.2044, Acc: 0.5872
Test  Loss: 1.1406, Acc: 0.5997

Epoch 66/100




Train Loss: 1.1237, Acc: 0.6129
Test  Loss: 1.1853, Acc: 0.6120

Epoch 67/100




Train Loss: 1.1690, Acc: 0.5942
Test  Loss: 1.2323, Acc: 0.5865

Epoch 68/100




Train Loss: 1.1365, Acc: 0.6102
Test  Loss: 1.1697, Acc: 0.6016

Epoch 69/100




Train Loss: 1.0732, Acc: 0.6304
Test  Loss: 1.1763, Acc: 0.6105

Epoch 70/100




Train Loss: 1.2300, Acc: 0.5833
Test  Loss: 1.1887, Acc: 0.5850

Epoch 71/100




Train Loss: 1.1269, Acc: 0.6125
Test  Loss: 1.0202, Acc: 0.6337

Epoch 72/100




Train Loss: 1.0730, Acc: 0.6293
Test  Loss: 1.0946, Acc: 0.6349

Epoch 73/100




Train Loss: 1.0501, Acc: 0.6414
Test  Loss: 0.9404, Acc: 0.6684

Epoch 74/100




Train Loss: 1.0087, Acc: 0.6527
Test  Loss: 1.0425, Acc: 0.6411

Epoch 75/100




Train Loss: 1.0082, Acc: 0.6530
Test  Loss: 1.2340, Acc: 0.5956

Epoch 76/100




Train Loss: 1.0219, Acc: 0.6510
Test  Loss: 1.0745, Acc: 0.6368

Epoch 77/100




Train Loss: 1.0359, Acc: 0.6479
Test  Loss: 1.1774, Acc: 0.6087

Epoch 78/100




Train Loss: 0.9865, Acc: 0.6598
Test  Loss: 1.0937, Acc: 0.6276

Epoch 79/100




Train Loss: 0.9955, Acc: 0.6610
Test  Loss: 0.9892, Acc: 0.6619

Epoch 80/100




Train Loss: 0.9400, Acc: 0.6792
Test  Loss: 0.9403, Acc: 0.6758

Epoch 81/100




Train Loss: 0.9284, Acc: 0.6821
Test  Loss: 1.0230, Acc: 0.6628

Epoch 82/100




Train Loss: 0.8963, Acc: 0.6939
Test  Loss: 0.9096, Acc: 0.6937

Epoch 83/100




Train Loss: 0.9540, Acc: 0.6790
Test  Loss: 1.1310, Acc: 0.6451

Epoch 84/100




Train Loss: 0.8936, Acc: 0.6977
Test  Loss: 1.0730, Acc: 0.6679

Epoch 85/100




Train Loss: 0.8507, Acc: 0.7126
Test  Loss: 0.8809, Acc: 0.6899

Epoch 86/100




Train Loss: 0.9064, Acc: 0.7007
Test  Loss: 1.0673, Acc: 0.6538

Epoch 87/100




Train Loss: 0.8595, Acc: 0.7129
Test  Loss: 0.8439, Acc: 0.7119

Epoch 88/100




Train Loss: 0.9048, Acc: 0.7010
Test  Loss: 1.0228, Acc: 0.6761

Epoch 89/100




Train Loss: 0.7850, Acc: 0.7340
Test  Loss: 0.9452, Acc: 0.6980

Epoch 90/100




Train Loss: 0.8077, Acc: 0.7318
Test  Loss: 0.7811, Acc: 0.7359

Epoch 91/100




Train Loss: 0.7655, Acc: 0.7456
Test  Loss: 0.7929, Acc: 0.7402

Epoch 92/100




Train Loss: 0.7618, Acc: 0.7441
Test  Loss: 0.8241, Acc: 0.7392

Epoch 93/100




Train Loss: 0.7453, Acc: 0.7557
Test  Loss: 0.9399, Acc: 0.7100

Epoch 94/100




Train Loss: 0.8335, Acc: 0.7311
Test  Loss: 0.8177, Acc: 0.7342

Epoch 95/100




Train Loss: 0.7518, Acc: 0.7503
Test  Loss: 0.9005, Acc: 0.7398

Epoch 96/100




Train Loss: 0.7454, Acc: 0.7552
Test  Loss: 0.9049, Acc: 0.7159

Epoch 97/100




Train Loss: 0.7194, Acc: 0.7632
Test  Loss: 0.7230, Acc: 0.7660

Epoch 98/100




Train Loss: 0.6961, Acc: 0.7713
Test  Loss: 0.7850, Acc: 0.7493

Epoch 99/100




Train Loss: 0.7028, Acc: 0.7727
Test  Loss: 0.8087, Acc: 0.7458

Epoch 100/100




Train Loss: 0.6714, Acc: 0.7821
Test  Loss: 0.6717, Acc: 0.7848

✅ Best Accuracy on cifar10: 0.7848
