In [1]:
import torch
import torch.nn.functional as F
import torch.nn as nn

class TraditionalMLP(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim, num_layers):
        super(TraditionalMLP, self).__init__()

        self.layers = nn.ModuleList([nn.Sequential(
            nn.Linear(input_dim if i == 0 else hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim if i == num_layers - 1 else hidden_dim)
        ) for i in range(num_layers)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

In [131]:
def weights_init(m):
    if isinstance(m, nn.Linear):
        nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
        m.bias.data.fill_(0.01)

def test_mlp():
    model = TraditionalMLP(2, 1, 4, 2)
    model.apply(weights_init)
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print("Parameters: ", params)
    optimizer = torch.optim.SGD(model.parameters(), lr=.1)
    objective_function = lambda u, v: torch.exp(v ** 2 + torch.sin(torch.pi * u)) + v
    with tqdm(range(1000)) as pbar:
        for i in pbar:
            loss = None

            def closure():
                optimizer.zero_grad()
                x = torch.rand(1024, 2)
                y = model(x)

                assert y.shape == (1024, 1)
                nonlocal loss
                u = x[:, 0]
                v = x[:, 1]
                loss = nn.functional.mse_loss(y.squeeze(-1), objective_function(u, v))

                loss.backward()
                return loss

            optimizer.step(closure)
            pbar.set_postfix(mse_loss=loss.item())

In [132]:
test_mlp()

Parameters:  57


100%|██████████| 1000/1000 [00:03<00:00, 268.16it/s, mse_loss=2.13]


In [91]:
model = TraditionalMLP(28 * 28, 10, 32, 2)
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
print(params)

27562


In [43]:
# Train on MNIST
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

# Load MNIST
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
trainset = torchvision.datasets.MNIST(
    root="./data", train=True, download=True, transform=transform
)
valset = torchvision.datasets.MNIST(
    root="./data", train=False, download=True, transform=transform
)
trainloader = DataLoader(trainset, batch_size=512, shuffle=True)
valloader = DataLoader(valset, batch_size=512, shuffle=False)

# Define model
model = TraditionalMLP(28 * 28, 10, 32, 2)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
# Define learning rate scheduler
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8)

# Define loss
criterion = nn.CrossEntropyLoss()
for epoch in range(5):
    # Train
    model.train()
    with tqdm(trainloader) as pbar:
        for i, (images, labels) in enumerate(pbar):
            images = images.view(-1, 28 * 28).to(device)
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels.to(device))
            loss.backward()
            optimizer.step()
            accuracy = (output.argmax(dim=1) == labels.to(device)).float().mean()
            pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item(), lr=optimizer.param_groups[0]['lr'])

    # Validation
    model.eval()
    val_loss = 0
    val_accuracy = 0
    with torch.no_grad():
        for images, labels in valloader:
            images = images.view(-1, 28 * 28).to(device)
            output = model(images)
            val_loss += criterion(output, labels.to(device)).item()
            val_accuracy += (
                (output.argmax(dim=1) == labels.to(device)).float().mean().item()
            )
    val_loss /= len(valloader)
    val_accuracy /= len(valloader)

    # Update learning rate
    scheduler.step()

    print(
        f"Epoch {epoch + 1}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}"
    )

100%|██████████| 118/118 [00:13<00:00,  8.88it/s, accuracy=0.875, loss=0.394, lr=0.001]


Epoch 1, Val Loss: 0.47861494272947314, Val Accuracy: 0.8656594663858413


100%|██████████| 118/118 [00:13<00:00,  8.99it/s, accuracy=0.854, loss=0.5, lr=0.0008]


Epoch 2, Val Loss: 0.3532404191792011, Val Accuracy: 0.8962201297283172


100%|██████████| 118/118 [00:13<00:00,  9.05it/s, accuracy=0.917, loss=0.292, lr=0.00064]


Epoch 3, Val Loss: 0.3266025297343731, Val Accuracy: 0.9040843278169632


100%|██████████| 118/118 [00:13<00:00,  9.03it/s, accuracy=0.927, loss=0.26, lr=0.000512]


Epoch 4, Val Loss: 0.31495350524783133, Val Accuracy: 0.9072437971830368


100%|██████████| 118/118 [00:13<00:00,  9.06it/s, accuracy=0.938, loss=0.244, lr=0.00041]


Epoch 5, Val Loss: 0.3057694613933563, Val Accuracy: 0.9112132340669632


In [99]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class TemperatureSoftmax(nn.Module):
    def __init__(self, temperature=1, dim=None):
        super().__init__()
        self.temperature = temperature
        self.dim = dim

    def forward(self, input):
        return F.softmax(input / self.temperature, dim=self.dim)

class SplineMLP(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim, control_points):
        super(SplineMLP, self).__init__()

        self.control_points = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, output_dim * control_points)
        )

        self.n = control_points
        self.output_dim = output_dim
        self.basis_network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, control_points),
            TemperatureSoftmax(0.5, dim=1)  # normalize the output to make it a proper basis function
        )

    def forward(self, x):
        # bs, control points
        basis_values = self.basis_network(x)  # learnable basis function
        outputs = self.control_points(x).reshape(-1, self.n, self.output_dim)
        outputs = torch.einsum('bno,bn->bo', outputs, basis_values)
        return outputs

class KAN(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim, control_points):
        super(KAN, self).__init__()
        self.fc1 = SplineMLP( input_dim, output_dim, hidden_dim, control_points)
    def forward(self, x):
        outputs = self.fc1(x)
        return outputs

In [135]:
import numpy as np
from tqdm import tqdm
import torch.nn as nn

def weights_init(m):
    if isinstance(m, nn.Linear):
        nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')
        m.bias.data.fill_(0.01)

def test_kan():
  model = KAN(2, 1, 3, 5)
  model.apply(weights_init)
  model_parameters = filter(lambda p: p.requires_grad, model.parameters())
  params = sum([np.prod(p.size()) for p in model_parameters])
  print("Parameters: ", params)
  optimizer = torch.optim.SGD(model.parameters(), lr=.1)
  objective_function = lambda u, v: torch.exp(v ** 2 + torch.sin(torch.pi * u)) + v
  with tqdm(range(1000)) as pbar:
      for i in pbar:
          loss = None

          def closure():
              optimizer.zero_grad()
              x = torch.rand(1024, 2)
              y = model(x)

              assert y.shape == (1024, 1)
              nonlocal loss
              u = x[:, 0]
              v = x[:, 1]
              loss = nn.functional.mse_loss(y.squeeze(-1), objective_function(u, v))

              loss.backward()
              return loss

          optimizer.step(closure)
          pbar.set_postfix(mse_loss=loss.item())

In [136]:
test_kan()

Parameters:  58


100%|██████████| 1000/1000 [00:04<00:00, 230.50it/s, mse_loss=0.0421]


In [64]:
model = KAN(28 * 28, 10, 16, 4)
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
print(params)

25868


In [65]:
# Train on MNIST
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

# Load MNIST
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
trainset = torchvision.datasets.MNIST(
    root="./data", train=True, download=True, transform=transform
)
valset = torchvision.datasets.MNIST(
    root="./data", train=False, download=True, transform=transform
)
trainloader = DataLoader(trainset, batch_size=512, shuffle=True)
valloader = DataLoader(valset, batch_size=512, shuffle=False)

# Define model
model = KAN(28 * 28, 10, 32, 4)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Define optimizer
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
# Define learning rate scheduler
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8)

# Define loss
criterion = nn.CrossEntropyLoss()
for epoch in range(5):
    # Train
    model.train()
    with tqdm(trainloader) as pbar:
        for i, (images, labels) in enumerate(pbar):
            images = images.view(-1, 28 * 28).to(device)
            optimizer.zero_grad()
            output = model(images)
            loss = criterion(output, labels.to(device))
            loss.backward()
            optimizer.step()
            accuracy = (output.argmax(dim=1) == labels.to(device)).float().mean()
            pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item(), lr=optimizer.param_groups[0]['lr'])

    # Validation
    model.eval()
    val_loss = 0
    val_accuracy = 0
    with torch.no_grad():
        for images, labels in valloader:
            images = images.view(-1, 28 * 28).to(device)
            output = model(images)
            val_loss += criterion(output, labels.to(device)).item()
            val_accuracy += (
                (output.argmax(dim=1) == labels.to(device)).float().mean().item()
            )
    val_loss /= len(valloader)
    val_accuracy /= len(valloader)

    # Update learning rate
    scheduler.step()

    print(
        f"Epoch {epoch + 1}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}"
    )

100%|██████████| 118/118 [00:13<00:00,  8.87it/s, accuracy=0.896, loss=0.488, lr=0.001]


Epoch 1, Val Loss: 0.38142223209142684, Val Accuracy: 0.8984202653169632


100%|██████████| 118/118 [00:13<00:00,  8.88it/s, accuracy=0.906, loss=0.27, lr=0.0008]


Epoch 2, Val Loss: 0.28616719916462896, Val Accuracy: 0.9168485760688782


100%|██████████| 118/118 [00:13<00:00,  8.95it/s, accuracy=0.896, loss=0.342, lr=0.00064]


Epoch 3, Val Loss: 0.2480481218546629, Val Accuracy: 0.9259306073188782


100%|██████████| 118/118 [00:13<00:00,  8.98it/s, accuracy=0.917, loss=0.291, lr=0.000512]


Epoch 4, Val Loss: 0.2346076589077711, Val Accuracy: 0.9322035849094391


100%|██████████| 118/118 [00:13<00:00,  8.65it/s, accuracy=0.969, loss=0.131, lr=0.00041]


Epoch 5, Val Loss: 0.2127841033041477, Val Accuracy: 0.9388614445924759
