# Adversarial Autoencoder (AAE) Tutorial

## Introduction

An Adversarial Autoencoder (AAE) is an autoencoder that uses adversarial training, as in Generative Adversarial Networks (GANs), to impose a prior distribution on its latent space. Because the distribution of encoded codes is pushed to match a known prior, the latent space becomes more structured, and new samples can be generated by drawing from the prior and passing the draw through the decoder.

## Architecture

An AAE consists of three main components:
1. **Encoder**: Maps the input to a latent-space representation (a latent code).
2. **Decoder**: Reconstructs the input from the latent-space representation.
3. **Discriminator**: Distinguishes samples drawn from a prior distribution from latent codes produced by the encoder.

### Encoder

The encoder function, $z = f(x)$, maps the input $x$ to a latent representation $z$. For a single-layer encoder, this can be written as:

$$
z = f(x) = \sigma(Wx + b)
$$

where:
- $W$ is a weight matrix
- $b$ is a bias vector
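- $\sigma$ is an activation function (e.g., ReLU, sigmoid). With a Gaussian prior, the final encoder layer is often left linear (no activation) so that $z$ can take any real value, as in the PyTorch example below.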

### Decoder

The decoder function, $\hat{x} = g(z)$, maps the latent representation $z$ back to the original input space. For a single-layer decoder, this can be written as:

$$
\hat{x} = g(z) = \sigma(W'z + b')
$$

where:
- $W'$ is a weight matrix (not necessarily the transpose of $W$)
- $b'$ is a bias vector
- $\sigma$ is an activation function (e.g., a sigmoid, which keeps reconstructed pixel values in $[0, 1]$)

### Discriminator

The discriminator function, $D(z)$, outputs the probability that $z$ was drawn from the prior distribution rather than produced by the encoder. For a single-layer (logistic) discriminator, this can be written as:

$$
D(z) = \text{sigmoid}(W''z + b'')
$$

where:
- $W''$ is a weight matrix with a single row, since $D(z)$ is a scalar
- $b''$ is a bias vector
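To make the shapes concrete, here is a minimal NumPy sketch of the three single-layer maps. It assumes tanh as the encoder activation and sigmoid for the decoder, illustrative sizes (784 inputs, an 8-dimensional latent space), and small random weights; `W_dec` and `W_disc` are hypothetical names standing for $W'$ and $W''$.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

rng = np.random.default_rng(0)
input_dim, z_dim = 784, 8   # illustrative sizes (assumed)

# Parameters: encoder (W, b), decoder (W', b'), discriminator (W'', b'')
W, b = rng.normal(scale=0.01, size=(z_dim, input_dim)), np.zeros(z_dim)
W_dec, b_dec = rng.normal(scale=0.01, size=(input_dim, z_dim)), np.zeros(input_dim)
W_disc, b_disc = rng.normal(scale=0.1, size=(1, z_dim)), np.zeros(1)  # one row: D(z) is a scalar

def encode(x):
    return np.tanh(W @ x + b)                # z = sigma(Wx + b), sigma = tanh (assumed)

def decode(z):
    return sigmoid(W_dec @ z + b_dec)        # x_hat = sigma(W'z + b'), sigma = sigmoid

def discriminate(z):
    return sigmoid(W_disc @ z + b_disc)[0]   # D(z) = sigmoid(W''z + b''), a probability

x = rng.uniform(size=input_dim)              # stand-in for a flattened 28x28 image
z = encode(x)
x_hat = decode(z)
z_prior = rng.standard_normal(z_dim)         # a draw from a N(0, I) prior

print(z.shape, x_hat.shape)                  # (8,) (784,)
print(discriminate(z), discriminate(z_prior))
```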

## Loss Function

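The AAE objective is built from three losses:
1. **Reconstruction Loss**: Measures how well the decoder reconstructs the input.
2. **Adversarial Loss (Discriminator)**: Measures how well the discriminator distinguishes samples from the prior from encoded samples.
3. **Adversarial Loss (Generator/Encoder)**: Measures how well the encoder fools the discriminator.

These losses are not minimized as a single sum, because the discriminator and the encoder have opposing goals. The encoder and decoder minimize

$$
L = \text{Reconstruction Loss} + \lambda_2 \cdot \text{Adversarial Loss (Generator/Encoder)}
$$

while the discriminator, in a separate step with the encoder held fixed, minimizes $\lambda_1 \cdot \text{Adversarial Loss (Discriminator)}$. The weight $\lambda_2$ trades reconstruction quality against how closely the encoded samples match the prior; $\lambda_1$ only rescales the discriminator's gradient (the PyTorch example below uses $\lambda_1 = 0.5$ and $\lambda_2 = 1$).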

### Reconstruction Loss

The reconstruction loss is typically the mean squared error (MSE) over the $n$ input dimensions (e.g., $n = 784$ for a flattened 28x28 MNIST image):

$$
\text{Reconstruction Loss} = \frac{1}{n} \sum_{i=1}^{n} (x_i - \hat{x}_i)^2
$$
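A small worked example of the MSE formula, using toy pixel values chosen for illustration. It also shows the batch form: `nn.MSELoss()` with its default `reduction='mean'`, as used in the PyTorch example below, averages over every element of the batch, which for equal-size inputs equals the mean of the per-example values.

```python
import numpy as np

# Two flattened "images" with n = 4 pixels each (toy values, assumed)
x = np.array([[0.0, 0.5, 1.0, 1.0],
              [1.0, 1.0, 0.0, 0.0]])
x_hat = np.array([[0.1, 0.4, 0.8, 1.0],
                  [0.9, 1.0, 0.0, 0.2]])

per_example = np.mean((x - x_hat) ** 2, axis=1)  # (1/n) * sum_i (x_i - x_hat_i)^2 for each image
batch_mse = np.mean((x - x_hat) ** 2)            # mean over all batch * n elements
print(per_example, batch_mse)                    # approximately [0.015 0.0125] 0.01375
```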

### Adversarial Loss (Discriminator)

The adversarial loss for the discriminator is the binary cross-entropy loss:

$$
\text{Adversarial Loss (Discriminator)} = -\left( \mathbb{E}_{z \sim P_z}[\log D(z)] + \mathbb{E}_{z \sim Q_z}[\log (1 - D(z))] \right)
$$

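where $P_z$ is the prior distribution and $Q_z$ is the distribution of encoded samples $z = f(x)$ over the training data (often called the aggregated posterior). Prior samples carry the label 1 (real) and encoded samples the label 0 (fake).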
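A short numerical sketch of the discriminator loss, with the expectations replaced by batch means. The discriminator outputs are toy values (assumed). It also computes the value for a maximally confused discriminator ($D = 0.5$ everywhere), which is $2 \ln 2$; since the PyTorch example below computes `d_loss = 0.5 * (d_real_loss + d_fake_loss)`, its corresponding value would be $\ln 2 \approx 0.693$.

```python
import numpy as np

def disc_bce(d_prior, d_encoded):
    """-(E_{P_z}[log D(z)] + E_{Q_z}[log(1 - D(z))]), expectations as batch means."""
    return -(np.mean(np.log(d_prior)) + np.mean(np.log(1.0 - d_encoded)))

# Toy discriminator outputs on prior samples and on encoded samples (assumed values)
confident = disc_bce(np.array([0.9, 0.8]), np.array([0.2, 0.3]))
confused = disc_bce(np.array([0.5, 0.5]), np.array([0.5, 0.5]))
print(round(confident, 4), round(confused, 4))  # 0.4542 1.3863
```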

### Adversarial Loss (Generator/Encoder)

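The adversarial loss for the generator/encoder is also a binary cross-entropy loss, but with flipped labels: the encoded samples are labeled as real (1), so the encoder lowers this loss by making $D(z)$ large on its own codes: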

$$
\text{Adversarial Loss (Generator/Encoder)} = -\mathbb{E}_{z \sim Q_z}[\log D(z)]
$$
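Labeling encoded samples as real gives the non-saturating loss $-\log D(z)$, rather than the minimax alternative of minimizing $\log(1 - D(z))$. One way to see why this is preferred is to compare the gradients of the two with respect to the discriminator's pre-sigmoid score $s$ (with $D = \text{sigmoid}(s)$): when the discriminator confidently rejects an encoded code ($s \ll 0$), the minimax gradient nearly vanishes while the non-saturating one stays close to $-1$. The scores below are arbitrary illustrative values.

```python
import numpy as np

def sigmoid(s):
    return 1.0 / (1.0 + np.exp(-s))

def grad_non_saturating(s):
    # d/ds [-log D(s)] = D(s) - 1
    return sigmoid(s) - 1.0

def grad_saturating(s):
    # d/ds [log(1 - D(s))] = -D(s)
    return -sigmoid(s)

for s in [-6.0, 0.0, 6.0]:
    print(f"s={s:+.1f}  D={sigmoid(s):.4f}  "
          f"non-saturating={grad_non_saturating(s):+.4f}  saturating={grad_saturating(s):+.4f}")
```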

## Training Process

Training an AAE involves alternating between optimizing the encoder/decoder and the discriminator.

### Derivatives

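Let's derive the gradients for the encoder/decoder and discriminator weights of the single-layer networks above, for a single training example; mini-batch gradients are the averages of these per-example gradients.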

#### Decoder Gradients

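Neither adversarial loss depends on the decoder, so only the reconstruction loss contributes to the gradient with respect to the decoder weights $W'$. Write $a' = W'z + b'$, so that $\hat{x} = \sigma(a')$. Differentiating the MSE gives

$$
\frac{\partial L}{\partial \hat{x}} = \frac{2}{n}(\hat{x} - x)
$$

and applying the chain rule through $\sigma$ gives the error signal at the decoder pre-activation:

$$
\delta' = \frac{\partial L}{\partial a'} = \frac{2}{n}(\hat{x} - x) \odot \sigma'(W'z + b')
$$

where $\odot$ denotes element-wise multiplication. Since $a'$ is linear in $W'$ and $b'$,

$$
\frac{\partial L}{\partial W'} = \delta' \, z^T, \qquad \frac{\partial L}{\partial b'} = \delta'
$$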
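A derivation like this is easy to get wrong by a sign or a factor, so a finite-difference check is worthwhile. The sketch compares the analytic decoder gradients $\frac{\partial L}{\partial W'} = \delta' z^T$ and $\frac{\partial L}{\partial b'} = \delta'$, with $\delta' = \frac{2}{n}(\hat{x} - x) \odot \sigma'(W'z + b')$, against central differences. It assumes a sigmoid decoder activation, so $\sigma'(a') = \hat{x}(1 - \hat{x})$; the sizes and random values are arbitrary.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def recon_loss(W_dec, b_dec, z, x):
    """MSE between x and x_hat = sigmoid(W'z + b'), averaged over the n outputs."""
    x_hat = sigmoid(W_dec @ z + b_dec)
    return np.mean((x - x_hat) ** 2)

def decoder_grad(W_dec, b_dec, z, x):
    """Analytic gradients: dL/dW' = delta' z^T and dL/db' = delta'."""
    x_hat = sigmoid(W_dec @ z + b_dec)
    delta = (2.0 / x.size) * (x_hat - x) * x_hat * (1.0 - x_hat)  # sigmoid'(a') = x_hat (1 - x_hat)
    return np.outer(delta, z), delta

def numerical_grad(f, P, eps=1e-6):
    """Central finite differences of the scalar function f() with respect to array P (in place)."""
    g = np.zeros_like(P)
    for idx in np.ndindex(P.shape):
        old = P[idx]
        P[idx] = old + eps
        f_plus = f()
        P[idx] = old - eps
        f_minus = f()
        P[idx] = old
        g[idx] = (f_plus - f_minus) / (2 * eps)
    return g

rng = np.random.default_rng(0)
n, z_dim = 5, 3
z, x = rng.normal(size=z_dim), rng.uniform(size=n)
W_dec, b_dec = rng.normal(size=(n, z_dim)), rng.normal(size=n)

def loss():
    return recon_loss(W_dec, b_dec, z, x)

dW_dec, db_dec = decoder_grad(W_dec, b_dec, z, x)
print(np.max(np.abs(dW_dec - numerical_grad(loss, W_dec))))  # tiny: finite-difference error only
```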

#### Encoder Gradients

Here $L$ denotes the encoder/decoder objective: the reconstruction loss plus $\lambda_2$ times the encoder's adversarial loss, which for a single example is $-\log D(z)$. The discriminator's loss is minimized in a separate step with the encoder held fixed, so it contributes no gradient to $W$. Write $a = Wx + b$, so that $z = \sigma(a)$, and let $\delta' = \frac{2}{n}(\hat{x} - x) \odot \sigma'(W'z + b')$ be the decoder error signal, where $\odot$ is element-wise multiplication.

The gradient with respect to the latent code collects both contributions:

$$
\frac{\partial L}{\partial z} = W'^T \delta' + \lambda_2 \cdot \frac{\partial \left(-\log D(z)\right)}{\partial z}
$$

With $s = W''z + b''$ and $D(z) = \text{sigmoid}(s)$, the identity $\text{sigmoid}'(s) = D(z)(1 - D(z))$ gives

$$
\frac{\partial \left(-\log D(z)\right)}{\partial z} = -\frac{1}{D(z)} \cdot D(z)\,(1 - D(z))\, W''^T = -(1 - D(z))\, W''^T
$$

Propagating through the encoder activation,

$$
\delta = \frac{\partial L}{\partial a} = \frac{\partial L}{\partial z} \odot \sigma'(Wx + b)
$$

Thus,

$$
\frac{\partial L}{\partial W} = \delta \, x^T = \left[\left(W'^T \delta' - \lambda_2 (1 - D(z))\, W''^T\right) \odot \sigma'(Wx + b)\right] x^T, \qquad \frac{\partial L}{\partial b} = \delta
$$
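The same kind of finite-difference check applies to the encoder. The sketch below compares $\frac{\partial L}{\partial W} = \left[\left(W'^T \delta' - \lambda_2 (1 - D(z))\, W''^T\right) \odot \sigma'(Wx + b)\right] x^T$ with central differences of $L = \text{MSE}(x, \hat{x}) - \lambda_2 \log D(z)$. It assumes tanh for the encoder (so $\sigma'(a) = 1 - z^2$), sigmoid for the decoder, $\lambda_2 = 0.5$, and arbitrary random parameters.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def encoder_objective(W, b, x, W_dec, b_dec, W_disc, b_disc, lam):
    """L = MSE(x, x_hat) - lam * log D(z) for a single input x."""
    z = np.tanh(W @ x + b)                    # encoder, sigma = tanh (assumed)
    x_hat = sigmoid(W_dec @ z + b_dec)        # decoder, sigma = sigmoid (assumed)
    d = sigmoid(W_disc @ z + b_disc)[0]       # discriminator output D(z)
    return np.mean((x - x_hat) ** 2) - lam * np.log(d)

def encoder_grad(W, b, x, W_dec, b_dec, W_disc, b_disc, lam):
    """Analytic dL/dW = delta x^T and dL/db = delta."""
    z = np.tanh(W @ x + b)
    x_hat = sigmoid(W_dec @ z + b_dec)
    d = sigmoid(W_disc @ z + b_disc)[0]
    delta_dec = (2.0 / x.size) * (x_hat - x) * x_hat * (1.0 - x_hat)   # delta'
    dL_dz = W_dec.T @ delta_dec - lam * (1.0 - d) * W_disc[0]          # reconstruction + adversarial
    delta = dL_dz * (1.0 - z ** 2)                                      # tanh'(a) = 1 - tanh(a)^2
    return np.outer(delta, x), delta

def numerical_grad(f, P, eps=1e-6):
    """Central finite differences of the scalar function f() with respect to array P (in place)."""
    g = np.zeros_like(P)
    for idx in np.ndindex(P.shape):
        old = P[idx]
        P[idx] = old + eps
        f_plus = f()
        P[idx] = old - eps
        f_minus = f()
        P[idx] = old
        g[idx] = (f_plus - f_minus) / (2 * eps)
    return g

rng = np.random.default_rng(1)
n, z_dim, lam = 5, 3, 0.5
x = rng.uniform(size=n)
W, b = rng.normal(size=(z_dim, n)), rng.normal(size=z_dim)
W_dec, b_dec = rng.normal(size=(n, z_dim)), rng.normal(size=n)
W_disc, b_disc = rng.normal(size=(1, z_dim)), rng.normal(size=1)

def loss():
    return encoder_objective(W, b, x, W_dec, b_dec, W_disc, b_disc, lam)

dW, db = encoder_grad(W, b, x, W_dec, b_dec, W_disc, b_disc, lam)
print(np.max(np.abs(dW - numerical_grad(loss, W))))  # tiny: finite-difference error only
```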

#### Discriminator Gradients

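The discriminator is updated in its own step to minimize the Adversarial Loss (Discriminator), with the encoder held fixed. Write $s = W''z + b''$, so that $D(z) = \text{sigmoid}(s)$ and

$$
\frac{\partial D(z)}{\partial W''} = D(z)\,(1 - D(z))\, z^T
$$

For a prior sample $z_p \sim P_z$ (label 1) and an encoded sample $z_q \sim Q_z$ (label 0), differentiating the two log terms gives

$$
\frac{\partial \left(-\log D(z_p)\right)}{\partial W''} = -(1 - D(z_p))\, z_p^T, \qquad \frac{\partial \left(-\log(1 - D(z_q))\right)}{\partial W''} = D(z_q)\, z_q^T
$$

Taking expectations (in practice, mini-batch averages),

$$
\frac{\partial \text{Adversarial Loss (Discriminator)}}{\partial W''} = \mathbb{E}_{z_p \sim P_z}\left[(D(z_p) - 1)\, z_p^T\right] + \mathbb{E}_{z_q \sim Q_z}\left[D(z_q)\, z_q^T\right]
$$

The gradient with respect to $b''$ is the same expression with $z_p^T$ and $z_q^T$ replaced by 1.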
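A finite-difference check of the mini-batch discriminator gradient $\frac{1}{m}\sum_i (D(z_{p,i}) - 1)\, z_{p,i}^T + \frac{1}{m}\sum_i D(z_{q,i})\, z_{q,i}^T$ and the matching bias gradient. It assumes equal-size batches of prior samples and stand-in encoded samples; the shifted Gaussian used for the encoded batch is an arbitrary choice for illustration.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def disc_loss(W_disc, b_disc, z_p, z_q):
    """Mini-batch BCE: prior samples z_p labeled 1, encoded samples z_q labeled 0."""
    d_p = sigmoid(z_p @ W_disc.T + b_disc)     # shape (m, 1)
    d_q = sigmoid(z_q @ W_disc.T + b_disc)
    return -(np.mean(np.log(d_p)) + np.mean(np.log(1.0 - d_q)))

def disc_grad(W_disc, b_disc, z_p, z_q):
    """Analytic gradients: mean (D(z_p) - 1) z_p^T + mean D(z_q) z_q^T, and the bias analogue."""
    d_p = sigmoid(z_p @ W_disc.T + b_disc)
    d_q = sigmoid(z_q @ W_disc.T + b_disc)
    dW = (np.mean((d_p - 1.0) * z_p, axis=0, keepdims=True)
          + np.mean(d_q * z_q, axis=0, keepdims=True))
    db = np.array([np.mean(d_p - 1.0) + np.mean(d_q)])
    return dW, db

def numerical_grad(f, P, eps=1e-6):
    """Central finite differences of the scalar function f() with respect to array P (in place)."""
    g = np.zeros_like(P)
    for idx in np.ndindex(P.shape):
        old = P[idx]
        P[idx] = old + eps
        f_plus = f()
        P[idx] = old - eps
        f_minus = f()
        P[idx] = old
        g[idx] = (f_plus - f_minus) / (2 * eps)
    return g

rng = np.random.default_rng(2)
m, z_dim = 8, 3
z_p = rng.standard_normal((m, z_dim))               # samples from the prior N(0, I)
z_q = 1.0 + 0.5 * rng.standard_normal((m, z_dim))   # stand-in "encoded" samples (assumed)
W_disc, b_disc = rng.normal(size=(1, z_dim)), np.zeros(1)

def loss():
    return disc_loss(W_disc, b_disc, z_p, z_q)

dW, db = disc_grad(W_disc, b_disc, z_p, z_q)
print(np.max(np.abs(dW - numerical_grad(loss, W_disc))))  # tiny: finite-difference error only
```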

### Gradient Descent Update

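Each network's weights and biases are updated with the gradient of the objective that network minimizes: the encoder/decoder objective for $W, b, W', b'$, and the discriminator loss for $W'', b''$. For the encoder parameters, for example: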

$$
W \leftarrow W - \eta \frac{\partial L}{\partial W}
$$

$$
b \leftarrow b - \eta \frac{\partial L}{\partial b}
$$

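where $\eta$ is the learning rate. The PyTorch example below uses the Adam optimizer rather than plain gradient descent, but it alternates the discriminator and encoder/decoder updates in the same way.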
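The sketch below ties the derivation together with one alternating update on a toy mini-batch in NumPy: first a discriminator step with the encoder fixed, then an encoder/decoder step with the discriminator fixed, each a plain gradient-descent step using the analytic gradients averaged over the batch. The sizes, the tanh encoder and sigmoid decoder, $\lambda_2 = 1$ and $\eta = 0.01$ are assumptions for illustration; a real implementation would repeat this over many mini-batches.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

rng = np.random.default_rng(0)
m, n_in, z_dim, lam, eta = 16, 6, 2, 1.0, 0.01   # toy sizes, lambda_2 and learning rate (assumed)
X = rng.uniform(size=(m, n_in))                   # stand-in mini-batch, one input per row

params = {
    "W": rng.normal(scale=0.5, size=(z_dim, n_in)), "b": np.zeros(z_dim),          # encoder
    "W_dec": rng.normal(scale=0.5, size=(n_in, z_dim)), "b_dec": np.zeros(n_in),   # decoder
    "W_disc": rng.normal(scale=0.5, size=(1, z_dim)), "b_disc": np.zeros(1),       # discriminator
}

def encode(p, X):
    return np.tanh(X @ p["W"].T + p["b"])

def disc(p, Z):
    return sigmoid(Z @ p["W_disc"].T + p["b_disc"])   # shape (rows, 1)

def disc_loss(p, Zp, Zq):
    return -(np.mean(np.log(disc(p, Zp))) + np.mean(np.log(1.0 - disc(p, Zq))))

def ae_loss(p, X):
    Z = encode(p, X)
    X_hat = sigmoid(Z @ p["W_dec"].T + p["b_dec"])
    return np.mean((X - X_hat) ** 2) - lam * np.mean(np.log(disc(p, Z)))

# Step 1: discriminator update, encoder held fixed
Zp = rng.standard_normal((m, z_dim))   # samples from the prior N(0, I), label 1
Zq = encode(params, X)                 # encoded samples from Q_z, label 0
d_before = disc_loss(params, Zp, Zq)
Dp, Dq = disc(params, Zp), disc(params, Zq)
params["W_disc"] -= eta * (np.mean((Dp - 1.0) * Zp, axis=0, keepdims=True)
                           + np.mean(Dq * Zq, axis=0, keepdims=True))
params["b_disc"] -= eta * np.array([np.mean(Dp - 1.0) + np.mean(Dq)])
d_after = disc_loss(params, Zp, Zq)

# Step 2: encoder/decoder update, discriminator held fixed
ae_before = ae_loss(params, X)
Z = encode(params, X)
X_hat = sigmoid(Z @ params["W_dec"].T + params["b_dec"])
D = disc(params, Z)
delta_dec = (2.0 / X.size) * (X_hat - X) * X_hat * (1.0 - X_hat)           # delta' for each row
dZ = delta_dec @ params["W_dec"] - lam * (1.0 - D) * params["W_disc"] / m  # reconstruction + adversarial
delta_enc = dZ * (1.0 - Z ** 2)                                             # tanh'(a) = 1 - z^2
grads = {"W_dec": delta_dec.T @ Z, "b_dec": delta_dec.sum(axis=0),
         "W": delta_enc.T @ X, "b": delta_enc.sum(axis=0)}
for name, g in grads.items():
    params[name] -= eta * g                                                 # W <- W - eta * dL/dW
ae_after = ae_loss(params, X)

print(f"D loss: {d_before:.4f} -> {d_after:.4f}")
print(f"AE loss: {ae_before:.4f} -> {ae_after:.4f}")
```

With a small enough learning rate, each step should lower the loss it targets on the same batch; the two losses pull in opposite directions over many steps, which is the adversarial game described above.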

### Advantages and Drawbacks

#### Advantages
- **Structured Latent Space**: AAEs impose a prior distribution on the latent space, leading to more structured and interpretable latent representations.
- **Realistic Sample Generation**: Because the encoded codes are matched to a known prior, new samples can be generated by decoding draws from that prior, which a traditional autoencoder with an unconstrained latent space does not support reliably.
- **Flexibility**: AAEs allow for flexible and customizable prior distributions on the latent space, enabling various applications such as semi-supervised learning and disentangled representation learning.

#### Drawbacks
- **Complex Training**: Training AAEs involves alternating between optimizing the encoder/decoder and the discriminator, which can be challenging and time-consuming.
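- **Mode Collapse**: As in GANs, the adversarial game can be unstable, and the encoded distribution may cover only part of the prior, so decoding draws from the uncovered regions can give poor samples. The reconstruction loss tends to make this milder than in standard GANs, but it does not remove the risk.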
- **Hyperparameter Sensitivity**: AAEs require careful tuning of hyperparameters, such as the learning rate and the balance between the reconstruction and adversarial losses.

### Innovations of Adversarial Autoencoders (AAEs)

#### Combining Autoencoders and GANs
- **Autoencoders**: Traditionally used for learning efficient codings of input data by minimizing reconstruction loss.
- **GANs**: Designed to generate realistic data samples by training a generator and discriminator in a competitive setting.
- **AAEs**: Integrate the autoencoder architecture with adversarial training, where the latent space learned by the encoder is regularized to follow a prior distribution using a discriminator.

#### Regularization of Latent Space
- **Prior Distribution**: The encoder is regularized to produce latent codes that follow a specific prior distribution (e.g., Gaussian, mixture of Gaussians). This regularization is enforced by training a discriminator to differentiate between true samples from the prior and encoded samples.
- **Structured Latent Space**: Encourages the latent space to have a well-defined structure, leading to more meaningful and interpretable representations. This is important for tasks like interpolation, clustering, and generation of diverse samples.

#### Improved Sample Generation
- **Realistic Samples**: By employing adversarial training, AAEs can generate more realistic and diverse samples compared to traditional autoencoders.
- **Mode Coverage**: Because every training input must be reconstructed from its code, the encoder cannot simply ignore parts of the data distribution, which makes AAEs less prone to the mode collapse common in standard GANs.

#### Flexibility in Prior Distribution
- **Customizable Priors**: AAEs allow for the use of various prior distributions on the latent space, providing flexibility for different applications. For example, a mixture of Gaussians can be used for clustering, and a Gaussian prior can be used for smooth interpolations.
- **Application in Semi-Supervised Learning**: The structured latent space can facilitate semi-supervised learning by providing meaningful clusters in the latent space.
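As an illustration of a customizable prior, the sketch below samples from a mixture of $k$ two-dimensional Gaussians whose means are evenly spaced on a circle. The function name `sample_mixture_prior`, the circle layout, and the values of `k`, `radius`, and `std` are hypothetical choices for illustration, not settings taken from this tutorial.

```python
import numpy as np

def sample_mixture_prior(n, k=10, radius=4.0, std=0.5, rng=None):
    """Draw n 2-D samples from an equally weighted mixture of k isotropic Gaussians
    whose means lie evenly spaced on a circle (illustrative prior)."""
    rng = np.random.default_rng() if rng is None else rng
    components = rng.integers(0, k, size=n)                       # pick a component per sample
    angles = 2.0 * np.pi * components / k
    means = radius * np.stack([np.cos(angles), np.sin(angles)], axis=1)
    samples = means + std * rng.standard_normal((n, 2))           # add isotropic Gaussian noise
    return samples, components

z_prior, comp = sample_mixture_prior(1000, rng=np.random.default_rng(0))
print(z_prior.shape, np.bincount(comp, minlength=10))
```

To try it in the PyTorch training loop below with `z_dim = 2`, the line `z_real = torch.randn(real_data.size(0), z_dim)` could be replaced by `z_real = torch.from_numpy(sample_mixture_prior(real_data.size(0))[0]).float()`, so that the discriminator compares encoded codes against the mixture instead of a standard Gaussian.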

#### Stability in Training
- **Combining MSE and Adversarial Loss**: The reconstruction loss (e.g., Mean Squared Error) combined with adversarial loss provides a stabilizing effect during training. The reconstruction loss ensures that the autoencoder retains the essential features of the input, while the adversarial loss regularizes the latent space.
- **Reduced Mode Collapse**: By balancing the autoencoder and adversarial objectives, AAEs can achieve a more stable training process compared to standard GANs, reducing the likelihood of mode collapse.

## Numerical Example

The following example uses Python and PyTorch to train an AAE on the MNIST dataset, which consists of 28x28 grayscale images of handwritten digits. The prior is a standard Gaussian $\mathcal{N}(0, I)$ over a 64-dimensional latent space.




```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Encoder: maps a flattened image to a latent code z.
# The output layer has no activation, so z can take any real value (needed to match a Gaussian prior).
class Encoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, z_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, z_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        z = self.fc2(x)
        return z

# Decoder: maps a latent code back to pixel space; the sigmoid keeps outputs in [0, 1].
class Decoder(nn.Module):
    def __init__(self, z_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, z):
        z = torch.relu(self.fc1(z))
        x_reconstructed = torch.sigmoid(self.fc2(z))
        return x_reconstructed

# Discriminator: outputs D(z), the probability that z was drawn from the prior.
class Discriminator(nn.Module):
    def __init__(self, z_dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, z):
        z = torch.relu(self.fc1(z))
        validity = torch.sigmoid(self.fc2(z))
        return validity

# Hyperparameters
input_dim = 28*28
hidden_dim = 256
z_dim = 64
batch_size = 128
num_epochs = 50
learning_rate = 1e-3

# Data loading (ToTensor scales pixel values to [0, 1])
transform = transforms.ToTensor()
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# Initialize models
encoder = Encoder(input_dim, hidden_dim, z_dim)
decoder = Decoder(z_dim, hidden_dim, input_dim)
discriminator = Discriminator(z_dim, hidden_dim)

# Loss functions
reconstruction_loss = nn.MSELoss()  # mean over all pixels in the batch
adversarial_loss = nn.BCELoss()     # expects probabilities, i.e. sigmoid outputs

# Optimizers: one for encoder + decoder, one for the discriminator
optimizer_G = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=learning_rate)
optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate)

# Training loop: each mini-batch alternates an encoder/decoder update and a discriminator update
for epoch in range(num_epochs):
    for batch_idx, (real_data, _) in enumerate(train_loader):
        real_data = real_data.view(-1, input_dim)  # flatten to (batch, 784)

        # Adversarial ground truths: 1 = drawn from the prior, 0 = produced by the encoder
        valid = torch.ones(real_data.size(0), 1)
        fake = torch.zeros(real_data.size(0), 1)

        # --------------------------------------
        #  Train Generator (encoder and decoder)
        # --------------------------------------
        optimizer_G.zero_grad()

        # Encode and Decode
        z = encoder(real_data)
        reconstructed_data = decoder(z)

        # Reconstruction loss
        recon_loss = reconstruction_loss(reconstructed_data, real_data)

        # Regularization loss: the encoder is rewarded when D labels its codes as prior samples
        validity = discriminator(z)
        reg_loss = adversarial_loss(validity, valid)

        # Total loss (the adversarial weight lambda_2 is 1 here)
        g_loss = recon_loss + reg_loss
        g_loss.backward()
        # Only encoder/decoder parameters are stepped; the gradients this backward pass
        # also wrote into the discriminator are cleared by optimizer_D.zero_grad() below.
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()

        # Sample from the prior distribution N(0, I), labeled as real
        z_real = torch.randn(real_data.size(0), z_dim)
        validity_real = discriminator(z_real)
        d_real_loss = adversarial_loss(validity_real, valid)

        # Encoded samples, labeled as fake; detach() keeps this step from updating the encoder
        z_fake = encoder(real_data).detach()
        validity_fake = discriminator(z_fake)
        d_fake_loss = adversarial_loss(validity_fake, fake)

        # Total discriminator loss
        d_loss = 0.5 * (d_real_loss + d_fake_loss)
        d_loss.backward()
        optimizer_D.step()

    # Losses of the last mini-batch in the epoch
    print(f"Epoch [{epoch+1}/{num_epochs}] | G Loss: {g_loss.item():.4f} | D Loss: {d_loss.item():.4f}")

print("Training complete!")
```

Output of the recorded run (the log ends at epoch 21 of 50; each line shows the losses of the last mini-batch in that epoch):

```text
Epoch [1/50] | G Loss: 5.5588 | D Loss: 0.3284
Epoch [2/50] | G Loss: 3.4995 | D Loss: 0.1180
Epoch [3/50] | G Loss: 4.5743 | D Loss: 0.0086
Epoch [4/50] | G Loss: 5.0441 | D Loss: 0.0094
Epoch [5/50] | G Loss: 5.6633 | D Loss: 0.0038
Epoch [6/50] | G Loss: 5.9425 | D Loss: 0.0025
Epoch [7/50] | G Loss: 6.5514 | D Loss: 0.0010
Epoch [8/50] | G Loss: 6.7419 | D Loss: 0.0008
Epoch [9/50] | G Loss: 12.9642 | D Loss: 0.0619
Epoch [10/50] | G Loss: 5.9849 | D Loss: 0.0017
Epoch [11/50] | G Loss: 6.8452 | D Loss: 0.0007
Epoch [12/50] | G Loss: 7.0446 | D Loss: 0.0019
Epoch [13/50] | G Loss: 7.3296 | D Loss: 0.0006
Epoch [14/50] | G Loss: 7.2917 | D Loss: 0.0004
Epoch [15/50] | G Loss: 8.1754 | D Loss: 0.0002
Epoch [16/50] | G Loss: 8.5684 | D Loss: 0.0001
Epoch [17/50] | G Loss: 8.6007 | D Loss: 0.0001
Epoch [18/50] | G Loss: 8.5869 | D Loss: 0.0004
Epoch [19/50] | G Loss: 8.9012 | D Loss: 0.0001
Epoch [20/50] | G Loss: 8.5518 | D Loss: 0.0001
Epoch [21/50] | G Loss: 9.3223 | D Loss: 0.0001
```
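In the recorded log, the D loss falls to about 0.0001 while the G loss climbs from about 5.6 to about 9.3. Because `d_loss` averages two binary cross-entropy terms, a discriminator that could not tell encoded codes from prior samples ($D = 0.5$ everywhere) would report $\ln 2 \approx 0.693$; a value near zero suggests that, in this run, the discriminator separates the two almost perfectly, i.e. the encoded codes are not yet matching the $\mathcal{N}(0, I)$ prior. A simple moment check can help confirm this. The helper `prior_match_report` below is hypothetical (not a library function), and matching means and covariances is necessary but not sufficient for matching a Gaussian.

```python
import numpy as np

def prior_match_report(z):
    """Compare codes z (shape [N, z_dim]) with a N(0, I) prior using simple moment statistics.

    For codes that match the prior, max_abs_mean should be near 0, the per-dimension
    standard deviations near 1, and the off-diagonal covariances near 0.
    """
    mean = z.mean(axis=0)
    std = z.std(axis=0)
    cov = np.cov(z, rowvar=False)
    off_diag = cov - np.diag(np.diag(cov))
    return {
        "max_abs_mean": float(np.max(np.abs(mean))),
        "min_std": float(std.min()),
        "max_std": float(std.max()),
        "max_abs_offdiag_cov": float(np.max(np.abs(off_diag))),
    }

# Usage with the trained PyTorch encoder (assumes `encoder` and a batch `real_data` from the code above):
# with torch.no_grad():
#     z = encoder(real_data).numpy()
# print(prior_match_report(z))

# Self-contained check on synthetic codes drawn from the prior itself
print(prior_match_report(np.random.default_rng(0).standard_normal((5000, 8))))
```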
