# part 1

In [27]:
import torch
import six
import torchvision
six.string_classes = str,
torch._six = six

from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils
import torch.nn.functional as F
import matplotlib.pyplot as plt
import os

# 1. Hyperparameters and setup
batch_size = 128       # samples per optimisation step
learning_rate = 1e-4   # step size handed to the optimizer
num_epochs = 50        # full passes over the training set
latent_dim = 20        # dimensionality of the latent code
device = torch.device("cpu")

# Create output directories (per-epoch reconstructions and prior samples)
os.makedirs("output/epochs", exist_ok=True)
os.makedirs("output/samples", exist_ok=True)

# 2. Data loading
# ToTensor alone (no Normalize) keeps pixel values in [0, 1], which is what
# the binary-cross-entropy reconstruction loss expects as a target.
transform = transforms.ToTensor()
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# 3. Model definitions
class Encoder(nn.Module):
    """Convolutional encoder that maps an image batch to two latent vectors.

    Two stride-2 convolutions (each followed by batch norm and ReLU) feed a
    pair of linear heads, ``fc_mu`` and ``fc_logvar``, both of width
    ``latent_dim``. The linear heads expect ``64 * 7 * 7`` flattened
    features, i.e. a 7x7 feature map, which is what a 28x28 single-channel
    input produces.
    """

    def __init__(self, latent_dim):
        super().__init__()
        feature_layers = [
            nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),   # 28x28 -> 14x14
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),  # 14x14 -> 7x7
            nn.BatchNorm2d(64),
            nn.ReLU(),
        ]
        self.conv = nn.Sequential(*feature_layers)

        flat_features = 64 * 7 * 7
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

    def forward(self, x):
        """Return ``(fc_mu(h), fc_logvar(h))`` where ``h`` is the flattened conv output."""
        features = self.conv(x)
        # Collapse channel and spatial axes, keeping the batch axis.
        flattened = features.view(features.size(0), -1)
        mu = self.fc_mu(flattened)
        logvar = self.fc_logvar(flattened)
        return mu, logvar

class Decoder(nn.Module):
    """Transposed-convolution decoder that maps latent vectors to images.

    A linear layer expands each latent vector to ``64 * 7 * 7`` features,
    which are reshaped to a 64-channel 7x7 map and upsampled twice by
    stride-2 transposed convolutions. A final sigmoid keeps every output
    value in [0, 1].
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 64 * 7 * 7)
        upsample_layers = [
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # 7x7 -> 14x14
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),   # 14x14 -> 28x28
            nn.Sigmoid(),  # output in [0,1]
        ]
        self.deconv = nn.Sequential(*upsample_layers)

    def forward(self, z):
        """Decode a batch of latent vectors ``z`` into single-channel images."""
        expanded = self.fc(z)
        # Reinterpret the flat features as a (batch, 64, 7, 7) feature map.
        feature_map = expanded.view(z.size(0), 64, 7, 7)
        return self.deconv(feature_map)

class VAE(nn.Module):
    """Variational autoencoder built from an ``Encoder`` and a ``Decoder``."""

    def __init__(self, latent_dim):
        super().__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        ``std`` is recovered from the log-variance as ``exp(0.5 * logvar)``.
        Sampling the noise separately keeps ``z`` differentiable with
        respect to ``mu`` and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)
        noise = torch.randn_like(std)
        scaled_noise = noise * std
        return mu + scaled_noise

    def forward(self, x):
        """Encode ``x``, sample a latent code, and decode it.

        Returns:
            A ``(reconstruction, mu, logvar)`` tuple.
        """
        mu, logvar = self.encoder(x)
        latent = self.reparameterize(mu, logvar)
        reconstruction = self.decoder(latent)
        return reconstruction, mu, logvar

# 4. Loss function
def vae_loss(recon_x, x, mu, logvar, beta=1.0):
    """Return the summed VAE objective: reconstruction BCE plus ``beta`` times KL.

    Args:
        recon_x: Reconstruction with values in [0, 1], same shape as ``x``.
        x: Target with values in [0, 1].
        mu: Mean of the approximate posterior.
        logvar: Log-variance of the approximate posterior.
        beta: Weight applied to the KL term.

    Returns:
        A scalar tensor. Both terms are summed (not averaged) over every
        element of the batch.
    """
    reconstruction_term = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # Closed-form KL divergence between N(mu, exp(logvar)) and N(0, I).
    kl_elements = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_term = -0.5 * torch.sum(kl_elements)
    return reconstruction_term + beta * kl_term

# 5. Build the model and its optimizer; keep a per-epoch loss history
model = VAE(latent_dim)
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)
train_losses = []

# 6. Visualization functions
# A single batch of 64 latent vectors drawn once, so the sample grids saved
# at different epochs decode the same codes and are directly comparable.
fixed_noise = torch.randn((64, latent_dim), device=device)

def plot_reconstructions(epoch, model, data_loader):
    """Save a grid comparing 8 inputs with their reconstructions.

    Takes the first batch yielded by ``data_loader``, runs it through
    ``model`` in eval mode without gradients, and writes the first 8 inputs
    followed by their 8 reconstructions (8 images per grid row) to
    ``output/epochs/epoch_{epoch}_recon.png``.

    The model's training/eval mode is restored before returning, so calling
    this in the middle of training does not leave the model in eval mode.

    Args:
        epoch: Epoch number, used only in the output file name.
        model: Module whose forward pass returns ``(recon, mu, logvar)``.
        data_loader: Iterable yielding ``(data, label)`` batches.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            data, _ = next(iter(data_loader))  # data in [0,1]
            data = data.to(device)
            recon, _, _ = model(data)
            comp = torch.cat([data[:8], recon[:8]])
            grid = utils.make_grid(comp.cpu(), nrow=8, pad_value=1)
            plt.figure(figsize=(8,2))
            # Move the channel axis last, which is the layout imshow expects.
            plt.imshow(grid.permute(1,2,0).squeeze(), cmap='gray')
            plt.axis('off')
            plt.savefig(f'output/epochs/epoch_{epoch}_recon.png')
            plt.close()
    finally:
        # Put the model back in whatever mode the caller had it in; otherwise
        # BatchNorm layers would keep using running statistics afterwards.
        model.train(was_training)


def plot_samples(epoch, model):
    """Save an 8-per-row grid of images decoded from ``fixed_noise``.

    Decodes the module-level ``fixed_noise`` latent batch with
    ``model.decoder`` in eval mode without gradients and writes the grid to
    ``output/samples/epoch_{epoch}_samples.png``.

    The model's training/eval mode is restored before returning, so calling
    this in the middle of training does not leave the model in eval mode.

    Args:
        epoch: Epoch number, used only in the output file name.
        model: Module exposing a ``decoder`` attribute that maps latent
            vectors to images.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            samples = model.decoder(fixed_noise)
            grid = utils.make_grid(samples.cpu(), nrow=8, pad_value=1)
            plt.figure(figsize=(8,8))
            # Move the channel axis last, which is the layout imshow expects.
            plt.imshow(grid.permute(1,2,0).squeeze(), cmap='gray')
            plt.axis('off')
            plt.savefig(f'output/samples/epoch_{epoch}_samples.png')
            plt.close()
    finally:
        # Put the model back in whatever mode the caller had it in; otherwise
        # BatchNorm layers would keep using running statistics afterwards.
        model.train(was_training)

# 7. Training loop with checkpoints and plots
def train():
    """Train the module-level ``model`` for ``num_epochs`` epochs.

    Each epoch:
      * runs one pass over ``train_loader``, optimising ``vae_loss``;
      * appends the per-sample average loss (summed loss divided by the
        dataset size) to ``train_losses`` and prints it;
      * saves a checkpoint to ``output/vae_epoch_{epoch}.pth`` every 10th
        epoch;
      * writes reconstruction and sample grids via ``plot_reconstructions``
        and ``plot_samples``.

    After the last epoch the loss history is plotted to
    ``output/vae_loss_curve.png``.
    """
    for epoch in range(1, num_epochs + 1):
        # The plotting helpers at the end of each epoch switch the model to
        # eval mode. Re-enter train mode here so that BatchNorm keeps using
        # batch statistics on every epoch, not just the first one.
        model.train()
        epoch_loss = 0.0
        for data, _ in train_loader:
            data = data.to(device)
            optimizer.zero_grad()
            recon, mu, logvar = model(data)
            loss = vae_loss(recon, data, mu, logvar)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        # vae_loss sums over the batch, so dividing by the number of samples
        # gives the average loss per sample.
        avg_loss = epoch_loss / len(train_loader.dataset)
        train_losses.append(avg_loss)
        print(f'Epoch {epoch}, Loss: {avg_loss:.4f}')

        # Save model checkpoint every 10 epochs
        if epoch % 10 == 0:
            torch.save(model.state_dict(), f'output/vae_epoch_{epoch}.pth')

        # Visualization
        plot_reconstructions(epoch, model, train_loader)
        plot_samples(epoch, model)

    # Plot loss curve
    plt.figure()
    plt.plot(train_losses)
    plt.title('VAE Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.savefig('output/vae_loss_curve.png')
    plt.close()

# Start training only when this file is executed directly, not when imported.
if __name__ == '__main__':
    train()


Epoch 1, Loss: 269.3275
Epoch 2, Loss: 144.4843
Epoch 3, Loss: 127.0099
Epoch 4, Loss: 121.7240
Epoch 5, Loss: 118.6754
Epoch 6, Loss: 116.5082
Epoch 7, Loss: 114.8202
Epoch 8, Loss: 113.4181
Epoch 9, Loss: 112.2658
Epoch 10, Loss: 111.2998
Epoch 11, Loss: 110.4475
Epoch 12, Loss: 109.6446
Epoch 13, Loss: 109.0440
Epoch 14, Loss: 108.4405
Epoch 15, Loss: 107.8644
Epoch 16, Loss: 107.4612
Epoch 17, Loss: 107.0518
Epoch 18, Loss: 106.7034
Epoch 19, Loss: 106.3546
Epoch 20, Loss: 106.1110
Epoch 21, Loss: 105.8206
Epoch 22, Loss: 105.5513
Epoch 23, Loss: 105.3657
Epoch 24, Loss: 105.1612
Epoch 25, Loss: 104.9484
Epoch 26, Loss: 104.7859
Epoch 27, Loss: 104.5702
Epoch 28, Loss: 104.4047
Epoch 29, Loss: 104.2884
Epoch 30, Loss: 104.1232
Epoch 31, Loss: 104.0107
Epoch 32, Loss: 103.8521
Epoch 33, Loss: 103.7407
Epoch 34, Loss: 103.5849
Epoch 35, Loss: 103.4779
Epoch 36, Loss: 103.3703
Epoch 37, Loss: 103.2797
Epoch 38, Loss: 103.1848
Epoch 39, Loss: 103.0769
Epoch 40, Loss: 102.9422
Epoch 41,