# KAN Variational Inference

In [4]:
import kan
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from kan import KAN, create_dataset, KANLayer

In [7]:
# MNIST data
from torchvision import datasets, transforms
# Both splits share the same storage directory, tensor conversion and
# batch size, so those are named once and reused below.
DATA_ROOT = './data'
BATCH_SIZE = 100
to_tensor = transforms.ToTensor()

mnist_train = datasets.MNIST(
    root=DATA_ROOT, train=True, download=True, transform=to_tensor,
)
mnist_test = datasets.MNIST(
    root=DATA_ROOT, train=False, download=True, transform=to_tensor,
)

# Mini-batch iterators over each split; both are reshuffled on every pass.
train_loader = DataLoader(mnist_train, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=BATCH_SIZE, shuffle=True)

In [8]:
# Classical VAE

class VAE(nn.Module):
    """Fully connected variational autoencoder with a diagonal Gaussian latent.

    The encoder is one hidden ReLU layer followed by two linear heads that
    produce the mean and log-variance of the approximate posterior. The
    decoder mirrors it and ends in a sigmoid, so reconstructions lie in
    [0, 1].

    Args:
        input_dim: Number of features per flattened input sample.
        hidden_dim: Width of the single hidden layer in encoder and decoder.
        latent_dim: Dimensionality of the latent code.

    The defaults (784, 400, 20) reproduce the original fixed architecture.
    """

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim

        # Layers are created in the same order as before so that a fixed
        # random seed yields the same initial weights.
        self.fc1 = nn.Linear(input_dim, hidden_dim)    # encoder hidden layer
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # posterior mean head
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # posterior log-variance head
        self.fc3 = nn.Linear(latent_dim, hidden_dim)   # decoder hidden layer
        self.fc4 = nn.Linear(hidden_dim, input_dim)    # decoder output layer

    def encode(self, x):
        """Map flattened inputs to the posterior parameters ``(mu, logvar)``."""
        h1 = torch.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Sampling is expressed as a deterministic function of ``mu`` and
        ``logvar`` plus external noise so gradients can flow through both.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes to reconstructions with values in [0, 1]."""
        h3 = torch.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Encode, sample and decode a batch.

        ``x`` is flattened to ``(-1, input_dim)`` before encoding, so any
        input whose element count is a multiple of ``input_dim`` is accepted.

        Returns:
            A tuple ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encode(x.view(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

In [9]:
# Define the loss function (ELBO)
def loss_function(recon_x, x, mu, logvar):
    """Return the negative ELBO for a batch, summed over all elements.

    Args:
        recon_x: Reconstructions with values in [0, 1].
        x: Original inputs; reshaped to ``recon_x``'s shape, so it may be
            passed either flattened or in its original layout.
        mu: Mean of the approximate posterior.
        logvar: Log-variance of the approximate posterior.

    Returns:
        Scalar tensor: summed binary cross-entropy reconstruction term plus
        the KL divergence between N(mu, exp(logvar)) and N(0, I).
    """
    # Reshape to the reconstruction's shape instead of a fixed 784 features,
    # so the loss works for any input dimensionality.
    target = x.reshape(recon_x.shape)
    BCE = nn.functional.binary_cross_entropy(recon_x, target, reduction='sum')
    # Closed-form KL divergence for a diagonal Gaussian against N(0, I).
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

In [10]:
# Fit the classical VAE on the MNIST training loader for 10 epochs with Adam,
# recording the loss of every mini-batch in `losses`.
vae = VAE()
optimizer = optim.Adam(vae.parameters(), lr=1e-3)
vae.train()
losses = []  # one entry per optimisation step (mini-batch), not per epoch

# Quantities used only for the progress report; they do not change while training.
n_batches = len(train_loader)
n_samples = len(train_loader.dataset)
progress_msg = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'

for epoch in range(10):
    for i, (data, _) in enumerate(train_loader):
        data = data.view(-1, 784)  # flatten each image to a 784-vector

        # One gradient step on the current mini-batch.
        optimizer.zero_grad()
        recon_batch, mu, logvar = vae(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        losses.append(batch_loss)

        # Report progress on every 100th batch only.
        if i % 100 != 0:
            continue
        print(progress_msg.format(
            epoch, i * len(data), n_samples,
            100. * i / n_batches, batch_loss))



In [18]:
# Evaluate the trained VAE on the test loader and report the average
# loss per test sample.
vae.eval()
test_loss = 0
with torch.no_grad():  # no gradients are needed for evaluation
    for i, (data, _) in enumerate(test_loader):
        data = data.view(-1, 784)  # flatten each image to a 784-vector
        recon_batch, mu, logvar = vae(data)
        batch_loss = loss_function(recon_batch, data, mu, logvar)
        test_loss += batch_loss.item()

# Normalise the accumulated loss by the number of test samples.
n_test_samples = len(test_loader.dataset)
test_loss /= n_test_samples
print('Test set loss: {:.4f}'.format(test_loss))

Test set loss: 105.3020


In [19]:
# Training loss plot
# `losses` receives one value per mini-batch inside the training loop, so the
# x-axis counts training iterations (optimisation steps), not epochs.
plt.plot(losses)
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.show()

NameError: name 'losses' is not defined

In [None]:
# KAN-VAE

class KANVAE(nn.Module):
    def __init__(self):
        super(KANVAE, self).__init__()
        self.fc1 = nn.Linear(784, 400)
        self.