In [None]:
pip install torch torchvision matplotlib


In [None]:
import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt


In [None]:
# Define the VAE model
class VAE(nn.Module):
    """Fully-connected variational autoencoder with a diagonal-Gaussian latent space.

    The encoder maps an input vector to ``2 * latent_dim`` values, which are split
    into the posterior mean and log-variance. The decoder maps a latent vector back
    to ``input_dim`` values passed through a sigmoid, so outputs lie in (0, 1).

    Args:
        input_dim: size of each flattened input vector (default 784).
        latent_dim: dimensionality of the latent space (default 20).
        hidden_dim: width of the single hidden layer in both encoder and decoder
            (default 400, matching the original hard-coded width).
    """

    def __init__(self, input_dim=784, latent_dim=20, hidden_dim=400):
        super().__init__()
        # Keep the latent size on the instance so encode() does not depend on a
        # module-level variable of the same name (which may be absent or differ).
        self.latent_dim = latent_dim
        # Encoder: input -> hidden -> concatenated [mean, log_var]
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2 * latent_dim),  # mean and log variance
        )
        # Decoder: latent -> hidden -> reconstruction squashed to (0, 1)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),
        )

    def encode(self, x):
        """Return ``(mean, log_var)`` of the approximate posterior q(z|x).

        The encoder output is split along its last dimension into two halves of
        size ``self.latent_dim`` each.
        """
        params = self.encoder(x)
        mean = params[..., : self.latent_dim]
        log_var = params[..., self.latent_dim :]
        return mean, log_var

    def reparameterize(self, mean, log_var):
        """Draw z = mean + eps * std with eps ~ N(0, I), keeping gradients w.r.t. mean/log_var."""
        std = torch.exp(0.5 * log_var)
        epsilon = torch.randn_like(std)
        return mean + epsilon * std

    def decode(self, z):
        """Map latent vectors back to input space."""
        return self.decoder(z)

    def forward(self, x):
        """Encode, sample a latent via the reparameterization trick, and decode.

        Returns:
            ``(reconstruction, mean, log_var)``.
        """
        mean, log_var = self.encode(x)
        z = self.reparameterize(mean, log_var)
        return self.decode(z), mean, log_var

# Seed torch's global RNG before creating the model so that weight initialisation
# (and later draws from the global RNG) are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

latent_dim = 20
vae = VAE(latent_dim=latent_dim)

In [None]:
def loss_function(reconstructed, original, mean, log_var):
    """Negative ELBO of a VAE with a Bernoulli decoder, summed over all elements.

    Args:
        reconstructed: decoder output; binary cross-entropy requires values in [0, 1].
        original: target tensor with the same shape as ``reconstructed``.
        mean: posterior means from the encoder.
        log_var: posterior log-variances from the encoder (same shape as ``mean``).

    Returns:
        Scalar tensor: summed reconstruction BCE plus the summed KL divergence
        between N(mean, exp(log_var)) and the standard normal prior.
    """
    bce = nn.functional.binary_cross_entropy(reconstructed, original, reduction="sum")
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)), one term per latent element.
    kl_terms = 1 + log_var - mean.pow(2) - log_var.exp()
    kl = -0.5 * kl_terms.sum()
    return bce + kl


In [None]:
# Load MNIST dataset
DATA_DIR = "./data"
BATCH_SIZE = 128

# ToTensor yields a (1, 28, 28) tensor per image; flatten it to a 784-vector to
# match the fully-connected encoder. `torch.flatten` is used instead of an inline
# lambda so the transform stays picklable, which DataLoader workers need when
# started with the "spawn" method (the default on Windows and macOS).
transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(torch.flatten)])
train_dataset = datasets.MNIST(root=DATA_DIR, train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)


In [None]:
# Fit the VAE with Adam. `loss_function` sums over the whole batch, so the
# epoch's running total is divided by the dataset size to report the mean
# loss per training image.
epochs = 10
optimizer = optim.Adam(vae.parameters(), lr=1e-3)

for epoch in range(epochs):
    vae.train()
    train_loss = 0
    for images, _ in train_loader:
        optimizer.zero_grad()
        recon_batch, mu, logvar = vae(images)
        batch_loss = loss_function(recon_batch, images, mu, logvar)
        batch_loss.backward()
        optimizer.step()
        train_loss += batch_loss.item()

    avg_loss = train_loss / len(train_loader.dataset)
    print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}")


In [None]:
# Generate synthetic images by decoding samples from the N(0, I) prior.
n_rows, n_cols = 4, 4  # grid shape; the number of samples follows from it

vae.eval()
with torch.no_grad():
    # Sample z on the model's own device so this cell still works if the VAE
    # has been moved to a GPU.
    model_device = next(vae.parameters()).device
    z = torch.randn(n_rows * n_cols, latent_dim, device=model_device)
    # Decoder outputs flat 784-vectors; reshape to 28x28 for display.
    generated_images = vae.decode(z).view(-1, 28, 28)

# Plot the generated images. squeeze=False keeps `axes` 2-D for any grid shape.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows), squeeze=False)
for img, ax in zip(generated_images, axes.flat):
    ax.imshow(img.cpu().numpy(), cmap="gray")
    ax.axis("off")
fig.suptitle("Generated Images")
plt.show()
