In [2]:
import os

import torch
import torch.nn.functional as F
import torch.distributions as D
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [3]:
a_dim = 10  # size of the per-frame code `a`: Encoder output dimension and Decoder input dimension

In [4]:
class Encoder(nn.Module):
    """Convolutional recognition network q(a | x) for single frames.

    Three stride-2 convolutions reduce a (N, 3, 32, 32) image batch to a
    (N, 32, 4, 4) feature map; two linear heads turn the flattened features
    into the mean and standard deviation of a diagonal Gaussian over the
    ``a_dim``-dimensional code ``a``.

    Args:
        a_dim: Dimensionality of the latent code ``a``.
    """

    def __init__(self, a_dim):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1)
        # 32 channels x 4 x 4 positions: a 32x32 input halved by each of the three convs.
        self.fc_mean = nn.Linear(in_features=32*4*4, out_features=a_dim)
        self.fc_std = nn.Linear(in_features=32*4*4, out_features=a_dim)

    def forward(self, x):
        """Encode images ``x`` of shape (N, 3, 32, 32).

        Returns:
            ``D.Normal`` with batch shape (N, a_dim).
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        features = x.flatten(start_dim=1)
        x_mean = self.fc_mean(features)
        # softplus maps the head's output to a positive scale, as D.Normal requires.
        x_std = F.softplus(self.fc_std(features))
        return D.Normal(x_mean, x_std)

In [5]:
class Decoder(nn.Module):
    """Transposed-convolution generator p(x | a) for single frames.

    A linear layer lifts a code of shape (N, a_dim) to a (N, 32, 4, 4) feature
    map, and three stride-2 transposed convolutions upsample it to (N, 3, 32, 32).

    Args:
        a_dim: Dimensionality of the input code ``a``.
        upscale_factor: Not used by the current architecture; kept so existing
            callers that pass it keep working.
    """

    def __init__(self, a_dim, upscale_factor=2):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(in_features=a_dim, out_features=32*4*4)
        # Each layer doubles height and width: 4 -> 8 -> 16 -> 32.
        upsample = dict(kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv1 = nn.ConvTranspose2d(in_channels=32, out_channels=32, **upsample)
        self.deconv2 = nn.ConvTranspose2d(in_channels=32, out_channels=32, **upsample)
        self.deconv3 = nn.ConvTranspose2d(in_channels=32, out_channels=3, **upsample)

    def forward(self, x):
        """Decode codes ``x`` of shape (N, a_dim).

        Returns:
            Tuple ``(probs, logits)``, both (N, 3, 32, 32), with ``probs = sigmoid(logits)``.
        """
        hidden = F.relu(self.fc(x)).view(-1, 32, 4, 4)
        for deconv in (self.deconv1, self.deconv2):
            hidden = F.relu(deconv(hidden))
        logits = self.deconv3(hidden)
        return torch.sigmoid(logits), logits

In [6]:
class LSTMModel(nn.Module):
    """LSTM that maps a sequence of codes to per-step mixture weights.

    Input shape is (sequence_length, batch_size, a_dim) because the LSTM is
    built with ``batch_first=False``. The output has shape
    (sequence_length, batch_size, K) and is softmax-normalised over the last axis.

    Args:
        a_dim: Size of each input vector.
        K: Number of mixture weights per step (also the LSTM hidden size).
    """

    def __init__(self, a_dim, K):
        super(LSTMModel, self).__init__()
        self.a_dim = a_dim
        self.K = K
        self.lstm = nn.LSTM(a_dim, K, batch_first=False)

    def forward(self, x):
        # The hidden size equals K, so the per-step LSTM outputs are used directly as logits.
        # NOTE(review): LSTM outputs are tanh-bounded, so these weights can never be sharply
        # peaked; a Linear(K, K) before the softmax may be intended — confirm.
        outputs, _state = self.lstm(x)
        return torch.softmax(outputs, dim=-1)

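## Shape

Image sequences `x` fed to the model are 5-D tensors with these dimensions, in order:

1. Sequence length
1. Batch size
1. Channels: 3 (the `Encoder`'s first convolution expects `in_channels=3`)
1. Height: 32
1. Width: 32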

In [7]:
class StateSpaceModel(nn.Module):
    """Linear-Gaussian state-space model with LSTM-mixed transition/emission matrices.

    At step t the model uses
        z_{t+1} = A_t z_t + w_t,   w_t ~ N(0, Q)
        a_t     = C_t z_t + v_t,   v_t ~ N(0, R)
    where A_t and C_t are softmax-weighted (convex) combinations of K learned
    base matrices, the weights coming from ``LSTMModel`` run over the codes ``a``.

    Args:
        a_dim: Dimensionality of the observations a_t.
        z_dim: Dimensionality of the latent state z_t.
        K: Number of base matrices mixed at each step.
    """

    def __init__(self, a_dim, z_dim, K):

        super(StateSpaceModel, self).__init__()

        self.a_dim = a_dim
        self.z_dim = z_dim
        self.K = K

        # Base transition (K, z_dim, z_dim) and emission (K, a_dim, z_dim) matrices.
        # NOTE(review): randn init can give A_t a spectral radius > 1 (unstable rollouts).
        self.mat_A_K = nn.Parameter(torch.randn(K, z_dim, z_dim))
        self.mat_C_K = nn.Parameter(torch.randn(K, a_dim, z_dim))
        # Unconstrained factors; the noise covariances are L @ L^T (see mat_Q / mat_R).
        self.mat_Q_L = nn.Parameter(torch.randn(z_dim, z_dim))
        self.mat_R_L = nn.Parameter(torch.randn(a_dim, a_dim))

        # input shape: (sequence_length, batch_size, a_dim)
        # output shape: (sequence_length, batch_size, K)
        self.weight_model = LSTMModel(a_dim, K)

    @property
    def mat_Q(self):
        """Process-noise covariance Q = L_Q L_Q^T, shape (z_dim, z_dim); PSD by construction."""
        # Tensor.transpose() requires two dims; .T is the 2-D transpose.
        return self.mat_Q_L @ self.mat_Q_L.T

    @property
    def mat_R(self):
        """Observation-noise covariance R = L_R L_R^T, shape (a_dim, a_dim); PSD by construction."""
        return self.mat_R_L @ self.mat_R_L.T

    def kalman_filter(self, as_, initial_cov_scale=1e9):
        """Run a batched Kalman filter over an observation sequence.

        Args:
            as_: Observations a_0, ..., a_{T-1}, shape (T, batch_size, a_dim).
            initial_cov_scale: Prior covariance Sigma_{0|-1} = initial_cov_scale * I.
                The large default makes the prior nearly uninformative.

        Returns:
            Tuple ``(means, covariances, next_means, next_covariances, mat_As, mat_Cs)``:
            lists of length T with means[t] = z_hat_{t|t} (batch_size, z_dim, 1),
            covariances[t] = Sigma_{t|t} (batch_size, z_dim, z_dim),
            next_means[t] = z_hat_{t+1|t} and next_covariances[t] = Sigma_{t+1|t};
            mat_As has shape (T, batch_size, z_dim, z_dim) and mat_Cs has shape
            (T, batch_size, a_dim, z_dim).
        """
        sequence_length, batch_size, _ = as_.size()

        # Prior z_hat_{0|-1} and Sigma_{0|-1}, created on as_'s device/dtype so the filter
        # also runs on GPU / in double precision.
        mean_t_plus = as_.new_zeros((batch_size, self.z_dim, 1))
        eye = torch.eye(self.z_dim, dtype=as_.dtype, device=as_.device)
        cov_t_plus = (eye * initial_cov_scale).unsqueeze(0).repeat(batch_size, 1, 1)

        # weights: (T, batch_size, K).
        # NOTE(review): weights[t] comes from an LSTM that has already consumed a_t, so C_t
        # depends on the observation it is used to explain — confirm this is intended.
        weights = self.weight_model(as_)
        # A_0..A_{T-1}: (T, batch_size, z_dim, z_dim); C_0..C_{T-1}: (T, batch_size, a_dim, z_dim)
        mat_As = torch.einsum('tbk,kij->tbij', weights, self.mat_A_K)
        mat_Cs = torch.einsum('tbk,kij->tbij', weights, self.mat_C_K)
        # Hoisted: the properties rebuild L @ L^T on every access.
        mat_Q = self.mat_Q
        mat_R = self.mat_R

        means = []             # z_hat_{t|t}
        covariances = []       # Sigma_{t|t}
        next_means = []        # z_hat_{t+1|t}
        next_covariances = []  # Sigma_{t+1|t}

        for t in range(sequence_length):
            mat_A = mat_As[t]
            mat_C = mat_Cs[t]

            # Innovation covariance S_t = C_t Sigma_{t|t-1} C_t^T + R.
            innovation_cov = mat_C @ cov_t_plus @ mat_C.transpose(1, 2) + mat_R
            # Kalman gain K_t = Sigma_{t|t-1} C_t^T S_t^{-1}. Sigma and S are symmetric, so
            # K_t^T = S_t^{-1} C_t Sigma_{t|t-1}; solving is better conditioned than inverting S_t.
            K_t = torch.linalg.solve(innovation_cov, mat_C @ cov_t_plus).transpose(1, 2)

            # Measurement update: z_hat_{t|t}, Sigma_{t|t}.
            mean_t = mean_t_plus + K_t @ (as_[t].unsqueeze(2) - mat_C @ mean_t_plus)
            cov_t = cov_t_plus - K_t @ mat_C @ cov_t_plus

            # Time update: z_hat_{t+1|t}, Sigma_{t+1|t}.
            mean_t_plus = mat_A @ mean_t
            cov_t_plus = mat_A @ cov_t @ mat_A.transpose(1, 2) + mat_Q

            means.append(mean_t)
            covariances.append(cov_t)
            next_means.append(mean_t_plus)
            next_covariances.append(cov_t_plus)

        return means, covariances, next_means, next_covariances, mat_As, mat_Cs

    def kalman_smooth(self, as_, filter_means, filter_covariances, filter_next_means, filter_next_covariances, mat_As, mat_Cs):
        """Rauch-Tung-Striebel backward pass over the outputs of ``kalman_filter``.

        The arguments after ``as_`` are exactly the tuple returned by ``kalman_filter``.
        ``as_`` is only used for its sequence length, and ``mat_Cs`` is unused
        (it is kept so the filter output can be splatted in directly).

        Returns:
            ``(means, covariances)``: lists of length T holding the smoothed
            z_hat_{t|T-1} and Sigma_{t|T-1} for t = 0, ..., T-1.
        """
        sequence_length = as_.size(0)

        smoothed_mean = filter_means[-1]         # z_hat_{T-1|T-1}
        smoothed_cov = filter_covariances[-1]    # Sigma_{T-1|T-1}
        # Collected in reverse time order and flipped once at the end
        # (avoids the O(T^2) cost of repeated list.insert(0, ...)).
        means = [smoothed_mean]
        covariances = [smoothed_cov]

        for t in reversed(range(sequence_length - 1)):
            # Smoother gain J_t = Sigma_{t|t} A_t^T Sigma_{t+1|t}^{-1}, computed via a solve:
            # J_t^T = Sigma_{t+1|t}^{-1} A_t Sigma_{t|t} (both covariances symmetric).
            J_t = torch.linalg.solve(filter_next_covariances[t], mat_As[t] @ filter_covariances[t]).transpose(1, 2)

            smoothed_mean = filter_means[t] + J_t @ (smoothed_mean - filter_next_means[t])
            smoothed_cov = filter_covariances[t] + J_t @ (smoothed_cov - filter_next_covariances[t]) @ J_t.transpose(1, 2)

            means.append(smoothed_mean)
            covariances.append(smoothed_cov)

        means.reverse()
        covariances.reverse()
        return means, covariances

In [8]:
# Scratch check of the batched initial-state shapes that StateSpaceModel.kalman_filter also builds.
z_dim, batch_size = 2, 3

mean_t = torch.zeros((batch_size, z_dim, 1))                        # initial state estimate, (batch, z_dim, 1)
cov_t = torch.stack([torch.eye(z_dim) for _ in range(batch_size)])  # identity covariance per batch item, (batch, z_dim, z_dim)

In [9]:
cov_t  # display the (batch_size, z_dim, z_dim) stack of identity covariances built above

tensor([[[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]]])

In [10]:
class KalmanVariationalAutoencoder(nn.Module):
    """Kalman VAE: a per-frame Encoder/Decoder wrapped around a StateSpaceModel on the codes ``a``.

    Args:
        a_dim: Dimensionality of the per-frame code a_t.
        z_dim: Dimensionality of the latent state z_t.
        K: Number of base dynamics matrices mixed by the state-space model.
    """

    def __init__(self, a_dim, z_dim, K):
        super(KalmanVariationalAutoencoder, self).__init__()
        self.encoder = Encoder(a_dim)
        self.decoder = Decoder(a_dim)
        self.state_space_model = StateSpaceModel(a_dim, z_dim, K)

    def objective(self, x):
        """Encode an image sequence and draw a reparameterized code sequence ``a``.

        Args:
            x: Image sequence of shape (seq_length, batch_size, *frame_shape).

        Returns:
            Sampled codes ``a`` of shape (seq_length, batch_size, a_dim).

        TODO: the rest of the objective is not implemented yet; for now this
        returns the sampled codes.
        """
        seq_length = x.shape[0]
        batch_size = x.shape[1]
        # Fold time into the batch so the frame-level Encoder sees (seq*batch, *frame_shape);
        # reshape (not view) also accepts non-contiguous inputs.
        frames = x.reshape(seq_length * batch_size, *x.shape[2:])
        q_a = self.encoder(frames)
        # The Encoder returns a distribution, not a tensor. Sample with rsample() so gradients
        # flow through the reparameterization, then unfold time from batch.
        a = q_a.rsample().view(seq_length, batch_size, -1)
        return a
        

In [117]:
image = torch.zeros(1, 3, 32, 32)
encoder = Encoder(a_dim=a_dim)
decoder = Decoder(a_dim=a_dim)
# Smoke test of the round trip. The Encoder returns a D.Normal, so decode a sample from it;
# the Decoder returns (probabilities, logits), so unpack before taking the shape.
reconstruction, _logits = decoder(encoder(image).rsample())
reconstruction.shape

torch.Size([1, 3, 32, 32])