In [9]:
import os

import torch
import torch.nn.functional as F
import torch.distributions as D
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from vae import Encoder, Decoder
from ssm import StateSpaceModel

In [6]:
# Value passed as `a_dim` to the Encoder and Decoder built in the last cell of this notebook.
a_dim = 10

In [None]:
# Toy sizes for exercising the state-space model on random data.  The three
# constructor arguments are positional, in the same order the KVAE class in this
# notebook uses them: StateSpaceModel(a_dim, z_dim, K).
demo_a_dim, demo_z_dim, demo_K = 2, 2, 3
demo_seq_length, demo_batch_size = 8, 16

# Seed so the random input (and any random initialisation inside the model) is
# repeatable after a kernel restart.
torch.manual_seed(0)

ssm = StateSpaceModel(demo_a_dim, demo_z_dim, demo_K)
# Presumably laid out as (seq_length, batch_size, a_dim), matching the tensor that
# `objective` hands to `kalman_filter` -- TODO confirm against ssm.py.
as_ = torch.randn(demo_seq_length, demo_batch_size, demo_a_dim)

In [17]:
# Forward pass: run the Kalman filter over the toy sequence.
(
    filter_means,
    filter_covariances,
    filter_next_means,
    filter_next_covariances,
    mat_As,
    mat_Cs,
) = ssm.kalman_filter(as_)

# Backward pass: smooth using everything the filter returned.
means, covariances = ssm.kalman_smooth(
    as_,
    filter_means,
    filter_covariances,
    filter_next_means,
    filter_next_covariances,
    mat_As,
    mat_Cs,
)

In [None]:
class KalmanVariationalAutoencoder(nn.Module):
    """Combines an encoder, a decoder and a state-space model into one objective.

    Args:
        a_dim: passed to ``Encoder``, ``Decoder`` and ``StateSpaceModel``.
        z_dim: passed to ``StateSpaceModel``; also the size of the sampled state.
        K: third positional argument of ``StateSpaceModel``.
    """

    def __init__(self, a_dim, z_dim, K):
        super().__init__()
        # Stored on the instance so `objective` never relies on notebook globals.
        self.a_dim = a_dim
        self.z_dim = z_dim
        self.encoder = Encoder(a_dim)
        self.decoder = Decoder(a_dim)
        self.state_space_model = StateSpaceModel(a_dim, z_dim, K)

    def objective(self, xs):
        """Single-sample estimate of the training objective for a batch of sequences.

        Every term is summed over time (and feature axes) and averaged over the
        batch, so the four terms share one scale.

        Args:
            xs: tensor whose first two axes are (seq_length, batch_size).  It is
                scored with a Bernoulli likelihood, so it is expected to be binary;
                ``Bernoulli.log_prob`` rejects other values when argument
                validation is enabled.

        Returns:
            ``(objective, terms)`` where ``objective`` is a scalar tensor and
            ``terms`` maps 'reconstruction', 'q', 'kalman_reconst' and 'gamma' to
            the scalar tensors that add up to it.
        """
        seq_length, batch_size = xs.shape[:2]
        z_dim = self.z_dim
        ssm = self.state_space_model

        # Fold time into the batch axis for the encoder, then unfold again.
        # Only the leading (seq * batch) axis is split, so keep shape[1:].
        as_mean, as_std = self.encoder(xs.reshape(-1, *xs.shape[2:]))
        as_mean = as_mean.reshape(seq_length, batch_size, *as_mean.shape[1:])
        as_std = as_std.reshape(seq_length, batch_size, *as_std.shape[1:])

        # Sample from q_\phi (a|x)
        q_distrib = D.Normal(as_mean, as_std)
        as_sample = q_distrib.rsample()

        # Reconstruction term: log-likelihood of the *observed* xs under the decoder.
        # The decoder gets the same flattened layout the encoder was given; its
        # first output is not needed here.
        # Assumes the logits hold as many elements as xs -- TODO confirm against vae.py.
        _, xs_logits = self.decoder(as_sample.reshape(seq_length * batch_size, -1))
        xs_logits = xs_logits.reshape(xs.shape)
        reconstruction_obj = (
            D.Bernoulli(logits=xs_logits).log_prob(xs).sum(0).mean(0).sum()
        )

        # -ln q_\phi (a|x)
        q_obj = -q_distrib.log_prob(as_sample).sum(0).mean(0).sum()

        # Kalman filter and smoother
        (
            filter_means,
            filter_covariances,
            filter_next_means,
            filter_next_covariances,
            mat_As,
            mat_Cs,
        ) = ssm.kalman_filter(as_sample)
        means, covariances = ssm.kalman_smooth(
            as_sample,
            filter_means,
            filter_covariances,
            filter_next_means,
            filter_next_covariances,
            mat_As,
            mat_Cs,
        )

        # Sample from p_\gamma (z|a,u)
        # Shape of means: (sequence_length, batch_size, z_dim, 1)
        # Shape of covariances: (sequence_length, batch_size, z_dim, z_dim)
        # MultivariateNormal takes the event axis as the LAST axis of loc, so the
        # trailing singleton axis is dropped here.
        zs_distrib = D.MultivariateNormal(
            means.reshape(-1, z_dim),
            covariance_matrix=covariances.reshape(-1, z_dim, z_dim),
        )
        zs_flat = zs_distrib.rsample()
        zs_sample = zs_flat.reshape(seq_length, batch_size, z_dim, 1)

        # ln p_\gamma(a|z)
        # Drop the column-vector axis so the mean lines up with as_sample.
        # Assumes mat_Cs is (seq_length, batch_size, a_dim, z_dim) and mat_R is the
        # (a_dim, a_dim) observation-noise covariance -- TODO confirm against ssm.py.
        as_pred = (mat_Cs @ zs_sample).squeeze(-1)
        kalman_reconst_distrib = D.MultivariateNormal(
            as_pred, covariance_matrix=ssm.mat_R
        )
        kalman_reconst_obj = kalman_reconst_distrib.log_prob(as_sample).sum(0).mean(0)

        # -ln p_\gamma(z|a)
        # log_prob comes back flat over (seq * batch); unfold before reducing so
        # the batch axis is averaged rather than summed.
        zs_log_prob = zs_distrib.log_prob(zs_flat).reshape(seq_length, batch_size)
        gamma_obj = -zs_log_prob.sum(0).mean(0)

        # NOTE(review): no prior term over z (transition / initial-state density)
        # is added here -- confirm whether that omission is intended.
        objective = reconstruction_obj + q_obj + kalman_reconst_obj + gamma_obj

        return objective, {
            'reconstruction': reconstruction_obj,
            'q': q_obj,
            'kalman_reconst': kalman_reconst_obj,
            'gamma': gamma_obj
        }


In [None]:
# Shape smoke test: push one blank image through the encoder and back out.
image = torch.zeros(1, 3, 32, 32)
encoder = Encoder(a_dim=a_dim)
decoder = Decoder(a_dim=a_dim)

# `KalmanVariationalAutoencoder.objective` unpacks the encoder output as
# (mean, std) and the decoder output as a pair whose second item is the logits;
# this cell follows the same convention.
# NOTE(review): confirm against vae.py -- if both actually return a single
# tensor, the class is what needs changing instead.
with torch.no_grad():
    enc_mean, enc_std = encoder(image)
    dec_first, dec_logits = decoder(enc_mean)
dec_logits.shape