# DATASCI 503, Group Work 13: Autoencoders with MNIST

**Instructions:** During lab section, and afterward as necessary, you will collaborate in two-person teams to complete the problems that are interspersed below. The GSI will help individual teams encountering difficulty, make announcements addressing common issues, and help ensure progress for all teams. **During lab, feel free to flag down your GSI to ask questions at any point!** Upon completion, one member of the team should submit their team's work through Canvas **as html**.

In [None]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

---

Autoencoders are the deep learning frontier of dimensionality reduction.

**Challenge**: Summarize a wide dataset with as few dimensions as possible.

In the past, we have covered dimensionality reduction techniques, primarily PCA. The idea behind PCA is to find principal directions, which may differ from the raw coordinate axes of the data, that capture as much of the total variance as possible.
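
As a refresher, here is a minimal sketch of PCA computed through an SVD of the centered data. The toy data, the choice `k = 2`, and all variable names are assumptions made for illustration; they are not part of the lab.

```python
import torch

torch.manual_seed(0)

# Toy data: 200 samples, 5 features, with most variance along the first two features.
X = torch.randn(200, 5) * torch.tensor([5.0, 3.0, 0.5, 0.5, 0.5])

# Center the data, then take the SVD; the rows of Vh are the principal directions.
X_centered = X - X.mean(dim=0)
U, S, Vh = torch.linalg.svd(X_centered, full_matrices=False)

# Project onto the top-k directions to get a k-dimensional summary.
k = 2
Z = X_centered @ Vh[:k].T

# Fraction of the total variance captured by each principal direction.
explained_variance_ratio = S**2 / (S**2).sum()

# Map the summary back to the original space (a linear "decoder").
X_hat = Z @ Vh[:k] + X.mean(dim=0)
pca_mse = ((X - X_hat) ** 2).mean()
print(Z.shape, explained_variance_ratio[:k].sum().item(), pca_mse.item())
```

The projection plays the role of a linear encoder and the back-projection that of a linear decoder, which is the structure an autoencoder generalizes with nonlinear layers.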

In this notebook, we attempt to take in data $X$ and map it to a latent space $Z$ and see if we can learn a representative enough $Z$ to reconstruct $X$. This exercise in deep learning is accomplished with a network structure called an [autoencoder](https://www.geeksforgeeks.org/auto-encoders/#).

---

### Problem 1: Data Setup

We have provided most of the dataset construction for the MNIST dataset. Please create training, validation, and testing datasets as well as `DataLoader` objects for each dataset.

Specifically:
1. Apply `transforms.ToTensor()` to convert images to tensors and scale pixel values to [0, 1].
2. Download the MNIST training set (60,000 examples) and split it into 80% training and 20% validation.
3. Download the MNIST test set (10,000 examples).
4. Create `DataLoader` objects for each dataset with `batch_size=512`.

Store your results in variables named `train_dataset`, `val_dataset`, `test_dataset`, `train_loader`, `val_loader`, and `test_loader`.
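
The split-and-load pattern can be tried on a small stand-in dataset before downloading MNIST. This is only a sketch: the `toy_*` names, the fake data, and the seed are assumptions, and passing a seeded `generator` to `random_split` is an optional way to make the split repeatable rather than something the lab requires.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Stand-in dataset (hypothetical): 100 fake 1x28x28 "images" with digit labels.
toy_dataset = TensorDataset(torch.rand(100, 1, 28, 28), torch.randint(0, 10, (100,)))

# Passing a seeded generator makes the 80/20 split repeatable across runs.
split_generator = torch.Generator().manual_seed(503)
toy_train, toy_val = random_split(toy_dataset, [80, 20], generator=split_generator)

# Shuffle only the training loader; the validation loader keeps a fixed order.
toy_train_loader = DataLoader(toy_train, batch_size=32, shuffle=True)
toy_val_loader = DataLoader(toy_val, batch_size=32, shuffle=False)

toy_batch, toy_labels = next(iter(toy_train_loader))
print(len(toy_train), len(toy_val), toy_batch.shape)
```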

In [None]:
# BEGIN SOLUTION
# Apply transforms to convert images to tensors and scale pixel values to [0, 1]
transform = transforms.ToTensor()

# Download the MNIST training set (60,000 examples)
full_trainset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)

# Split training set into 80% training and 20% validation
train_size = int(0.8 * len(full_trainset))
val_size = len(full_trainset) - train_size
train_dataset, val_dataset = random_split(full_trainset, [train_size, val_size])

# Download the MNIST test set (10,000 examples)
test_dataset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)

# Create DataLoaders to load data in batches
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=512, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=512, shuffle=False)
# END SOLUTION

In [None]:
# Test assertions
assert len(train_dataset) == 48000, f"Expected 48000 training samples, got {len(train_dataset)}"
assert len(val_dataset) == 12000, f"Expected 12000 validation samples, got {len(val_dataset)}"
assert len(test_dataset) == 10000, f"Expected 10000 test samples, got {len(test_dataset)}"
assert train_loader.batch_size == 512, "train_loader batch_size should be 512"
assert val_loader.batch_size == 512, "val_loader batch_size should be 512"
assert test_loader.batch_size == 512, "test_loader batch_size should be 512"
print("All tests passed!")

# BEGIN HIDDEN TESTS
assert train_loader.dataset is train_dataset, "train_loader should use train_dataset"
assert val_loader.dataset is val_dataset, "val_loader should use val_dataset"
assert test_loader.dataset is test_dataset, "test_loader should use test_dataset"
sample_batch, _ = next(iter(train_loader))
assert sample_batch.shape == torch.Size(
    [512, 1, 28, 28]
), "Sample batch shape should be [512, 1, 28, 28]"
# END HIDDEN TESTS

### Problem 2: Creating an Autoencoder

Fill in the following `AutoEncoder` class by constructing an encoder network that has a depth of at least 3 layers. Note that the latent dimension is an argument that we pass into the class constructor.

The encoder should:
- Take flattened 28x28 images (784 features) as input
- Have at least 3 hidden layers with ReLU activations
- Output a latent representation of size `latent_dim`

The decoder should:
- Take the latent representation as input
- Mirror the encoder structure
- Output reconstructed images of size 784 with Sigmoid activation

The `forward` method should return both the reconstruction and the latent representation.

In [None]:
class AutoEncoder(torch.nn.Module):
    def __init__(self, latent_dim):
        super().__init__()
        # BEGIN SOLUTION
        # Encoder: 784 -> 256 -> 128 -> 64 -> latent_dim
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, latent_dim),
        )

        # Decoder: latent_dim -> 64 -> 128 -> 256 -> 784
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 28 * 28),
            nn.Sigmoid(),
        )
        # END SOLUTION

    def forward(self, x):
        # BEGIN SOLUTION
        latents = self.encoder(x)
        reconstruction = self.decoder(latents)
        return reconstruction, latents
        # END SOLUTION

In [None]:
# Test assertions
test_ae = AutoEncoder(latent_dim=2)
test_input = torch.randn(10, 784)
test_output, test_latent = test_ae(test_input)
assert test_output.shape == (10, 784), f"Expected output shape (10, 784), got {test_output.shape}"
assert test_latent.shape == (10, 2), f"Expected latent shape (10, 2), got {test_latent.shape}"
assert hasattr(test_ae, "encoder"), "AutoEncoder should have an encoder attribute"
assert hasattr(test_ae, "decoder"), "AutoEncoder should have a decoder attribute"
print("All tests passed!")

# BEGIN HIDDEN TESTS
# Test with different latent dimensions
test_ae_10 = AutoEncoder(latent_dim=10)
test_output_10, test_latent_10 = test_ae_10(test_input)
assert test_latent_10.shape == (10, 10), "Latent dim should match constructor argument"
# Check output is in valid range (0, 1) due to Sigmoid
assert torch.all(test_output >= 0) and torch.all(
    test_output <= 1
), "Output should be between 0 and 1"
# END HIDDEN TESTS

### Problem 3: Training an Autoencoder

Last week we learned to train neural networks with early stopping as a regularizer. For our purposes, an autoencoder is trained like any other neural network, because it is a subclass of `torch.nn.Module` whose `forward` method we override.

Please finish implementing the `train_autoencoder` function to train your autoencoder. The function should:
1. Use MSE loss as the criterion
2. Use the Adam optimizer with learning rate 0.001
3. Implement early stopping based on validation loss
4. Print the epoch, validation loss, and patience counter each epoch

Please leave the print statement that reports the current epoch, validation loss, and patience counter.

In [None]:
def train_autoencoder(
    train_loader,
    val_loader,
    model=None,
    num_epochs=100,
    early_stopping_patience=5,
):
    if model is None:
        model = AutoEncoder(latent_dim=2)

    # BEGIN SOLUTION
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    patience_counter = 0
    lowest_val_loss = float("inf")

    for epoch in range(num_epochs):
        # Training loop
        model.train()
        for data in train_loader:
            images, _ = data
            images = images.view(images.shape[0], -1)

            outputs, _latents = model(images)
            loss = criterion(outputs, images)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Validation loop
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for data in val_loader:
                images, _ = data
                images = images.view(images.shape[0], -1)

                outputs, _latents = model(images)
                val_loss += criterion(outputs, images).item()
        val_loss /= len(val_loader)

        # Early stopping check
        if val_loss < lowest_val_loss:
            lowest_val_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= early_stopping_patience:
                break

        print(
            f"Epoch [{epoch + 1}/{num_epochs}], "
            f"Validation Loss: {val_loss:.4f}, Patience Count: {patience_counter}"
        )
    # END SOLUTION

    return model

In [None]:
# Test assertions
# Test that train_autoencoder returns a model
import inspect

sig = inspect.signature(train_autoencoder)
params = list(sig.parameters.keys())
assert "train_loader" in params, "train_autoencoder should have train_loader parameter"
assert "val_loader" in params, "train_autoencoder should have val_loader parameter"
assert "num_epochs" in params, "train_autoencoder should have num_epochs parameter"
assert (
    "early_stopping_patience" in params
), "train_autoencoder should have early_stopping_patience parameter"
print("All tests passed!")

# BEGIN HIDDEN TESTS
# Quick test with 1 epoch to verify function works
quick_model = AutoEncoder(latent_dim=2)
quick_result = train_autoencoder(train_loader, val_loader, model=quick_model, num_epochs=1)
assert isinstance(quick_result, AutoEncoder), "train_autoencoder should return an AutoEncoder"
# END HIDDEN TESTS

In [None]:
trained_autoencoder = train_autoencoder(train_loader, val_loader, num_epochs=100)

### Problem 4: Understanding Big Picture Details of Autoencoders

Now that you have built an autoencoder which consists of an encoder and a decoder, please write in markdown the answers to the following questions:

**1)** In 1-2 sentences **MAX**, what is the encoder doing to the input data?

> BEGIN SOLUTION

The encoder compresses the high-dimensional input data (784 pixels) into a lower-dimensional latent representation (e.g., 2 dimensions), capturing the most important features needed for reconstruction.
> END SOLUTION


**2)** Why as practitioners might we care about a latent space?

> BEGIN SOLUTION

The latent space provides a compact, meaningful representation of the data that can be used for visualization, clustering, interpolation between data points, anomaly detection, and as features for downstream tasks. It also enables generative modeling by sampling from the latent space.
> END SOLUTION


**3)** In 1-2 sentences **MAX**, what is the decoder doing given values from the latent space?

> BEGIN SOLUTION

The decoder reconstructs the original high-dimensional data from the compressed latent representation by learning to map from the low-dimensional latent space back to the original image space.
> END SOLUTION


### Problem 5: Bring the Latent Spaces to Life

Create a 2D uniform mesh grid of values to plug in as inputs to the decoder network.

Plot the reconstructed images on a grid to see how changing the latent values changes the reconstruction.

Use `SIDE_LENGTH = 25` for the grid dimensions and create latent values ranging from -5 to 5.
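
One possible way to organize the grid is to stack every (x, y) pair into a single batch and decode it in one pass. The sketch below uses a smaller grid and an untrained stand-in decoder; `side`, `latent_grid`, and `toy_decoder` are hypothetical names, not part of the lab.

```python
import torch
import torch.nn as nn

side = 5  # small grid for illustration

# Build every (x, y) pair on the grid and stack them into a (side*side, 2) batch.
grid_axis = torch.linspace(-5, 5, side)
grid_xx, grid_yy = torch.meshgrid(grid_axis, grid_axis, indexing="ij")
latent_grid = torch.stack([grid_xx.flatten(), grid_yy.flatten()], dim=1)

# Stand-in decoder (hypothetical, untrained) mapping 2 latent values to 784 pixels.
toy_decoder = nn.Sequential(nn.Linear(2, 28 * 28), nn.Sigmoid())

# One batched forward pass decodes every grid point; no gradients are needed for plotting.
with torch.no_grad():
    grid_images = toy_decoder(latent_grid).view(side * side, 28, 28)

print(latent_grid.shape, grid_images.shape)
```

With `indexing="ij"`, the first latent coordinate stays fixed while the second sweeps across the grid, so consecutive rows of `latent_grid` move along the second latent dimension.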

In [None]:
SIDE_LENGTH = 25
NUM_IMAGES = SIDE_LENGTH**2

# BEGIN SOLUTION
# Create a mesh grid of latent values
latent_x = torch.linspace(-5, 5, SIDE_LENGTH)
latent_y = torch.linspace(-5, 5, SIDE_LENGTH)
xx, yy = torch.meshgrid(latent_x, latent_y, indexing="ij")

# Create the plot grid
fig, ax = plt.subplots(SIDE_LENGTH, SIDE_LENGTH, figsize=(SIDE_LENGTH, SIDE_LENGTH))
ax = ax.flatten()

# Generate and plot reconstructions for each latent point
for idx, latent in enumerate(zip(xx.flatten(), yy.flatten())):
    # Stack the two 0-d tensors into a single latent vector of shape (2,)
    latent_tensor = torch.stack(latent)
    reconstruction = trained_autoencoder.decoder(latent_tensor)
    ax[idx].imshow(reconstruction.detach().numpy().reshape(28, 28), cmap="gray")
    ax[idx].set_xticks([])
    ax[idx].set_yticks([])

plt.tight_layout()
plt.show()
# END SOLUTION

In [None]:
# Test assertions
assert SIDE_LENGTH == 25, "SIDE_LENGTH should be 25"
assert NUM_IMAGES == 625, f"NUM_IMAGES should be 625, got {NUM_IMAGES}"
assert latent_x.shape == (25,), "latent_x should have 25 values"
assert latent_y.shape == (25,), "latent_y should have 25 values"
print("All tests passed!")

# BEGIN HIDDEN TESTS
assert latent_x.min() == -5, "latent_x should start at -5"
assert latent_x.max() == 5, "latent_x should end at 5"
assert latent_y.min() == -5, "latent_y should start at -5"
assert latent_y.max() == 5, "latent_y should end at 5"
# END HIDDEN TESTS

### Problem 6: Latent Space Interpolation

Write a function called `extrapolate_digits` that takes as input 2 digits and returns a spectrum of 12 images shifting from one digit to the other in latent space. Then, present **3 separate cells** that call the function with 3 unique pairs of digits of your choosing.

**Hint:** Think about the steps necessary to pull this off:
1. Get 2 inputs corresponding to digits.
2. Get their latent values from your trained autoencoder.
3. Find a way to create equal intervals in latent space and send those intermediate values through the decoder.
4. Tile your outcomes together and present the final plot.
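
Step 3 can be sketched with broadcasting. The two latent codes below are made-up values standing in for encoded digits, and the names `latent_start`, `latent_end`, and `latent_path` are assumptions for illustration.

```python
import torch

# Two stand-in latent codes (hypothetical values) playing the role of encoded digits.
latent_start = torch.tensor([[-2.0, 1.0]])
latent_end = torch.tensor([[3.0, -1.0]])

# Weights from 0 to 1 give equal steps; broadcasting yields a (num_steps, 2) path.
num_steps = 12
weights = torch.linspace(0, 1, num_steps).unsqueeze(1)
latent_path = latent_start + weights * (latent_end - latent_start)

print(latent_path.shape)
```

Each row of `latent_path` can then be passed through the decoder; the first and last rows reproduce the two endpoints.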

We provide a helper function `get_test_digit` that retrieves an example image of a specified digit from the test set.

In [None]:
def get_test_digit(testloader, digit_val=0):
    """Retrieve an example image of a specified digit from the test set."""
    shuffled_loader = torch.utils.data.DataLoader(
        testloader.dataset, batch_size=testloader.batch_size, shuffle=True
    )

    for data in shuffled_loader:
        images, labels = data
        for idx in range(len(labels)):
            if labels[idx] == digit_val:
                return images[idx]

    # Fail loudly instead of returning a string that would break later tensor operations
    raise ValueError(f"Test set doesn't contain digit {digit_val}.")

In [None]:
def extrapolate_digits(digit1, digit2, dataloader, encoder, decoder, num_images=12):
    # BEGIN SOLUTION
    # Get example images for each digit
    digit1_image = get_test_digit(dataloader, digit1)
    digit2_image = get_test_digit(dataloader, digit2)

    # Encode to latent space
    digit1_latent = encoder(digit1_image.view(1, -1)).detach()
    digit2_latent = encoder(digit2_image.view(1, -1)).detach()

    # Create interpolations in latent space
    interpolation_weights = torch.linspace(0, 1, num_images)
    interpolations = torch.stack(
        [
            digit1_latent + (digit2_latent - digit1_latent) * weight
            for weight in interpolation_weights
        ]
    )

    # Plot the interpolations
    _fig, ax = plt.subplots(1, num_images, figsize=(num_images, 1))
    ax = ax.flatten()

    for idx, latent in enumerate(interpolations):
        reconstruction = decoder(latent)
        ax[idx].imshow(reconstruction.detach().numpy().reshape(28, 28), cmap="gray")
        ax[idx].set_xticks([])
        ax[idx].set_yticks([])

    plt.tight_layout()
    plt.show()
    # END SOLUTION

In [None]:
# Test assertions
import inspect

sig = inspect.signature(extrapolate_digits)
params = list(sig.parameters.keys())
assert "digit1" in params, "extrapolate_digits should have digit1 parameter"
assert "digit2" in params, "extrapolate_digits should have digit2 parameter"
assert "encoder" in params, "extrapolate_digits should have encoder parameter"
assert "decoder" in params, "extrapolate_digits should have decoder parameter"
print("All tests passed!")

# BEGIN HIDDEN TESTS
# Test that get_test_digit returns a tensor
test_digit_img = get_test_digit(test_loader, 5)
assert isinstance(test_digit_img, torch.Tensor), "get_test_digit should return a tensor"
assert test_digit_img.shape == torch.Size([1, 28, 28]), "Digit image should be shape [1, 28, 28]"
# END HIDDEN TESTS

In [None]:
# BEGIN SOLUTION
extrapolate_digits(3, 6, test_loader, trained_autoencoder.encoder, trained_autoencoder.decoder)
# END SOLUTION

In [None]:
# BEGIN SOLUTION
extrapolate_digits(2, 5, test_loader, trained_autoencoder.encoder, trained_autoencoder.decoder)
# END SOLUTION

In [None]:
# BEGIN SOLUTION
extrapolate_digits(1, 7, test_loader, trained_autoencoder.encoder, trained_autoencoder.decoder)
# END SOLUTION

### Problem 7: Motivation for Variational Autoencoders

Based on what you understand about variational autoencoders (VAEs), why might one want to use a VAE over a standard AE? Answer in 1-2 sentences **MAX**.

> BEGIN SOLUTION

VAEs learn a continuous, regularized latent space where sampling from the prior (e.g., standard normal) produces meaningful outputs, enabling generative modeling. Unlike standard AEs, VAEs provide a principled probabilistic framework that prevents "holes" in the latent space and allows for smooth interpolation and generation of new data.
> END SOLUTION


---

**Challenge**: Summarize a wide dataset with as few dimensions as possible... but now ensure that we can sample from the distribution of the latent space and have it mean something.

This is the mission statement of **variational** autoencoders, the probabilistic extension of the autoencoder. If we assume that the true data is generated from some latent variable $z$, we want to maximize the marginal likelihood $p(x) = \int p(x|z)p(z)dz$. Ideally, we want to learn a decoder $p_{\theta}(x|z)$ that is able to make probabilistic, but accurate, reconstructions given some sample from the latent space. If we assume the latent space has a prior $p(z)$, then the posterior distribution takes the form $p(z|x) = \frac{p(x|z)p(z)}{\int p(x|z)p(z) dz}$. The integral in the denominator is intractable, which is what motivates us to learn an approximate posterior $q_{\phi}(z|x)$.

To learn the approximate posterior, however, we need a new objective: the evidence lower bound (ELBO). The derivation of the formula can be found [here](https://fangdahan.medium.com/derivation-of-elbo-in-vae-25ad7991fdf7).
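
For orientation, a compact sketch of the standard derivation uses Jensen's inequality. The subscripts $\theta$ and $\phi$ follow the notation introduced earlier for the decoder and the approximate posterior.

$$
\begin{aligned}
\log p_\theta(x) &= \log \int p_\theta(x \mid z)\, p(z)\, dz \\
&= \log E_{q_\phi(z \mid x)}\left[\frac{p_\theta(x \mid z)\, p(z)}{q_\phi(z \mid x)}\right] \\
&\geq E_{q_\phi(z \mid x)}\left[\log \frac{p_\theta(x \mid z)\, p(z)}{q_\phi(z \mid x)}\right] \\
&= E_{q_\phi(z \mid x)}\left[\log p_\theta(x \mid z)\right] - \mathrm{KL}\left(q_\phi(z \mid x)\,\|\,p(z)\right)
\end{aligned}
$$

The last line is a lower bound on $\log p_\theta(x)$, so maximizing it pushes up the log-likelihood while keeping the approximate posterior close to the prior.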

More details about the full theory decomposition of VAEs and all formula derivations can be found in lecture materials and this [wonderful Medium post](https://medium.com/@j.zh/mathematics-behind-variational-autoencoders-c69297301957).

The most important takeaway is that the ELBO can be decomposed as

$$\textrm{ELBO} = E_{q(z|x)}(\log p(x|z)) - \textrm{KL}(q(z|x) || p(z))$$

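The decoder $p(x|z)$ outputs a predicted mean $\hat{x}$ for the reconstruction. If we treat $p(x|z)$ as a Gaussian with fixed variance centered at $\hat{x}$, then maximizing the first term amounts (up to constants) to minimizing an $\textrm{MSE} = (x - \hat{x})^2$. If we assume that the approximate posterior is also a Gaussian distribution with mean $\mu$ and standard deviation $\sigma$, and that the prior is a standard normal, we can write the KL divergence as: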

$$\mathrm{KL}=-\frac{1}{2} \sum\left(1+\log \sigma^2-\mu^2-\sigma^2\right)$$
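
As a sanity check, the closed-form expression can be compared numerically against `torch.distributions.kl_divergence`, assuming a standard normal prior. The tensors `mu_example` and `logvar_example` are hypothetical stand-ins for encoder outputs.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)

# Hypothetical encoder outputs for a batch of 4 inputs with a 2-dimensional latent space.
mu_example = torch.randn(4, 2)
logvar_example = torch.randn(4, 2)

# Closed-form KL from the formula above, summed over all entries.
kl_closed_form = -0.5 * torch.sum(
    1 + logvar_example - mu_example.pow(2) - logvar_example.exp()
)

# Reference value: KL(q || p) between diagonal Gaussians, computed by torch.distributions.
q = Normal(mu_example, torch.exp(0.5 * logvar_example))
p = Normal(torch.zeros_like(mu_example), torch.ones_like(mu_example))
kl_reference = kl_divergence(q, p).sum()

print(kl_closed_form.item(), kl_reference.item())
```

Working with the log variance rather than $\sigma$ directly keeps the variance positive without constraining the network output.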

The first term in the ELBO is an expectation over $z \sim q_\phi(z \mid x)$, which we estimate by sampling. But we cannot backpropagate through a sample unless we rewrite the sampling process.
So instead of:

$$
z \sim \mathcal{N}\left(\mu_\phi(x), \sigma_\phi(x)^2\right)
$$


We write:

$$
z=\mu_\phi(x)+\sigma_\phi(x) \cdot \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)
$$

Now the randomness is pushed to $\epsilon$ (which is independent), and everything else is deterministic and differentiable. This is the reparameterization trick. You will need to implement such a method in the VAE class to ensure it backpropagates.
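
A small sketch can confirm that gradients reach $\mu$ and $\log \sigma^2$ once the sample is written this way. The tensors below are hypothetical stand-ins for encoder outputs, not part of the `VAE` class you will write.

```python
import torch

torch.manual_seed(0)

# Hypothetical encoder outputs, marked as requiring gradients so we can inspect them.
mu_param = torch.zeros(3, 2, requires_grad=True)
logvar_param = torch.zeros(3, 2, requires_grad=True)

# Reparameterized sample: the noise is drawn separately, so z is a
# differentiable function of mu and logvar.
noise = torch.randn_like(mu_param)
z_sample = mu_param + torch.exp(0.5 * logvar_param) * noise

# Any scalar loss built from z now sends gradients back to mu and logvar.
z_sample.sum().backward()

print(mu_param.grad)      # d(sum z)/d(mu) is 1 for every entry
print(logvar_param.grad)  # d(sum z)/d(logvar) is 0.5 * std * noise
```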

---

### Problem 8: Implementing a Variational Autoencoder

We now move from autoencoders (AE) to variational autoencoders (VAE). Please fill in each method of the VAE class for training purposes.

The VAE should have:
- An encoder that outputs both `mu` and `logvar` (log variance)
- A `reparameterize` method that implements the reparameterization trick
- A decoder that reconstructs the input from the latent sample
- A `forward` method that returns the reconstruction, mu, and logvar

In [None]:
class VAE(nn.Module):
    def __init__(self, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim

        # BEGIN SOLUTION
        # Encoder network
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
        )
        # Output layers for mu and logvar
        self.mu = nn.Linear(256, latent_dim)
        self.logvar = nn.Linear(256, latent_dim)

        # Decoder network
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 28 * 28),
            nn.Sigmoid(),
        )
        # END SOLUTION

    def encode(self, x):
        # BEGIN SOLUTION
        hidden = self.encoder(x)
        mu = self.mu(hidden)
        logvar = self.logvar(hidden)
        return mu, logvar
        # END SOLUTION

    def reparameterize(self, mu, logvar):
        # BEGIN SOLUTION
        # Reparameterization trick: z = mu + std * epsilon
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
        # END SOLUTION

    def decode(self, z):
        # BEGIN SOLUTION
        return self.decoder(z)
        # END SOLUTION

    def forward(self, x):
        # BEGIN SOLUTION
        mu, logvar = self.encode(x)
        latent_sample = self.reparameterize(mu, logvar)
        reconstruction = self.decode(latent_sample)
        return reconstruction, mu, logvar
        # END SOLUTION

In [None]:
# Test assertions
test_vae = VAE(latent_dim=2)
test_input = torch.randn(10, 784)
test_recon, test_mu, test_logvar = test_vae(test_input)
assert test_recon.shape == (
    10,
    784,
), f"Expected reconstruction shape (10, 784), got {test_recon.shape}"
assert test_mu.shape == (10, 2), f"Expected mu shape (10, 2), got {test_mu.shape}"
assert test_logvar.shape == (10, 2), f"Expected logvar shape (10, 2), got {test_logvar.shape}"
assert hasattr(test_vae, "encode"), "VAE should have encode method"
assert hasattr(test_vae, "decode"), "VAE should have decode method"
assert hasattr(test_vae, "reparameterize"), "VAE should have reparameterize method"
print("All tests passed!")

# BEGIN HIDDEN TESTS
# Test reparameterize produces different samples
sample1 = test_vae.reparameterize(test_mu, test_logvar)
sample2 = test_vae.reparameterize(test_mu, test_logvar)
assert not torch.allclose(sample1, sample2), "Reparameterize should produce stochastic samples"
# Test decode output range
assert torch.all(test_recon >= 0) and torch.all(
    test_recon <= 1
), "Reconstruction should be between 0 and 1"
# END HIDDEN TESTS

### Problem 9: Training a Variational Autoencoder

Please finish implementing the `train_vae` function to train your variational autoencoder. The function should:
1. Use the VAE loss (reconstruction loss + KL divergence)
2. Use the Adam optimizer with learning rate 0.001
3. Implement early stopping based on validation loss
4. Print the epoch, validation loss, and patience counter each epoch

Please leave the print statement that reports the current epoch, validation loss, and patience counter.

In [None]:
def train_vae(
    train_loader,
    val_loader,
    model=None,
    num_epochs=100,
    early_stopping_patience=5,
):
    if model is None:
        model = VAE(latent_dim=2)

    # BEGIN SOLUTION
    def vae_loss(reconstruction, x, mu, logvar, kl_weight=1.0):
        """Compute VAE loss: reconstruction loss + KL divergence."""
        mse_loss = F.mse_loss(reconstruction, x, reduction="sum")
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return mse_loss + kl_loss * kl_weight

    optimizer = optim.Adam(model.parameters(), lr=0.001)

    patience_counter = 0
    lowest_val_loss = float("inf")

    for epoch in range(num_epochs):
        # Training loop
        model.train()
        for data in train_loader:
            images, _ = data
            images = images.view(images.shape[0], -1)

            outputs, latents_mu, latents_logvar = model(images)
            loss = vae_loss(outputs, images, latents_mu, latents_logvar)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Validation loop
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for data in val_loader:
                images, _ = data
                images = images.view(images.shape[0], -1)

                outputs, latents_mu, latents_logvar = model(images)
                val_loss += vae_loss(outputs, images, latents_mu, latents_logvar).item()
        val_loss /= len(val_loader)

        # Early stopping check
        if val_loss < lowest_val_loss:
            lowest_val_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= early_stopping_patience:
                break

        print(
            f"Epoch [{epoch + 1}/{num_epochs}], "
            f"Validation Loss: {val_loss:.4f}, Patience Count: {patience_counter}"
        )
    # END SOLUTION

    return model

In [None]:
# Test assertions
import inspect

sig = inspect.signature(train_vae)
params = list(sig.parameters.keys())
assert "train_loader" in params, "train_vae should have train_loader parameter"
assert "val_loader" in params, "train_vae should have val_loader parameter"
assert "num_epochs" in params, "train_vae should have num_epochs parameter"
assert (
    "early_stopping_patience" in params
), "train_vae should have early_stopping_patience parameter"
print("All tests passed!")

# BEGIN HIDDEN TESTS
# Quick test with 1 epoch to verify function works
quick_vae = VAE(latent_dim=2)
quick_vae_result = train_vae(train_loader, val_loader, model=quick_vae, num_epochs=1)
assert isinstance(quick_vae_result, VAE), "train_vae should return a VAE model"
# END HIDDEN TESTS

In [None]:
vae_model = VAE(2)

trained_vae = train_vae(train_loader, val_loader, vae_model, num_epochs=5)

### Problem 10: Sampling from the Latent Space

Now that we have trained a variational autoencoder, let's take some samples from the prior space and pass them through the decoder to see what reconstructions we can get.

Create `latent_samples` by sampling 250 points from a standard normal distribution (the prior), then decode them to get reconstructions. Run the prewritten code afterwards to plot the reconstructions on a scatterplot.

What happens at (0, 0)? Answer in the markdown cell below.

In [None]:
NUM_IMAGES = 250

# BEGIN SOLUTION
# Sample from the prior (standard normal)
latent_samples = torch.randn(NUM_IMAGES, 2)
# Decode the samples
reconstruction = trained_vae.decode(latent_samples)
reconstruction = reconstruction.view(NUM_IMAGES, 28, 28)
# END SOLUTION

In [None]:
# Test assertions
assert latent_samples.shape == (
    250,
    2,
), f"Expected latent_samples shape (250, 2), got {latent_samples.shape}"
assert reconstruction.shape == (
    250,
    28,
    28,
), f"Expected reconstruction shape (250, 28, 28), got {reconstruction.shape}"
print("All tests passed!")

# BEGIN HIDDEN TESTS
# Check that latent samples are roughly standard normal
assert abs(latent_samples.mean()) < 0.3, "Latent samples should have mean near 0"
assert abs(latent_samples.std() - 1.0) < 0.3, "Latent samples should have std near 1"
# END HIDDEN TESTS

In [None]:
from matplotlib.offsetbox import AnnotationBbox, OffsetImage

fig, ax = plt.subplots(figsize=(10, 10))

# First plot your latent points
ax.scatter(latent_samples[:, 0], latent_samples[:, 1], alpha=0.2)

# Now overlay each image
for latent in latent_samples:
    reconstruction = trained_vae.decode(latent)
    img = reconstruction.detach().numpy().reshape(28, 28)
    # Wrap it in an OffsetImage
    im = OffsetImage(img, zoom=0.5, cmap="gray", origin="upper")
    # Create an AnnotationBbox, disable the frame so only image shows up
    ab = AnnotationBbox(im, latent, frameon=False)
    ax.add_artist(ab)

# Set limits to cover all points, plus a little padding
x_min, x_max = latent_samples[:, 0].min(), latent_samples[:, 0].max()
y_min, y_max = latent_samples[:, 1].min(), latent_samples[:, 1].max()
pad_x = (x_max - x_min) * 0.05
pad_y = (y_max - y_min) * 0.05
ax.set_xlim(x_min - pad_x, x_max + pad_x)
ax.set_ylim(y_min - pad_y, y_max + pad_y)

ax.set_xlabel("Latent Dimension 1")
ax.set_ylabel("Latent Dimension 2")
ax.set_title("Reconstruction Images in Latent Space")

plt.show()

**What happens at (0, 0)?**

> BEGIN SOLUTION

At (0, 0), which is the center of the prior distribution, the VAE typically produces an "average" or blended digit that represents the most common features across all training digits. This is because the KL divergence term encourages the encoder to map inputs near the origin, and the decoder learns to produce reasonable outputs for this high-density region. The reconstruction at (0, 0) often appears as a smooth blend of multiple digits.
> END SOLUTION


### Problem 11: Evaluating Latent Space Effectiveness

Ideally, we would get meaningful reconstructions just by sampling from the prior and passing those samples through the decoder. **This is because the entire point of the exercise is learning a continuous representation of digits that can be summarized in 2 dimensions.** In practice, if the aggregate posterior $q(z)$ does not match the prior $p(z)$, then sampling from the prior may land us in "dead zones": parts of latent space that the decoder never learned to handle.

We have started some code to try to compare the distribution of the approximate posterior and the prior. Please finish the code by:
1. Creating `prior` samples from a standard normal distribution
2. Creating `approx_posterior` samples using the reparameterization trick with the encoded mu and logvar
3. Plotting prior samples in red and approximate posterior samples in blue

Do the distributions line up? What differences (if any) do you observe? Answer in markdown.

In [None]:
NUM_SAMPLES = 500

# Get test samples
comparison_loader = DataLoader(test_dataset, batch_size=NUM_SAMPLES, shuffle=True)
test_samples, _ = next(iter(comparison_loader))
test_samples = test_samples.view(NUM_SAMPLES, -1)

# Encode the test samples to get mu and logvar
with torch.no_grad():
    reconstruction, latent_mu, latent_logvar = trained_vae(test_samples)

# BEGIN SOLUTION
# Sample from the prior
prior = torch.randn(NUM_SAMPLES, 2)

# Sample from the approximate posterior using reparameterization
approx_posterior = latent_mu + torch.exp(0.5 * latent_logvar) * torch.randn_like(latent_mu)
posterior_samples = approx_posterior.detach().numpy()

# Plot both distributions
plt.figure(figsize=(8, 8))
plt.scatter(prior[:, 0], prior[:, 1], c="red", alpha=0.5, label="Prior")
plt.scatter(
    posterior_samples[:, 0],
    posterior_samples[:, 1],
    c="blue",
    alpha=0.5,
    label="Approx Posterior",
)
plt.legend()
plt.xlabel("Latent Dimension 1")
plt.ylabel("Latent Dimension 2")
plt.title("Prior vs Approximate Posterior Distributions")
plt.show()
# END SOLUTION

In [None]:
# Test assertions
assert prior.shape == (500, 2), f"Expected prior shape (500, 2), got {prior.shape}"
assert approx_posterior.shape == (
    500,
    2,
), f"Expected approx_posterior shape (500, 2), got {approx_posterior.shape}"
print("All tests passed!")

# BEGIN HIDDEN TESTS
# Check prior is standard normal
assert abs(prior.mean()) < 0.2, "Prior should have mean near 0"
assert abs(prior.std() - 1.0) < 0.2, "Prior should have std near 1"
# END HIDDEN TESTS

**Do the distributions line up? What differences (if any) do you observe?**

> BEGIN SOLUTION

The distributions may not perfectly align. Common observations include:
1. The approximate posterior samples may cover a different region than the prior (for example, more tightly concentrated or shifted), so part of the prior's mass falls where the decoder saw few latent codes during training.
2. The approximate posterior may form clusters corresponding to different digit classes, while the prior is a single standard normal centered at the origin.
3. With limited training (5 epochs), the KL divergence may not have fully regularized the posterior to match the prior.

If the distributions differ significantly, sampling from the prior may produce lower-quality reconstructions because the decoder was trained on a different distribution of latent codes.
> END SOLUTION
