In [5]:
# PadeKAN + CNN
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


# Define the PadeLayer
class PadeLayer(nn.Module):
    def __init__(self, input_dim, out_dim, order_n, order_m, addbias=True):
        super(PadeLayer, self).__init__()
        self.input_dim = input_dim
        self.out_dim = out_dim
        self.order_n = order_n
        self.order_m = order_m
        self.addbias = addbias

        # Initialize coefficients for the numerator (a_i) and denominator (b_i)
        self.numerator_coeffs = nn.Parameter(torch.randn(out_dim, input_dim, order_n) * 0.01)
        self.denominator_coeffs = nn.Parameter(torch.randn(out_dim, input_dim, order_m) * 0.01)
        if self.addbias:
            self.bias = nn.Parameter(torch.zeros(1, out_dim))

    def forward(self, x):
        shape = x.shape
        outshape = shape[0:-1] + (self.out_dim,)
        x = torch.reshape(x, (-1, self.input_dim))

        x_expanded = x.unsqueeze(1).expand(-1, self.out_dim, -1)

        # Compute the numerator: P(x) = a_0 + a_1 x + a_2 x^2 + ... + a_N x^N
        numerator = torch.zeros((x.shape[0], self.out_dim), device=x.device)
        for i in range(self.order_n):
            term = (x_expanded ** i) * self.numerator_coeffs[:, :, i]
            numerator += term.sum(dim=-1)

        # Compute the denominator: Q(x) = 1 + b_1 x + b_2 x^2 + ... + b_M x^M
        denominator = torch.ones((x.shape[0], self.out_dim), device=x.device)
        for i in range(1, self.order_m+1 ):
            term = (x_expanded ** i) * self.denominator_coeffs[:, :, i-1]
            denominator += term.sum(dim=-1)

        # Compute the final output as P(x) / Q(x)
        #y = numerator / torch.clamp(denominator, min=1e-8)
        y = numerator / (1+abs(denominator))

        if self.addbias:
            y += self.bias

        y = torch.reshape(y, outshape)
        return y

class PadeCNN(nn.Module):
  def __init__(self, input_dim=28, order_n=3, order_m=3):
    super(PadeCNN, self).__init__()
    self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1) # in_channels=1
    self.pool1 = nn.MaxPool2d(2)
    self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
    self.pool2 = nn.MaxPool2d(2)
    self.pade1 = PadeLayer(32 * 7*7, 128, order_n=order_n, order_m=order_m)  # Reduced input_dim due to pooling
    self.pade2 = PadeLayer(128, 10, order_n=order_n, order_m=order_m)

  def forward(self, x):
    #x = x.unsqueeze(1)  # Add a channel dimension for Conv1d, shape becomes [batch_size, 1, 32]
    x = F.selu(self.conv1(x))
    x = self.pool1(x)
    x = F.selu(self.conv2(x))
    x = self.pool2(x)
    x = x.view(x.size(0), -1)
    x = self.pade1(x)
    x = self.pade2(x)
    return x


# Dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = PadeCNN().to(device)

# optimizer = optim.LBFGS(model.parameters(), lr=0.001)
# optimizer = optim.SGD(model.parameters(), lr=0.001, weight_decay=1e-4, momentum=0.9)
optimizer = optim.RAdam(model.parameters(), lr=0.0001)



# Training
def train(model, device, train_loader, optimizer, epoch):
  model.train()
  for i, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()
    output = model(data.to(device))
    loss = nn.CrossEntropyLoss()(output, target.to(device))
    loss.backward()
    optimizer.step()
    if i % 10 == 0:
      print(f'Train Epoch: {epoch} [{i * len(data)}/{len(train_loader.dataset)} ({100. * i / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

# Evaluation
def evaluate(model, device, test_loader):
  model.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data, target in test_loader:
      data, target = data.to(device), target.to(device)
      output = model(data)
      test_loss += nn.CrossEntropyLoss()(output, target).item()
      pred = output.argmax(dim=1, keepdim=True)
      correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n')

# Running
for epoch in range(0, 20):
  train(model, device, train_loader, optimizer, epoch)
evaluate(model, device, test_loader)


Test set: Average loss: 0.0002, Accuracy: 9890/10000 (99%)



In [4]:
import torch
print(torch.cuda.is_available())

True


In [5]:
torch.cuda.get_arch_list()

[]