In [None]:
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
from torchvision.utils import save_image

# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
# Network sizes
input_dim = 28 * 28  # flattened MNIST image (784 pixels)
hidden_dim = 400
latent_dim = 20

# Importance sampling: number of latent draws per input
k_samples_train = 5
k_samples_eval = 50

# Optimisation
batch_size = 128
epochs = 15
learning_rate = 1e-3

# Multiplier on (log p(z) - log q(z|x)) in the importance weights (β-IWAE knob)
beta = 1.0

# Make sure the image output folder and the TensorBoard log root exist.
for _out_dir in ('./results/iwae_samples', './logs'):
    os.makedirs(_out_dir, exist_ok=True)

# Single TensorBoard writer shared by the loss, training and visualisation code.
writer = SummaryWriter('./logs/iwae')

# IWAE Model
class IWAE(nn.Module):
    """Fully connected encoder/decoder used with the importance-weighted bound.

    The encoder maps a flattened input to the mean and log-variance of a
    diagonal Gaussian over the latent code; the decoder maps a latent code to
    per-feature values in (0, 1) via a sigmoid. Layer sizes are read from the
    module-level ``input_dim``, ``hidden_dim`` and ``latent_dim`` settings.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in a fixed order (fc1, fc21, fc22, fc3, fc4);
        # their attribute names are the state_dict keys of saved checkpoints.
        # Encoder: input -> hidden -> (mean, log-variance)
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # mean head
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # log-variance head
        # Decoder: latent -> hidden -> reconstruction
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Return ``(mu, logvar)`` for a batch of flattened inputs."""
        hidden = torch.relu(self.fc1(x))
        mean = self.fc21(hidden)
        log_variance = self.fc22(hidden)
        return mean, log_variance

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(logvar / 2)`` with ``eps`` standard normal."""
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map latent codes to sigmoid outputs of size ``input_dim``."""
        hidden = torch.relu(self.fc3(z))
        return self.fc4(hidden).sigmoid()

    def forward(self, x):
        """Encode, sample one latent per input, decode; also return mu/logvar."""
        flat = x.view(-1, input_dim)
        mean, log_variance = self.encode(flat)
        return self.decode(self.reparameterize(mean, log_variance)), mean, log_variance

# IWAE Loss function
def iwae_loss(x, model, k=5, beta=1.0, global_step=0, mode='train'):
    """Return the negative importance-weighted bound, averaged over the batch.

    For every input, ``k`` latents are drawn from q(z|x) and the bound
    ``log((1/k) * sum_i w_i)`` is formed from the log-weights
    ``log w_i = log p(x|z_i) + beta * (log p(z_i) - log q(z_i|x))``.
    With ``k=1`` and ``beta=1`` this is a single-sample ELBO estimate.
    Diagnostics are written to the module-level TensorBoard ``writer``:
    the bound, a Monte-Carlo KL estimate, the mean effective sample size and,
    when ``mode == 'train'``, histograms of the encoder outputs.

    Args:
        x: Batch of inputs; flattened to ``(batch, -1)``. Used as targets of
            ``binary_cross_entropy``, so presumably valued in [0, 1].
        model: Object with ``encode(x) -> (mu, logvar)`` and ``decode(z)``
            returning values in (0, 1) with the same width as flattened ``x``.
        k: Number of importance samples per input.
        beta: Multiplier on ``log p(z) - log q(z|x)``.
        global_step: TensorBoard step for the logged values.
        mode: Tag prefix for logged values; histograms only for ``'train'``.

    Returns:
        Scalar tensor: batch mean of the negative bound (to be minimised).
    """
    batch_size = x.size(0)
    x = x.view(batch_size, -1)

    mu, logvar = model.encode(x)

    # Tile so each input gets k draws. Row j * batch_size + b holds draw j of
    # input b, which is what the .view(k, batch_size) below relies on.
    mu_k = mu.repeat(k, 1)
    logvar_k = logvar.repeat(k, 1)
    x_k = x.repeat(k, 1)

    # Reparameterised samples from q(z|x).
    std_k = torch.exp(0.5 * logvar_k)
    eps = torch.randn_like(std_k)
    z = mu_k + eps * std_k

    recon_x = model.decode(z)

    # Per-draw log densities, summed over data / latent dimensions.
    log_p_x_given_z = -F.binary_cross_entropy(recon_x, x_k, reduction='none').sum(1)
    log_p_z = Normal(0, 1).log_prob(z).sum(1)
    log_q_z_given_x = Normal(mu_k, std_k).log_prob(z).sum(1)

    log_w = log_p_x_given_z + beta * (log_p_z - log_q_z_given_x)
    log_w = log_w.view(k, batch_size)
    # Stable log-mean-exp over the k draws; math.log keeps the constant a
    # Python float instead of a CPU tensor.
    log_likelihood = torch.logsumexp(log_w, dim=0) - math.log(k)

    with torch.no_grad():
        normalized_weights = F.softmax(log_w, dim=0)
        # Per-input ESS is 1 / sum_i w_i^2 (between 1 and k). Average those
        # values rather than inverting the averaged sum of squares.
        ess = (1.0 / normalized_weights.pow(2).sum(dim=0)).mean()

    writer.add_scalar(f'Loss/{mode}_log_likelihood', log_likelihood.mean().item(), global_step)
    writer.add_scalar(f'Loss/{mode}_kl_divergence',
                      (log_q_z_given_x.mean() - log_p_z.mean()).item(), global_step)
    writer.add_scalar(f'Metrics/{mode}_ess', ess.item(), global_step)

    if mode == 'train':
        # Histogram the encoder outputs once per input, not k tiled copies.
        writer.add_histogram('Latent/mu', mu, global_step)
        writer.add_histogram('Latent/logvar', logvar, global_step)

    return -log_likelihood.mean()

# Visualization functions
def visualize_reconstruction(data, epoch, n=8):
    """Save and log a grid of inputs (top row) and their reconstructions (bottom row).

    Args:
        data: Batch of images from the MNIST loaders; each image must have
            28 * 28 elements. Only the first ``n`` images are used.
        epoch: Used in the output filename and as the TensorBoard step.
        n: Maximum number of images per row (default 8). Smaller batches are
            handled by showing all of their images.
    """
    n = min(n, data.size(0))
    if n == 0:
        return  # nothing to show for an empty batch

    with torch.no_grad():
        sample = data[:n]
        recon_batch, _, _ = model(sample)

        # Both halves as (n, 1, 28, 28); concatenation gives (2n, 1, 28, 28).
        comparison = torch.cat([sample.view(n, 1, 28, 28),
                                recon_batch.view(n, 1, 28, 28)])
        save_image(comparison.cpu(),
                   f'./results/iwae_samples/reconstruction_{epoch}.png', nrow=n)

        # ``comparison`` is already 4-D NCHW, the layout add_images expects;
        # adding another dimension would make it 5-D.
        writer.add_images('Reconstruction/original_vs_reconstructed',
                          comparison, epoch)

def visualize_latent_space(data_loader, epoch):
    """Scatter-plot the first two coordinates of the encoder means, coloured by label.

    Every batch from ``data_loader`` is encoded with ``model.encode`` (the
    mean only, no sampling). The figure is saved to
    ``./results/iwae_samples/latent_space_{epoch}.png`` and logged to
    TensorBoard. Nothing is produced when ``latent_dim < 2``.

    Args:
        data_loader: Iterable of ``(images, labels)`` batches.
        epoch: Used in the filename, the title and as the TensorBoard step.
    """
    if latent_dim < 2:
        return  # no 2-D projection to draw, so skip the encoding pass

    with torch.no_grad():
        latent_batches = []
        label_batches = []
        for data, label in data_loader:
            mu, _ = model.encode(data.view(-1, input_dim))
            latent_batches.append(mu)
            label_batches.append(label)

    latent_vectors = torch.cat(latent_batches, dim=0).cpu().numpy()
    labels = torch.cat(label_batches, dim=0).cpu().numpy()

    fig, ax = plt.subplots(figsize=(10, 10))
    points = ax.scatter(latent_vectors[:, 0], latent_vectors[:, 1], c=labels,
                        cmap='tab10', alpha=0.5)
    fig.colorbar(points, ax=ax)
    ax.set_title(f'Latent Space Visualization (Epoch {epoch})')
    fig.savefig(f'./results/iwae_samples/latent_space_{epoch}.png')
    # Log this figure object explicitly: after plt.close(), plt.gcf() would
    # return a new, blank figure. add_figure closes ``fig`` (close=True default).
    writer.add_figure('Latent/2D_visualization', fig, epoch)

def generate_random_samples(epoch, num_samples=8):
    """Decode draws from a standard-normal prior, then save and log them as images.

    Args:
        epoch: Used in the output filename and as the TensorBoard step.
        num_samples: Number of latent codes drawn from N(0, I) (default 8).
    """
    with torch.no_grad():
        prior_z = torch.randn(num_samples, latent_dim)
        images = model.decode(prior_z).view(num_samples, 1, 28, 28)
        out_path = f'./results/iwae_samples/generated_{epoch}.png'
        save_image(images, out_path, nrow=4)
        writer.add_images('Generation/random_samples', images, epoch)

def evaluate(test_loader, epoch):
    """Compute, print and log the average per-example IWAE loss on ``test_loader``.

    Each batch is scored with ``iwae_loss`` using ``k_samples_eval``
    importance samples. Per-batch statistics are logged under the ``test``
    prefix, and the epoch average is logged as ``Loss/test``.

    Args:
        test_loader: DataLoader of ``(images, labels)``; must expose ``.dataset``.
        epoch: 1-based epoch index, used for TensorBoard steps.

    Returns:
        Average negative bound per example over the whole dataset.
    """
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for batch_idx, (data, _) in enumerate(test_loader):
            global_step = (epoch - 1) * len(test_loader) + batch_idx
            loss = iwae_loss(data, model, k=k_samples_eval, beta=beta,
                             global_step=global_step, mode='test')
            # iwae_loss returns a batch mean. Re-weight by batch size so dividing
            # by the dataset size gives a true per-example average (the last
            # batch may be smaller).
            total_loss += loss.item() * data.size(0)

    avg_loss = total_loss / len(test_loader.dataset)
    print(f'====> Test set loss: {avg_loss:.4f}')
    writer.add_scalar('Loss/test', avg_loss, epoch)
    return avg_loss

# MNIST train/test splits, converted to tensors by ToTensor. The test split
# reuses the files fetched by the train split's download=True.
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

# Both loaders use the same batch size and shuffle their data.
loader_kwargs = dict(batch_size=batch_size, shuffle=True)
train_loader = DataLoader(train_dataset, **loader_kwargs)
test_loader = DataLoader(test_dataset, **loader_kwargs)

# Model and optimiser (module-level, used by the train/eval/visualisation code).
model = IWAE()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
def train(epoch):
    """Run one training epoch, produce visualisations, then evaluate on the test set.

    Every 100 batches the progress is printed and histograms of the
    parameters and their gradients are logged. The epoch's average
    per-example loss is logged as ``Loss/train``.

    Args:
        epoch: 1-based epoch index, used to derive TensorBoard steps.

    Returns:
        The average test loss returned by ``evaluate``.
    """
    model.train()
    total_loss = 0.0
    num_batches = len(train_loader)
    num_examples = len(train_loader.dataset)

    for batch_idx, (data, _) in enumerate(train_loader):
        global_step = (epoch - 1) * num_batches + batch_idx
        optimizer.zero_grad()

        # iwae_loss returns the batch-mean negative bound.
        loss = iwae_loss(data, model, k=k_samples_train, beta=beta,
                         global_step=global_step, mode='train')
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        # Weight by batch size so the epoch average is per example even when
        # the final batch is smaller.
        total_loss += batch_loss * data.size(0)

        if batch_idx % 100 == 0:
            # batch_loss is already a per-example mean; don't divide it again.
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{num_examples} '
                  f'({100. * batch_idx / num_batches:.0f}%)]\tLoss: {batch_loss:.6f}')

            for name, param in model.named_parameters():
                writer.add_histogram(f'Params/{name}', param, global_step)
                if param.grad is not None:
                    writer.add_histogram(f'Grads/{name}', param.grad, global_step)

    avg_loss = total_loss / num_examples
    print(f'====> Epoch: {epoch} Average loss: {avg_loss:.4f}')
    writer.add_scalar('Loss/train', avg_loss, epoch)

    # Visualisations use one test batch plus the full test loader.
    with torch.no_grad():
        test_data, _ = next(iter(test_loader))
        visualize_reconstruction(test_data, epoch)
        visualize_latent_space(test_loader, epoch)
        generate_random_samples(epoch)

    return evaluate(test_loader, epoch)

# Main training loop
# Train for `epochs` epochs and keep the weights with the lowest test loss.
# The writer is closed in `finally` so pending TensorBoard events are written
# out even if training raises or is interrupted.
best_test_loss = float('inf')
try:
    for epoch in range(1, epochs + 1):
        test_loss = train(epoch)

        # Save best model
        if test_loss < best_test_loss:
            best_test_loss = test_loss
            torch.save(model.state_dict(), './results/iwae_best_model.pth')
finally:
    # Close TensorBoard writer
    writer.close()