# Lec 31: Variational Autoencoders II and GANs I

## Instructions
Build a Convolutional VAE for the EuroSAT dataset

1. Demonstrate that your model performs well by plotting original and reconstructed images. You are free to choose your model configuration and any hyperparameters.
2. Plot the images reconstructed from the mean latent representation of the 'industrial buildings' and 'forest' classes.
3. Plot the transition from 'industrial buildings' to 'forest' by moving through the latent space from one class to the other in 10 steps, generating the full image at each step with the decoder.

## Imports

In [None]:
# PyTorch imports
import torch
import torch.nn as nn

In [None]:
# Import helper files
from helpers.helper_utils import set_all_seeds
from helpers.helper_data import get_dataloaders_eurosat
from helpers.helper_models import Reshape, Trim
from helpers.helper_train import train_cvae

## Settings

In [None]:
# Optimisation hyperparameters
LEARNING_RATE = 5e-4
BATCH_SIZE = 256
NUM_EPOCHS = 100

# Reproducibility seed and compute device (first CUDA GPU when present, else CPU)
RANDOM_SEED = 123
if torch.cuda.is_available():
    DEVICE = torch.device('cuda:0')
else:
    DEVICE = torch.device('cpu')

print(f'Device: {DEVICE}')

In [None]:
# Seed the run with RANDOM_SEED before any data loading or model construction.
# NOTE(review): which RNGs set_all_seeds actually covers is defined in
# helpers.helper_utils — confirm there.
set_all_seeds(RANDOM_SEED)

## EuroSAT Dataset

In [None]:
train_loader = get_dataloaders_eurosat(BATCH_SIZE, num_workers=8)

# Sanity check: report the tensor shapes of the first batch only
print('Train Dataset:')
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    images, labels = first_batch
    print(f'Image batch dimensions: {images.size()}')
    print(f'Image label dimensions: {labels.size()}')

## Convolutional Variational Autoencoder Model

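For a convolutional layer with input size $i$, kernel size $k$, padding $p$ and stride $s$, the output size $o$ is:
$$o=\left\lfloor\frac{i+2p-k}{s}\right\rfloor+1$$

For a transposed convolutional layer (no output padding, dilation 1), the output size is:
$$o=s(i-1)+k-2p$$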

In [None]:
class CVAE(nn.Module):
    """Convolutional variational autoencoder for 3-channel images.

    The encoder applies five stride-2 convolutions (3 -> 512 channels) and
    flattens the result; two linear heads then map the 2048 flattened
    features to the mean and log-variance of a 512-dimensional latent
    Gaussian. The decoder mirrors the encoder with transposed convolutions
    and ends in a sigmoid, so reconstructed values lie in the (0, 1) range.

    The spatial sizes quoted in the layer comments assume 64x64 inputs; the
    linear layers hard-code 2048 = 512 * 2 * 2 features accordingly.
    ``Reshape`` and ``Trim`` are the helper modules imported from
    ``helpers.helper_models``.
    """

    def __init__(self):
        super().__init__()

        self.encoder = nn.Sequential(
            # Input: 3 channels x 64x64 pixel images

            # Output: 32 channels x 32x32 pixel images
            nn.Conv2d(3, 32, stride=2, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.1, inplace=True),

            # Output: 64 channels x 16x16 pixel images
            nn.Conv2d(32, 64, stride=2, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.1, inplace=True),

            # Output: 128 channels x 8x8 pixel images
            nn.Conv2d(64, 128, stride=2, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.1, inplace=True),

            # Output: 256 channels x 4x4 pixel images
            nn.Conv2d(128, 256, stride=2, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.1, inplace=True),

            # Output: 512 channels x 2x2 pixel images
            nn.Conv2d(256, 512, stride=2, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.1, inplace=True),

            # Output: 512x2x2 = 2048 features
            nn.Flatten()
        )

        # Latent mean head: 2048 -> 512
        self.z_mean = nn.Linear(2048, 512)

        # Latent log-variance head: 2048 -> 512
        self.z_log_var = nn.Linear(2048, 512)

        self.decoder = nn.Sequential(
            # Fully connected layer: 512 -> 2048
            nn.Linear(512, 2048),

            # Reshape to 512x2x2 (how it was in the encoder before flattening)
            Reshape(-1, 512, 2, 2),

            # No padding here: 2*(2-1)+3 = 5
            # Output: 256 channels x 5x5 pixel images
            nn.ConvTranspose2d(512, 256, stride=2, kernel_size=3),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.1, inplace=True),

            # Output: 128 channels x 9x9 pixel images
            nn.ConvTranspose2d(256, 128, stride=2, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.1, inplace=True),

            # Output: 64 channels x 17x17 pixel images
            nn.ConvTranspose2d(128, 64, stride=2, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.1, inplace=True),

            # Output: 32 channels x 33x33 pixel images
            nn.ConvTranspose2d(64, 32, stride=2, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.1, inplace=True),

            # Output: 3 channels x 65x65 pixel images
            nn.ConvTranspose2d(32, 3, stride=2, kernel_size=3, padding=1),

            # Trim images from 65x65 back to the 64x64 input size
            Trim(64),
            nn.Sigmoid()
        )

    def _latent_stats(self, x):
        """Return ``(z_mean, z_log_var)`` produced by the encoder for ``x``."""
        features = self.encoder(x)
        return self.z_mean(features), self.z_log_var(features)

    def encoding_fn(self, x):
        """Encode ``x`` and return a sampled latent vector for each input.

        The result is a random draw via :meth:`reparameterize`, not the
        latent mean, so repeated calls on the same input differ.
        """
        z_mean, z_log_var = self._latent_stats(x)
        return self.reparameterize(z_mean, z_log_var)

    def reparameterize(self, z_mean, z_log_var):
        """Sample ``z = z_mean + eps * exp(z_log_var / 2)`` with ``eps ~ N(0, I)``.

        Args:
            z_mean: Tensor of latent means.
            z_log_var: Tensor of latent log-variances, same shape as ``z_mean``.

        Returns:
            A tensor with the shape of ``z_mean`` holding one sample per entry.
        """
        # Draw the noise directly on z_mean's device (and with its dtype)
        # instead of relying on a module-level DEVICE global; this keeps the
        # model self-contained and avoids a host-to-device copy per call.
        eps = torch.randn_like(z_mean)
        return z_mean + eps * torch.exp(z_log_var / 2.)

    def forward(self, x):
        """Run a full encode-sample-decode pass.

        Returns:
            Tuple ``(encoded, z_mean, z_log_var, decoded)``: the sampled
            latent vectors, the latent means, the latent log-variances and
            the decoder output for ``x``.
        """
        z_mean, z_log_var = self._latent_stats(x)
        encoded = self.reparameterize(z_mean, z_log_var)
        decoded = self.decoder(encoded)
        return encoded, z_mean, z_log_var, decoded

In [None]:
# Build the network and place its parameters on the selected device
# (nn.Module.to returns the module itself, so the call can be chained).
model = CVAE().to(DEVICE)

# Adam over every model parameter, using the configured learning rate
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

## Training

In [None]:
# Launch training through the project helper and keep whatever it returns.
# NOTE(review): the contents of log_dict are defined by train_cvae in
# helpers.helper_train — confirm there before relying on specific keys.
log_dict = train_cvae(model, optimizer,
                      device=DEVICE,
                      train_loader=train_loader,
                      num_epochs=NUM_EPOCHS)