In [None]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal, kl_divergence
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image

# Hyperparameters
# --- model sizes ---
input_dim = 28 * 28  # flattened MNIST image (28x28 pixels)
hidden_dim = 400
latent_dim = 20

# --- estimator / optimisation settings ---
num_samples = 5  # importance samples K drawn per input for the IWAE bound
batch_size = 128
epochs = 10
learning_rate = 0.001

# VAE Model
class VAE(nn.Module):
    """Fully-connected VAE with a one-hidden-layer encoder and decoder.

    The encoder produces the mean and log-variance of a diagonal Gaussian
    posterior; the decoder maps latents back to ``input_dim`` values squashed
    through a sigmoid. Layer widths come from the module-level ``input_dim``,
    ``hidden_dim`` and ``latent_dim``.
    """

    def __init__(self):
        super().__init__()

        # Encoder: input -> hidden, then two linear heads on the shared hidden layer.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # posterior mean head
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # posterior log-variance head

        # Decoder: latent -> hidden -> input-sized output.
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for flattened inputs ``x``."""
        hidden = self.fc1(x).relu()
        mu_out = self.fc21(hidden)
        logvar_out = self.fc22(hidden)
        return mu_out, logvar_out

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * sigma`` with ``eps ~ N(0, I)`` so gradients reach mu/logvar."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map latent codes ``z`` to sigmoid outputs in (0, 1), one per input feature."""
        hidden = F.relu(self.fc3(z))
        logits = self.fc4(hidden)
        return logits.sigmoid()

    def forward(self, x):
        """Flatten ``x``, encode it, draw one latent sample, and return ``(recon, mu, logvar)``."""
        flat = x.view(-1, input_dim)
        mu, logvar = self.encode(flat)
        sample = self.reparameterize(mu, logvar)
        return self.decode(sample), mu, logvar

# Importance-weighted loss function
def importance_weighted_loss(recon_x, x, mu, logvar, num_samples=5, vae=None):
    """Return the negative importance-weighted (IWAE) bound, averaged over the batch.

    For every input, ``num_samples`` latents z_k are drawn from the diagonal
    Gaussian q(z|x) = N(mu, exp(logvar)) with the reparameterization trick,
    decoded, and combined as

        log (1/K) * sum_k  p(x|z_k) p(z_k) / q(z_k|x)

    where p(x|z) is a per-feature Bernoulli likelihood (binary cross-entropy)
    and p(z) is a standard normal prior.

    Args:
        recon_x: Ignored. Kept only so existing call sites keep working;
            reconstructions are recomputed from the fresh importance samples.
        x: Input batch. Its first dimension is treated as the batch size and all
            remaining dimensions are flattened. Values act as BCE targets, so they
            must lie in [0, 1].
        mu: Posterior means, shape (batch, latent).
        logvar: Posterior log-variances, same shape as ``mu``.
        num_samples: Number of importance samples K per input; must be >= 1.
        vae: Object exposing ``decode(z)`` that returns probabilities in (0, 1).
            Defaults to the module-level ``model``.

    Returns:
        Scalar tensor: minus the batch-mean IWAE log-likelihood bound.

    Raises:
        ValueError: If ``num_samples`` is less than 1 (the bound is undefined).
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")
    if vae is None:
        vae = model  # backward-compatible fallback to the script-level model

    batch_size = x.size(0)
    # reshape (not view): no dependency on the global input_dim, and works on non-contiguous input.
    x = x.reshape(batch_size, -1)

    # Stack K copies of the batch: row k * batch_size + i holds sample k of input i.
    x = x.repeat(num_samples, 1)
    mu = mu.repeat(num_samples, 1)
    logvar = logvar.repeat(num_samples, 1)

    # Reparameterized sample z ~ q(z|x).
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    z = mu + eps * std

    # Use the model's own decoder instead of re-implementing its layers here.
    recon = vae.decode(z)

    log_p_x_given_z = -F.binary_cross_entropy(recon, x, reduction='none').sum(1)
    log_p_z = Normal(0, 1).log_prob(z).sum(1)
    log_q_z_given_x = Normal(mu, std).log_prob(z).sum(1)

    # Log importance weights log w_k = log p(x, z_k) - log q(z_k | x).
    log_w = log_p_x_given_z + log_p_z - log_q_z_given_x

    # Given the repeat() layout, view(K, B) puts the K samples of input i in column i;
    # logsumexp - log K is a numerically stable log-mean-exp over the samples.
    log_w = log_w.view(num_samples, batch_size)
    log_likelihood = torch.logsumexp(log_w, dim=0) - math.log(num_samples)

    return -log_likelihood.mean()

# Load the MNIST training split. ToTensor converts each image to a float tensor
# scaled to [0, 1], which is what the binary cross-entropy likelihood expects.
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# Build the model and an Adam optimizer over all of its parameters.
model = VAE()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

# Training loop
def train(epoch):
    """Run one training epoch and return the mean per-example loss.

    Uses the script-level ``model``, ``optimizer`` and ``train_loader``. Each
    batch is encoded and scored with ``importance_weighted_loss`` using
    ``num_samples`` importance samples. One optimizer step is then taken.

    Args:
        epoch: Epoch number; used only in the log messages.

    Returns:
        float: Average negative IWAE bound per training example over the epoch.
    """
    model.train()
    total_loss = 0.0
    num_examples = len(train_loader.dataset)
    for batch_idx, (data, _) in enumerate(train_loader):
        optimizer.zero_grad()

        # Encode only; importance_weighted_loss draws and decodes its own samples,
        # which is why recon_x is passed as None.
        mu, logvar = model.encode(data.view(-1, input_dim))
        loss = importance_weighted_loss(None, data, mu, logvar, num_samples=num_samples)

        loss.backward()
        optimizer.step()

        # `loss` is already a per-example mean, so weight it by the batch size to build a
        # dataset total (the final batch may be smaller than batch_size).
        total_loss += loss.item() * len(data)

        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{num_examples} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

    avg_loss = total_loss / num_examples
    print(f'====> Epoch: {epoch} Average loss: {avg_loss:.4f}')
    return avg_loss

# Run training: one call to train() per epoch, numbered 1..epochs inclusive
# (the number is only used by train() in its log messages).
for epoch in range(1, epochs + 1):
    train(epoch)