In [None]:
#default_exp core.train

In [None]:
#export
import torch
from torch import nn
from torch.nn import functional as F
from vase.config import DATA_PATH


In [None]:
#hide
from vase.core.models import VanillaVAE
from vase.core.datasets.moving_mnist import MovingFashionMNIST
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor


In [None]:
# Build the MovingFashionMNIST dataset from DATA_PATH; each sample is passed through ToTensor.
# NOTE(review): download=True presumably fetches the data when it is missing — confirm in the dataset class.
fashion_data = MovingFashionMNIST(DATA_PATH, download=True, transform=ToTensor())
# Iterator over batches of 64, handy for pulling single batches with next().
loader_iter = iter(DataLoader(fashion_data, batch_size=64))

# Training For VASE
> all losses and training code for VASE (variational inference, environmental inference, latent masking, generative replay, object classification, location regression)

## Problem Setup

TODO: add all the distributions

## Standard VAE (Reconstruction + Target KL)
The paper's "Minimum Description Length (MDL)" loss is a variant of the standard VAE ELBO loss, maximizing the likelihood while minimizing the KL divergence to the prior:

$$\mathcal{L}_{MDL}(\phi, \theta) = E_{\mathbf{z}^s \sim q_{\phi}(\cdot|\mathbf{x}^s)}[-\log{p_{\theta}(\mathbf{x}|\mathbf{z}^s, s)}] + \gamma |KL(q_{\phi}(\mathbf{z}^s|\mathbf{x}^s)||p(z)) - C|^2$$


However, you'll notice the KL divergence term is slightly non-standard. Rather than penalizing the KL divergence at a fixed rate, the loss penalizes the squared difference between the KL divergence and a dynamic target $C$, which increases over the course of training, allowing for gradually more representational capacity. This trick was taken from [Understanding disentanglement in the $\beta$-VAE](https://arxiv.org/pdf/1804.03599.pdf)

For now we'll also drop the environment superscript $s$, just training an autoencoder on iid data:

$$\mathcal{L}_{MDL}(\phi, \theta) = E_{\mathbf{z} \sim q_{\phi}(\cdot|\mathbf{x})}[-\log{p_{\theta}(\mathbf{x}|\mathbf{z})}] + \gamma |KL(q_{\phi}(\mathbf{z}|\mathbf{x})||p(z)) - C|^2$$

### Reconstruction Loss

We'll use Binary Cross Entropy Loss with $y$ the ground truth image $x$, and $p(y)$ the reconstructed image. In terms of log likelihood, I'm not really sure how this makes sense, but it seems to be how it's done... (TODO figure this out)

In [None]:
#export
def reconstruction_loss(x, x_rec):
    """Mean binary cross-entropy between a reconstruction and its target.

    Args:
        x: ground-truth tensor, used as the BCE target.
        x_rec: reconstructed tensor, used as the BCE input. It must have the
            same shape as `x`, and `F.binary_cross_entropy` expects its
            values to be probabilities in [0, 1].

    Returns:
        Scalar tensor holding the BCE averaged over every element.
    """
    # `reduction="mean"` is the supported spelling of the legacy `reduce=True`
    # (with the default `size_average`); the legacy form warns on every call.
    return F.binary_cross_entropy(x_rec, x, reduction="mean")

### KL Div Target

Recall the definition of KL Divergence is the expected value under the reference distribution of the information ratio (or something like that):

$$D_{KL}(q||p) = E_q[\log{\frac{q}{p}}] $$

So in our case, with
$$KL(q_{\phi}(\mathbf{z}|\mathbf{x})||p(z))$$
we have 

$$KL(q_{\phi}(\mathbf{z}|\mathbf{x})||p(z)) = E_{q_{\phi}(\mathbf{z}|x)}[\log{q_{\phi}(\mathbf{z}|\mathbf{x})} - \log{p(z)}]$$

Note that both $q_{\phi}(\mathbf{z}|\mathbf{x})$ and $p(z)$ are diagonal Gaussians. The KL divergence between diagonal Gaussians can be [derived analytically](https://stats.stackexchange.com/questions/7440/kl-divergence-between-two-univariate-gaussians), and for each dimension is given by:

$$ \log{\frac{\sigma_2}{\sigma_1}} + \frac{\sigma_1^2 + (\mu_1 - \mu_2)^2}{2\sigma_2^2} - \frac{1}{2}$$

Since $p(z)$ is standard normal, we have $\mu_2 = 0, \sigma_2 = 1$, reducing the equation to:

$$ \frac{1}{2}(\sigma_1^2 + \mu_1^2 - 1) - \log{\sigma_1} = \frac{1}{2}(\sigma_1^2 + \mu_1^2 - 1 - \log{\sigma_1^2})$$

#### KLDiv Standard Normal

In [None]:
#export
def kl_div_stdnorm(mu, logvar):
    """Mean KL divergence between N(mu, exp(logvar)) and the standard normal.

    Uses the closed form for a univariate Gaussian against N(0, 1):
        KL = 0.5 * (sigma^2 + mu^2 - 1 - log(sigma^2))

    Args:
        mu: tensor of posterior means.
        logvar: tensor of posterior log-variances (log(sigma^2)); must
            broadcast against `mu`.

    Returns:
        Scalar tensor: the element-wise KL averaged over every element of the
        inputs (`torch.mean` with no `dim`), not summed over any dimension.
    """
    # log(sigma) == 0.5 * logvar, so the log term belongs inside the 0.5 factor;
    # subtracting the full logvar would count the log term twice.
    return torch.mean(0.5 * (logvar.exp() + mu.pow(2) - 1 - logvar))

In [None]:
assert kl_div_stdnorm(torch.Tensor([0]), torch.log(torch.Tensor([1]))) == 0

Let $\mu_1$ = 2, $\sigma_1^2$ = 4 (so $\sigma_1 = 2$), then we would have
$$KL(q, p) = \log \frac{1}{2} + \frac{4 + (2-0)^2}{2} - \frac{1}{2} = 4 - \frac{1}{2} + \log{\frac{1}{2}}$$

In [None]:
assert torch.isclose(kl_div_stdnorm(torch.Tensor([2]), torch.log(torch.Tensor([4]))), 4 - .5 + torch.log(torch.tensor(.5)))

#### KLDiv Target Loss

Now we can define the full loss:

$$\gamma |KL(q_{\phi}(\mathbf{z}^s|\mathbf{x}^s)||p(z)) - C|^2$$

I'm not sure if the difference is computed element wise, or by batch....

In [None]:
#export
def kl_div_target(mu, logvar, C=0, gamma=1):
    """Capacity-targeted KL penalty: gamma * (KL - C)^2.

    Args:
        mu: tensor of posterior means, forwarded to `kl_div_stdnorm`.
        logvar: tensor of posterior log-variances, forwarded to `kl_div_stdnorm`.
        C: target value the mean KL divergence is pulled towards.
        gamma: weight applied to the squared gap.

    Returns:
        Scalar tensor: squared difference between the mean KL divergence
        returned by `kl_div_stdnorm` and `C`, scaled by `gamma`.
    """
    kl_gap = kl_div_stdnorm(mu, logvar) - C
    return gamma * kl_gap.pow(2)

In [None]:
# Zero mean and unit variance with the default target C=0 gives zero penalty.
assert kl_div_target(mu=torch.zeros(1), logvar=torch.ones(1).log()) == 0

In [None]:
# Zero KL against a target of C=1 leaves a squared gap of (0 - 1)^2 = 1.
assert kl_div_target(mu=torch.zeros(1), logvar=torch.ones(1).log(), C=1) == 1

In [None]:
# Zero KL, target C=2, weight gamma=3: 3 * (0 - 2)^2 = 12.
assert kl_div_target(mu=torch.zeros(1), logvar=torch.ones(1).log(), C=2, gamma=3) == 12

### Train 

In [None]:
# Optimisation settings
epochs = 10      # passes over the training loader
batch_size = 64  # samples per batch for the training DataLoader
lr = 6e-4        # Adam learning rate

# Model size
latents = 24     # passed to VanillaVAE(latents=...)

# KL target loss settings (see kl_div_target)
gamma = 1        # weight on the squared (KL - C) gap
C = 0            # KL target; held constant in this run

In [None]:
# Instantiate the model, its Adam optimizer, and the loader the training loop iterates over.
vanilla_vae = VanillaVAE(latents=latents)
optimizer = torch.optim.Adam(vanilla_vae.parameters(), lr=lr)
# No shuffle argument is passed, so batches follow the dataset's own order.
loader = DataLoader(fashion_data, batch_size=batch_size)

In [None]:
# Train the vanilla VAE: each step minimises reconstruction loss + KL target loss.
for epoch in range(epochs):
    # Running sum of the per-batch losses for this epoch, kept as a plain float.
    total_loss = 0.0
    for X, _y, _pos in loader:
        optimizer.zero_grad()

        rec_X, mu, logvar = vanilla_vae(X)

        rec_loss = reconstruction_loss(X, rec_X)
        kl_loss = kl_div_target(mu, logvar, C=C, gamma=gamma)
        loss = rec_loss + kl_loss

        loss.backward()
        optimizer.step()
        # .item() extracts the Python number; accumulating the tensor itself
        # would chain every batch's autograd history into total_loss and keep
        # it alive for the whole epoch.
        total_loss += loss.item()
    print(f"epoch: {epoch}, loss={total_loss}")



epoch: 0, loss=132.15542602539062
epoch: 1, loss=131.6337432861328
epoch: 2, loss=131.08164978027344
epoch: 3, loss=130.8853302001953
epoch: 4, loss=131.02780151367188
epoch: 5, loss=130.64657592773438
epoch: 6, loss=130.9297332763672
epoch: 7, loss=130.6197052001953


KeyboardInterrupt: 