In [1]:
import os

import torch
import torch.nn.functional as F
import torch.distributions as D
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from vae import Encoder, BernoulliDecoder, GaussianDecoder
from ssm import StateSpaceModel

In [6]:
class KalmanVariationalAutoencoder(nn.Module):
    def __init__(self, image_size, image_channels, a_dim, z_dim, K):
        super(KalmanVariationalAutoencoder, self).__init__()
        self.encoder = Encoder(image_size, image_channels, a_dim)
        self.decoder = GaussianDecoder(a_dim, image_size, image_channels)
        self.state_space_model = StateSpaceModel(a_dim, z_dim, K)
        self.a_dim = a_dim
        self.z_dim = z_dim
    
    def objective(self, xs):
        seq_length = xs.shape[0]
        batch_size = xs.shape[1]
        
        as_dist = self.encoder(xs.view(-1, *xs.shape[2:]))
        as_sample = as_dist.rsample().view(seq_length, batch_size, self.a_dim)

        # Reconstruction objective
        xs_dist = self.decoder(as_sample.view(-1, self.a_dim))
        reconstruction_obj = xs_dist.log_prob(xs.view(-1, *xs.shape[2:])).sum(0).mean(0).sum()

        # Regularization objective
        # -ln q_\phi(a|x)
        regularization_obj = - as_dist.log_prob(as_sample.view(-1, self.a_dim)).sum(0).mean(0).sum()
        
        # Kalman filter and smoother
        filter_means, filter_covariances, filter_next_means, filter_next_covariances, mat_As, mat_Cs = self.state_space_model.kalman_filter(as_sample)
        means, covariances = self.state_space_model.kalman_smooth(as_sample, filter_means, filter_covariances, filter_next_means, filter_next_covariances, mat_As, mat_Cs)

        # Sample from p_\gamma (z|a,u)
        # Shape of means: (sequence_length, batch_size, z_dim, 1)
        # Shape of covariances: (sequence_length, batch_size, z_dim, z_dim)
        zs_distrib = D.MultivariateNormal(means.view(-1, self.z_dim), covariances.view(-1, self.z_dim, self.z_dim))
        zs_sample = zs_distrib.rsample()
        zs_sample = zs_sample.view(seq_length, batch_size, self.z_dim, 1)

        # ln p_\gamma(a|z)
        kalman_reconst_distrib = D.Normal(mat_Cs @ zs_sample, self.state_space_model.mat_R)
        kalman_reconst_obj = kalman_reconst_distrib.log_prob(as_sample).sum(0).mean(0).sum()

        # -ln p_\gamma(z|a)
        gamma_obj = - zs_distrib.log_prob(zs_sample.view(-1, self.z_dim, 1)).sum(0).mean(0).sum()

        objective = reconstruction_obj + regularization_obj + kalman_reconst_obj + gamma_obj
        
        return objective, {
            'reconstruction': reconstruction_obj,
            'regularization': regularization_obj,
            'kalman_reconst': kalman_reconst_obj,
            'gamma': gamma_obj
        }


In [7]:
image = torch.zeros(8, 16, 1, 32, 32)
image_size = image.shape[-2:]

In [8]:
kvae = KalmanVariationalAutoencoder(image_size, 1, 2, 4, 8)

In [9]:
kvae.objective(image)

ValueError: Expected parameter covariance_matrix (Tensor of shape (128, 4, 4)) of distribution MultivariateNormal(loc: torch.Size([128, 4]), covariance_matrix: torch.Size([128, 4, 4])) to satisfy the constraint PositiveDefinite(), but found invalid values:
tensor([[[ 5.2508e+06, -4.4078e+05, -2.4251e+05,  3.3224e+06],
         [-4.4390e+05,  3.6006e+05,  2.4741e+05, -1.3273e+06],
         [-2.4162e+05,  2.4702e+05, -2.8348e+05, -3.3266e+05],
         [ 3.3238e+06, -1.3240e+06, -3.3330e+05,  2.9675e+06]],

        [[ 2.5314e+06,  4.7952e+04,  6.2980e+05,  1.1008e+06],
         [ 5.0944e+04, -1.4036e+06,  3.3371e+05,  5.0498e+05],
         [ 6.2916e+05,  3.3349e+05,  7.3876e+04,  1.9722e+05],
         [ 1.0998e+06,  5.0288e+05,  1.9744e+05,  2.2086e+05]],

        [[-1.1091e+06,  7.5966e+05,  6.0016e+04, -1.4265e+06],
         [ 7.6165e+05, -5.7683e+05,  4.6904e+04,  1.1930e+06],
         [ 5.9768e+04,  4.6968e+04,  1.3475e+05, -1.4287e+05],
         [-1.4270e+06,  1.1914e+06, -1.4256e+05, -1.9015e+06]],

        ...,

        [[ 2.7008e+03, -6.9132e+02,  8.1735e+02,  1.1021e+03],
         [-6.9133e+02,  1.3783e+03, -3.6175e+02, -6.8818e+02],
         [ 8.1752e+02, -3.6185e+02,  1.8424e+03, -1.4145e+02],
         [ 1.1021e+03, -6.8814e+02, -1.4161e+02,  1.8276e+03]],

        [[ 2.3984e+03, -5.2052e+02,  6.8487e+02,  9.0012e+02],
         [-5.2043e+02,  1.2909e+03, -2.7181e+02, -5.6843e+02],
         [ 6.8438e+02, -2.7148e+02,  1.6041e+03, -1.0321e+02],
         [ 9.0006e+02, -5.6848e+02, -1.0272e+02,  1.6467e+03]],

        [[ 2.6278e+03, -6.6549e+02,  7.7805e+02,  1.0478e+03],
         [-6.6546e+02,  1.3642e+03, -3.6843e+02, -6.5722e+02],
         [ 7.7805e+02, -3.6845e+02,  1.7880e+03, -1.6575e+02],
         [ 1.0478e+03, -6.5721e+02, -1.6578e+02,  1.8003e+03]]],
       grad_fn=<ExpandBackward0>)