In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt


## Download MNIST

In [None]:
# Threshold pixel intensities into a hard 0/1 mask
def binarize_data(x):
    """Map each element of ``x`` to 1.0 if it is strictly above 0.5, else 0.0.

    Args:
        x: Tensor of intensities.

    Returns:
        A float32 tensor shaped like ``x`` that contains only 0.0 and 1.0.
    """
    above_half = torch.gt(x, 0.5)
    return above_half.to(torch.float32)

# Per-image preprocessing: convert to a tensor, then binarize the pixels
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Lambda(binarize_data)]
)

# MNIST train / test splits, stored under ./data (downloaded when requested)
train_dataset = datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)
test_dataset = datasets.MNIST(
    root='./data', train=False, download=True, transform=transform
)

# Mini-batch iterators; only the training split is reshuffled
train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=128, shuffle=False)



## Define VAE

In [None]:
class BernoulliVAE(nn.Module):
    """Fully connected variational autoencoder with a sigmoid (Bernoulli-mean) decoder.

    The encoder is an ``input_dim -> 512 -> 256`` ReLU MLP followed by two
    linear heads that emit the mean and log-variance of a diagonal Gaussian
    over the latent code. The decoder is a ``latent_dim -> 256 -> 512 ->
    input_dim`` ReLU MLP whose sigmoid output lies in [0, 1].

    Args:
        latent_dim: Size of the latent code.
        input_dim: Number of values per flattened input sample. Defaults to
            ``28 * 28`` (one MNIST image).
    """

    def __init__(self, latent_dim=40, input_dim=28 * 28):
        super().__init__()
        self.latent_dim = latent_dim
        self.input_dim = input_dim

        # Encoder trunk shared by the two posterior heads
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU()
        )
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)

        # Decoder: mirrors the encoder and squashes outputs into [0, 1]
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, input_dim),
            nn.Sigmoid()
        )

    def encode(self, x):
        """Return ``(mu, logvar)`` for a flattened batch ``x`` of shape (N, input_dim)."""
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Writing the sample this way keeps it differentiable with respect to
        ``mu`` and ``logvar``. Noise is drawn on every call, regardless of
        train/eval mode.
        """
        std = torch.exp(0.5 * logvar)  # logvar is log(sigma^2), so sigma = exp(logvar / 2)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes ``z`` of shape (N, latent_dim) to outputs of shape (N, input_dim)."""
        return self.decoder(z)

    def forward(self, x):
        """Encode, sample and decode a batch.

        Args:
            x: Tensor whose elements can be regrouped into rows of
                ``input_dim`` values, e.g. (N, 1, 28, 28) for the default size.

        Returns:
            Tuple ``(recon_x, mu, logvar)`` with shapes (N, input_dim),
            (N, latent_dim) and (N, latent_dim).
        """
        # reshape (unlike view) also accepts non-contiguous inputs such as
        # transposed or strided batches.
        x = x.reshape(-1, self.input_dim)
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon_x = self.decode(z)
        return recon_x, mu, logvar


In [None]:
def loss_function(recon_x, x, mu, logvar):
    """Summed binary cross-entropy plus the Gaussian KL term.

    Args:
        recon_x: Reconstructed values in [0, 1], laid out as (N, 784).
        x: Target batch; it is viewed as (-1, 784) before comparison.
        mu: Posterior means.
        logvar: Posterior log-variances, same shape as ``mu``.

    Returns:
        Scalar tensor: both terms are summed over every element (no averaging).
    """
    # Reconstruction term: targets flattened to match recon_x
    targets = x.view(-1, 28 * 28)
    reconstruction = nn.functional.binary_cross_entropy(
        recon_x, targets, reduction='sum'
    )

    # Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over batch and latent dims
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    divergence = -0.5 * kl_terms.sum()

    return reconstruction + divergence


## Train Model

In [None]:
# Seed the global RNG so that weight initialisation, batch shuffling and
# latent noise can be reproduced on a fresh Restart & Run All.
# NOTE(review): some CUDA kernels may still be non-deterministic - confirm
# if bit-exact reruns on GPU are required.
SEED = 0
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
latent_dim = 40
vae = BernoulliVAE(latent_dim=latent_dim).to(device)
optimizer = optim.Adam(vae.parameters(), lr=1e-3)

epochs = 50
for epoch in range(epochs):
    # ---- training pass ----
    # Re-enter train mode every epoch because the evaluation pass below
    # switches the model to eval mode.
    vae.train()
    train_loss = 0.0
    for data, _ in train_loader:  # labels are not used
        data = data.to(device)

        optimizer.zero_grad()
        recon_batch, mu, logvar = vae(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # loss_function sums over the batch, so dividing by the dataset size
    # gives the average loss per image.
    print(f"Epoch {epoch + 1}, Loss: {train_loss / len(train_loader.dataset):.4f}")

    # ---- held-out pass ----
    # Same objective on the test split, without gradient tracking. The model
    # still samples z, so this is a single-sample estimate.
    vae.eval()
    test_loss = 0.0
    with torch.no_grad():
        for data, _ in test_loader:
            data = data.to(device)
            recon_batch, mu, logvar = vae(data)
            test_loss += loss_function(recon_batch, data, mu, logvar).item()

    print(f"Test Loss: {test_loss / len(test_loader.dataset):.4f}")


Epoch 1, Loss: 156.9529
Test Loss: 116.1648
Epoch 2, Loss: 108.2379
Test Loss: 101.4151
Epoch 3, Loss: 99.7825
Test Loss: 96.2774
Epoch 4, Loss: 95.8233
Test Loss: 93.6222
Epoch 5, Loss: 93.5137
Test Loss: 92.1666
Epoch 6, Loss: 91.8868
Test Loss: 90.6394
Epoch 7, Loss: 90.6144
Test Loss: 89.8236
Epoch 8, Loss: 89.6615
Test Loss: 89.1215
Epoch 9, Loss: 88.8404
Test Loss: 88.3381
Epoch 10, Loss: 88.2634
Test Loss: 88.0514


## Show Results

In [None]:
# Decode 16 latent vectors drawn from a standard normal into 28x28 images
vae.eval()
with torch.no_grad():
    z = torch.randn(16, latent_dim).to(device)
    samples = vae.decode(z).view(-1, 1, 28, 28).cpu()

# Show the generated samples on a 4x4 grid, one image per axes
fig, axes = plt.subplots(4, 4, figsize=(5, 5))
for ax, img in zip(axes.flat, samples[:16]):
    ax.imshow(img.squeeze(), cmap='gray')
    ax.axis('off')
plt.show()
