In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
from torch.optim import Adam
import timm
from tqdm import tqdm
import argparse
import os


def laguerre_polynomial(x, degree):
    if degree == 0:
        return torch.ones_like(x)
    elif degree == 1:
        return 1 - x

    L_prev_prev = torch.ones_like(x)
    L_prev = 1 - x
    for n in range(2, degree + 1):
        L = ((2 * n - 1 - x) * L_prev - (n - 1) * L_prev_prev) / n
        L_prev_prev, L_prev = L_prev, L
    return L

class KANLayer(nn.Module):
    def __init__(self, input_dim, output_dim, degree=3, grid_size=4, use_resid=False, resid=nn.SiLU):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.degree = degree
        self.grid_size = grid_size
        self.use_resid = use_resid

        self.grid_centers = nn.Parameter(torch.linspace(0, 1, grid_size).view(1, 1, 1, grid_size))
        self.grid_scales = nn.Parameter(torch.ones(input_dim, grid_size))
        self.coefficients = nn.Parameter(torch.randn(input_dim, grid_size, output_dim))

        if self.use_resid:
            self.resid_act = resid()
            self.resid_linear = nn.Linear(input_dim, output_dim)

        self._reset_parameters()

    def _reset_parameters(self):
        nn.init.constant_(self.grid_scales, 1.0)

        nn.init.normal_(self.coefficients, std=1 / (self.grid_size ** 0.5))
        if self.use_resid:
            nn.init.kaiming_normal_(self.resid_linear.weight)
            nn.init.zeros_(self.resid_linear.bias)

    def kan_transform(self, x):
        x_transformed = (x.unsqueeze(-1) - self.grid_centers) * self.grid_scales.abs()
        x_transformed = torch.clamp(x_transformed, -1, 1)  # Clamp to avoid large values

        basis = laguerre_polynomial(x_transformed, self.degree)
        out = torch.einsum('bldg,dgo->blo', basis, self.coefficients)
        return out

    def forward(self, x):
        out = self.kan_transform(x)
        if self.use_resid:
            res = self.resid_linear(self.resid_act(x))
            out = out + res
        return out



class ViTWithKAN(nn.Module):
    def __init__(self, num_classes, hidden_dims=[256, 128], kan_degree=3, grid_size=4, vit_model_name='vit_base_patch16_224'):
        super().__init__()

        # Load pretrained ViT from timm
        self.backbone = timm.create_model(vit_model_name, pretrained=True)
        self.backbone.head = nn.Identity()  # Remove the classification head

        in_features = self.backbone.num_features  # Typically 768 for vit_base_patch16_224

        layers = []
        input_dim = in_features
        for hidden_dim in hidden_dims:
            layers.append(KANLayer(input_dim, hidden_dim, degree=kan_degree, grid_size=grid_size, use_resid=True))
            input_dim = hidden_dim

        layers.append(KANLayer(input_dim, num_classes, degree=kan_degree, grid_size=grid_size, use_resid=True))
        self.kan_head = nn.Sequential(*layers)

    def forward(self, x):
        x = self.backbone(x)  # [B, D] output of CLS token
        x = x.unsqueeze(1)    # [B, 1, D]
        x = self.kan_head(x)  # [B, 1, num_classes]
        return x.squeeze(1)   # [B, num_classes]



In [2]:
def train_one_epoch(model, loader, optimizer, criterion, device):
    model.train()
    total_loss, correct, total = 0, 0, 0
    for images, labels in tqdm(loader, desc="Train", leave=False):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        total_loss += loss.item() * labels.size(0)
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    return total_loss / total, correct / total

@torch.no_grad()
def evaluate(model, loader, criterion, device):
    model.eval()
    total_loss, correct, total = 0, 0, 0
    for images, labels in tqdm(loader, desc="Eval", leave=False):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)

        total_loss += loss.item() * labels.size(0)
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    return total_loss / total, correct / total





In [3]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.optim import Adam

# Assuming ViTWithKAN is defined elsewhere and imported

def main(args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ViT expects 224x224 inputs and ImageNet normalization
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.RandomCrop(224, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),  # ImageNet mean
                             (0.229, 0.224, 0.225))  # ImageNet std
    ])

    # Load CIFAR-10 or CIFAR-100
    if args.dataset == 'cifar10':
        trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
        testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
        num_classes = 10
    else:
        trainset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
        testset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)
        num_classes = 100

    train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=4)
    test_loader = DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=4)


    model = ViTWithKAN(num_classes=num_classes,
                   hidden_dims=[256, 128],
                   kan_degree=args.degree,
                   grid_size=args.grid_size,
                   vit_model_name='vit_tiny_patch16_224').to(device)


    optimizer = Adam(model.parameters(), lr=args.lr)
    criterion = nn.CrossEntropyLoss()

    best_acc = 0
    for epoch in range(args.epochs):
        print(f"\nEpoch {epoch + 1}/{args.epochs}")
        train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
        test_loss, test_acc = evaluate(model, test_loader, criterion, device)

        print(f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f}")
        print(f"Test  Loss: {test_loss:.4f}, Acc: {test_acc:.4f}")

        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), f'best_model_{args.dataset}_vit.pth')

    print(f"\nBest Accuracy on {args.dataset}: {best_acc:.4f}")


In [4]:
def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100'])
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--degree', type=int, default=3)
    parser.add_argument('--grid_size', type=int, default=4)


    args, _ = parser.parse_known_args()
    return args

if __name__ == '__main__':
    args = get_args()
    main(args)


100%|██████████| 170M/170M [00:03<00:00, 48.9MB/s] 


model.safetensors:   0%|          | 0.00/22.9M [00:00<?, ?B/s]


Epoch 1/100


                                                        

Train Loss: 15.9984, Acc: 0.1042
Test  Loss: 7.3344, Acc: 0.1003

Epoch 2/100


                                                        

Train Loss: 10.4319, Acc: 0.1083
Test  Loss: 5.7018, Acc: 0.1125

Epoch 3/100


                                                        

Train Loss: 4.4480, Acc: 0.1498
Test  Loss: 4.7515, Acc: 0.1899

Epoch 4/100


                                                        

Train Loss: 3.9559, Acc: 0.1686
Test  Loss: 2.9759, Acc: 0.1467

Epoch 5/100


                                                        

Train Loss: 3.5872, Acc: 0.1767
Test  Loss: 2.6561, Acc: 0.1981

Epoch 6/100


                                                        

Train Loss: 9.2493, Acc: 0.1545
Test  Loss: 4.7433, Acc: 0.0999

Epoch 7/100


                                                        

Train Loss: 4.0293, Acc: 0.1011
Test  Loss: 2.9365, Acc: 0.1000

Epoch 8/100


                                                        

Train Loss: 3.4918, Acc: 0.1226
Test  Loss: 3.2144, Acc: 0.1577

Epoch 9/100


                                                        

Train Loss: 2.7458, Acc: 0.1606
Test  Loss: 3.1173, Acc: 0.1478

Epoch 10/100


                                                        

Train Loss: 2.7386, Acc: 0.1723
Test  Loss: 2.6438, Acc: 0.2179

Epoch 11/100


                                                        

Train Loss: 2.6901, Acc: 0.1834
Test  Loss: 3.6549, Acc: 0.1679

Epoch 12/100


                                                        

Train Loss: 2.5382, Acc: 0.1942
Test  Loss: 2.7452, Acc: 0.2057

Epoch 13/100


                                                        

Train Loss: 3.2722, Acc: 0.1736
Test  Loss: 2.1230, Acc: 0.2160

Epoch 14/100


                                                        

Train Loss: 2.5011, Acc: 0.1971
Test  Loss: 2.8160, Acc: 0.1799

Epoch 15/100


                                                        

Train Loss: 2.4341, Acc: 0.2014
Test  Loss: 2.3839, Acc: 0.1831

Epoch 16/100


                                                        

Train Loss: 2.3579, Acc: 0.2116
Test  Loss: 2.7545, Acc: 0.1279

Epoch 17/100


                                                        

Train Loss: 2.3135, Acc: 0.2284
Test  Loss: 2.2397, Acc: 0.2385

Epoch 18/100


                                                        

Train Loss: 2.4971, Acc: 0.2223
Test  Loss: 2.4942, Acc: 0.2249

Epoch 19/100


                                                        

Train Loss: 2.5255, Acc: 0.2074
Test  Loss: 2.7357, Acc: 0.1570

Epoch 20/100


                                                        

Train Loss: 2.4228, Acc: 0.1842
Test  Loss: 2.3642, Acc: 0.2300

Epoch 21/100


                                                        

Train Loss: 2.6251, Acc: 0.1818
Test  Loss: 2.3267, Acc: 0.1791

Epoch 22/100


                                                        

Train Loss: 2.3489, Acc: 0.1906
Test  Loss: 2.5528, Acc: 0.1624

Epoch 23/100


                                                        

Train Loss: 2.2651, Acc: 0.2096
Test  Loss: 2.4101, Acc: 0.2146

Epoch 24/100


                                                        

Train Loss: 2.2155, Acc: 0.2201
Test  Loss: 2.2639, Acc: 0.2303

Epoch 25/100


                                                        

Train Loss: 2.6520, Acc: 0.2024
Test  Loss: 3.1907, Acc: 0.1647

Epoch 26/100


                                                        

Train Loss: 2.6716, Acc: 0.1690
Test  Loss: 2.2439, Acc: 0.1989

Epoch 27/100


                                                        

Train Loss: 2.3671, Acc: 0.1990
Test  Loss: 2.2594, Acc: 0.2181

Epoch 28/100


                                                        

Train Loss: 2.3997, Acc: 0.1899
Test  Loss: 2.7607, Acc: 0.1437

Epoch 29/100


                                                        

Train Loss: 2.3334, Acc: 0.2043
Test  Loss: 2.2389, Acc: 0.2187

Epoch 30/100


                                                        

Train Loss: 2.1828, Acc: 0.2247
Test  Loss: 1.9870, Acc: 0.2744

Epoch 31/100


                                                        

Train Loss: 2.1356, Acc: 0.2308
Test  Loss: 2.2037, Acc: 0.2372

Epoch 32/100


                                                        

Train Loss: 2.1412, Acc: 0.2296
Test  Loss: 2.0527, Acc: 0.2549

Epoch 33/100


                                                        

Train Loss: 2.2002, Acc: 0.2340
Test  Loss: 1.9195, Acc: 0.2685

Epoch 34/100


                                                        

Train Loss: 2.0685, Acc: 0.2482
Test  Loss: 2.1878, Acc: 0.2548

Epoch 35/100


                                                        

Train Loss: 2.0061, Acc: 0.2602
Test  Loss: 1.9171, Acc: 0.2728

Epoch 36/100


                                                        

Train Loss: 2.0658, Acc: 0.2515
Test  Loss: 1.8679, Acc: 0.3087

Epoch 37/100


                                                        

Train Loss: 2.0648, Acc: 0.2502
Test  Loss: 1.9933, Acc: 0.2432

Epoch 38/100


                                                        

Train Loss: 2.0714, Acc: 0.2501
Test  Loss: 2.2982, Acc: 0.2012

Epoch 39/100


                                                        

Train Loss: 2.1901, Acc: 0.2396
Test  Loss: 2.2118, Acc: 0.2555

Epoch 40/100


                                                        

Train Loss: 2.0389, Acc: 0.2622
Test  Loss: 1.9753, Acc: 0.2524

Epoch 41/100


                                                        

Train Loss: 1.9871, Acc: 0.2721
Test  Loss: 1.8585, Acc: 0.2682

Epoch 42/100


                                                        

Train Loss: 1.9500, Acc: 0.2875
Test  Loss: 1.9214, Acc: 0.2917

Epoch 43/100


                                                        

Train Loss: 1.9649, Acc: 0.2817
Test  Loss: 2.0313, Acc: 0.2836

Epoch 44/100


                                                        

Train Loss: 2.0564, Acc: 0.2708
Test  Loss: 2.0747, Acc: 0.2395

Epoch 45/100


                                                        

Train Loss: 2.0866, Acc: 0.2538
Test  Loss: 1.8563, Acc: 0.3008

Epoch 46/100


                                                        

Train Loss: 2.0037, Acc: 0.2731
Test  Loss: 1.9399, Acc: 0.2788

Epoch 47/100


                                                        

Train Loss: 1.9680, Acc: 0.2797
Test  Loss: 1.7966, Acc: 0.3111

Epoch 48/100


                                                        

Train Loss: 1.9052, Acc: 0.3051
Test  Loss: 1.9095, Acc: 0.2806

Epoch 49/100


                                                        

Train Loss: 1.9030, Acc: 0.2994
Test  Loss: 1.9960, Acc: 0.2850

Epoch 50/100


                                                        

Train Loss: 1.8534, Acc: 0.3176
Test  Loss: 1.9689, Acc: 0.2980

Epoch 51/100


                                                        

Train Loss: 1.8523, Acc: 0.3178
Test  Loss: 1.8792, Acc: 0.2988

Epoch 52/100


                                                        

Train Loss: 1.8405, Acc: 0.3281
Test  Loss: 1.9980, Acc: 0.2978

Epoch 53/100


                                                        

Train Loss: 1.8758, Acc: 0.3145
Test  Loss: 1.7811, Acc: 0.3173

Epoch 54/100


                                                        

Train Loss: 1.8372, Acc: 0.3279
Test  Loss: 2.0169, Acc: 0.2849

Epoch 55/100


                                                        

Train Loss: 1.8671, Acc: 0.3244
Test  Loss: 1.8132, Acc: 0.3400

Epoch 56/100


                                                        

Train Loss: 1.9459, Acc: 0.2957
Test  Loss: 2.3872, Acc: 0.2236

Epoch 57/100


                                                        

Train Loss: 1.9007, Acc: 0.3004
Test  Loss: 1.8461, Acc: 0.2925

Epoch 58/100


                                                        

Train Loss: 1.8272, Acc: 0.3329
Test  Loss: 2.3987, Acc: 0.2590

Epoch 59/100


                                                        

Train Loss: 1.8138, Acc: 0.3352
Test  Loss: 1.8488, Acc: 0.3238

Epoch 60/100


                                                        

Train Loss: 1.8145, Acc: 0.3375
Test  Loss: 1.7152, Acc: 0.3635

Epoch 61/100


                                                        

Train Loss: 1.7942, Acc: 0.3537
Test  Loss: 1.9260, Acc: 0.2978

Epoch 62/100


                                                        

Train Loss: 1.8715, Acc: 0.3204
Test  Loss: 1.8369, Acc: 0.3323

Epoch 63/100


                                                        

Train Loss: 1.8074, Acc: 0.3490
Test  Loss: 1.7721, Acc: 0.3655

Epoch 64/100


                                                        

Train Loss: 1.7483, Acc: 0.3600
Test  Loss: 1.8401, Acc: 0.3660

Epoch 65/100


                                                        

Train Loss: 1.7348, Acc: 0.3765
Test  Loss: 1.9797, Acc: 0.3734

Epoch 66/100


                                                        

Train Loss: 1.7233, Acc: 0.3775
Test  Loss: 1.8874, Acc: 0.3555

Epoch 67/100


                                                        

Train Loss: 1.7533, Acc: 0.3712
Test  Loss: 2.3197, Acc: 0.2398

Epoch 68/100


                                                        

Train Loss: 1.8625, Acc: 0.3258
Test  Loss: 1.9247, Acc: 0.3280

Epoch 69/100


                                                        

Train Loss: 1.7871, Acc: 0.3479
Test  Loss: 1.7398, Acc: 0.3546

Epoch 70/100


                                                        

Train Loss: 1.8078, Acc: 0.3462
Test  Loss: 1.7829, Acc: 0.3109

Epoch 71/100


                                                        

Train Loss: 1.7698, Acc: 0.3612
Test  Loss: 1.8553, Acc: 0.3633

Epoch 72/100


                                                        

Train Loss: 1.7175, Acc: 0.3739
Test  Loss: 1.8173, Acc: 0.3675

Epoch 73/100


                                                        

Train Loss: 1.7508, Acc: 0.3737
Test  Loss: 1.7348, Acc: 0.3642

Epoch 74/100


                                                        

Train Loss: 1.7935, Acc: 0.3582
Test  Loss: 1.6942, Acc: 0.3886

Epoch 75/100


                                                        

Train Loss: 1.7107, Acc: 0.3800
Test  Loss: 1.8898, Acc: 0.3713

Epoch 76/100


                                                        

Train Loss: 1.7234, Acc: 0.3785
Test  Loss: 1.7579, Acc: 0.3569

Epoch 77/100


                                                        

Train Loss: 1.7407, Acc: 0.3696
Test  Loss: 1.6731, Acc: 0.3830

Epoch 78/100


                                                        

Train Loss: 1.8296, Acc: 0.3491
Test  Loss: 2.0560, Acc: 0.3055

Epoch 79/100


                                                        

Train Loss: 1.8273, Acc: 0.3406
Test  Loss: 1.6328, Acc: 0.3779

Epoch 80/100


                                                        

Train Loss: 1.7295, Acc: 0.3681
Test  Loss: 1.6977, Acc: 0.3839

Epoch 81/100


                                                        

Train Loss: 1.7122, Acc: 0.3773
Test  Loss: 1.7942, Acc: 0.3597

Epoch 82/100


                                                        

Train Loss: 1.7207, Acc: 0.3821
Test  Loss: 1.6987, Acc: 0.3723

Epoch 83/100


                                                        

Train Loss: 1.7223, Acc: 0.3760
Test  Loss: 1.7225, Acc: 0.3551

Epoch 84/100


                                                        

Train Loss: 1.7262, Acc: 0.3811
Test  Loss: 1.7847, Acc: 0.3577

Epoch 85/100


                                                        

Train Loss: 1.7404, Acc: 0.3684
Test  Loss: 1.6339, Acc: 0.3860

Epoch 86/100


                                                        

Train Loss: 1.6763, Acc: 0.3977
Test  Loss: 1.6277, Acc: 0.3939

Epoch 87/100


                                                        

Train Loss: 1.6759, Acc: 0.3946
Test  Loss: 1.6027, Acc: 0.4047

Epoch 88/100


                                                        

Train Loss: 1.6515, Acc: 0.4023
Test  Loss: 1.6028, Acc: 0.4108

Epoch 89/100


                                                        

Train Loss: 1.7035, Acc: 0.3852
Test  Loss: 1.6233, Acc: 0.4122

Epoch 90/100


                                                        

Train Loss: 1.6859, Acc: 0.3893
Test  Loss: 1.8174, Acc: 0.3506

Epoch 91/100


                                                        

Train Loss: 1.6952, Acc: 0.3893
Test  Loss: 1.6502, Acc: 0.3965

Epoch 92/100


                                                        

Train Loss: 1.7031, Acc: 0.3852
Test  Loss: 1.8170, Acc: 0.3030

Epoch 93/100


                                                        

Train Loss: 1.6502, Acc: 0.3995
Test  Loss: 1.8379, Acc: 0.3625

Epoch 94/100


                                                        

Train Loss: 1.7768, Acc: 0.3702
Test  Loss: 1.7085, Acc: 0.3691

Epoch 95/100


                                                        

Train Loss: 1.7276, Acc: 0.3812
Test  Loss: 1.8175, Acc: 0.3803

Epoch 96/100


                                                        

Train Loss: 1.6835, Acc: 0.3941
Test  Loss: 1.7718, Acc: 0.3902

Epoch 97/100


                                                        

Train Loss: 1.6523, Acc: 0.4027
Test  Loss: 1.6904, Acc: 0.3821

Epoch 98/100


                                                        

Train Loss: 1.7083, Acc: 0.3824
Test  Loss: 1.7798, Acc: 0.3619

Epoch 99/100


                                                        

Train Loss: 1.6715, Acc: 0.3958
Test  Loss: 1.7978, Acc: 0.3716

Epoch 100/100


                                                        

Train Loss: 1.6688, Acc: 0.3992
Test  Loss: 1.5379, Acc: 0.4249

Best Accuracy on cifar10: 0.4249


