In [24]:
import torch
import torch.nn as nn

class BSplineActivation(nn.Module):
    def __init__(self, num_points=5, domain=(-1, 1)):
        super(BSplineActivation, self).__init__()
        # Number of B-spline grid points (trainable parameters)
        self.num_points = num_points
        self.domain = domain
        
        # Trainable grid points (control points for the B-spline)
        self.grid_points = nn.Parameter(torch.linspace(domain[0], domain[1], num_points))
        
        # Coefficients for each B-spline segment
        self.coefficients = nn.Parameter(torch.ones(num_points))

    def forward(self, x):
        # Map input x to be in the domain of [-1, 1]
        x = torch.clamp(x, self.domain[0], self.domain[1])
        
        # Find the indices of the grid points
        grid_indices = torch.bucketize(x, self.grid_points)

        # Interpolate between the grid points using B-splines
        output = torch.zeros_like(x)
        for i in range(self.num_points - 1):
            mask = (grid_indices == i)
            output[mask] = self.coefficients[i] * (x[mask] - self.grid_points[i])
        
        return output

In [25]:
class KANModel(nn.Module):
    def __init__(self, input_size=28*28, hidden_size=512, output_size=10, num_splines=5):
        super(KANModel, self).__init__()
        
        # Define layers
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)
        
        # B-spline activations
        self.b_spline = BSplineActivation(num_points=num_splines)
        
    def forward(self, x):
        # Flatten the input image
        x = x.view(x.size(0), -1)
        
        # Apply layers with B-spline activations
        x = self.fc1(x)
        x = self.b_spline(x)
        x = self.fc2(x)
        x = self.b_spline(x)
        x = self.fc3(x)
        
        return x

In [27]:
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
from datasets import load_dataset
from torchvision import transforms
from torch.utils.data import Dataset

# Define the transform for the MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load MNIST dataset from Hugging Face
dataset = load_dataset("ylecun/mnist")

# Define a custom Dataset class for MNIST with transformation
class MNISTDataset(Dataset):
    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image = self.dataset[idx]['image']
        label = self.dataset[idx]['label']
        
        if self.transform:
            image = self.transform(image)
        
        return image, label

# Apply the transformation to the dataset
train_dataset = MNISTDataset(dataset['train'], transform)
test_dataset = MNISTDataset(dataset['test'], transform)

# Create DataLoader
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Instantiate the models
kan_model = KANModel(input_size=28*28, hidden_size=512, output_size=10, num_splines=5)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
kan_optimizer = optim.Adam(kan_model.parameters(), lr=0.001)

# Training Loop for MLP and KAN
def train(model, train_loader, optimizer, criterion, num_epochs=5):
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        for images, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        print(f"Epoch {epoch+1}, Loss: {running_loss / len(train_loader)}")

print("Training KAN model...")
train(kan_model, train_loader, kan_optimizer, criterion)

# Testing function to evaluate model
def test(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print(f'Accuracy: {accuracy:.2f}%')


print("Evaluating KAN model...")
test(kan_model, test_loader)

Training KAN model...
Epoch 1, Loss: 0.40556850476559797
Epoch 2, Loss: 0.22589783714428893
Epoch 3, Loss: 0.2111099288721424
Epoch 4, Loss: 0.21233184316725745
Epoch 5, Loss: 0.22737231423288012
Evaluating KAN model...
Accuracy: 93.02%


In [32]:
import torch
import torch.nn as nn
import torch.optim as optim

class KANModelWithGridExtension(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_splines=10):
        super(KANModelWithGridExtension, self).__init__()
        
        # Define layers
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        
        # Define trainable grid points for the B-splines
        self.grid_points = nn.Parameter(torch.randn(num_splines))
        
        # Set number of splines
        self.num_splines = num_splines
    
    def forward(self, x):
        # First hidden layer
        x = self.fc1(x.view(x.size(0), -1))  # Flatten the input
        x = torch.relu(x)
        
        # Apply the KAN activation with grid extension
        x = self.kan_activation(x)
        
        # Second layer
        x = self.fc2(x)
        
        return x
    
    def kan_activation(self, x):
        # Apply a simple form of grid-based activation
        # Interpolate activations based on grid points
        grid_values = self.interpolate_grid(x, self.grid_points)
        return grid_values
    
    def interpolate_grid(self, x, grid_points):
        # Linear interpolation function between grid points and activations
        batch_size = x.size(0)
        
        # Normalize x to [0, 1]
        min_val, _ = torch.min(x, dim=1, keepdim=True)
        max_val, _ = torch.max(x, dim=1, keepdim=True)
        norm_x = (x - min_val) / (max_val - min_val + 1e-6)  # Avoid division by zero
        
        # Rescale norm_x to map it to the range of the grid points
        grid_size = grid_points.size(0) - 1
        indices = torch.floor(norm_x * grid_size).long()  # Get the indices for interpolation
        
        # Get the fractional part of norm_x
        frac = norm_x * grid_size - indices.float()
        
        # Ensure indices are within valid range
        indices = torch.clamp(indices, 0, grid_size - 1)
        
        # Interpolate between the grid points
        lower = grid_points[indices]
        upper = grid_points[indices + 1]
        
        # Linear interpolation formula
        interp_values = lower + frac * (upper - lower)
        
        return interp_values

# Check if a GPU is available, otherwise use the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the model with grid extension and send it to the selected device
kan_model_with_grid = KANModelWithGridExtension(input_dim=28*28, hidden_dim=128, output_dim=10, num_splines=20).to(device)

# Define loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(kan_model_with_grid.parameters(), lr=0.001)

# Training loop with grid extension
epochs = 10
for epoch in range(epochs):
    kan_model_with_grid.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()  # Zero the gradients
        
        outputs = kan_model_with_grid(images)  # Forward pass
        loss = loss_fn(outputs, labels)  # Compute loss
        loss.backward()  # Backpropagation
        optimizer.step()  # Update weights
        
        running_loss += loss.item()
    
    avg_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}")

# Evaluate the model
kan_model_with_grid.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = kan_model_with_grid(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f"Accuracy: {accuracy:.2f}%")

Epoch 1/10, Loss: 0.5185
Epoch 2/10, Loss: 0.2424
Epoch 3/10, Loss: 0.1857
Epoch 4/10, Loss: 0.1526
Epoch 5/10, Loss: 0.1316
Epoch 6/10, Loss: 0.1133
Epoch 7/10, Loss: 0.1018
Epoch 8/10, Loss: 0.0882
Epoch 9/10, Loss: 0.0856
Epoch 10/10, Loss: 0.0749
Accuracy: 97.03%


In [33]:
# Save KAN model
torch.save(kan_model_with_grid.state_dict(), "kan_model.pt")
print("KAN model saved successfully.")

KAN model saved successfully.
