# SwapVAE

[SwapVAE](https://proceedings.neurips.cc/paper/2021/file/58182b82110146887c02dbd78719e3d5-Paper.pdf) is a modification of the vanilla VAE architecture that partitions the latent representation into "content" and "style" components.

## Imports

In [None]:
import os, sys
sys.path.append("../")

In [None]:
import torch
import torchvision
from pytorch_lightning import Trainer
from playground.models import SwapVAE
from playground.datamodules import MNISTDataModule
from playground.utils import imshow

In [None]:
log_dir = '../lightning_logs'
pretrained_pth = '../pretrained'

## Introduction

### Architecture

As mentioned above, SwapVAE is similar to a vanilla VAE. Its novelty lies in the `BlockSwap` operation, as coined by the authors, which swaps the content components of the latent representations of two samples. In the authors' schematic, this corresponds to a change in reach direction while preserving the reach dynamics.

![swapvae](img/swapvae.png)
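A minimal sketch of the `BlockSwap` idea in PyTorch is shown below. It assumes each latent vector is laid out as `[content | style]`, with the first `content_dim` entries holding the content; the function name `block_swap` is hypothetical and is not taken from the `playground` package.

```python
import torch


def block_swap(z1: torch.Tensor, z2: torch.Tensor, content_dim: int):
    """Swap the content blocks of two latent batches.

    Assumes (hypothetically) that each latent vector is laid out as
    [content | style], i.e. the first `content_dim` entries are content.
    """
    c1, s1 = z1[..., :content_dim], z1[..., content_dim:]
    c2, s2 = z2[..., :content_dim], z2[..., content_dim:]
    # Each sample keeps its own style but receives the other sample's content.
    z1_swap = torch.cat([c2, s1], dim=-1)
    z2_swap = torch.cat([c1, s2], dim=-1)
    return z1_swap, z2_swap


# Usage: two batches of 4 latents with 16 content + 16 style dimensions.
z1 = torch.randn(4, 32)
z2 = torch.randn(4, 32)
z1_swap, z2_swap = block_swap(z1, z2, content_dim=16)
```

Swapping only a contiguous block keeps the operation cheap and differentiable: it is just slicing and concatenation, so gradients flow to both encoders' outputs.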

### Loss function

The loss function is composed of three components:
- Reconstruction loss $\mathcal{L}_{\text{rec}}$
- Style space regularization $\mathcal{D}_{KL}$
- Content space alignment loss $\mathcal{L}_{\text{align}}$

The style space regularization and content space alignment loss are weighted by $\beta$ and $\alpha$, respectively.

\begin{equation}
\min_{f,g} \sum_{i=1,2} \mathcal{L}_{\text{rec}} (\mathbf{x}_i, g(\mathbf{z}_i)) + \beta\sum_{i=1,2} \mathcal{D}_{KL} (\mathbf{z}_i^{(s)} || \mathbf{z}_{i,\text{prior}}^{(s)}) + \alpha\mathcal{L}_{\text{align}} (\mathbf{z}_1^{(c)}, \mathbf{z}_2^{(c)})
\label{loss_total}
\end{equation}
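The total loss can be sketched as follows. The concrete choices here are assumptions for illustration, not necessarily those of the paper or the `playground` package: mean squared error for $\mathcal{L}_{\text{rec}}$, a standard normal prior $\mathcal{N}(\mathbf{0}, \mathbf{I})$ for the style space (which gives the KL term a closed form in the encoder's mean and log-variance), and, for $\mathcal{L}_{\text{align}}$, the squared distance between L2-normalized content vectors. All helper names are hypothetical.

```python
import torch
import torch.nn.functional as F


def kl_standard_normal(mu, logvar):
    # Closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over dims, averaged over batch.
    return (-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)).mean()


def alignment_loss(c1, c2):
    # Assumed alignment: squared L2 distance between normalized content vectors
    # (equal to 2 - 2 * cosine similarity), averaged over the batch.
    c1 = F.normalize(c1, dim=-1)
    c2 = F.normalize(c2, dim=-1)
    return (c1 - c2).pow(2).sum(dim=-1).mean()


def swapvae_loss(x1, x2, x1_hat, x2_hat, c1, c2, mu1, logvar1, mu2, logvar2,
                 alpha=1.0, beta=1.0):
    # Reconstruction term summed over both views (MSE is an assumption).
    rec = F.mse_loss(x1_hat, x1) + F.mse_loss(x2_hat, x2)
    # Style regularization, weighted by beta.
    kl = kl_standard_normal(mu1, logvar1) + kl_standard_normal(mu2, logvar2)
    # Content alignment, weighted by alpha.
    align = alignment_loss(c1, c2)
    return rec + beta * kl + alpha * align
```

As a sanity check, when the reconstructions are perfect, the style posterior equals the prior ($\mu = 0$, $\log\sigma^2 = 0$), and the two content vectors coincide, every term vanishes and the loss is zero.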

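To further encourage the content and style spaces to disentangle, the reconstruction loss is extended to also reconstruct each sample from its swapped representation $\tilde{\mathbf{z}}_i$, in which the content block has been exchanged with that of the other sample (e.g. $\tilde{\mathbf{z}}_1 = [\mathbf{z}_2^{(c)}, \mathbf{z}_1^{(s)}]$):

\begin{equation}
\mathcal{L}_{\text{rec}}^{\text{swap}} = \sum_{i=1,2} \left[ \mathcal{L}_{\text{rec}} (\mathbf{x}_i, g(\tilde{\mathbf{z}}_i)) + \mathcal{L}_{\text{rec}} (\mathbf{x}_i, g(\mathbf{z}_i)) \right]
\end{equation}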
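The swap-augmented reconstruction term can be sketched with any decoder $g$ (here a toy `nn.Linear`). The sketch assumes latents are laid out as `[content | style]` and uses mean squared error for $\mathcal{L}_{\text{rec}}$; both are assumptions, and `swap_reconstruction_loss` is a hypothetical name.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def swap_reconstruction_loss(decoder, x1, x2, z1, z2, content_dim):
    # Latents are assumed to be laid out as [content | style].
    c1, s1 = z1[..., :content_dim], z1[..., content_dim:]
    c2, s2 = z2[..., :content_dim], z2[..., content_dim:]
    z1_swap = torch.cat([c2, s1], dim=-1)  # content from sample 2, style from sample 1
    z2_swap = torch.cat([c1, s2], dim=-1)  # content from sample 1, style from sample 2
    loss = 0.0
    for x, z, z_swap in ((x1, z1, z1_swap), (x2, z2, z2_swap)):
        # Reconstruct x_i from both its swapped latent and its own latent.
        loss = loss + F.mse_loss(decoder(z_swap), x) + F.mse_loss(decoder(z), x)
    return loss


# Usage with a toy linear decoder (16 content + 16 style dims -> 256 outputs).
decoder = nn.Linear(32, 256)
x1, x2 = torch.rand(4, 256), torch.rand(4, 256)
z1, z2 = torch.randn(4, 32), torch.randn(4, 32)
loss = swap_reconstruction_loss(decoder, x1, x2, z1, z2, content_dim=16)
```

If the two samples happen to share identical content blocks, then $\tilde{\mathbf{z}}_i = \mathbf{z}_i$ and the swapped terms contribute nothing beyond a copy of the ordinary reconstruction loss.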

### Training

The model is trained using the following hyperparameters:
- `learning_rate`: 

## Initialize model

The SwapVAE architecture consists of an encoder and a decoder, much like a vanilla VAE.

In [None]:
input_dim = 256
hidden_dim = [128, 64, 64]
content_dim = 16
style_dim = 16

model = SwapVAE(input_dim, hidden_dim, content_dim, style_dim)

model
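The internals of `playground.models.SwapVAE` are not shown in this notebook. For intuition only, here is a hypothetical, minimal PyTorch module taking the same constructor arguments. It assumes an MLP encoder whose output is split into a deterministic content vector and a Gaussian style vector (mean and log-variance, sampled with the reparameterization trick), and a mirrored MLP decoder. `TinySwapVAE` and `mlp` are illustrative names, not the package's API.

```python
import torch
import torch.nn as nn


def mlp(dims):
    # Linear layers with ReLU in between (no activation after the final layer).
    layers = []
    for d_in, d_out in zip(dims[:-1], dims[1:]):
        layers += [nn.Linear(d_in, d_out), nn.ReLU()]
    return nn.Sequential(*layers[:-1])


class TinySwapVAE(nn.Module):
    """Hypothetical stand-in; NOT the playground.models.SwapVAE implementation."""

    def __init__(self, input_dim, hidden_dim, content_dim, style_dim):
        super().__init__()
        self.content_dim = content_dim
        # Encoder outputs content plus style mean and log-variance.
        self.encoder = mlp([input_dim, *hidden_dim, content_dim + 2 * style_dim])
        # Decoder mirrors the encoder and maps [content | style] back to input space.
        self.decoder = mlp([content_dim + style_dim, *reversed(hidden_dim), input_dim])

    def encode(self, x):
        h = self.encoder(x)
        c = h[..., :self.content_dim]
        mu, logvar = h[..., self.content_dim:].chunk(2, dim=-1)
        s = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)  # reparameterization trick
        return c, s, mu, logvar

    def forward(self, x1, x2):
        c1, s1, _, _ = self.encode(x1)
        c2, s2, _, _ = self.encode(x2)
        # Reconstruct each sample from its own [content | style] latent.
        y1 = self.decoder(torch.cat([c1, s1], dim=-1))
        y2 = self.decoder(torch.cat([c2, s2], dim=-1))
        return y1, y2


tiny = TinySwapVAE(256, [128, 64, 64], 16, 16)
y1, y2 = tiny(torch.rand(256), torch.rand(256))
```

Only the style half passes through a sampling step here, which mirrors the loss above where only $\mathbf{z}^{(s)}$ carries a KL penalty; whether the real implementation also samples the content is not stated in this notebook.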

We can pass in a pair of 256-dimensional vectors $\mathbf{x}_1$ and $\mathbf{x}_2$ to confirm that the forward pass runs. The outputs $\mathbf{y}_1$ and $\mathbf{y}_2$ should have the same dimension as the inputs.

In [None]:
x1 = torch.rand(input_dim)
x2 = torch.rand(input_dim)

y1, y2 = model(x1, x2)

y1.shape, y2.shape

## Datasets

The authors conducted experiments on synthetic and real neural datasets; however, as is common when demonstrating such models, we will use MNIST, a dataset of handwritten digits, as well as SVHN (Street View House Numbers), a dataset of digits cropped from photographs of house numbers.

### MNIST

In [None]:
data_dir = "../data"
batch_size = 32
num_workers = 8
train_val_split = 0.9

Initialize a PyTorch Lightning `LightningDataModule` for the MNIST dataset.

In [None]:
mnist_dm = MNISTDataModule(
    data_dir=data_dir,
    batch_size=batch_size,
    num_workers=num_workers,
    split=train_val_split
)

mnist_dm.prepare_data()
mnist_dm.setup()

In [None]:
mnist_train = mnist_dm.train_dataloader()
inputs, classes = next(iter(mnist_train))
out = torchvision.utils.make_grid(inputs)
imshow(out)
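SwapVAE consumes pairs of inputs, while the loader above yields single images of 28 × 28 = 784 pixels and the model was built with `input_dim = 256`. How `MNISTDataModule` reconciles this is not shown here, so the following is a hypothetical sketch: it assumes images are downsampled to 16 × 16 and flattened, and it creates two views of each image by randomly zeroing pixels, a stand-in for the neuron-dropout augmentation suggested by the paper's title. `make_views` and `drop_prob` are illustrative names.

```python
import torch
import torch.nn.functional as F


def make_views(images, size=16, drop_prob=0.2):
    # Downsample (B, 1, 28, 28) images to (B, 1, size, size), then flatten to (B, size * size).
    x = F.interpolate(images, size=(size, size), mode="bilinear", align_corners=False)
    x = x.flatten(start_dim=1)
    # Two independent random masks, so each view drops a different subset of pixels.
    mask1 = (torch.rand_like(x) >= drop_prob).float()
    mask2 = (torch.rand_like(x) >= drop_prob).float()
    return x * mask1, x * mask2


# Usage with a random stand-in for an MNIST batch; in the notebook, pass `inputs` instead.
images = torch.rand(32, 1, 28, 28)
x1, x2 = make_views(images)
```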

### SVHN

## Train model

In [None]:
min_epochs = 1
max_epochs = 1

trainer = Trainer(min_epochs=min_epochs, max_epochs=max_epochs)

In [None]:
trainer.fit(model, mnist_dm)

## References

[Drop, Swap, and Generate: A Self-Supervised Approach for Generating Neural Activity](https://proceedings.neurips.cc/paper/2021/file/58182b82110146887c02dbd78719e3d5-Paper.pdf)

[PyTorch Lightning](https://www.pytorchlightning.ai/)