# PyTorch VAE Implementation

# In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

class VAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder passes the input through two ReLU hidden layers and emits
    the mean and log-variance of an element-wise Gaussian over the latent
    space. The decoder mirrors the encoder and ends in a sigmoid, so every
    reconstructed value lies in [0, 1].

    Args:
        input_size: Number of features in each input vector.
        hidden_size1: Width of the outer hidden layer (first in the
            encoder, last in the decoder).
        hidden_size2: Width of the inner hidden layer.
        latent_size: Dimensionality of the latent code.
    """

    def __init__(self, input_size, hidden_size1, hidden_size2, latent_size):
        super().__init__()

        # Encoder: input -> hidden1 -> hidden2 -> (mu, logvar)
        self.fc1 = nn.Linear(input_size, hidden_size1)
        self.fc2 = nn.Linear(hidden_size1, hidden_size2)
        self.fc3_mu = nn.Linear(hidden_size2, latent_size)
        self.fc3_logvar = nn.Linear(hidden_size2, latent_size)

        # Decoder: latent -> hidden2 -> hidden1 -> input
        self.fc4 = nn.Linear(latent_size, hidden_size2)
        self.fc5 = nn.Linear(hidden_size2, hidden_size1)
        self.fc6 = nn.Linear(hidden_size1, input_size)

    def encoder(self, x):
        """Return ``(mu, logvar)`` of the latent distribution for ``x``."""
        hidden = F.relu(self.fc2(F.relu(self.fc1(x))))
        return self.fc3_mu(hidden), self.fc3_logvar(hidden)

    def reparameterize(self, mu, logvar):
        """Draw a latent sample as ``mu + noise * sigma``.

        ``sigma`` is recovered from the log-variance and ``noise`` is
        standard normal, so the sample stays differentiable with respect to
        ``mu`` and ``logvar``.
        """
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decoder(self, z):
        """Map a latent code ``z`` to a reconstruction in [0, 1]."""
        hidden = F.relu(self.fc4(z))
        hidden = F.relu(self.fc5(hidden))
        logits = self.fc6(hidden)
        return torch.sigmoid(logits)

    def forward(self, x):
        """Encode ``x``, sample a latent code, and decode it.

        Returns:
            A tuple ``(recon_x, mu, logvar)``.
        """
        mu, logvar = self.encoder(x)
        latent = self.reparameterize(mu, logvar)
        return self.decoder(latent), mu, logvar

def loss_function(recon_x, x, mu, logvar):
    """Return the VAE loss: reconstruction error plus KL divergence.

    Both terms are summed over every element (no averaging), so the result
    scales with the batch size.

    Args:
        recon_x: Reconstruction with values in [0, 1], same shape as ``x``.
        x: Target tensor with values in [0, 1].
        mu: Mean of the approximate posterior.
        logvar: Log-variance of the approximate posterior.

    Returns:
        A scalar tensor holding the summed loss.
    """
    reconstruction_term = F.binary_cross_entropy(recon_x, x, reduction='sum')

    # Closed-form KL divergence between N(mu, exp(logvar)) and N(0, 1).
    kl_elements = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_term = -0.5 * kl_elements.sum()

    return reconstruction_term + kl_term

# Training Configuration
# Network dimensions.
input_size = 784  # presumably a flattened 28x28 MNIST image; confirm when real data is wired in
hidden_size1, hidden_size2 = 512, 256
latent_size = 20

# Optimisation settings.
epochs = 10
batch_size = 128
learning_rate = 1e-3

# Load MNIST data (placeholder, replace with actual dataset loading code)
# Uniform random values in [0, 1) stand in for the real inputs.
X_train = torch.rand(50000, input_size)
X_valid = torch.rand(10000, input_size)

# Each dataset pairs a tensor with itself, so the target equals the input.
# Only the training loader shuffles; the validation loader keeps dataset order.
train_loader = DataLoader(
    dataset=TensorDataset(X_train, X_train),
    batch_size=batch_size,
    shuffle=True,
)
valid_loader = DataLoader(
    dataset=TensorDataset(X_valid, X_valid),
    batch_size=batch_size,
)

# Model, optimizer, and training loop
model = VAE(input_size, hidden_size1, hidden_size2, latent_size)
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

for epoch in range(epochs):
    model.train()
    train_loss = 0
    for x, _ in train_loader:
        # Forward pass: reconstruct the batch and compute its summed loss.
        recon_x, mu, logvar = model(x)
        loss = loss_function(recon_x, x, mu, logvar)

        # Backward pass: clear stale gradients, backpropagate, update weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate the summed batch loss for the per-sample epoch average.
        train_loss += loss.item()

    # Batch losses are sums, so divide by the number of samples, not batches.
    print(f"Epoch {epoch + 1}, Loss: {train_loss / len(train_loader.dataset):.4f}")

# Evaluation
# Switch to eval mode and accumulate the summed loss over the validation set
# with gradient tracking disabled.
model.eval()
valid_loss = 0
with torch.no_grad():
    for x, _ in valid_loader:
        recon_x, mu, logvar = model(x)
        loss = loss_function(recon_x, x, mu, logvar)
        valid_loss += loss.item()

# Report the per-sample average (batch losses are sums, not means).
print(f"Validation Loss: {valid_loss / len(valid_loader.dataset):.4f}")
