# Contents

In this notebook, we will create a RealNVP(real-valued non-volume preserving) generative model for the Oxford Pets dataset, with an intermediate AutoEncoder.

Instead of training the model on direct pixel values and generating images, we will

1. Train an AutoEncoder for the images
2. Convert the data into embeddings using the AutoEncoder
3. Train our RealNVP model on the embeddings
4. Generate embeddings using the RealNVP model and convert them to images using the AutoEncoder's decoder

This notebook is heavily based on [This Repo](https://github.com/SpencerSzabados/realnvp-pytorch/tree/master)

In [2]:
import copy
import os

import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from PIL import Image

import torch
from torch import nn, distributions
from torch.nn import MSELoss
from torchvision import transforms
from datasets import load_dataset

In [3]:
# Set the random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Setup device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Define constants
IMAGE_SIZE = 64
BATCH_SIZE = 64
EMBEDDING_DIM = 256 # Dimension of the latent embedding

Using device: cpu


### Loading Data

In [4]:
from torchvision import transforms, datasets

transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    # 你可以添加更多 transform
])

train_dataset = datasets.OxfordIIITPet(
    root='data',
    split='trainval',
    target_types='category',
    transform=transform,  # 这里很关键
    download=True
)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2
)

### Autoencoder definition

# **📌 Autoencoder for Compressing Pet Images**

The Oxford Pets dataset contains much more complex images (64x64, RGB, varied poses and backgrounds) than MNIST. To handle this, we use a more powerful **Autoencoder** with a deeper, DCGAN-style architecture.

## **🔹 How it Works**

1️⃣ **Encoder**: A deep convolutional network that compresses a 64x64 pet image into a low-dimensional latent vector (embedding). This embedding must capture the key features that define the pet's appearance and species.

2️⃣ **Decoder**: A deconvolutional network that attempts to perfectly reconstruct the original image from this compressed embedding.

By training this model to minimize reconstruction error, we create a rich, low-dimensional "embedding space" that we can then model with our Normalizing Flow.



## **📌 Expected Input & Output Shapes**

- **Input Image:** `(batch_size, 3, 64, 64)`
- **Latent Embedding:** `(batch_size, 256)`  *(256 is our `EMBEDDING_DIM`)*
- **Reconstructed Image:** `(batch_size, 3, 64, 64)`

In [5]:
class Autoencoder(nn.Module):
    def __init__(self, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim
        
        # --- Encoder ---
        # Input: (B, 3, 64, 64)
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1), # -> (B, 32, 32, 32)
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1), # -> (B, 64, 16, 16)
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), # -> (B, 128, 8, 8)
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1), # -> (B, 256, 4, 4)
            nn.ReLU(),
            nn.Flatten(), # -> (B, 256 * 4 * 4)
            nn.Linear(256 * 4 * 4, self.embedding_dim)
        )
        
        # --- Decoder ---
        # Input: (B, embedding_dim)
        self.decoder_input = nn.Linear(self.embedding_dim, 256 * 4 * 4)
        self.decoder = nn.Sequential(
            # Reshape -> (B, 256, 4, 4)
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1), # -> (B, 128, 8, 8)
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1), # -> (B, 64, 16, 16)
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1), # -> (B, 32, 32, 32)
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1), # -> (B, 3, 64, 64)
            nn.Tanh() # Tanh activation to output values in [-1, 1]
        )

    def forward(self, x):
        embedding = self.encoder(x)
        decoder_in = self.decoder_input(embedding)
        decoder_in = decoder_in.view(-1, 256, 4, 4) # Reshape for deconvolution
        reconstructed = self.decoder(decoder_in)
        return reconstructed, embedding

### Autoencoder training on Pets

In [6]:
# Create model and optimizer
ae_model = Autoencoder(EMBEDDING_DIM).to(device)
optimizer = torch.optim.Adam(ae_model.parameters(), lr=2e-4)
loss_fn = MSELoss()
AE_EPOCHS = 100

# Training Loop
ae_model.train()
for epoch in range(AE_EPOCHS):
    total_loss = 0
    for batch in tqdm(train_loader, desc=f"AE Epoch {epoch+1}/{AE_EPOCHS}"):
        # OxfordIIITPet returns (image, target) tuples
        if isinstance(batch, dict):
            images = batch['image']
            # If images are PIL Images, apply transform
            if isinstance(images[0], Image.Image):
                images = torch.stack([transform(img) for img in images])
            images = images.to(device)
        else:
            images, _ = batch
            images = images.to(device)
        
        # Forward pass
        reconstructed, _ = ae_model(images)
        loss = loss_fn(reconstructed, images)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{AE_EPOCHS}], Average Reconstruction Loss: {avg_loss:.4f}")

# Save the trained model
torch.save(ae_model.state_dict(), 'pets_autoencoder.pth')
print("Autoencoder model saved.")

AE Epoch 1/100: 100%|██████████| 115/115 [00:27<00:00,  4.11it/s]


Epoch [1/100], Average Reconstruction Loss: 0.0862


AE Epoch 2/100: 100%|██████████| 115/115 [00:29<00:00,  3.89it/s]


Epoch [2/100], Average Reconstruction Loss: 0.0316


AE Epoch 3/100: 100%|██████████| 115/115 [00:30<00:00,  3.75it/s]


Epoch [3/100], Average Reconstruction Loss: 0.0256


AE Epoch 4/100: 100%|██████████| 115/115 [00:30<00:00,  3.75it/s]


Epoch [4/100], Average Reconstruction Loss: 0.0226


AE Epoch 5/100: 100%|██████████| 115/115 [00:28<00:00,  4.09it/s]


Epoch [5/100], Average Reconstruction Loss: 0.0204


AE Epoch 6/100: 100%|██████████| 115/115 [00:28<00:00,  4.01it/s]


Epoch [6/100], Average Reconstruction Loss: 0.0189


AE Epoch 7/100: 100%|██████████| 115/115 [00:27<00:00,  4.21it/s]


Epoch [7/100], Average Reconstruction Loss: 0.0177


AE Epoch 8/100: 100%|██████████| 115/115 [00:27<00:00,  4.24it/s]


Epoch [8/100], Average Reconstruction Loss: 0.0166


AE Epoch 9/100: 100%|██████████| 115/115 [00:27<00:00,  4.22it/s]


Epoch [9/100], Average Reconstruction Loss: 0.0157


AE Epoch 10/100: 100%|██████████| 115/115 [00:27<00:00,  4.20it/s]


Epoch [10/100], Average Reconstruction Loss: 0.0149


AE Epoch 11/100: 100%|██████████| 115/115 [00:27<00:00,  4.20it/s]


Epoch [11/100], Average Reconstruction Loss: 0.0141


AE Epoch 12/100: 100%|██████████| 115/115 [00:27<00:00,  4.13it/s]


Epoch [12/100], Average Reconstruction Loss: 0.0136


AE Epoch 13/100: 100%|██████████| 115/115 [00:27<00:00,  4.15it/s]


Epoch [13/100], Average Reconstruction Loss: 0.0129


AE Epoch 14/100: 100%|██████████| 115/115 [00:27<00:00,  4.16it/s]


Epoch [14/100], Average Reconstruction Loss: 0.0125


AE Epoch 15/100: 100%|██████████| 115/115 [00:28<00:00,  4.06it/s]


Epoch [15/100], Average Reconstruction Loss: 0.0120


AE Epoch 16/100: 100%|██████████| 115/115 [00:28<00:00,  3.98it/s]


Epoch [16/100], Average Reconstruction Loss: 0.0116


AE Epoch 17/100: 100%|██████████| 115/115 [00:27<00:00,  4.12it/s]


Epoch [17/100], Average Reconstruction Loss: 0.0111


AE Epoch 18/100: 100%|██████████| 115/115 [00:27<00:00,  4.11it/s]


Epoch [18/100], Average Reconstruction Loss: 0.0107


AE Epoch 19/100: 100%|██████████| 115/115 [00:27<00:00,  4.21it/s]


Epoch [19/100], Average Reconstruction Loss: 0.0105


AE Epoch 20/100: 100%|██████████| 115/115 [00:27<00:00,  4.15it/s]


Epoch [20/100], Average Reconstruction Loss: 0.0100


AE Epoch 21/100: 100%|██████████| 115/115 [00:27<00:00,  4.16it/s]


Epoch [21/100], Average Reconstruction Loss: 0.0097


AE Epoch 22/100: 100%|██████████| 115/115 [00:27<00:00,  4.14it/s]


Epoch [22/100], Average Reconstruction Loss: 0.0095


AE Epoch 23/100: 100%|██████████| 115/115 [00:27<00:00,  4.16it/s]


Epoch [23/100], Average Reconstruction Loss: 0.0093


AE Epoch 24/100: 100%|██████████| 115/115 [00:29<00:00,  3.92it/s]


Epoch [24/100], Average Reconstruction Loss: 0.0090


AE Epoch 25/100: 100%|██████████| 115/115 [00:28<00:00,  4.09it/s]


Epoch [25/100], Average Reconstruction Loss: 0.0088


AE Epoch 26/100: 100%|██████████| 115/115 [00:29<00:00,  3.91it/s]


Epoch [26/100], Average Reconstruction Loss: 0.0086


AE Epoch 27/100: 100%|██████████| 115/115 [00:29<00:00,  3.87it/s]


Epoch [27/100], Average Reconstruction Loss: 0.0085


AE Epoch 28/100: 100%|██████████| 115/115 [00:27<00:00,  4.19it/s]


Epoch [28/100], Average Reconstruction Loss: 0.0082


AE Epoch 29/100: 100%|██████████| 115/115 [00:27<00:00,  4.26it/s]


Epoch [29/100], Average Reconstruction Loss: 0.0081


AE Epoch 30/100: 100%|██████████| 115/115 [00:27<00:00,  4.18it/s]


Epoch [30/100], Average Reconstruction Loss: 0.0082


AE Epoch 31/100: 100%|██████████| 115/115 [00:27<00:00,  4.21it/s]


Epoch [31/100], Average Reconstruction Loss: 0.0078


AE Epoch 32/100: 100%|██████████| 115/115 [00:27<00:00,  4.20it/s]


Epoch [32/100], Average Reconstruction Loss: 0.0077


AE Epoch 33/100: 100%|██████████| 115/115 [00:27<00:00,  4.18it/s]


Epoch [33/100], Average Reconstruction Loss: 0.0075


AE Epoch 34/100: 100%|██████████| 115/115 [00:27<00:00,  4.15it/s]


Epoch [34/100], Average Reconstruction Loss: 0.0074


AE Epoch 35/100: 100%|██████████| 115/115 [00:27<00:00,  4.20it/s]


Epoch [35/100], Average Reconstruction Loss: 0.0073


AE Epoch 36/100: 100%|██████████| 115/115 [00:27<00:00,  4.17it/s]


Epoch [36/100], Average Reconstruction Loss: 0.0072


AE Epoch 37/100: 100%|██████████| 115/115 [00:27<00:00,  4.21it/s]


Epoch [37/100], Average Reconstruction Loss: 0.0071


AE Epoch 38/100: 100%|██████████| 115/115 [00:27<00:00,  4.24it/s]


Epoch [38/100], Average Reconstruction Loss: 0.0070


AE Epoch 39/100: 100%|██████████| 115/115 [00:27<00:00,  4.22it/s]


Epoch [39/100], Average Reconstruction Loss: 0.0068


AE Epoch 40/100: 100%|██████████| 115/115 [00:26<00:00,  4.28it/s]


Epoch [40/100], Average Reconstruction Loss: 0.0067


AE Epoch 41/100: 100%|██████████| 115/115 [00:27<00:00,  4.23it/s]


Epoch [41/100], Average Reconstruction Loss: 0.0067


AE Epoch 42/100: 100%|██████████| 115/115 [00:27<00:00,  4.16it/s]


Epoch [42/100], Average Reconstruction Loss: 0.0065


AE Epoch 43/100: 100%|██████████| 115/115 [00:27<00:00,  4.20it/s]


Epoch [43/100], Average Reconstruction Loss: 0.0065


AE Epoch 44/100: 100%|██████████| 115/115 [00:27<00:00,  4.22it/s]


Epoch [44/100], Average Reconstruction Loss: 0.0065


AE Epoch 45/100: 100%|██████████| 115/115 [00:27<00:00,  4.22it/s]


Epoch [45/100], Average Reconstruction Loss: 0.0063


AE Epoch 46/100: 100%|██████████| 115/115 [00:27<00:00,  4.18it/s]


Epoch [46/100], Average Reconstruction Loss: 0.0062


AE Epoch 47/100: 100%|██████████| 115/115 [00:27<00:00,  4.14it/s]


Epoch [47/100], Average Reconstruction Loss: 0.0061


AE Epoch 48/100: 100%|██████████| 115/115 [00:27<00:00,  4.23it/s]


Epoch [48/100], Average Reconstruction Loss: 0.0060


AE Epoch 49/100: 100%|██████████| 115/115 [00:27<00:00,  4.20it/s]


Epoch [49/100], Average Reconstruction Loss: 0.0059


AE Epoch 50/100: 100%|██████████| 115/115 [00:27<00:00,  4.18it/s]


Epoch [50/100], Average Reconstruction Loss: 0.0059


AE Epoch 51/100: 100%|██████████| 115/115 [00:27<00:00,  4.20it/s]


Epoch [51/100], Average Reconstruction Loss: 0.0058


AE Epoch 52/100: 100%|██████████| 115/115 [00:27<00:00,  4.16it/s]


Epoch [52/100], Average Reconstruction Loss: 0.0058


AE Epoch 53/100: 100%|██████████| 115/115 [00:27<00:00,  4.18it/s]


Epoch [53/100], Average Reconstruction Loss: 0.0057


AE Epoch 54/100: 100%|██████████| 115/115 [00:27<00:00,  4.16it/s]


Epoch [54/100], Average Reconstruction Loss: 0.0056


AE Epoch 55/100: 100%|██████████| 115/115 [00:27<00:00,  4.14it/s]


Epoch [55/100], Average Reconstruction Loss: 0.0056


AE Epoch 56/100: 100%|██████████| 115/115 [00:27<00:00,  4.21it/s]


Epoch [56/100], Average Reconstruction Loss: 0.0056


AE Epoch 57/100: 100%|██████████| 115/115 [00:27<00:00,  4.18it/s]


Epoch [57/100], Average Reconstruction Loss: 0.0054


AE Epoch 58/100: 100%|██████████| 115/115 [00:27<00:00,  4.12it/s]


Epoch [58/100], Average Reconstruction Loss: 0.0053


AE Epoch 59/100: 100%|██████████| 115/115 [00:27<00:00,  4.15it/s]


Epoch [59/100], Average Reconstruction Loss: 0.0053


AE Epoch 60/100: 100%|██████████| 115/115 [00:27<00:00,  4.18it/s]


Epoch [60/100], Average Reconstruction Loss: 0.0054


AE Epoch 61/100: 100%|██████████| 115/115 [00:27<00:00,  4.12it/s]


Epoch [61/100], Average Reconstruction Loss: 0.0052


AE Epoch 62/100: 100%|██████████| 115/115 [00:27<00:00,  4.14it/s]


Epoch [62/100], Average Reconstruction Loss: 0.0052


AE Epoch 63/100: 100%|██████████| 115/115 [00:27<00:00,  4.13it/s]


Epoch [63/100], Average Reconstruction Loss: 0.0052


AE Epoch 64/100: 100%|██████████| 115/115 [00:27<00:00,  4.11it/s]


Epoch [64/100], Average Reconstruction Loss: 0.0050


AE Epoch 65/100: 100%|██████████| 115/115 [00:28<00:00,  4.11it/s]


Epoch [65/100], Average Reconstruction Loss: 0.0050


AE Epoch 66/100: 100%|██████████| 115/115 [00:27<00:00,  4.12it/s]


Epoch [66/100], Average Reconstruction Loss: 0.0050


AE Epoch 67/100: 100%|██████████| 115/115 [00:27<00:00,  4.15it/s]


Epoch [67/100], Average Reconstruction Loss: 0.0052


AE Epoch 68/100: 100%|██████████| 115/115 [00:27<00:00,  4.14it/s]


Epoch [68/100], Average Reconstruction Loss: 0.0049


AE Epoch 69/100: 100%|██████████| 115/115 [00:27<00:00,  4.13it/s]


Epoch [69/100], Average Reconstruction Loss: 0.0048


AE Epoch 70/100: 100%|██████████| 115/115 [00:27<00:00,  4.14it/s]


Epoch [70/100], Average Reconstruction Loss: 0.0047


AE Epoch 71/100: 100%|██████████| 115/115 [00:27<00:00,  4.15it/s]


Epoch [71/100], Average Reconstruction Loss: 0.0047


AE Epoch 72/100: 100%|██████████| 115/115 [00:27<00:00,  4.15it/s]


Epoch [72/100], Average Reconstruction Loss: 0.0047


AE Epoch 73/100: 100%|██████████| 115/115 [00:27<00:00,  4.11it/s]


Epoch [73/100], Average Reconstruction Loss: 0.0047


AE Epoch 74/100: 100%|██████████| 115/115 [00:27<00:00,  4.15it/s]


Epoch [74/100], Average Reconstruction Loss: 0.0046


AE Epoch 75/100: 100%|██████████| 115/115 [00:27<00:00,  4.13it/s]


Epoch [75/100], Average Reconstruction Loss: 0.0046


AE Epoch 76/100: 100%|██████████| 115/115 [00:28<00:00,  4.10it/s]


Epoch [76/100], Average Reconstruction Loss: 0.0046


AE Epoch 77/100: 100%|██████████| 115/115 [00:27<00:00,  4.11it/s]


Epoch [77/100], Average Reconstruction Loss: 0.0045


AE Epoch 78/100: 100%|██████████| 115/115 [00:27<00:00,  4.13it/s]


Epoch [78/100], Average Reconstruction Loss: 0.0045


AE Epoch 79/100: 100%|██████████| 115/115 [00:27<00:00,  4.11it/s]


Epoch [79/100], Average Reconstruction Loss: 0.0045


AE Epoch 80/100: 100%|██████████| 115/115 [00:27<00:00,  4.12it/s]


Epoch [80/100], Average Reconstruction Loss: 0.0044


AE Epoch 81/100: 100%|██████████| 115/115 [00:27<00:00,  4.15it/s]


Epoch [81/100], Average Reconstruction Loss: 0.0043


AE Epoch 82/100: 100%|██████████| 115/115 [00:27<00:00,  4.15it/s]


Epoch [82/100], Average Reconstruction Loss: 0.0043


AE Epoch 83/100: 100%|██████████| 115/115 [00:27<00:00,  4.13it/s]


Epoch [83/100], Average Reconstruction Loss: 0.0043


AE Epoch 84/100: 100%|██████████| 115/115 [00:27<00:00,  4.13it/s]


Epoch [84/100], Average Reconstruction Loss: 0.0044


AE Epoch 85/100: 100%|██████████| 115/115 [00:27<00:00,  4.13it/s]


Epoch [85/100], Average Reconstruction Loss: 0.0042


AE Epoch 86/100: 100%|██████████| 115/115 [00:27<00:00,  4.13it/s]


Epoch [86/100], Average Reconstruction Loss: 0.0042


AE Epoch 87/100: 100%|██████████| 115/115 [00:27<00:00,  4.15it/s]


Epoch [87/100], Average Reconstruction Loss: 0.0042


AE Epoch 88/100: 100%|██████████| 115/115 [00:27<00:00,  4.13it/s]


Epoch [88/100], Average Reconstruction Loss: 0.0042


AE Epoch 89/100: 100%|██████████| 115/115 [00:27<00:00,  4.11it/s]


Epoch [89/100], Average Reconstruction Loss: 0.0041


AE Epoch 90/100: 100%|██████████| 115/115 [00:27<00:00,  4.11it/s]


Epoch [90/100], Average Reconstruction Loss: 0.0041


AE Epoch 91/100: 100%|██████████| 115/115 [00:27<00:00,  4.14it/s]


Epoch [91/100], Average Reconstruction Loss: 0.0042


AE Epoch 92/100: 100%|██████████| 115/115 [00:27<00:00,  4.12it/s]


Epoch [92/100], Average Reconstruction Loss: 0.0040


AE Epoch 93/100: 100%|██████████| 115/115 [00:27<00:00,  4.14it/s]


Epoch [93/100], Average Reconstruction Loss: 0.0040


AE Epoch 94/100: 100%|██████████| 115/115 [00:27<00:00,  4.15it/s]


Epoch [94/100], Average Reconstruction Loss: 0.0040


AE Epoch 95/100: 100%|██████████| 115/115 [00:27<00:00,  4.16it/s]


Epoch [95/100], Average Reconstruction Loss: 0.0041


AE Epoch 96/100: 100%|██████████| 115/115 [00:27<00:00,  4.12it/s]


Epoch [96/100], Average Reconstruction Loss: 0.0040


AE Epoch 97/100: 100%|██████████| 115/115 [00:28<00:00,  4.09it/s]


Epoch [97/100], Average Reconstruction Loss: 0.0039


AE Epoch 98/100: 100%|██████████| 115/115 [00:27<00:00,  4.12it/s]


Epoch [98/100], Average Reconstruction Loss: 0.0039


AE Epoch 99/100: 100%|██████████| 115/115 [00:27<00:00,  4.15it/s]


Epoch [99/100], Average Reconstruction Loss: 0.0038


AE Epoch 100/100: 100%|██████████| 115/115 [00:28<00:00,  4.05it/s]

Epoch [100/100], Average Reconstruction Loss: 0.0038
Autoencoder model saved.





### Note: the following is done to make the embeddings in a normalized scale that the NF model expects

In [7]:
# --- Generate Embeddings ---
ae_model.eval()
all_embeddings = []
all_labels = []
with torch.no_grad():
    for batch in tqdm(train_loader, desc="Generating Embeddings"):
        images, labels = batch  # Unpack tuple
        images = images.to(device)
        _, embeddings = ae_model(images)
        all_embeddings.append(embeddings.cpu())
        all_labels.append(labels)

all_embeddings = torch.cat(all_embeddings, dim=0)
all_labels = torch.cat(all_labels, dim=0)

# --- Normalize Embeddings ---
mean = all_embeddings.mean(0, keepdim=True)
std = all_embeddings.std(0, keepdim=True)
normalized_embeddings = (all_embeddings - mean) / std

# Save embeddings, labels, and normalization stats
torch.save({
    'embeddings': normalized_embeddings,
    'labels': all_labels,
    'mean': mean,
    'std': std
}, 'pets_embeddings.pth')

print("Embeddings generated and saved.")
print(f"Shape of normalized embeddings: {normalized_embeddings.shape}")

Generating Embeddings: 100%|██████████| 115/115 [00:21<00:00,  5.33it/s]

Embeddings generated and saved.
Shape of normalized embeddings: torch.Size([3680, 256])





### Normalizing Flow training

# **📌 RealNVP for Generating Pet Embeddings**
Now that we have a way to represent complex pet images as 256-dimensional vectors, we can train a **Normalizing Flow** to learn the *distribution* of these vectors. We will use the same **RealNVP** model, but this time it will be *conditional*.

## **🔹 Key Concepts**
1️⃣ **Conditional Generation**: We provide the class label (0 for Cat, 1 for Dog) as an additional input to the model. This allows the flow to learn two distinct distributions within the same model and lets us control whether we generate a cat or a dog.

2️⃣ **Invertible Transformation**: The model learns an invertible function `f` that maps a pet embedding `x` and its label `y` to a latent point `z` from a simple Gaussian distribution. This can be reversed to generate a new pet embedding from a random point `z` and a chosen label `y`.

3️⃣ **Coupling Layers**: RealNVP uses these clever layers to split the input, transforming one part based on the other part *and the conditional label*. This makes the transformation powerful while keeping the necessary calculations efficient.



## **📌 Expected Input & Output Shapes**
- **Input (Embeddings):** `(batch_size, 256)`
- **Conditional Input (Labels):** `(batch_size, 2)`  (One-hot encoded: [1,0] for Cat, [0,1] for Dog)
- **Output (Latent `u`):** `(batch_size, 256)`

In [8]:
class CouplingLayer(nn.Module):
    def __init__(self, input_dim, hidden_dim, mask, condition_dim=2):
        super().__init__()
        self.mask = mask
        # The scale and translate networks now also take the conditional input
        self.s_net = nn.Sequential(nn.Linear(input_dim // 2 + condition_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, input_dim // 2))
        self.t_net = nn.Sequential(nn.Linear(input_dim // 2 + condition_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, input_dim // 2))

    def forward(self, x, condition):
        x_a = x * self.mask
        x_b = x * (1 - self.mask)
        
        s_t_input = torch.cat([x_a, condition], dim=1)
        s = self.s_net(s_t_input)
        t = self.t_net(s_t_input)
        
        y_b = x_b * torch.exp(s * (1 - self.mask)) + t * (1 - self.mask)
        return x_a + y_b, s.sum(dim=1)

    def inverse(self, y, condition):
        y_a = y * self.mask
        y_b = y * (1 - self.mask)
        
        s_t_input = torch.cat([y_a, condition], dim=1)
        s = self.s_net(s_t_input)
        t = self.t_net(s_t_input)
        
        x_b = (y_b - t * (1 - self.mask)) * torch.exp(-s * (1 - self.mask))
        return y_a + x_b

class RealNVP(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers):
        super().__init__()
        self.prior = distributions.MultivariateNormal(torch.zeros(input_dim).to(device), torch.eye(input_dim).to(device))
        masks = [self.create_mask(input_dim, i) for i in range(num_layers)]
        self.layers = nn.ModuleList([CouplingLayer(input_dim, hidden_dim, m) for m in masks])

    def create_mask(self, dim, variation):
        mask = torch.arange(dim) % 2
        return (mask if variation % 2 == 0 else 1 - mask).float().to(device)

    def forward(self, x, condition):
        log_det_J = 0
        for layer in self.layers:
            x, ldj = layer(x, condition)
            log_det_J += ldj
        return x, log_det_J

    def inverse(self, z, condition):
        for layer in reversed(self.layers):
            z = layer.inverse(z, condition)
        return z

    def sample(self, num_samples, condition):
        z = self.prior.sample((num_samples,))
        return self.inverse(z, condition)

### Training the Conditional Normalizing Flow
We will now train the RealNVP model to learn the distribution of the normalized pet image embeddings, conditioned on whether the image is a cat or a dog.

1️⃣ **Forward Pass** → Transform a pet embedding `emb` and its one-hot encoded label `y` into a latent vector `u`.

2️⃣ **Compute Loss** → Maximize the log-likelihood of this transformation.

3️⃣ **Backward Pass** → Update the flow's parameters to better model the two conditional distributions.

In [9]:
# Load the embeddings and create a DataLoader
data = torch.load('pets_embeddings.pth')
normalized_embeddings = data['embeddings']
all_labels = data['labels']
embedding_dataset = torch.utils.data.TensorDataset(normalized_embeddings, all_labels)
embedding_loader = torch.utils.data.DataLoader(embedding_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Model and Optimizer
FLOW_LAYERS = 8
FLOW_HIDDEN = 512
flow_model = RealNVP(EMBEDDING_DIM, FLOW_HIDDEN, FLOW_LAYERS).to(device)
optimizer = torch.optim.Adam(flow_model.parameters(), lr=1e-4)
FLOW_EPOCHS = 200

# Training Loop
flow_model.train()
for epoch in range(FLOW_EPOCHS):
    total_loss = 0
    for embeds, labels in tqdm(embedding_loader, desc=f"Flow Epoch {epoch+1}/{FLOW_EPOCHS}"):
        embeds = embeds.to(device)
        # Map breed indices to cat/dog: 0 for cat, 1 for dog
        # OxfordIIITPet: first 0-18 are cats, 19-36 are dogs
        binary_labels = (labels >= 19).long()
        condition = nn.functional.one_hot(binary_labels, num_classes=2).float().to(device)
        
        z, log_det = flow_model(embeds, condition)
        log_prob = flow_model.prior.log_prob(z)
        loss = -(log_prob + log_det).mean()
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(embedding_loader)
    print(f"Epoch [{epoch+1}/{FLOW_EPOCHS}], Avg Negative Log-Likelihood: {avg_loss:.4f}")

# Save the trained flow model
torch.save(flow_model.state_dict(), 'pets_flow_model.pth')
print("Conditional Flow model saved.")

  data = torch.load('pets_embeddings.pth')
Flow Epoch 1/200:   0%|          | 0/58 [00:00<?, ?it/s]


RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x258 and 130x512)


### Evaluating the Model

In [None]:
# --- Load all models and stats ---
ae_model.load_state_dict(torch.load('pets_autoencoder.pth', map_location=device))
ae_model.eval()

flow_model.load_state_dict(torch.load('pets_flow_model.pth', map_location=device))
flow_model.eval()

embedding_stats = torch.load('pets_embeddings.pth')
mean = embedding_stats['mean'].to(device)
std = embedding_stats['std'].to(device)

# --- Generation ---
def generate_images(label, num_images=5):
    # 0=Cat, 1=Dog
    condition_label = torch.tensor([label] * num_images).long()
    condition_one_hot = nn.functional.one_hot(condition_label, num_classes=2).float().to(device)
    
    with torch.no_grad():
        # Generate normalized embeddings
        generated_norm_embeds = flow_model.sample(num_images, condition_one_hot)
        # De-normalize
        generated_embeds = generated_norm_embeds * std + mean
        # Decode to images
        generated_images = ae_model.decoder(ae_model.decoder_input(generated_embeds).view(-1, 256, 4, 4))
    return generated_images

# Generate and plot cats and dogs
generated_cats = generate_images(label=0, num_images=5)
generated_dogs = generate_images(label=1, num_images=5)

# --- Plotting ---
def plot_images(images, title):
    images = images.cpu().numpy().transpose(0, 2, 3, 1)
    images = (images * 0.5) + 0.5 # De-normalize from [-1, 1] to [0, 1] for plotting
    fig, axes = plt.subplots(1, 5, figsize=(15, 3))
    fig.suptitle(title, fontsize=16)
    for i, ax in enumerate(axes):
        ax.imshow(images[i])
        ax.axis('off')
    plt.show()

plot_images(generated_cats, "Generated Cats")
plot_images(generated_dogs, "Generated Dogs")

## **🔹 Exercise: Exploring the Generative Landscape**

The final image quality is a result of two models working in tandem. Tweaking the hyperparameters of either the Autoencoder or the Normalizing Flow can lead to better results(hopefully).

### **📝 Tasks**

1.  **Autoencoder Quality**: The Autoencoder was trained for 100 epochs. Try training it for longer (e.g., `AE_EPOCHS = 200`). Does a lower reconstruction loss in the AE lead to sharper, more realistic generations from the complete system?
2.  **Embedding Dimension**: Change `EMBEDDING_DIM` to `128` (more compression) or `512` (less compression). How does this trade-off affect the detail (e.g., fur texture) and variety of the generated pets?
3.  **Flow Complexity**: Adjust the `flow_n` (e.g., to `8` or `16`) and the `coupling_topology` in the `LinearRNVP` (e.g. `[512, 512]`). How does the flow's capacity impact its ability to model the subtle differences between cat and dog breeds?

Note:
-   A high-quality **Autoencoder is crucial**. Garbage in, garbage out; if the embeddings are poor, the Normalizing Flow cannot generate good images.
    - This was a very big issue when I was creating the notebook 🙃

### Contributed by: Ali Habibullah.