import torch
import torch.nn as nn

# Model: AutoEncoder (2-D bottleneck)
### What is an AutoEncoder (AE)?

An **AutoEncoder** learns to:
1) **Encode** an input $x$ into a low-dimensional **latent vector** $z$ (the **bottleneck**), and
2) **Decode** $z$ back into a reconstruction $\hat{x}$ that resembles the original.

For MNIST (28×28 grayscale), we’ll use a **fully connected AE** (a multilayer perceptron, MLP) with a **2-D bottleneck**:
- **Why 2-D?** So we can **visualize** the latent space later as a 2D scatter plot (each image → one point).
- **Encoder**: $(784 \rightarrow 300 \rightarrow 2)$
- **Decoder**: $(2 \rightarrow 300 \rightarrow 784)$

We keep it minimal:
- **Activations**: `LeakyReLU` after each hidden layer (its small negative slope keeps gradients flowing for negative pre-activations).
- **No output activation** on the decoder’s last layer, because the inputs are **standardized** (mean = 0.1307, std = 0.3081), so target pixel values span roughly -0.42 to 2.82 rather than $[0, 1]$.  
  If you were using raw $[0,1]$ pixels, a `Sigmoid` output would be sensible; with standardized inputs, an unbounded linear output pairs naturally with L1/L2 losses.
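To make the output-activation argument concrete, the sketch below (assuming only the mean/std values quoted above) maps the darkest and brightest raw pixels onto the standardized scale and compares that range with what a `Sigmoid` can produce. The helper name `standardize` is illustrative, not part of any library.

```python
import torch

# MNIST normalization statistics quoted in the text
MEAN, STD = 0.1307, 0.3081

def standardize(pixels: torch.Tensor) -> torch.Tensor:
    """Map raw [0, 1] pixel intensities to the standardized scale."""
    return (pixels - MEAN) / STD

raw = torch.tensor([0.0, 1.0])            # darkest and brightest possible pixel
lo, hi = standardize(raw).tolist()
print(f"standardized range: [{lo:.4f}, {hi:.4f}]")

# A Sigmoid head can only emit values strictly inside (0, 1), so it could never
# match background pixels (about -0.42) or bright strokes (about 2.82).
sigmoid_out = torch.sigmoid(torch.linspace(-10.0, 10.0, 5))
print("sigmoid output range:", sigmoid_out.min().item(), sigmoid_out.max().item())
```

Since every background pixel sits below 0 after standardization, a bounded `(0, 1)` head would be systematically wrong on most of the image, which is why the decoder ends with a plain `nn.Linear`.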

---

### Shape flow

- Input batch \(x\): `[B, 1, 28, 28]`
- **Flatten** to `[B, 784]` before feeding the encoder.
- Encoder outputs \(z\): `[B, 2]`
- Decoder maps back to `[B, 784]`
- **Unflatten** to `[B, 1, 28, 28]` to compare against the original images.

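Because `x_hat` has exactly the same shape as `x`, the reconstruction loss can be computed element-wise, e.g. `loss(x_hat, x)` (PyTorch losses take the prediction first and the target second).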
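The shape flow above can be checked in isolation. This sketch uses a random tensor as a stand-in for a standardized MNIST batch and a noisy copy as a pretend reconstruction (both assumptions for illustration); it shows that `Flatten`/`Unflatten` are pure reshapes and that the loss is defined per pixel.

```python
import torch
import torch.nn as nn

B = 8                                     # any batch size works
x = torch.randn(B, 1, 28, 28)             # stand-in for a standardized MNIST batch

flatten = nn.Flatten()                    # by default flattens every dim after dim 0
unflatten = nn.Unflatten(dim=1, unflattened_size=(1, 28, 28))

flat = flatten(x)                         # [B, 784]
restored = unflatten(flat)                # [B, 1, 28, 28]
assert flat.shape == (B, 784)
assert torch.equal(restored, x)           # lossless round trip: only a reshape

# Pretend reconstruction with the same shape as x, so the loss is element-wise
x_hat = restored + 0.1 * torch.randn_like(restored)
per_pixel = nn.functional.mse_loss(x_hat, x, reduction="none")  # [B, 1, 28, 28]
loss = per_pixel.mean()                   # equals the default reduction="mean"
print(per_pixel.shape, loss.item())
```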

---

### Methods we implement

- `encode(x) → z`: only the encoder path (used later to visualize latents).
- `decode(z) → x_hat`: only the decoder path (used later to manipulate latents and decode).
- `forward(x) → x_hat`: full AE (encode then decode).
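As a rough illustration of what a standalone `decode` enables, the sketch below decodes a regular grid of 2-D latent points and a straight-line interpolation between two latent codes. The decoder here is an **untrained stand-in** with the same 2 → 300 → 784 layout as the one described above; with a trained model you would call `model.decode` instead. The grid bounds and the codes `z_a`, `z_b` are arbitrary, hypothetical choices.

```python
import torch
import torch.nn as nn

# Untrained stand-in with the decoder's layout (2 -> 300 -> 784);
# replace with model.decode on a trained AutoEncoder.
decoder = nn.Sequential(nn.Linear(2, 300), nn.LeakyReLU(), nn.Linear(300, 28 * 28))
unflatten = nn.Unflatten(dim=1, unflattened_size=(1, 28, 28))

def decode(z: torch.Tensor) -> torch.Tensor:
    """[N, 2] latent codes -> [N, 1, 28, 28] images."""
    return unflatten(decoder(z))

# 1) A 5x5 grid of latent points covering [-2, 2] x [-2, 2]
axis = torch.linspace(-2.0, 2.0, 5)
g1, g2 = torch.meshgrid(axis, axis, indexing="ij")
grid = torch.stack([g1.flatten(), g2.flatten()], dim=1)    # [25, 2]

# 2) Linear interpolation between two (hypothetical) latent codes
z_a, z_b = torch.tensor([-1.5, 0.5]), torch.tensor([1.0, -1.0])
alphas = torch.linspace(0.0, 1.0, 7).unsqueeze(1)           # [7, 1]
path = (1 - alphas) * z_a + alphas * z_b                    # [7, 2]

with torch.no_grad():
    grid_imgs = decode(grid)
    path_imgs = decode(path)
print(grid_imgs.shape, path_imgs.shape)
```

Keeping `encode` and `decode` as separate methods is what makes this possible: the latent space can be probed directly without ever feeding an image through the encoder.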

---

### Why `LeakyReLU`?

- Like ReLU but with a small slope on negatives (mitigates “dying ReLUs”).
- Often improves stability for simple MLP autoencoders.
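A quick way to see the "small slope on negatives" claim is to compare gradients directly. This minimal sketch backpropagates through `nn.ReLU` and `nn.LeakyReLU` (default `negative_slope=0.01`) on a few hand-picked inputs.

```python
import torch
import torch.nn as nn

x = torch.tensor([-2.0, -0.5, 0.5, 2.0], requires_grad=True)
grads = {}

for act in (nn.ReLU(), nn.LeakyReLU()):   # LeakyReLU's default negative_slope is 0.01
    x.grad = None                         # reset the accumulated gradient
    act(x).sum().backward()               # x.grad[i] = derivative of act at x[i]
    grads[type(act).__name__] = x.grad.clone()
    print(type(act).__name__, x.grad.tolist())
```

ReLU passes zero gradient back for every negative input, so a unit stuck in the negative regime stops learning; LeakyReLU still passes back a small gradient of 0.01 there.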

---

### Sanity check idea (optional)

After defining the model, run a tiny batch through it and verify:
- `x.shape == recon.shape == [B, 1, 28, 28]`
- `z.shape == [B, 2]`

We’ll do that right after we instantiate the model.


### Why is it called a *bottleneck*?

In an autoencoder, the **bottleneck** is the **narrowest layer** in the middle — the smallest representation `z` through which all information about the input `x` must pass. By forcing data through this tight “neck,” the model must **compress** the input and keep only its most salient factors.

$$
\begin{array}{c}
x \in \mathbb{R}^{784} \xrightarrow{\text{Encoder}} z \in \mathbb{R}^{d} \xrightarrow{\text{Decoder}} \hat{x} \in \mathbb{R}^{784} \\
\text{(wide)} \quad \longrightarrow \quad \text{(narrow)} \quad \longrightarrow \quad \text{(wide)}
\end{array}
$$


---

### What the bottleneck does
- **Compression:** reduces dimensionality so the model captures key structure rather than copying pixels.  
- **Architectural regularization:** discourages the trivial identity mapping (especially when `d << D`).  
- **Generalization:** encourages discarding noise or redundancy and retaining meaningful features.

---

### Compression factor

Let `D` be the input dimensionality (e.g., `28 × 28 = 784` for MNIST) and `d` the bottleneck size.

$$
\text{compression factor} = \frac{D}{d}
$$

**Examples**
- `d = 2`  → \( \frac{784}{2} = 392\times \) compression (extreme; great for visualization).  
- `d = 64` → \( \frac{784}{64} = 12.25\times \) compression (typically better reconstructions).
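The formula is simple enough to tabulate for several bottleneck sizes. In this sketch, `d = 16` and `d = 784` are extra illustrative values not discussed above; `d = 784` shows the no-compression case where identity mapping becomes possible.

```python
D = 28 * 28  # flattened MNIST input dimensionality (784)

def compression_factor(D: int, d: int) -> float:
    """Ratio of input dimensionality to bottleneck dimensionality."""
    return D / d

for d in (2, 16, 64, 784):
    print(f"d = {d:>3} -> {compression_factor(D, d):g}x")
```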

---

### Choosing the bottleneck size `d` (trade-offs)
- **Too small `d`** → heavy compression → risk of **underfitting** (blurry or poor reconstructions).  
- **Too large `d`** → weak compression → risk of near **identity mapping** (little abstraction).  
- **Balanced `d`** → good reconstruction quality **and** compact, structured latents.

---

### Undercomplete vs. Overcomplete
- **Undercomplete** (`d < D`): a true *bottleneck*; capacity is limited by the latent dimensionality.  
- **Overcomplete** (`d ≥ D`): no geometric bottleneck; to avoid identity mapping you need extra regularization, e.g.:  
  - **Sparsity** penalties (e.g., L1 on activations),  
  - **Denoising** (corrupt input, reconstruct clean),  
  - **Contractive/variational** objectives (e.g., VAE).
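Two of the regularizers listed above, denoising and an L1 sparsity penalty on latent activations, can be combined into a single objective. The following is a minimal sketch, not a tuned recipe: the overcomplete size `d = 1000`, `noise_std`, and `l1_weight` are hypothetical values, and the input is random data standing in for MNIST.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def denoising_loss(encoder: nn.Module, decoder: nn.Module, x: torch.Tensor,
                   noise_std: float = 0.3, l1_weight: float = 0.0) -> torch.Tensor:
    """Reconstruct the clean x from a corrupted copy, optionally adding an L1
    sparsity penalty on the latent activations (illustrative defaults)."""
    x_noisy = x + noise_std * torch.randn_like(x)   # corrupt the input
    z = encoder(x_noisy.flatten(1))                  # [B, d]
    x_hat = decoder(z).view_as(x)                    # back to the input shape
    loss = F.mse_loss(x_hat, x)                      # target is the CLEAN x
    if l1_weight > 0:
        loss = loss + l1_weight * z.abs().mean()     # sparsity penalty on z
    return loss

# Overcomplete example: d = 1000 > D = 784 (hypothetical sizes)
D, d = 28 * 28, 1000
encoder = nn.Sequential(nn.Linear(D, d), nn.ReLU())
decoder = nn.Linear(d, D)
x = torch.randn(16, 1, 28, 28)
loss = denoising_loss(encoder, decoder, x, noise_std=0.3, l1_weight=1e-3)
loss.backward()
print(loss.item())
```

Because the encoder never sees the clean input, copying pixels through the wide latent layer no longer minimizes the loss; the network has to learn structure that lets it remove the noise.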

---

### Why use a 2-D bottleneck in demos?
A 2-D latent `z` can be **plotted directly**: each image becomes a point in 2D, typically forming clusters by digit class. This makes the learned representation easy to interpret and debug.
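A minimal sketch of the plotting workflow, assuming you encode a labelled dataset batch by batch and then scatter the 2-D codes colored by digit. The encoder and data here are **untrained, random stand-ins** (hypothetical); in practice you would pass `model.encode` and the real MNIST test loader. `collect_latents` and `plot_latents` are illustrative helper names, and matplotlib is only imported when plotting is actually requested.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

@torch.no_grad()
def collect_latents(encode, loader):
    """Run every batch through `encode` and stack the 2-D codes with labels."""
    zs, ys = [], []
    for x, y in loader:
        zs.append(encode(x))
        ys.append(y)
    return torch.cat(zs), torch.cat(ys)

def plot_latents(z, y):
    import matplotlib.pyplot as plt       # imported lazily; only needed for plotting
    plt.scatter(z[:, 0], z[:, 1], c=y, cmap="tab10", s=4)
    plt.colorbar(label="digit")
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.show()

# Stand-ins: untrained 784 -> 300 -> 2 encoder and random "images" with labels
encoder = nn.Sequential(nn.Flatten(), nn.Linear(784, 300), nn.LeakyReLU(), nn.Linear(300, 2))
data = TensorDataset(torch.randn(100, 1, 28, 28), torch.randint(0, 10, (100,)))
z, y = collect_latents(encoder, DataLoader(data, batch_size=32))
print(z.shape, y.shape)                   # one 2-D point (and one label) per image
```

With a trained model, calling `plot_latents(z, y)` is where the per-digit clusters would become visible; an untrained encoder will just produce an unstructured cloud.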



class AutoEncoder(nn.Module):
    def __init__(self, bottleneck_dim: int = 2):
        super().__init__()
        # Flatten image to vector and back
        self.flatten   = nn.Flatten()                          # [B, 1, 28, 28] -> [B, 784]
        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(1, 28, 28))  # [B, 784] -> [B, 1, 28, 28]

        # Encoder: 784 -> 300 -> bottleneck_dim
        self.encoder = nn.Sequential(
            nn.Linear(28*28, 300),
            nn.LeakyReLU(),
            nn.Linear(300, bottleneck_dim)
        )

        # Decoder: bottleneck_dim -> 300 -> 784
        self.decoder = nn.Sequential(
            nn.Linear(bottleneck_dim, 300),
            nn.LeakyReLU(),
            nn.Linear(300, 28*28)
            # No final activation since inputs are standardized; use L1/L2 loss.
            # If using raw [0,1] pixels, consider nn.Sigmoid() here.
        )

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        x = self.flatten(x)       # [B, 1, 28, 28] -> [B, 784]
        z = self.encoder(x)       # [B, 784] -> [B, bottleneck_dim]
        return z

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        x_hat = self.decoder(z)   # [B, bottleneck_dim] -> [B, 784]
        x_hat = self.unflatten(x_hat)  # [B, 784] -> [B, 1, 28, 28]
        return x_hat

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.decode(self.encode(x))


# (Optional) quick shape sanity check
if __name__ == "__main__":
    model = AutoEncoder(bottleneck_dim=2)
    dummy = torch.randn(4, 1, 28, 28)     # batch of 4
    with torch.no_grad():
        z = model.encode(dummy)           # [4, 2]
        recon = model(dummy)              # [4, 1, 28, 28]
    # Fail loudly if the shape flow is broken
    assert z.shape == (4, 2)
    assert recon.shape == dummy.shape
    print("z shape:", z.shape)
    print("recon shape:", recon.shape)


z shape: torch.Size([4, 2])
recon shape: torch.Size([4, 1, 28, 28])
