**Task IX: Kolmogorov-Arnold Network**

Implement a classical Kolmogorov-Arnold Network (KAN) using B-splines (basis splines) or another KAN architecture, and apply it to MNIST. Report its performance on the test data. Comment on potential ideas for extending this classical KAN architecture to a quantum KAN, and sketch that architecture in detail.


In [1]:
!pip install pennylane
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import time
import pennylane as qml

# Reproducibility: fix the PyTorch RNG and NumPy's global RNG (the synthetic
# dataset below draws its noise and shifts from np.random).
for _seed_fn in (torch.manual_seed, np.random.seed):
    _seed_fn(42)

# Run on a CUDA GPU when one is available, otherwise on the CPU.
_has_cuda = torch.cuda.is_available()
device = torch.device('cuda' if _has_cuda else 'cpu')
print(f"Using: {device}")

# Synthetic MNIST-like dataset
class SynthMNIST(Dataset):
    """Synthetic MNIST-shaped dataset of noisy geometric glyphs, one per class 0-9.

    Each sample is a 28x28 float32 image built from a fixed binary template.
    The template is shifted by up to +/-2 pixels per axis with ``np.roll``, so
    content wraps around the border. N(0, 0.1) noise is then added and the
    result is clipped to [0, 1].

    Classes are balanced: each class gets ``n_samples // 10`` samples, and the
    remaining ``n_samples % 10`` samples go to classes 0, 1, ... in order.

    All randomness comes from NumPy's global RNG. Seed it with
    ``np.random.seed`` to get a reproducible dataset.

    Args:
        n_samples: total number of images to generate.
        is_train: stored on the instance but not used during generation.
            Train and test sets differ only through the RNG state at
            construction time.

    Items are ``(image, label)``. ``image`` is a flattened float32 tensor of
    shape (784,) and ``label`` is a NumPy int64 class index.
    """

    def __init__(self, n_samples=10000, is_train=True):
        super().__init__()
        self.n_samples = n_samples
        self.is_train = is_train

        patterns = self._digit_patterns()

        # Contiguous per-class blocks (class 0 first, then 1, ...) followed by
        # the remainder spread over the lowest classes. No slot is left as an
        # all-zero image that would be silently labelled 0.
        per_class = n_samples // 10
        labels = np.repeat(np.arange(10, dtype=np.int64), per_class)
        remainder = np.arange(n_samples - labels.size, dtype=np.int64)
        self.labels = np.concatenate([labels, remainder])
        self.images = np.zeros((n_samples, 28, 28), dtype=np.float32)

        for j, digit in enumerate(self.labels):
            # Keep the draw order fixed (noise, x shift, y shift) so a given
            # seed always reproduces the same dataset.
            noise = np.random.normal(0, 0.1, (28, 28))
            shift_x = np.random.randint(-2, 3)
            shift_y = np.random.randint(-2, 3)
            image = np.roll(np.roll(patterns[digit], shift_x, axis=0), shift_y, axis=1)
            self.images[j] = np.clip(image + noise, 0, 1)

        # Shuffle so the classes are interleaved.
        shuffle_idx = np.random.permutation(n_samples)
        self.images = self.images[shuffle_idx]
        self.labels = self.labels[shuffle_idx]

    @staticmethod
    def _digit_patterns():
        """Return the ten 28x28 float32 binary templates, indexed by class."""
        rows, cols = np.indices((28, 28))

        def dist(r0, c0):
            # Euclidean distance of every pixel (row, col) from (r0, c0).
            return np.sqrt((rows - r0) ** 2 + (cols - c0) ** 2)

        central_rows = (rows > 8) & (rows < 20)
        patterns = [np.zeros((28, 28), dtype=np.float32) for _ in range(10)]

        # 0: ring, 4 < distance from the centre < 6.
        d0 = dist(14, 14)
        patterns[0][central_rows & (cols > 8) & (cols < 20) & (d0 > 4) & (d0 < 6)] = 1.0
        # 1: vertical bar.
        patterns[1][5:23, 13:15] = 1.0
        # 2: "zigzag": four 4-row bars starting at rows 6, 10, 14 and 18. The
        #    bars abut, so the result is a solid block over rows 6-21, cols 8-19.
        for r in range(6, 22, 4):
            patterns[2][r:r + 4, 8:20] = 1.0
        # 3: cross (thick horizontal band plus a vertical bar).
        patterns[3][10:18, 8:20] = 1.0
        patterns[3][6:22, 13:15] = 1.0
        # 4: hollow square.
        patterns[4][8:20, 8:20] = 1.0
        patterns[4][10:18, 10:18] = 0.0
        # 5: filled diamond, |dr| + |dc| < 8 around the centre.
        patterns[5][np.abs(rows - 14) + np.abs(cols - 14) < 8] = 1.0
        # 6: plus.
        patterns[6][13:15, 8:20] = 1.0
        patterns[6][8:20, 13:15] = 1.0
        # 7: T-shape.
        patterns[7][8:10, 8:20] = 1.0
        patterns[7][10:22, 13:15] = 1.0
        # 8: two side-by-side discs (distance < 4).
        patterns[8][central_rows & ((dist(14, 10) < 4) | (dist(14, 18) < 4))] = 1.0
        # 9: disc with a horizontal tail to its right.
        patterns[9][central_rows & (cols > 8) & (cols < 16) & (dist(14, 12) < 4)] = 1.0
        patterns[9][14:16, 16:22] = 1.0

        return patterns

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        # flatten() copies, so the returned tensor never aliases self.images.
        image = torch.from_numpy(self.images[idx].flatten())
        return image, self.labels[idx]

# Hyperparameters
batch_sz = 64   # mini-batch size for both loaders
epochs = 5      # passes over the training set
lr = 0.001      # Adam learning rate

# Datasets. Both draw from NumPy's global RNG, so the construction order
# (train first, then test) is part of what the seed fixes.
train_data = SynthMNIST(n_samples=5000, is_train=True)
test_data = SynthMNIST(n_samples=1000, is_train=False)

# Reshuffle training batches every epoch; keep the test order fixed.
train_loader, test_loader = (
    DataLoader(dataset, batch_size=batch_sz, shuffle=shuffle)
    for dataset, shuffle in ((train_data, True), (test_data, False))
)

# B-spline basis function
class BSplineBasis(nn.Module):
    """Evaluate ``n_basis`` B-spline basis functions of a given degree on [0, 1].

    The knot vector is clamped (open uniform): ``degree`` extra zeros at the
    start, ``n_basis - degree + 1`` evenly spaced knots from 0 to 1, and
    ``degree`` extra ones at the end. With clamped knots the basis values are
    non-negative and sum to 1 everywhere on [0, 1], and the last basis function
    equals 1 at x = 1. The ``x == 1`` special case in ``forward`` relies on
    that. A plain uniform knot vector only forms a partition of unity on its
    interior spans, which under-weights the ends of [0, 1].

    Args:
        n_basis: number of basis functions; must be greater than ``degree``.
        degree: polynomial degree of the splines (3 = cubic).

    Raises:
        ValueError: if ``n_basis <= degree``.
    """

    def __init__(self, n_basis=10, degree=3):
        super().__init__()
        if n_basis <= degree:
            raise ValueError(
                f"n_basis ({n_basis}) must be greater than degree ({degree})"
            )
        self.n_basis = n_basis
        self.degree = degree
        knots = torch.cat([
            torch.zeros(degree),
            torch.linspace(0, 1, n_basis - degree + 1),
            torch.ones(degree),
        ])
        # Registered as a buffer so .to(device) moves the knots with the module.
        # It is non-persistent so state_dict keys stay unchanged.
        self.register_buffer("knots", knots, persistent=False)

    def cox_deboor(self, x, i, k):
        """Cox-de Boor recursion: value of basis function ``i`` of degree ``k`` at ``x``.

        Terms with a zero-width knot span are dropped (the usual 0/0 := 0
        convention). Degree-0 spans are half-open, ``[t_i, t_{i+1})``.
        """
        if k == 0:
            return ((x >= self.knots[i]) & (x < self.knots[i+1])).float()

        term1 = 0.0
        denom1 = self.knots[i+k] - self.knots[i]
        if denom1 > 0:
            term1 = (x - self.knots[i]) / denom1 * self.cox_deboor(x, i, k-1)

        term2 = 0.0
        denom2 = self.knots[i+k+1] - self.knots[i+1]
        if denom2 > 0:
            term2 = (self.knots[i+k+1] - x) / denom2 * self.cox_deboor(x, i+1, k-1)

        return term1 + term2

    def forward(self, x):
        """Return basis values of shape (len(x), n_basis) for a 1-D tensor ``x``.

        Inputs are clamped to [0, 1] before evaluation.
        """
        x = x.clamp(0, 1)
        dtype = x.dtype if torch.is_floating_point(x) else torch.get_default_dtype()

        basis_vals = torch.zeros(x.shape[0], self.n_basis, device=x.device, dtype=dtype)
        for i in range(self.n_basis):
            basis_vals[:, i] = self.cox_deboor(x, i, self.degree)

        # Degree-0 spans are half-open, so every recursion gives 0 at x == 1.
        # With clamped knots the exact values there are B_last(1) = 1 and all
        # other basis functions 0.
        basis_vals[x == 1, -1] = 1.0

        return basis_vals

# KAN Layer with B-splines
class KanLayer(nn.Module):
    """KAN layer whose edge functions are learnable combinations of B-splines.

    Each input feature ``i`` is expanded by its own ``BSplineBasis`` into
    ``n_basis`` values ``B_i(x_i)``. The outputs are

        y[o] = sum_i sum_k weights[o, i, k] * B_i(x_i)[k] + bias[o].

    Args:
        in_feats: number of input features (columns of ``x``).
        out_feats: number of outputs.
        n_basis: B-spline basis functions per input feature.
    """

    def __init__(self, in_feats, out_feats, n_basis=10):
        super().__init__()
        self.in_feats = in_feats
        self.out_feats = out_feats
        self.n_basis = n_basis

        # One B-spline basis per input feature.
        self.splines = nn.ModuleList([BSplineBasis(n_basis=n_basis) for _ in range(in_feats)])

        # Spline coefficients for every (output, input) edge, plus output bias.
        self.weights = nn.Parameter(torch.empty(out_feats, in_feats, n_basis))
        self.bias = nn.Parameter(torch.empty(out_feats))

        # Same scheme as nn.Linear.reset_parameters. For this 3-D tensor,
        # fan_in = in_feats * n_basis.
        nn.init.kaiming_uniform_(self.weights, a=math.sqrt(5))
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weights)
        bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        """Map ``x`` of shape (batch, in_feats) to (batch, out_feats)."""
        # (batch, in_feats, n_basis): spline basis values of each feature.
        basis = torch.stack(
            [spline(x[:, i]) for i, spline in enumerate(self.splines)], dim=1
        )
        # Contract over features and basis indices in a single kernel, instead
        # of out_feats * in_feats small matmuls issued from Python.
        return torch.einsum("bik,oik->bo", basis, self.weights) + self.bias

# Efficient KAN layer
import math
class EffKanLayer(nn.Module):
    """KAN-style layer whose per-feature basis expansion is a small learned MLP.

    Every input feature has its own network: Linear(1, n_basis) -> Sigmoid ->
    Linear(n_basis, n_basis) -> Tanh. The per-feature expansions are
    concatenated into ``in_feats * n_basis`` values and mixed by one linear
    layer into ``out_feats`` outputs.
    """

    def __init__(self, in_feats, out_feats, n_basis=10):
        super().__init__()
        self.in_feats = in_feats
        self.out_feats = out_feats
        self.n_basis = n_basis

        def make_basis_net():
            return nn.Sequential(
                nn.Linear(1, n_basis), nn.Sigmoid(),
                nn.Linear(n_basis, n_basis), nn.Tanh(),
            )

        # One learned basis network per input feature.
        self.basis_net = nn.ModuleList(make_basis_net() for _ in range(in_feats))
        # Mixes the concatenated expansions into the layer's outputs.
        self.output_layer = nn.Linear(in_feats * n_basis, out_feats)

    def forward(self, x):
        """Map ``x`` of shape (batch, in_feats) to (batch, out_feats)."""
        expansions = [
            net(x[:, col].view(-1, 1)) for col, net in enumerate(self.basis_net)
        ]
        return self.output_layer(torch.cat(expansions, dim=1))

# Kolmogorov-Arnold Network (KAN)
class KanNet(nn.Module):
    """Feed-forward stack of EffKanLayer blocks: input_dim -> *hidden_dims -> output_dim.

    A ReLU follows every layer except the last, so the network returns raw
    scores (logits).

    Args:
        input_dim: width of the input vectors.
        hidden_dims: list of hidden-layer widths.
        output_dim: number of outputs.
        n_basis: basis size passed to every EffKanLayer.
    """

    def __init__(self, input_dim, hidden_dims, output_dim, n_basis=10):
        super().__init__()

        widths = [input_dim] + hidden_dims + [output_dim]
        last = len(widths) - 2  # index of the final (input, output) pair

        modules = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            modules.append(EffKanLayer(fan_in, fan_out, n_basis))
            if idx != last:
                modules.append(nn.ReLU())

        self.network = nn.Sequential(*modules)

    def forward(self, x):
        """Run ``x`` through every layer in order."""
        return self.network(x)

# Model configuration: flattened 28x28 images in, 10 class scores out.
input_dim = 28*28
hidden_dims = [256, 128]
output_dim = 10
n_basis = 6

model = KanNet(input_dim, hidden_dims, output_dim, n_basis).to(device)
n_params = sum(param.numel() for param in model.parameters())
print(f"KAN model params: {n_params}")

# Cross-entropy on the raw logits, optimised with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

# Training loop: one pass over train_loader per epoch. The epoch loss is the
# mean of the per-batch losses.
steps_per_epoch = len(train_loader)
train_losses = []
print("Training started...")
for epoch in range(epochs):
    model.train()
    start_time = time.time()
    batch_losses = []

    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

        # Progress report every 10 steps
        if (i+1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Step [{i+1}/{steps_per_epoch}], Loss: {loss.item():.4f}')

    epoch_loss = sum(batch_losses) / len(train_loader)
    train_losses.append(epoch_loss)
    epoch_time = time.time() - start_time
    print(f'Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss:.4f}, Time: {epoch_time:.2f}s')

# Evaluate top-1 accuracy on the held-out set, without tracking gradients.
model.eval()
with torch.no_grad():
    correct = total = 0
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        predicted = model(images).argmax(dim=1)  # index of the highest score
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

test_accuracy = 100 * correct / total
print(f'Test Accuracy: {test_accuracy:.2f}%')

# Plot the mean training loss per epoch. Epochs are numbered from 1, matching
# the "Epoch [k/N]" lines printed during training.
epoch_axis = range(1, len(train_losses) + 1)
plt.figure(figsize=(10, 5))
plt.plot(epoch_axis, train_losses, marker='o')
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.xticks(list(epoch_axis))
plt.grid(True)
plt.savefig('kan_training_loss.png')
plt.close()

# ===========================
# Quantum KAN (QKAN)
# ===========================

# Quantum Device
n_qbits = 4
q_device = qml.device("default.qubit", wires=n_qbits)
print(f"Quantum device qubits: {n_qbits}")

# Quantum basis circuit
@qml.qnode(q_device)
def quantum_basis_circuit(inputs, weights):
    """Variational circuit used as a learnable basis function.

    The circuit angle-embeds ``inputs`` on the ``n_qbits`` wires, applies
    StronglyEntanglingLayers parameterised by ``weights``, and returns one
    PauliZ expectation value per wire.
    """
    wires = range(n_qbits)
    qml.templates.AngleEmbedding(inputs, wires=wires)
    qml.templates.StronglyEntanglingLayers(weights, wires=wires)
    return [qml.expval(qml.PauliZ(wire)) for wire in wires]

# Quantum Basis Function Module
class QuantumBasisModule(nn.Module):
    """Learnable quantum basis expansion of one scalar feature.

    Each scalar is rescaled and encoded on every qubit. ``quantum_basis_circuit``
    then runs with the trainable ``q_weights``, and its ``n_qbits`` expectation
    values become the expansion.

    NOTE(review): ``quantum_basis_circuit`` is built for the module-level
    ``n_qbits`` wires, so this module's ``n_qbits`` should match it.
    """

    def __init__(self, n_qbits=4, n_layers=2):
        super().__init__()
        self.n_qbits = n_qbits
        self.n_layers = n_layers

        # (n_layers, n_qbits, 3) rotation angles for StronglyEntanglingLayers,
        # drawn uniformly from [0, 2*pi).
        self.q_weights = nn.Parameter(
            torch.empty(n_layers, n_qbits, 3, dtype=torch.float32).uniform_(0, 2 * np.pi)
        )

    def forward(self, x):
        """Apply the quantum basis to a 1-D batch of scalars.

        Args:
            x: tensor of shape (batch,). The affine map below sends [-1, 1] to
                [0, 2*pi]; other ranges are shifted and scaled the same way.

        Returns:
            Tensor of shape (batch, n_qbits) holding the circuit's expectation
            values. It stays attached to the autograd graph, so ``q_weights``
            and ``x`` both receive gradients.
        """
        out_dtype = torch.get_default_dtype()
        if x.shape[0] == 0:
            return torch.zeros(0, self.n_qbits, device=x.device, dtype=out_dtype)

        x_scaled = (x * 0.5 + 0.5) * 2 * np.pi

        # Run the simulator on CPU torch tensors. .cpu() and .to() are
        # differentiable; .detach(), .numpy() and .item() would cut the graph
        # and leave q_weights untrained.
        weights = self.q_weights.cpu()
        rows = []
        for value in x_scaled.cpu():
            inputs = value.repeat(self.n_qbits)  # same angle on every qubit
            expvals = quantum_basis_circuit(inputs, weights)
            rows.append(torch.stack([torch.as_tensor(e) for e in expvals]))

        return torch.stack(rows).to(device=x.device, dtype=out_dtype)

# Quantum KAN Layer
class QKanLayer(nn.Module):
    """KAN-style layer whose per-feature expansion is a QuantumBasisModule.

    Each of the ``in_feats`` features is expanded by its own quantum basis into
    ``n_qbits`` values. The concatenated ``in_feats * n_qbits`` values are then
    mixed linearly into ``out_feats`` outputs.
    """

    def __init__(self, in_feats, out_feats, n_qbits=4, n_layers=2):
        super().__init__()
        self.in_feats = in_feats
        self.out_feats = out_feats
        self.n_qbits = n_qbits

        # One quantum basis per input feature.
        self.quantum_basis = nn.ModuleList(
            QuantumBasisModule(n_qbits, n_layers) for _ in range(in_feats)
        )
        # Classical mixing of all quantum features.
        self.output_layer = nn.Linear(in_feats * n_qbits, out_feats)

    def forward(self, x):
        """Map ``x`` of shape (batch, in_feats) to (batch, out_feats)."""
        expansions = [basis(x[:, col]) for col, basis in enumerate(self.quantum_basis)]
        return self.output_layer(torch.cat(expansions, dim=1))

# Batched QKAN Layer (Simulated Quantum)
class BatchQKanLayer(nn.Module):
    """Classical stand-in ("simulated quantum") for a QKAN layer.

    Each processed feature goes through its own small network: Linear(1, q_feats)
    -> Tanh -> Linear(q_feats, q_feats) -> Sigmoid. The concatenated outputs are
    mixed linearly into ``out_feats`` values.

    Args:
        in_feats: number of input features.
        out_feats: number of outputs.
        q_feats: width of each per-feature expansion.
        max_feats: optional cap on how many leading features are processed,
            for quick demonstrations. ``None`` (default) processes all of them.

    Raises:
        ValueError: if ``max_feats`` is given and is smaller than 1.
    """

    def __init__(self, in_feats, out_feats, q_feats=10, max_feats=None):
        super().__init__()
        if max_feats is not None and max_feats < 1:
            raise ValueError(f"max_feats must be >= 1 or None, got {max_feats}")
        self.in_feats = in_feats
        self.out_feats = out_feats
        self.q_feats = q_feats
        # Features actually read in forward(). output_layer is sized from the
        # same count, so the two can never disagree.
        self.n_used = in_feats if max_feats is None else min(in_feats, max_feats)

        # Simulated quantum processing with nonlinear layers
        self.quantum_sim = nn.ModuleList([
            nn.Sequential(
                nn.Linear(1, q_feats),
                nn.Tanh(),
                nn.Linear(q_feats, q_feats),
                nn.Sigmoid()
            ) for _ in range(self.n_used)
        ])

        # Linear output
        self.output_layer = nn.Linear(self.n_used * q_feats, out_feats)

    def forward(self, x):
        """Map ``x`` of shape (batch, >= n_used) to (batch, out_feats).

        Only the first ``n_used`` columns of ``x`` are read.
        """
        q_outputs = [
            net(x[:, i].view(-1, 1)) for i, net in enumerate(self.quantum_sim)
        ]
        combined = torch.cat(q_outputs, dim=1)  # (batch, n_used * q_feats)
        return self.output_layer(combined)

# Quantum KAN (QKAN) Architecture
class QKanNet(nn.Module):
    """Hybrid classifier with three stages.

    1. Linear reduction ``input_dim -> hidden_dims[0]`` followed by ReLU.
    2. One BatchQKanLayer ``hidden_dims[0] -> hidden_dims[1]`` followed by ReLU.
    3. Linear head ``hidden_dims[1] -> output_dim``.

    ``hidden_dims`` must contain at least two widths; only the first two are used.
    """

    def __init__(self, input_dim, hidden_dims, output_dim, q_feats=10):
        super().__init__()
        reduced_dim, qkan_out_dim = hidden_dims[0], hidden_dims[1]

        # Classical input reduction
        self.input_reduce = nn.Linear(input_dim, reduced_dim)
        self.act1 = nn.ReLU()

        # QKAN layer (simulated)
        self.qkan_layer = BatchQKanLayer(reduced_dim, qkan_out_dim, q_feats)
        self.act2 = nn.ReLU()

        # Output head
        self.output_layer = nn.Linear(qkan_out_dim, output_dim)

    def forward(self, x):
        """Return (batch, output_dim) logits for (batch, input_dim) inputs."""
        hidden = self.act1(self.input_reduce(x))
        hidden = self.act2(self.qkan_layer(hidden))
        return self.output_layer(hidden)

# Written commentary on the QKAN design, printed section by section.
qkan_notes = [
    ("\nQKAN Components:", [
        "1. Input Encoding: Classical features to quantum states",
        "2. Quantum Basis Functions: Quantum circuits for feature processing",
        "3. Quantum Processing: Variational quantum circuits transform data",
        "4. Measurement: Quantum to classical values",
        "5. Classical Aggregation: Combine quantum outputs",
    ]),
    ("\nQKAN Advantages:", [
        "1. Enhanced expressivity",
        "2. Quantum interference for pattern capture",
        "3. Entanglement for basis function correlations",
        "4. Potential quantum speedup",
    ]),
    ("\nImplementation:", [
        "1. Feature-wise quantum circuits",
        "2. Quantum embedding for classical data",
        "3. Variational circuits as basis functions",
        "4. Hybrid classical-quantum architecture",
    ]),
    ("\nChallenges & Future:", [
        "1. Efficient quantum encoding for high-dim data",
        "2. Optimizing qubits for basis functions",
        "3. Training hybrid models",
        "4. Hardware implementation on NISQ devices",
        "5. Theoretical analysis of quantum advantage",
    ]),
]

print("\nQuantum KAN (QKAN) Architecture:")
print("QKAN enhances classical KAN with quantum basis functions.")
for heading, items in qkan_notes:
    print(heading)
    for item in items:
        print(item)

Collecting pennylane
  Downloading PennyLane-0.40.0-py3-none-any.whl.metadata (10 kB)
Collecting rustworkx>=0.14.0 (from pennylane)
  Downloading rustworkx-0.16.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Collecting tomlkit (from pennylane)
  Downloading tomlkit-0.13.2-py3-none-any.whl.metadata (2.7 kB)
Collecting appdirs (from pennylane)
  Downloading appdirs-1.4.4-py2.py3-none-any.whl.metadata (9.0 kB)
Collecting autoray>=0.6.11 (from pennylane)
  Downloading autoray-0.7.0-py3-none-any.whl.metadata (5.8 kB)
Collecting pennylane-lightning>=0.40 (from pennylane)
  Downloading PennyLane_Lightning-0.40.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (27 kB)
Collecting diastatic-malt (from pennylane)
  Downloading diastatic_malt-2.15.2-py3-none-any.whl.metadata (2.6 kB)
Collecting scipy-openblas32>=0.3.26 (from pennylane-lightning>=0.40->pennylane)
  Downloading scipy_openblas32-0.3.29.0.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5