# Kalman VAE
- Original paper: A Disentangled Recognition and Nonlinear Dynamics Model for Unsupervised Learning (https://arxiv.org/pdf/1710.05741.pdf)
- Original code: https://github.com/simonkamronn/kvae

## Kalman VAE summary
>KVAE disentangles two latent representations: an object’s representation, coming from a recognition model, and a latent state describing its dynamics. As a result, the evolution of the world can be imagined and missing data imputed, both without the need to generate high dimensional frames at each time step.

In [1]:
from tqdm import tqdm

import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
from tensorboardX import SummaryWriter

from pixyz.utils import print_latex

In [2]:
batch_size = 128
epochs = 1
seed = 1
torch.manual_seed(seed)

In [3]:
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [4]:
# Sequential MNIST: treat each 28x28 image as a sequence of 28 rows (one row per time step)
def init_dataset(f_batch_size):
    kwargs = {'num_workers': 1, 'pin_memory': True}
    data_dir = '../data'
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda data: data[0])
    ])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(data_dir, train=True, download=True,
                       transform=mnist_transform),
        batch_size=f_batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(data_dir, train=False, transform=mnist_transform),
        batch_size=f_batch_size, shuffle=True, **kwargs)

    fixed_t_size = 28
    return train_loader, test_loader, fixed_t_size

train_loader, test_loader, t_max = init_dataset(batch_size)
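
The loaders above yield image batches shaped (batch, 28, 28), i.e. batch-major, while a recurrence over time steps consumes one row $x_t$ per step for the whole batch. A minimal sketch of that reordering with nested Python lists; the helper `to_time_major` is hypothetical, not part of pixyz or torchvision. For a tensor `x` the same reordering is `x.transpose(0, 1)`.

```python
def to_time_major(batch):
    """Reorder a (batch, T, x_dim) nested list into (T, batch, x_dim).

    Element [t] of the result holds row x_t of every image in the batch.
    """
    num_steps = len(batch[0])
    return [[image[t] for image in batch] for t in range(num_steps)]


# two toy "images" with T = 3 rows of 2 pixels each
images = [
    [[0, 1], [1, 1], [0, 0]],
    [[1, 0], [0, 0], [1, 1]],
]
steps = to_time_major(images)
print(len(steps), steps[0])  # 3 [[0, 1], [1, 0]]
```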

## Define probability distributions
### In the original paper
Prior: $p_{\gamma}({\bf a}|{\bf u}) = {\int p_{\gamma}({\bf a}|{\bf z})p_{\gamma}({\bf z}|{\bf u})d{\bf z}}$ (equation (3) in the paper)
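
For the linear-Gaussian state space model this integral has a closed form, and the Kalman filter evaluates it sequentially as $\log p_{\gamma}({\bf a}|{\bf u}) = \sum_t \log p_{\gamma}(a_t | a_{1:t-1}, {\bf u})$. A minimal sketch for scalar $a_t$ and $z_t$, assuming fixed $A, B, C, Q, R$ (in the KVAE these change with $t$ through $\alpha_t$) and $p(z_1) = \mathcal{N}(m_1, P_1)$; all function and argument names are illustrative.

```python
import math


def normal_logpdf(x, mean, var):
    """Log density of a scalar Gaussian N(x; mean, var)."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)


def kalman_log_marginal(a, u, A, B, C, Q, R, m1, P1):
    """log p(a_1..a_T | u) for a scalar LGSSM, computed with the Kalman filter.

    Python index t corresponds to time step t + 1; u[0] is unused because
    the first transition happens at the second step.
    """
    m, P = m1, P1  # p(z_1) = N(m1, P1)
    total = 0.0
    for t in range(len(a)):
        if t > 0:
            # predict: p(z_t | a_{1:t-1}) = N(A m + B u_t, A^2 P + Q)
            m = A * m + B * u[t]
            P = A * A * P + Q
        # one-step predictive density of the observation: N(a_t; C m, C^2 P + R)
        S = C * C * P + R
        total += normal_logpdf(a[t], C * m, S)
        # update: condition z_t on a_t
        K = P * C / S
        m = m + K * (a[t] - C * m)
        P = (1 - K * C) * P
    return total


print(kalman_log_marginal([0.4, -0.7], [0.0, 1.0], A=0.9, B=0.5, C=1.2,
                          Q=0.3, R=0.2, m1=0.1, P1=1.0))
```

For two steps the result matches the bivariate Gaussian density of $(a_1, a_2)$ obtained by integrating out $z_1, z_2$ by hand.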

Generator: $p_{\theta}({\bf x}|{\bf a}) = \prod_{t=1}^{T}p_{\theta}(x_t | a_t)$  
>$p_{\theta}(x_t|a_t)$ is a deep neural network parameterized by θ, that emits either a factorized Gaussian or Bernoulli probability vector depending on the data type of $x_t$.
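
With the Bernoulli generator used in this notebook, $\log p_{\theta}(x_t|a_t)$ is a sum over the pixels of row $x_t$. A minimal plain-Python sketch, assuming `probs` are the per-pixel probabilities emitted by the decoder (the function name is illustrative). Note that `ToTensor` scales MNIST pixels to [0, 1] rather than binarizing them; for such values the same expression is the negative binary cross-entropy.

```python
import math


def bernoulli_log_prob(x_row, probs, eps=1e-7):
    """log p(x_t | a_t) = sum_i x_i log p_i + (1 - x_i) log(1 - p_i)."""
    total = 0.0
    for x, p in zip(x_row, probs):
        p = min(max(p, eps), 1.0 - eps)  # clamp to keep log() finite
        total += x * math.log(p) + (1.0 - x) * math.log(1.0 - p)
    return total


print(bernoulli_log_prob([1, 0, 1], [0.9, 0.2, 0.6]))
```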

Inference:  $q_{\phi}({\bf a}|{\bf x}) = \prod_{t=1}^{T}q_{\phi}(a_t | x_t)$  
$q_{\phi}(a_t | x_t) = {\cal N}(\mu_{\phi}(x_t), \Sigma_{\phi}(x_t))$
>$q_{\phi}(a_t|x_t)$ is a deep neural network that maps $x_t$ to the mean and the diagonal covariance of a Gaussian distribution.  
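
The standard way to draw $a_t \sim q_{\phi}(a_t|x_t)$ while keeping gradients with respect to $\phi$ is the reparameterization trick, $a_t = \mu_{\phi}(x_t) + \sigma_{\phi}(x_t) \odot \epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$. A plain-Python sketch of the sample and of the diagonal-Gaussian log density; `sigma` is the standard deviation, which is what the `scale` output of the `Inference` class holds (via softplus). Function names are illustrative.

```python
import math
import random


def sample_a(mu, sigma, rng):
    """Reparameterized sample a_t = mu + sigma * eps, eps ~ N(0, I) (diagonal covariance)."""
    return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, sigma)]


def diag_normal_log_prob(a, mu, sigma):
    """log N(a; mu, diag(sigma^2)), summed over dimensions."""
    return sum(
        -0.5 * math.log(2 * math.pi) - math.log(s) - 0.5 * ((x - m) / s) ** 2
        for x, m, s in zip(a, mu, sigma)
    )


rng = random.Random(0)
mu, sigma = [0.5, -1.0], [0.1, 0.3]
a_t = sample_a(mu, sigma, rng)
print(a_t, diag_normal_log_prob(a_t, mu, sigma))
```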

LGSSM:  
$p_{\gamma_{t}}\left(\mathbf{z}_{t} | \mathbf{z}_{t-1}, \mathbf{u}_{t}\right)=\mathcal{N}\left(\mathbf{z}_{t} ; \mathbf{A}_{t} \mathbf{z}_{t-1}+\mathbf{B}_{t} \mathbf{u}_{t}, \mathbf{Q}\right), \quad p_{\gamma_{t}}\left(\mathbf{a}_{t} | \mathbf{z}_{t}\right)=\mathcal{N}\left(\mathbf{a}_{t} ; \mathbf{C}_{t} \mathbf{z}_{t}, \mathbf{R}\right)$ (equation (1) in the paper)
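
Equation (1) reads as a generative recipe: propagate $z_t$ with $\mathbf{A}_t$, $\mathbf{B}_t$ and Gaussian noise of covariance $\mathbf{Q}$, then emit $a_t$ through $\mathbf{C}_t$ with noise of covariance $\mathbf{R}$. A minimal sampling sketch, assuming fixed matrices, diagonal $\mathbf{Q}$ and $\mathbf{R}$ (passed as the variance lists `q_var`, `r_var`), and a given starting state `z1` instead of a draw from $p(z_1)$; all names are illustrative.

```python
import random


def matvec(M, v):
    """Matrix-vector product for nested lists."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in M]


def sample_lgssm(A, B, C, q_var, r_var, z1, u, rng):
    """Draw z_{1:T} and a_{1:T} from equation (1) with fixed A, B, C.

    q_var / r_var are the diagonals of Q / R; u[t] is the input at step t + 1
    (u[0] is unused because the first transition happens at the second step).
    """
    z = list(z1)
    z_seq, a_seq = [], []
    for t in range(len(u)):
        if t > 0:
            # z_t = A z_{t-1} + B u_t + w_t,  w_t ~ N(0, Q)
            mean = [x + y for x, y in zip(matvec(A, z), matvec(B, u[t]))]
            z = [m + rng.gauss(0.0, q ** 0.5) for m, q in zip(mean, q_var)]
        # a_t = C z_t + v_t,  v_t ~ N(0, R)
        a = [m + rng.gauss(0.0, r ** 0.5) for m, r in zip(matvec(C, z), r_var)]
        z_seq.append(z)
        a_seq.append(a)
    return z_seq, a_seq


# constant-velocity toy: z = (position, velocity), u pushes the velocity, a observes the position
A = [[1.0, 1.0], [0.0, 1.0]]
B = [[0.0], [1.0]]
C = [[1.0, 0.0]]
z_seq, a_seq = sample_lgssm(A, B, C, q_var=[0.01, 0.01], r_var=[0.1],
                            z1=[0.0, 1.0], u=[[0.0]] * 5, rng=random.Random(0))
print(a_seq)
```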

$p_{\gamma}(\mathbf{a}, \mathbf{z} | \mathbf{u})=\prod_{t=1}^{T} p_{\gamma_{t}\left(\mathbf{a}_{0: t-1}\right)}\left(\mathbf{a}_{t} | \mathbf{z}_{t}\right) \cdot p\left(\mathbf{z}_{1}\right) \prod_{t=2}^{T} p_{\gamma_{t}\left(\mathbf{a}_{0: t-1}\right)}\left(\mathbf{z}_{t} | \mathbf{z}_{t-1}, \mathbf{u}_{t}\right)$ (equation (8) in the paper)
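
Once the per-step parameters are known, equation (8) is a sum of Gaussian log densities. A minimal sketch for scalar $a_t$ and $z_t$, assuming the lists `A`, `B`, `C` already hold $A_t$, $B_t$, $C_t$ for every step (in the KVAE they come from $\alpha_t({\bf a}_{0:t-1})$) and $p(z_1) = \mathcal{N}(m_1, P_1)$; names are illustrative.

```python
import math


def normal_logpdf(x, mean, var):
    """Log density of a scalar Gaussian N(x; mean, var)."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)


def lgssm_joint_log_prob(a, z, u, A, B, C, Q, R, m1, P1):
    """log p(a, z | u) from equation (8) for scalar a_t, z_t.

    A[t], B[t], C[t] hold the per-step parameters (Python index t = time step t + 1);
    A[0] and B[0] are unused because transitions start at the second step.
    """
    total = normal_logpdf(z[0], m1, P1)  # log p(z_1)
    for t in range(1, len(z)):
        # transition: log N(z_t; A_t z_{t-1} + B_t u_t, Q)
        total += normal_logpdf(z[t], A[t] * z[t - 1] + B[t] * u[t], Q)
    for t in range(len(a)):
        # emission: log N(a_t; C_t z_t, R)
        total += normal_logpdf(a[t], C[t] * z[t], R)
    return total


print(lgssm_joint_log_prob(a=[0.5, 1.0], z=[0.2, 0.4], u=[0.0, 1.0],
                           A=[0.0, 0.9], B=[0.0, 0.5], C=[1.0, 1.1],
                           Q=0.1, R=0.2, m1=0.0, P1=1.0))
```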

Dynamics parameter network:  
${\bf \alpha}_t = {\bf \alpha}_t({\bf a}_{0:t-1})$  
${\bf d}_t = \mathrm{LSTM}({\bf a}_{t-1}, {\bf d}_{t-1})$  
${\bf \alpha}_t = \mathrm{softmax}({\bf d}_t)$  
$\mathbf{A}_{t}=\sum_{k=1}^{K} \alpha_{t}^{(k)}\left(\mathbf{a}_{0: t-1}\right) \mathbf{A}^{(k)}, \quad \mathbf{B}_{t}=\sum_{k=1}^{K} \alpha_{t}^{(k)}\left(\mathbf{a}_{0: t-1}\right) \mathbf{B}^{(k)}, \quad \mathbf{C}_{t}=\sum_{k=1}^{K} \alpha_{t}^{(k)}\left(\mathbf{a}_{0: t-1}\right) \mathbf{C}^{(k)}$
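
Because the softmax makes the $K$ weights non-negative and sum to one, each $\mathbf{A}_t$, $\mathbf{B}_t$, $\mathbf{C}_t$ is a convex combination of the $K$ base matrices. A plain-Python sketch of both steps; the logits stand in for the network output (in this notebook, a linear layer applied to ${\bf d}_t$), and the helper names are illustrative.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of K logits."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def mix_matrices(alpha, mats):
    """Return sum_k alpha[k] * mats[k] for K matrices of equal shape (nested lists)."""
    rows, cols = len(mats[0]), len(mats[0][0])
    return [
        [sum(w * M[i][j] for w, M in zip(alpha, mats)) for j in range(cols)]
        for i in range(rows)
    ]


alpha_t = softmax([0.0, 0.0])               # two equally weighted base matrices
A_base = [[[1.0, 0.0], [0.0, 1.0]],         # A^(1): identity
          [[3.0, 2.0], [1.0, 0.0]]]         # A^(2)
print(mix_matrices(alpha_t, A_base))        # [[2.0, 1.0], [0.5, 0.5]]
```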

In [5]:
# https://github.com/simonkamronn/kvae/blob/849d631dbf2faf2c293d56a0d7a2e8564e294a51/kvae/KalmanVariationalAutoencoder.py#L97

from pixyz.distributions import Normal
class Inference(Normal):
    def __init__(self, x_dim, a_dim):
        super(Inference, self).__init__(name="q_phi", cond_var=["x_t"], var=["a_t"])
        self.fc1 = nn.Linear(x_dim, 25)
        self.fc2 = nn.Linear(25, 25)
        
        self.fc3_1 = nn.Linear(25, a_dim)
        self.fc3_2 = nn.Linear(25, a_dim)
    
    def forward(self, x):
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        return {"loc": self.fc3_1(h), "scale": F.softplus(self.fc3_2(h))}

In [6]:
encoder = Inference(44, 44)
print_latex(encoder)

In [7]:
# https://github.com/simonkamronn/kvae/blob/849d631dbf2faf2c293d56a0d7a2e8564e294a51/kvae/KalmanVariationalAutoencoder.py#L132

from pixyz.distributions import Bernoulli
class Generator(Bernoulli):
    def __init__(self, a_dim, x_dim):
        super(Generator, self).__init__(name="p_theta", cond_var=["a_t"], var=["x_t"])
        self.fc1 = nn.Linear(a_dim, 25)
        self.fc2 = nn.Linear(25, 25)
        
        self.fc3 = nn.Linear(25, x_dim)
    
    def forward(self, a):
        h = F.relu(self.fc1(a))
        h = F.relu(self.fc2(h))
        return {"probs": torch.sigmoid(self.fc3(h))}

In [8]:
g = Generator(44, 44)
print_latex(g)

In [11]:
# https://github.com/simonkamronn/kvae/blob/849d631dbf2faf2c293d56a0d7a2e8564e294a51/kvae/KalmanVariationalAutoencoder.py#L176

from pixyz.distributions import Deterministic
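class RNN(Deterministic):
    def __init__(self, a_dim, d_dim):
        # d_t = LSTM(a_{t-1}, d_{t-1}); the LSTM cell state c is carried alongside the hidden state d
        super(RNN, self).__init__(name="LSTM", cond_var=["a_prev", "d_prev", "c_prev"], var=["d", "c"])
        # d_dim = 50
        self.rnn = nn.LSTMCell(a_dim, d_dim)

    def forward(self, a_prev, d_prev, c_prev):
        # nn.LSTMCell takes the previous state as a (hidden, cell) tuple
        d, c = self.rnn(a_prev, (d_prev, c_prev))
        return {"d": d, "c": c}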

In [12]:
r = RNN(10, 50)
print_latex(r)

In [17]:
# https://github.com/simonkamronn/kvae/blob/849d631dbf2faf2c293d56a0d7a2e8564e294a51/kvae/KalmanVariationalAutoencoder.py#L220

from pixyz.distributions import Deterministic
class DynamicParameterNetwork(Deterministic):
    def __init__(self, d_dim, k_num):
        super(DynamicParameterNetwork, self).__init__(name="p_alpha", cond_var=["d"], var=["alpha"])
        self.fc1 = nn.Linear(d_dim, k_num)
        
    def forward(self, d):
        # mixture weights alpha over the K base matrices, normalized over the last (K) dimension
        return {"alpha": F.softmax(self.fc1(d), dim=-1)}

In [18]:
a = DynamicParameterNetwork(50, 3)
print_latex(a)

In [20]:
Dynamic = a * r
print_latex(Dynamic)

## Define loss function
$\mathcal{F}(\theta, \gamma, \phi)=\mathbb{E}_{q_{\phi}(\mathbf{a} | \mathbf{x})}\left[\log \frac{p_{\theta}(\mathbf{x} | \mathbf{a})}{q_{\phi}(\mathbf{a} | \mathbf{x})}+\mathbb{E}_{p_{\gamma}(\mathbf{z} | \mathbf{a}, \mathbf{u})}\left[\log \frac{p_{\gamma}(\mathbf{a} | \mathbf{z}) p_{\gamma}(\mathbf{z} | \mathbf{u})}{p_{\gamma}(\mathbf{z} | \mathbf{a}, \mathbf{u})}\right]\right]$ (equation (6) in the paper)

$\hat{\mathcal{F}}(\theta, \gamma, \phi)=\frac{1}{I} \sum_{i}\left[\log p_{\theta}\left(\mathbf{x} | \widetilde{\mathbf{a}}^{(i)}\right)+\log p_{\gamma}\left(\widetilde{\mathbf{a}}^{(i)}, \widetilde{\mathbf{z}}^{(i)} | \mathbf{u}\right)-\log q_{\phi}\left(\widetilde{\mathbf{a}}^{(i)} | \mathbf{x}\right)-\log p_{\gamma}\left(\widetilde{\mathbf{z}}^{(i)} | \widetilde{\mathbf{a}}^{(i)}, \mathbf{u}\right)\right]$ (equation (7) in the paper)
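
Equation (7) averages the four log terms over $I$ samples $\widetilde{\mathbf{a}}^{(i)} \sim q_{\phi}(\mathbf{a}|\mathbf{x})$ and $\widetilde{\mathbf{z}}^{(i)} \sim p_{\gamma}(\mathbf{z}|\widetilde{\mathbf{a}}^{(i)}, \mathbf{u})$. A sketch on a toy scalar model chosen so the result can be checked (an assumption, not the KVAE networks): $z \sim \mathcal{N}(0, 1)$, $a|z \sim \mathcal{N}(z, 1)$, $x|a \sim \mathcal{N}(a, 1)$, no input $u$, a single time step, and $q$ set to the exact posterior $\mathcal{N}(2x/3, 2/3)$. Then $p(z|a) = \mathcal{N}(a/2, 1/2)$, and every sample yields exactly $\log p(x) = \log \mathcal{N}(x; 0, 3)$, the tightest possible bound.

```python
import math
import random


def normal_logpdf(x, mean, var):
    """Log density of a scalar Gaussian N(x; mean, var)."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)


def elbo_estimate(x, num_samples, rng):
    """Monte Carlo estimate of equation (7) for the toy model described above."""
    total = 0.0
    for _ in range(num_samples):
        # a ~ q(a | x): here the exact posterior N(2x/3, 2/3)
        a = rng.gauss(2 * x / 3, math.sqrt(2 / 3))
        # z ~ p(z | a): the exact smoothing posterior N(a/2, 1/2)
        z = rng.gauss(a / 2, math.sqrt(1 / 2))
        log_px_a = normal_logpdf(x, a, 1.0)                                 # log p(x | a)
        log_paz = normal_logpdf(z, 0.0, 1.0) + normal_logpdf(a, z, 1.0)     # log p(a, z)
        log_qa = normal_logpdf(a, 2 * x / 3, 2 / 3)                         # log q(a | x)
        log_pz_a = normal_logpdf(z, a / 2, 1 / 2)                           # log p(z | a)
        total += log_px_a + log_paz - log_qa - log_pz_a
    return total / num_samples


print(elbo_estimate(1.3, 10, random.Random(0)), normal_logpdf(1.3, 0.0, 3.0))
```

With an approximate $q_{\phi}$, as in the actual KVAE, each sample gives a noisy value whose average lower-bounds $\log p(\mathbf{x})$ instead.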