# Exploring Latent Spaces

In this notebook, you'll explore the latent space of a trained VAE on Fashion-MNIST.

**What you'll do:**
- Encode Fashion-MNIST items into the latent space and visualize it in 2D with t-SNE
- Interpolate between two images in latent space and watch smooth, coherent transitions
- Perform latent arithmetic — extract directions between encoded items and apply them to new items
- Sample from different regions of the latent space and map which regions generate which categories
- Design a targeted generation experiment to create specific clothing styles by navigating the latent space

**For each exercise, PREDICT the output before running the cell.** Wrong predictions are more valuable than correct ones — they reveal gaps in your mental model.

In [None]:
# Setup: self-contained for Google Colab
# No additional pip installs needed: torch, torchvision, sklearn, and matplotlib are all preinstalled on Colab

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np

# Reproducible results
torch.manual_seed(42)
np.random.seed(42)

# Nice plots
plt.style.use('dark_background')
plt.rcParams['figure.figsize'] = [10, 4]

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

In [None]:
# ============================================================
# VAE Definition + Quick Training
# ============================================================
# If you have a saved checkpoint from the VAE notebook, you can
# load it instead. Otherwise, this trains a fresh VAE in ~2 min
# on a GPU (or ~5 min on CPU).

LATENT_DIM = 32

class VAE(nn.Module):
    def __init__(self, latent_dim=LATENT_DIM):
        super().__init__()
        self.latent_dim = latent_dim

        # Encoder: image -> (mu, logvar)
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2, padding=1),   # (32, 14, 14)
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),  # (64, 7, 7)
            nn.ReLU(),
            nn.Flatten(),                                # (64*7*7 = 3136)
        )
        self.fc_mu = nn.Linear(3136, latent_dim)
        self.fc_logvar = nn.Linear(3136, latent_dim)

        # Decoder: z -> image
        self.decoder_fc = nn.Linear(latent_dim, 3136)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),  # (32, 14, 14)
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 3, stride=2, padding=1, output_padding=1),   # (1, 28, 28)
            nn.Sigmoid(),
        )

    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + std * eps

    def decode(self, z):
        h = self.decoder_fc(z)
        h = h.view(-1, 64, 7, 7)
        return self.decoder(h)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


def vae_loss(recon_x, x, mu, logvar, beta=1.0):
    recon = F.binary_cross_entropy(recon_x, x, reduction='sum')
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + beta * kl


# Load Fashion-MNIST
transform = transforms.ToTensor()
train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

# Category names for labeling
CATEGORY_NAMES = [
    'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'
]

# Train the VAE (or load a checkpoint)
model = VAE(LATENT_DIM).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

print('Training VAE...')
for epoch in range(15):
    model.train()
    total_loss = 0
    for batch, _ in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        recon, mu, logvar = model(batch)
        loss = vae_loss(recon, batch, mu, logvar, beta=1.0)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(train_dataset)
    if (epoch + 1) % 5 == 0:
        print(f'  Epoch {epoch+1:2d}/15  loss: {avg_loss:.2f}')

model.eval()
print('Done! VAE is ready for exploration.')
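
The KL term in `vae_loss` is the closed-form KL divergence between the encoder's diagonal Gaussian $q(z \mid x) = \mathcal{N}(\mu, \sigma^2)$ and the standard normal prior, summed over the latent dimensions, with $\log \sigma_j^2$ stored as `logvar`:

$$D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \sigma^2)\,\|\,\mathcal{N}(0, 1)\big) = -\frac{1}{2}\sum_{j=1}^{d}\left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$$

As an optional sanity check (a sketch, not part of the notebook's pipeline), the NumPy snippet below compares that formula with a Monte Carlo estimate for a single latent dimension. The values of `mu` and `logvar` are arbitrary illustrative choices, and the function names are hypothetical.

```python
import numpy as np

def kl_closed_form(mu, logvar):
    """Closed-form KL(N(mu, exp(logvar)) || N(0, 1)), summed over dimensions (same formula as vae_loss)."""
    mu = np.asarray(mu, dtype=float)
    logvar = np.asarray(logvar, dtype=float)
    return float(-0.5 * np.sum(1 + logvar - mu**2 - np.exp(logvar)))

def kl_monte_carlo(mu, logvar, n_samples=200_000, seed=0):
    """Monte Carlo estimate of the same KL for one dimension: E_q[log q(z) - log p(z)]."""
    rng = np.random.default_rng(seed)
    std = np.exp(0.5 * logvar)
    z = mu + std * rng.standard_normal(n_samples)   # reparameterized samples from q
    log_q = -0.5 * np.log(2 * np.pi) - np.log(std) - 0.5 * ((z - mu) / std) ** 2
    log_p = -0.5 * np.log(2 * np.pi) - 0.5 * z**2
    return float(np.mean(log_q - log_p))

mu, logvar = 0.5, -0.5
print(kl_closed_form([mu], [logvar]))   # analytic value
print(kl_monte_carlo(mu, logvar))       # should agree closely with the analytic value
```

The KL is zero only when $\mu = 0$ and $\log\sigma^2 = 0$, which is why this term pulls every code toward the same $\mathcal{N}(0, 1)$ region.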

In [None]:
# ============================================================
# Shared Helpers
# ============================================================

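def show_grid(images, nrow=5, title=None):
    """Display a batch of images as a grid with `nrow` images per row."""
    n = images.shape[0]
    ncol = nrow                               # despite its name, `nrow` sets images per row
    nrow_actual = (n + ncol - 1) // ncol      # number of rows needed to fit all n images
    # squeeze=False always returns a 2D array of axes, even for a single row or column
    fig, axes = plt.subplots(nrow_actual, ncol, figsize=(ncol * 1.5, nrow_actual * 1.5),
                             squeeze=False)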
    for i in range(nrow_actual * ncol):
        r, c = divmod(i, ncol)
        axes[r, c].axis('off')
        if i < n:
            axes[r, c].imshow(images[i].squeeze().cpu().numpy(), cmap='gray', vmin=0, vmax=1)
    if title:
        fig.suptitle(title, fontsize=14, y=1.02)
    plt.tight_layout()
    plt.show()


def show_interpolation_strip(images, labels=None, title=None):
    """Display a row of images as an interpolation strip."""
    n = len(images)
    fig, axes = plt.subplots(1, n, figsize=(n * 1.5, 2), squeeze=False)
    for i, ax in enumerate(axes[0]):   # squeeze=False keeps this iterable even when n == 1
        ax.imshow(images[i].squeeze().cpu().numpy(), cmap='gray', vmin=0, vmax=1)
        ax.axis('off')
        if labels and i < len(labels):
            ax.set_title(labels[i], fontsize=9)
    if title:
        fig.suptitle(title, fontsize=13, y=1.05)
    plt.tight_layout()
    plt.show()


def get_items_by_category(dataset, category_idx, n=10):
    """Get n items from a specific category."""
    items = []
    for img, label in dataset:
        if label == category_idx:
            items.append(img)
            if len(items) >= n:
                break
    return torch.stack(items)


def encode_to_mu(images):
    """Encode images and return just the mu (mean) vectors."""
    with torch.no_grad():
        mu, _ = model.encode(images.to(device))
    return mu


print('Helpers loaded.')

---

## Exercise 1: Encode and Visualize the Latent Space (Guided)

The VAE compressed every 28x28 image (784 pixels) down to a 32-dimensional latent code. Similar items should encode to nearby points: T-shirts near T-shirts, sneakers near sneakers. But we cannot plot 32 dimensions directly. To see the structure, we project the latent codes down to 2D using **t-SNE** (t-distributed Stochastic Neighbor Embedding), which tries to preserve local neighborhoods: points that are close in 32D tend to stay close in 2D.

We will encode the entire test set (10,000 images), project to 2D, and color each point by its category label.

**Before running, predict:**
- Will the categories form distinct clusters, or will everything be mixed together?
- Which categories do you think will overlap? (Hint: think about which clothing items look similar.)
- Will there be sharp boundaries between clusters, or smooth transitions?

In [None]:
# Encode the entire test set
all_z = []
all_labels = []

with torch.no_grad():
    for images, labels in test_loader:
        mu, _ = model.encode(images.to(device))  # encode returns (mu, logvar) — use mu
        all_z.append(mu.cpu())
        all_labels.append(labels)

all_z = torch.cat(all_z).numpy()
all_labels = torch.cat(all_labels).numpy()

print(f'Encoded {len(all_z)} images to {all_z.shape[1]}-dimensional latent codes')
print(f'Now projecting to 2D with t-SNE (this takes ~30 seconds)...')

# Project to 2D with t-SNE
tsne = TSNE(n_components=2, random_state=42, perplexity=30)
z_2d = tsne.fit_transform(all_z)

# Plot colored by category
fig, ax = plt.subplots(figsize=(10, 8))
scatter = ax.scatter(z_2d[:, 0], z_2d[:, 1],
                     c=all_labels, cmap='tab10', s=1, alpha=0.5)

# Add a legend with category names
handles = []
for i, name in enumerate(CATEGORY_NAMES):
    handles.append(plt.Line2D([0], [0], marker='o', color='w',
                              markerfacecolor=plt.cm.tab10(i), markersize=8, label=name))  # integer index = tab10 color of label i
ax.legend(handles=handles, loc='upper right', fontsize=9, framealpha=0.8)

ax.set_title('VAE Latent Space (t-SNE projection)', fontsize=14)
ax.set_xlabel('t-SNE dim 1')
ax.set_ylabel('t-SNE dim 2')
plt.tight_layout()
plt.show()

print('\nLook for:')
print('  - Clusters: similar items group together')
print('  - Smooth transitions: related categories blur into each other')
print('  - Overlap: where categories share features (e.g., pullover and coat)')

**What just happened:**

The t-SNE plot reveals how the encoder organized the latent space. A plain autoencoder often groups similar items too; what the KL regularizer adds is pressure to pack every code into the same N(0,1) region, so clusters sit next to each other with few empty gaps between them:

- **Clusters:** T-shirts near T-shirts, sneakers near sneakers. The VAE learned that these items are similar and encoded them nearby.
- **Smooth transitions:** Related categories blend together. Shoes blur into boots. Shirts blur into coats. There are no hard walls between categories.
- **Overlap:** Where categories share visual features (pullover and coat, for instance), their regions overlap. The VAE never saw the labels; it places these items close together because they genuinely look similar.

This is the structure that makes interpolation and arithmetic possible. The space is not a random soup — it is a map of what the network learned about clothing.

**Caveat:** t-SNE is a *visualization* tool, not ground truth. It distorts distances and can exaggerate gaps. If you see a cluster, those points really are nearby in 32D. If you see a gap, it might be real or might be t-SNE exaggerating. Different runs (or different `perplexity` values) produce different layouts.
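
If you want a number to go with the picture, one rough check (an addition to this notebook, not a standard benchmark) uses k-nearest neighbors to ask two questions: what fraction of each point's neighbors in 32D share its label, and what fraction of its 32D neighbors are still neighbors in the 2D t-SNE layout? The sketch below uses only NumPy and builds a full pairwise-distance matrix, so run it on a random subsample of a couple of thousand points rather than all 10,000. The function names are illustrative; scikit-learn's `sklearn.manifold.trustworthiness` is an established metric in the same spirit as the second check.

```python
import numpy as np

def knn_indices(X, k):
    """Indices of each row's k nearest neighbors in X (Euclidean), excluding the row itself."""
    X = np.asarray(X, dtype=float)
    sq = np.sum(X**2, axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * X @ X.T   # squared pairwise distances, shape (n, n)
    np.fill_diagonal(d2, np.inf)                      # never count a point as its own neighbor
    return np.argsort(d2, axis=1)[:, :k]

def knn_label_agreement(Z, labels, k=10):
    """Average fraction of each point's k nearest neighbors (in Z) that share its label."""
    labels = np.asarray(labels)
    nbrs = knn_indices(Z, k)
    return float(np.mean(labels[nbrs] == labels[:, None]))

def knn_overlap(X_high, X_low, k=10):
    """Average fraction of each point's k neighbors in X_high that are also its k neighbors in X_low."""
    hi, lo = knn_indices(X_high, k), knn_indices(X_low, k)
    return float(np.mean([len(set(a) & set(b)) / k for a, b in zip(hi, lo)]))

# Example usage in this notebook (after the Exercise 1 cell has run):
#   idx = np.random.default_rng(0).choice(len(all_z), size=2000, replace=False)
#   print(knn_label_agreement(all_z[idx], all_labels[idx]))   # label purity of 32D neighborhoods
#   print(knn_overlap(all_z[idx], z_2d[idx]))                 # how much t-SNE kept of them
```

A high label agreement says the clusters are real in 32D; a low neighborhood overlap says the 2D picture is rearranging more than it appears to.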

---

## Exercise 2: Interpolate Between Two Images (Guided)

The latent space is smooth: the KL regularizer filled in most of the gaps between encoded points. That means you can **walk** from one encoded image to another and see coherent intermediate images along the way. This is **latent interpolation**.

The formula: encode image A to $z_A$, encode image B to $z_B$, then:

$$z_t = (1-t) \cdot z_A + t \cdot z_B \quad \text{for } t \in [0, 1]$$

At $t=0$ you decode $z_A$ and get the VAE's reconstruction of image A; at $t=1$, its reconstruction of image B. In between, you get plausible intermediate images: not ghostly overlays, but coherent garments morphing from one form to another.
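
The formula traces a straight line through latent space. As a minimal sketch (NumPy here; the same broadcasting works on PyTorch tensors with `torch.linspace`), all the steps can be computed as one batch instead of in a Python loop. The name `lerp_latents` is hypothetical, not a library function.

```python
import numpy as np

def lerp_latents(z_a, z_b, n_steps):
    """Return an (n_steps, d) array of z_t = (1 - t) * z_a + t * z_b for t evenly spaced in [0, 1]."""
    z_a = np.asarray(z_a, dtype=float).reshape(1, -1)   # (1, d)
    z_b = np.asarray(z_b, dtype=float).reshape(1, -1)   # (1, d)
    t = np.linspace(0.0, 1.0, n_steps).reshape(-1, 1)   # (n_steps, 1), broadcasts across dimensions
    return (1.0 - t) * z_a + t * z_b

path = lerp_latents([0.0, 2.0], [4.0, -2.0], n_steps=5)
# First row equals z_a, last row equals z_b, and the middle row is the midpoint [2., 0.]
print(path)

# In the notebook (hypothetical usage): decode the whole path in one call
#   codes = lerp_latents(z_A.cpu().numpy(), z_B.cpu().numpy(), 8)
#   with torch.no_grad():
#       imgs = model.decode(torch.from_numpy(codes).float().to(device))
```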

We will do two things:
1. **Pixel-space interpolation** — average the raw pixel values. This produces a ghostly double exposure.
2. **Latent-space interpolation** — average the latent codes, then decode. This produces coherent transitions.

**Before running, predict:**
- For pixel interpolation at $t=0.5$: what will you see? (Think: what does averaging two grayscale images look like?)
- For latent interpolation at $t=0.5$ between a T-shirt and a trouser: will the intermediate image look like clothing, or like noise?

In [None]:
# Get one T-shirt and one trouser from the test set
tshirt_images = get_items_by_category(test_dataset, 0, n=1)   # 0 = T-shirt/top
trouser_images = get_items_by_category(test_dataset, 1, n=1)  # 1 = Trouser

image_A = tshirt_images[0:1]   # shape: (1, 1, 28, 28)
image_B = trouser_images[0:1]

# Encode both images
z_A = encode_to_mu(image_A)  # shape: (1, 32)
z_B = encode_to_mu(image_B)

# Create interpolation steps
n_steps = 8
t_values = torch.linspace(0, 1, n_steps)

# --- Pixel-space interpolation (the naive approach) ---
pixel_interp = []
for t in t_values:
    blended = (1 - t) * image_A + t * image_B
    pixel_interp.append(blended.squeeze())

# --- Latent-space interpolation (the VAE approach) ---
latent_interp = []
with torch.no_grad():
    for t in t_values:
        z_t = (1 - t) * z_A + t * z_B  # interpolate in latent space
        decoded = model.decode(z_t)      # decode the interpolated code
        latent_interp.append(decoded.squeeze().cpu())

# Display both strips side by side
t_labels = [f't={t:.2f}' for t in t_values]

print('PIXEL-SPACE interpolation (averaging raw pixels):')
show_interpolation_strip(pixel_interp, labels=t_labels, title='Pixel Interpolation: T-shirt to Trouser')

print('LATENT-SPACE interpolation (averaging latent codes, then decoding):')
show_interpolation_strip(latent_interp, labels=t_labels, title='Latent Interpolation: T-shirt to Trouser')

**What just happened:**

The two strips tell completely different stories:

- **Pixel interpolation** produced a ghostly double exposure. At $t=0.5$, both shapes are visible at once — transparent and overlapping. It looks like a photography accident, not a real garment. That is because pixel-space does not understand "clothing." It just averages numbers.

- **Latent interpolation** produced a coherent transition. At each step there is typically one solid (if somewhat blurry) shape that morphs from a T-shirt into trousers, and the intermediate images look like clothing items rather than overlays. That is because the VAE's latent space is organized: the decoder was trained on codes spread throughout this region, so it has a sensible output for points along the path.
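
A toy NumPy example (synthetic "images", not Fashion-MNIST) shows why pixel averaging gives a double exposure: when two shapes occupy different pixels, the 50/50 blend keeps both shapes, each at half brightness, instead of producing one intermediate shape.

```python
import numpy as np

# Two synthetic 6x6 "images": a horizontal bar (image A) and a vertical bar (image B)
img_a = np.zeros((6, 6))
img_a[2, :] = 1.0
img_b = np.zeros((6, 6))
img_b[:, 4] = 1.0

t = 0.5
blend = (1 - t) * img_a + t * img_b   # pixel-space interpolation

# Pixels covered by only one shape end up at half brightness: both "ghosts" are visible at once
print(blend[2, 0], blend[0, 4])   # 0.5 0.5
# Only the single pixel where the bars cross keeps full brightness
print(blend[2, 4])                # 1.0
```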

This is the key insight of continuous latent spaces: **interpolation is not blending two images.** It is asking the decoder, "What image lives at this intermediate point in the space you organized?" The answer is a coherent image because the KL regularizer filled in the gaps between encoded points.

Remember the city map analogy from the VAE lesson: the KL term built roads connecting all the buildings. Interpolation is literally walking those roads. Every location along the path is a real place that decodes to a plausible image.

---

## Exercise 3: Latent Space Arithmetic (Supported)

If the latent space captures meaningful structure, then the **direction** between two encoded items captures the **difference** between them. You can extract that direction and apply it to something else.

The classic setup:
- Encode an ankle boot and a sneaker
- The vector $z(\text{ankle boot}) - z(\text{sneaker})$ captures roughly "what makes a boot different from a sneaker" — something about height or ankle coverage
- Add that direction to a sandal: $z(\text{sandal}) + [z(\text{ankle boot}) - z(\text{sneaker})]$
- If the space is well-organized, the result should look like a taller, more boot-like sandal

To reduce noise, we will average across multiple examples of each category rather than using single items. This gives us the "average sneaker," "average ankle boot," and "average sandal" in latent space.

**Your task:** Fill in the TODO sections to compute the attribute direction and apply it.

<details>
<summary>Hint</summary>

The arithmetic is just vector subtraction and addition on the latent codes (mu vectors). Compute the direction by subtracting one category's average from another's, then add that direction to a third category's average.

</details>

In [None]:
# Get multiple examples from each category and encode them
n_samples = 20  # average over 20 examples per category for a cleaner signal

sneaker_imgs = get_items_by_category(test_dataset, 7, n=n_samples)       # 7 = Sneaker
ankle_boot_imgs = get_items_by_category(test_dataset, 9, n=n_samples)    # 9 = Ankle boot
sandal_imgs = get_items_by_category(test_dataset, 5, n=n_samples)        # 5 = Sandal

# Encode each batch and compute the average latent code
z_sneaker_avg = encode_to_mu(sneaker_imgs).mean(dim=0, keepdim=True)       # shape: (1, 32)
z_ankle_boot_avg = encode_to_mu(ankle_boot_imgs).mean(dim=0, keepdim=True)
z_sandal_avg = encode_to_mu(sandal_imgs).mean(dim=0, keepdim=True)

# ============================================================
# TODO: Compute the "boot-ness" direction
# What direction in latent space takes you from a sneaker to an ankle boot?
# ============================================================
boot_direction = # TODO: subtract the average sneaker code from the average ankle boot code

# ============================================================
# TODO: Apply the boot direction to the sandal
# This should produce something more boot-like than a sandal
# ============================================================
z_result = # TODO: add the boot direction to the average sandal code

# Decode everything to see the results
with torch.no_grad():
    img_sneaker = model.decode(z_sneaker_avg)
    img_ankle_boot = model.decode(z_ankle_boot_avg)
    img_sandal = model.decode(z_sandal_avg)
    img_result = model.decode(z_result)

# Display: ankle boot - sneaker = direction; direction + sandal = result
fig, axes = plt.subplots(1, 5, figsize=(14, 3))
titles = ['Avg Ankle Boot', 'Avg Sneaker', '"Boot-ness"\ndirection', 'Avg Sandal', 'Sandal + Boot-ness']

# Panel order matches the arithmetic: ankle boot - sneaker = direction
axes[0].imshow(img_ankle_boot.squeeze().cpu().numpy(), cmap='gray', vmin=0, vmax=1)
axes[1].imshow(img_sneaker.squeeze().cpu().numpy(), cmap='gray', vmin=0, vmax=1)

# For the "direction" panel, show the pixel difference between the two decoded averages.
# This is only a visualization: the actual direction lives in latent space, not pixel space.
diff_image = (img_ankle_boot - img_sneaker).squeeze().cpu().numpy()
axes[2].imshow(diff_image, cmap='RdBu', vmin=-0.5, vmax=0.5)

axes[3].imshow(img_sandal.squeeze().cpu().numpy(), cmap='gray', vmin=0, vmax=1)
axes[4].imshow(img_result.squeeze().cpu().numpy(), cmap='gray', vmin=0, vmax=1)

for ax, title in zip(axes, titles):
    ax.set_title(title, fontsize=11)
    ax.axis('off')

# Add arithmetic symbols between panels
fig.text(0.26, 0.5, '-', fontsize=24, ha='center', va='center', color='white')
fig.text(0.445, 0.5, '=', fontsize=24, ha='center', va='center', color='white')
fig.text(0.635, 0.5, '+', fontsize=24, ha='center', va='center', color='white')
fig.text(0.82, 0.5, '=', fontsize=24, ha='center', va='center', color='white')

plt.tight_layout()
plt.show()

print('The result should look somewhat boot-like — taller than a sandal,')
print('possibly with more ankle coverage. It will be noisy — this is Fashion-MNIST,')
print('not a face dataset with smooth attribute variation.')

<details>
<summary>Solution</summary>

The key insight is that **directions in latent space capture relationships**. Subtracting two category averages extracts what makes them different. Adding that direction to a third category transfers the difference.

```python
# The "boot-ness" direction: what changes when you go from sneaker to ankle boot
boot_direction = z_ankle_boot_avg - z_sneaker_avg

# Apply it to the sandal: sandal + boot-ness = boot-like sandal
z_result = z_sandal_avg + boot_direction
```

The arithmetic is trivial, just vector subtraction and addition. The remarkable part is that it can produce meaningful results: the decoded image should look like a sandal that has been made more boot-like (taller, more ankle coverage).

The results are noisy because Fashion-MNIST has discrete categories with limited smooth attribute variation. On face datasets (CelebA), this same technique transfers attributes like "wearing glasses" or "smiling" much more cleanly. The concept is real; clean results require data with consistent, continuous variation.

</details>

**What the result shows:**

If the arithmetic worked, the result image should be directionally correct — something sandal-like but taller or with more coverage. It will not be perfectly clean. That noisiness is honest: most random directions in latent space do not correspond to interpretable features. Only specific learned directions encode meaningful attributes, and even those work best when the training data has consistent, continuous variation in that attribute. Fashion-MNIST's discrete categories make this noisier than the famous face-attribute examples.
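
One way to probe how consistent the effect is: scale the direction by a strength $\alpha$ and decode $z(\text{sandal}) + \alpha\,[z(\text{ankle boot}) - z(\text{sneaker})]$ for several values. Here $\alpha = 1$ is the exercise's result and $\alpha = 0$ is the plain average sandal. This is a common extension of latent arithmetic rather than something the exercise asks for; the helper name `attribute_sweep` is hypothetical, and the sketch works on NumPy arrays of encoded `mu` vectors (one row per image).

```python
import numpy as np

def attribute_sweep(z_from, z_to, z_base, alphas):
    """Apply the direction mean(z_to) - mean(z_from) to mean(z_base) at each strength in `alphas`.

    z_from, z_to, z_base: arrays of shape (n_i, d) holding encoded mu vectors for each category.
    Returns an array of shape (len(alphas), d), one latent code per strength.
    """
    direction = np.mean(z_to, axis=0) - np.mean(z_from, axis=0)   # e.g. ankle boot - sneaker
    base = np.mean(z_base, axis=0)                                  # e.g. average sandal
    alphas = np.asarray(alphas, dtype=float).reshape(-1, 1)        # (k, 1) broadcasts over d
    return base + alphas * direction

# In the notebook (hypothetical usage):
#   codes = attribute_sweep(encode_to_mu(sneaker_imgs).cpu().numpy(),
#                           encode_to_mu(ankle_boot_imgs).cpu().numpy(),
#                           encode_to_mu(sandal_imgs).cpu().numpy(),
#                           alphas=[0.0, 0.5, 1.0, 1.5])
#   with torch.no_grad():
#       imgs = model.decode(torch.from_numpy(codes).float().to(device))
#   show_interpolation_strip(list(imgs.cpu()), labels=['0.0', '0.5', '1.0', '1.5'])
```

If the change grows steadily with $\alpha$, the direction is doing something consistent. If the images stop looking like footwear at larger $\alpha$, the code has likely moved away from the region the decoder was trained on.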

---

## Exercise 4: Sample from Different Regions (Supported)

The KL regularizer organized the latent space so that similar items cluster together. Different regions of the space generate different types of clothing. You can map this out by systematically sampling from different locations.

**Your task:** Create a 2D grid of latent codes by varying the first two dimensions of z while keeping the rest at 0. Decode each point and display the resulting images. This gives you a "map" of what the decoder produces across the latent space.

You will set the remaining 30 dimensions (indices 2 through 31) to zero, the mean of N(0,1), and vary dimensions 0 and 1 (zero-indexed, as in the code) across a grid from -3 to +3. Each grid cell is a decoded image from that location in latent space.

<details>
<summary>Hint</summary>

Create a meshgrid of values from -3 to +3 for two dimensions. For each (d1, d2) pair, build a latent vector that is all zeros except for those two dimensions. Decode each vector and place the image in the corresponding grid cell.

</details>

In [None]:
# ============================================================
# TODO: Create a grid of latent codes and decode them
# ============================================================

grid_size = 10  # 10x10 grid
grid_range = torch.linspace(-3, 3, grid_size)

# We will vary dimensions 0 and 1, keeping all others at 0
# This gives a 2D "slice" through the 32-dimensional latent space

decoded_grid = []

with torch.no_grad():
    for d2_val in reversed(grid_range):   # reversed so top of plot = positive d2
        row_images = []
        for d1_val in grid_range:
            # TODO: Create a latent vector z of shape (1, LATENT_DIM)
            # with all zeros, except dimension 0 = d1_val and dimension 1 = d2_val
            z = # TODO

            # TODO: Decode z to get an image
            img = # TODO

            row_images.append(img.squeeze().cpu().numpy())
        decoded_grid.append(row_images)

# Display as a single large image
fig, axes = plt.subplots(grid_size, grid_size, figsize=(12, 12))
for r in range(grid_size):
    for c in range(grid_size):
        axes[r, c].imshow(decoded_grid[r][c], cmap='gray', vmin=0, vmax=1)
        axes[r, c].axis('off')

# Label the axes
fig.text(0.5, 0.02, 'Latent dimension 0 (from -3 to +3)', ha='center', fontsize=13)
fig.text(0.02, 0.5, 'Latent dimension 1 (from -3 to +3)', va='center', rotation='vertical', fontsize=13)
fig.suptitle('Decoded Images Across a 2D Slice of the Latent Space', fontsize=14, y=0.98)
plt.tight_layout(rect=[0.03, 0.03, 1, 0.97])
plt.show()

print('Each cell is a decoded image from a different point in latent space.')
print('Look for smooth transitions as you move in any direction.')
print('Different regions produce different categories of clothing.')

<details>
<summary>Solution</summary>

The key insight is that each point in latent space decodes to a specific image. By sweeping two dimensions across a grid, you are exploring a 2D plane through the 32-dimensional space. Different regions of this plane produce different types of clothing, and the transitions between regions are smooth.

```python
# Create a zero vector and set dimensions 0 and 1
z = torch.zeros(1, LATENT_DIM).to(device)
z[0, 0] = d1_val
z[0, 1] = d2_val

# Decode
img = model.decode(z)
```

Note that we are only seeing a thin slice of the full space. The other 30 dimensions are fixed at 0. Moving along those other dimensions would reveal more structure, more categories, and more variation. What you see here is just a 2D cross-section through a much richer space.

Common things to notice:
- Certain regions consistently produce the same category (e.g., a corner might always produce sneakers)
- Transitions between regions are gradual, not abrupt
- The center (0, 0) tends to produce something generic or average-looking, because that is the most densely populated region of N(0,1)
- The edges (near -3 or +3) may produce odder or less recognizable items, because those are far from the training distribution's center

</details>
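
A caveat worth checking, as an assumption rather than a guaranteed outcome: a VAE with 32 latent dimensions may leave some dimensions nearly unused, meaning their `mu` barely changes across inputs. If dimensions 0 and 1 happen to be among those, the grid above will look almost constant. The sketch below (NumPy; helper names are hypothetical) ranks dimensions by the variance of `mu` across encoded images and builds the whole grid of latent codes as one array, so you can sweep the two most active dimensions instead.

```python
import numpy as np

def most_active_dims(mu, k=2):
    """Indices of the k latent dimensions whose mu varies most across encoded images (mu: (n, d))."""
    variances = np.var(np.asarray(mu, dtype=float), axis=0)
    return np.argsort(variances)[::-1][:k]

def latent_grid(dim_x, dim_y, latent_dim, lo=-3.0, hi=3.0, n=10, base=None):
    """Return (n * n, latent_dim) codes sweeping dim_x across columns and dim_y down rows.

    Rows run top to bottom from hi to lo, matching the plot layout above.
    All other dimensions are copied from `base` (zeros by default).
    """
    base = np.zeros(latent_dim) if base is None else np.asarray(base, dtype=float)
    xs = np.linspace(lo, hi, n)
    ys = np.linspace(hi, lo, n)                    # reversed so the top row has the largest value
    gy, gx = np.meshgrid(ys, xs, indexing='ij')    # gy[r, c] = ys[r], gx[r, c] = xs[c]
    codes = np.tile(base, (n * n, 1))
    codes[:, dim_x] = gx.ravel()                   # row-major: code index = r * n + c
    codes[:, dim_y] = gy.ravel()
    return codes

# In the notebook (hypothetical usage, after Exercise 1 has run):
#   dx, dy = most_active_dims(all_z, k=2)
#   codes = latent_grid(dx, dy, LATENT_DIM)
#   with torch.no_grad():
#       imgs = model.decode(torch.from_numpy(codes).float().to(device))
#   show_grid(imgs, nrow=10, title=f'Latent dims {dx} and {dy}')
```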

---

## Exercise 5: Targeted Generation Experiment (Independent)

You now have all the tools: encoding, decoding, interpolation, arithmetic, and region mapping. Use them to design an experiment that generates specific types of images by intentionally navigating the latent space.

**Task:** Pick a clothing category you want to generate variations of (e.g., different styles of shoes, or variations of a T-shirt). Use any combination of the techniques you have practiced to:

1. Find where that category lives in the latent space (encode several examples and look at their average latent code)
2. Generate a grid of variations by perturbing the average code in different directions
3. Display the results and describe what each direction of perturbation seems to control

**Requirements:**
- Generate at least 15 images that are variations of your chosen category
- Use at least two different techniques (e.g., sampling near the category center + arithmetic from another category)
- Include a brief written interpretation: what do the different directions seem to control?

This is independent — write the code from scratch. No skeleton is provided.

In [None]:
# ============================================================
# YOUR EXPERIMENT
# ============================================================
# Design and implement your targeted generation experiment here.
# Some ideas to get you started:
#
# Approach A: "Neighborhood sampling"
#   - Encode many examples of one category
#   - Compute the average latent code (the category center)
#   - Sample nearby points: z_center + small_noise
#   - Vary the noise magnitude to control how different the variations are
#
# Approach B: "Dimension exploration"
#   - Start at a category center
#   - Vary one latent dimension at a time while holding others fixed
#   - See which dimensions control which visual features
#
# Approach C: "Category blending"
#   - Pick two related categories (e.g., sneaker and ankle boot)
#   - Use interpolation with different t values to generate a spectrum
#   - Use arithmetic to transfer attributes from a third category
#
# Write your code below:


<details>
<summary>Solution (one possible approach)</summary>

There is no single correct answer — the point is to navigate the latent space intentionally and observe what happens. Here is one approach that combines neighborhood sampling with dimension exploration for sneaker variations:

**Reasoning:** To generate variations of sneakers, we first find where sneakers live in latent space by averaging many encoded sneakers. Then we explore two things: (1) what happens when we add small random noise (variations within the category), and (2) what each latent dimension controls (by varying one dimension at a time).

```python
# Find the sneaker center in latent space
sneaker_imgs = get_items_by_category(test_dataset, 7, n=50)
z_sneaker_center = encode_to_mu(sneaker_imgs).mean(dim=0, keepdim=True)

# --- Technique 1: Neighborhood sampling ---
# Small perturbations around the center produce sneaker variations
torch.manual_seed(42)
noise_scale = 0.5  # small enough to stay in the sneaker region
z_variations = z_sneaker_center + noise_scale * torch.randn(15, LATENT_DIM).to(device)

with torch.no_grad():
    variation_imgs = model.decode(z_variations)

show_grid(variation_imgs, nrow=5, title='Sneaker Variations (neighborhood sampling, scale=0.5)')

# --- Technique 2: Dimension exploration ---
# Vary the top 5 most-variant dimensions one at a time
# First, find which dimensions have the most variance across sneakers
z_sneakers = encode_to_mu(sneaker_imgs)  # (50, 32)
dim_variance = z_sneakers.var(dim=0)     # (32,)
top_dims = dim_variance.argsort(descending=True)[:5].cpu().numpy()

print(f'Top 5 highest-variance dimensions for sneakers: {top_dims}')
print(f'(These are the dimensions that vary most across different sneakers)\n')

# For each top dimension, add an offset from -2 to +2 to the sneaker center's value
sweep_values = torch.linspace(-2, 2, 7)

for dim_idx in top_dims[:3]:  # just top 3 for space
    swept_images = []
    with torch.no_grad():
        for val in sweep_values:
            z = z_sneaker_center.clone()
            z[0, dim_idx] += val.item()   # offset from the center; other dimensions unchanged
            img = model.decode(z)
            swept_images.append(img.squeeze().cpu())
    labels = [f'{v:+.1f}' for v in sweep_values]
    show_interpolation_strip(swept_images, labels=labels,
                             title=f'Varying dimension {dim_idx} (offset from sneaker center)')

print('Look at each strip. What visual feature does each dimension control?')
print('Features you might see: width, brightness, shape curvature, sole thickness, etc.')
```

The neighborhood sampling shows that adding small noise to the sneaker center produces plausible sneaker variations — different shapes, widths, and styles, all recognizably sneakers. The dimension exploration reveals which latent dimensions control which visual features, though the features are often not cleanly interpretable (a single dimension might mix several visual attributes).

</details>

---

## Key Takeaways

1. **A trained VAE's latent space is a continuous, organized space where you can sample, interpolate, and do arithmetic.** Together with the reconstruction loss, the KL regularizer organizes the space so that similar items cluster together and most regions near the prior decode to plausible items. This is the structure that makes generation possible.

2. **Interpolation in latent space produces coherent transitions; pixel-space interpolation does not.** Pixel averaging creates ghostly overlays. Latent averaging creates plausible intermediate images because the decoder was trained on codes spread throughout the space between encoded points. Interpolation is not blending: it is asking "what lives at this point in the organized space?"

3. **Directions in latent space capture relationships between concepts.** Subtracting two encoded items gives a direction that represents their difference. Adding that direction to a third item transfers the difference. This works when the attribute has consistent, continuous variation in the training data.

4. **The structure in the latent space reflects what the network learned about the data.** Similar items cluster together. The t-SNE visualization makes this structure visible. Different regions generate different categories, with smooth transitions between them.

5. **VAE generation works but has a fundamental quality ceiling from the reconstruction-vs-KL tradeoff.** The blurriness is not simply a training failure; it is largely the price of a smooth, sampleable latent space, and more epochs alone will not fix it. Diffusion models take a fundamentally different approach to overcome this limitation.

You learned to sample from a distribution and create things that have never existed. That is the core of generative AI. Everything from here forward is about doing it better.