In [2]:
import os

import torch
import torch.nn.functional as F
import torch.distributions as D
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [3]:
# Size of the encoded vector `a`; passed as `a_dim` when the Encoder and
# Decoder are instantiated further down in this notebook.
a_dim = 10

In [4]:
class Encoder(nn.Module):
    """Convolutional encoder that maps an image batch to a diagonal Gaussian.

    Three stride-2 convolutions each halve the spatial resolution, and the
    two linear heads expect a flattened 32x4x4 feature map. With the layer
    settings below that corresponds to 3-channel 32x32 inputs.

    Args:
        a_dim: Dimensionality of the encoded vector.
    """

    def __init__(self, a_dim):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1)
        self.fc_mean = nn.Linear(in_features=32*4*4, out_features=a_dim)
        self.fc_std = nn.Linear(in_features=32*4*4, out_features=a_dim)

    def forward(self, x):
        """Encode ``x`` and return the resulting ``D.Normal`` distribution.

        Args:
            x: Image batch of shape (N, 3, 32, 32).

        Returns:
            A ``torch.distributions.Normal`` whose mean and standard
            deviation both have shape (N, a_dim).
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))

        # Flatten once and share the features between both heads.
        features = x.view(x.shape[0], -1)
        x_mean = self.fc_mean(features)
        # The std head is registered as ``fc_std``; the previous code called
        # the non-existent ``fc_var`` and raised AttributeError.
        # softplus keeps the standard deviation positive.
        x_std = F.softplus(self.fc_std(features))
        return D.Normal(x_mean, x_std)

In [5]:
class Decoder(nn.Module):
    """Transposed-convolution decoder from an encoded vector to an image.

    A linear layer expands the input to a 32x4x4 feature map, and three
    stride-2 transposed convolutions each double the spatial size, giving a
    3-channel 32x32 output.

    Args:
        a_dim: Dimensionality of the encoded input vector.
        upscale_factor: Currently unused; kept so existing callers that pass
            it keep working.
    """

    def __init__(self, a_dim, upscale_factor=2):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(in_features=a_dim, out_features=32*4*4)
        self.deconv1 = nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv2 = nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv3 = nn.ConvTranspose2d(in_channels=32, out_channels=3, kernel_size=3, stride=2, padding=1, output_padding=1)

    def forward(self, x):
        """Decode a batch of encoded vectors.

        Args:
            x: Tensor of shape (N, a_dim).

        Returns:
            A tuple ``(probs, logits)``, both of shape (N, 3, 32, 32).
            ``logits`` is the raw output of the last layer and ``probs`` is
            its element-wise sigmoid.
        """
        x = F.relu(self.fc(x))
        x = x.view(-1, 32, 4, 4)
        x = F.relu(self.deconv1(x))
        x = F.relu(self.deconv2(x))
        x = self.deconv3(x)
        # torch.sigmoid replaces the deprecated F.sigmoid alias.
        return torch.sigmoid(x), x

In [21]:
class LSTMModel(nn.Module):
    """LSTM that turns a sequence of encoded vectors into K softmax weights.

    Args:
        a_dim: Size of each input vector.
        K: Hidden size of the LSTM, and therefore the number of weights
            produced per time step.
    """

    def __init__(self, a_dim, K):
        super().__init__()
        self.a_dim, self.K = a_dim, K
        # Inputs are sequence-first: (seq_len, batch, a_dim).
        self.lstm = nn.LSTM(input_size=a_dim, hidden_size=K, batch_first=False)

    def forward(self, x):
        """Return per-step weights that sum to one over the last axis."""
        hidden_seq, _ = self.lstm(x)
        return F.softmax(hidden_seq, dim=-1)

## Shape

1. Sequence length
1. Batch size
1. Channel: 1
1. Height: 32
1. Width: 32

In [14]:
class StateSpaceModel(nn.Module):
    """Mixture-of-linear-Gaussian state space model with a Kalman filter.

    The module stores K candidate transition matrices (A), emission matrices
    (C) and square-root factors of the process / observation noise
    covariances (Q, R). At every time step ``weight_model`` produces K
    softmax weights per batch element, and the matrices used at that step
    are the corresponding weighted sums of the K candidates.

    Args:
        a_dim: Dimensionality of the observations ``a``.
        z_dim: Dimensionality of the latent state ``z``.
        K: Number of candidate matrix sets.
    """

    def __init__(self, a_dim, z_dim, K):
        super(StateSpaceModel, self).__init__()

        self.a_dim = a_dim
        self.z_dim = z_dim
        self.K = K

        self.mat_A_K = nn.Parameter(torch.randn(K, z_dim, z_dim))
        self.mat_C_K = nn.Parameter(torch.randn(K, a_dim, z_dim))
        self.mat_Q_L = nn.Parameter(torch.randn(K, z_dim, z_dim))
        self.mat_R_L = nn.Parameter(torch.randn(K, a_dim, a_dim))

        self.weight_model = LSTMModel(a_dim, K)

    @property
    def mat_Q(self):
        """K process-noise covariances, built as L L^T so each is symmetric PSD."""
        return torch.bmm(self.mat_Q_L, self.mat_Q_L.transpose(-1, -2))

    @property
    def mat_R(self):
        """K observation-noise covariances, built as L L^T so each is symmetric PSD."""
        return torch.bmm(self.mat_R_L, self.mat_R_L.transpose(-1, -2))

    @staticmethod
    def _mix(weights, mats):
        """Weighted sum of K matrices: (T, B, K) x (K, i, j) -> (T, B, i, j)."""
        return torch.einsum("tbk,kij->tbij", weights, mats)

    def kalman_filter(self, a, initial_cov_scale=1e9):
        """Run a Kalman filter over a sequence of observations.

        Args:
            a: Observations of shape (sequence_length, batch_size, a_dim).
            initial_cov_scale: The initial state covariance is this value
                times the identity. The default reproduces the previous
                hard-coded, very diffuse prior.
                NOTE(review): 1e9 is far beyond float32 precision relative to
                O(1) noise terms; a smaller value is likely needed for
                numerically meaningful results.

        Returns:
            A tuple ``(means, covariances, mat_As, mat_Cs)`` of lists:
            ``means`` and ``covariances`` hold the initial state followed by
            the filtered estimate of every step (sequence_length + 1 entries
            of shape (batch, z_dim, 1) and (batch, z_dim, z_dim));
            ``mat_As`` and ``mat_Cs`` hold the mixed A and C of every step
            (sequence_length entries each).
        """
        sequence_length, batch_size, _ = a.size()

        # Allocate the initial state on the same device / dtype as the input.
        mean_t = a.new_zeros(batch_size, self.z_dim, 1)  # Initial state estimate
        identity = torch.eye(self.z_dim, dtype=a.dtype, device=a.device)
        cov_t = (identity * initial_cov_scale).unsqueeze(0).repeat(batch_size, 1, 1)  # Initial state covariance
        cov_t_plus = cov_t

        weights = self.weight_model(a)  # (sequence_length, batch_size, K)

        # Mix the K candidates for all steps at once. torch.bmm cannot do
        # this: it requires two 3-D operands with equal batch size, whereas
        # weights[t] is (batch, K) and the parameters are (K, rows, cols).
        mixed_A = self._mix(weights, self.mat_A_K)
        mixed_C = self._mix(weights, self.mat_C_K)
        # The noise covariances also carry a leading K axis, so they are
        # mixed with the same weights to get one matrix per batch element
        # (a convex combination of PSD matrices is still PSD).
        mixed_Q = self._mix(weights, self.mat_Q)
        mixed_R = self._mix(weights, self.mat_R)

        means = [mean_t]
        covariances = [cov_t]
        mat_As = []
        mat_Cs = []

        for t in range(sequence_length):
            mat_A = mixed_A[t]
            mat_C = mixed_C[t]
            mat_C_T = mat_C.transpose(1, 2)

            mat_As.append(mat_A)
            mat_Cs.append(mat_C)

            # Prediction
            mean_t_plus = torch.bmm(mat_A, mean_t)  # Predicted state estimate

            # Kalman gain: P C^T (C P C^T + R)^-1
            innovation_cov = torch.bmm(torch.bmm(mat_C, cov_t_plus), mat_C_T) + mixed_R[t]
            K_t = torch.bmm(torch.bmm(cov_t_plus, mat_C_T), torch.inverse(innovation_cov))

            # Update
            residual = a[t].unsqueeze(2) - torch.bmm(mat_C, mean_t_plus)
            mean_t = mean_t_plus + torch.bmm(K_t, residual)  # Updated state estimate
            cov_t = cov_t_plus - torch.bmm(torch.bmm(K_t, mat_C), cov_t_plus)  # Updated state covariance

            means.append(mean_t)
            covariances.append(cov_t)

            # Predicted state covariance for the next step: A P A^T + Q.
            # The transpose must swap the two matrix axes (1, 2); swapping
            # (0, 1) would exchange the batch axis with the row axis.
            cov_t_plus = torch.bmm(torch.bmm(mat_A, cov_t), mat_A.transpose(1, 2)) + mixed_Q[t]

        return means, covariances, mat_As, mat_Cs

    def kalman_smooth(self, a, filter_means, filter_covariances):
        """Placeholder for the backward smoothing pass.

        TODO: not implemented yet; currently does nothing and returns None.
        """
        pass

In [18]:
# Toy sizes for inspecting the initial Kalman state by hand.
z_dim, batch_size = 2, 3

# Initial state estimate: one zero column vector per batch element.
mean_t = torch.zeros((batch_size, z_dim, 1))
# Initial state covariance: one identity matrix per batch element.
cov_t = torch.eye(z_dim).repeat(batch_size, 1, 1)

In [20]:
# Show the batched initial covariance: one identity matrix per batch element.
cov_t

tensor([[[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]],

        [[1., 0.],
         [0., 1.]]])

In [None]:
class KalmanVariationalAutoencoder(nn.Module):
    """Combines the image encoder/decoder with the state space model.

    Args:
        a_dim: Dimensionality of the encoded vector produced per frame.
        z_dim: Dimensionality of the state space model's latent state.
        K: Number of candidate matrix sets in the state space model.
    """

    def __init__(self, a_dim, z_dim, K):
        super(KalmanVariationalAutoencoder, self).__init__()
        self.encoder = Encoder(a_dim)
        self.decoder = Decoder(a_dim)
        self.state_space_model = StateSpaceModel(a_dim, z_dim, K)

    def objective(self, x):
        """Start of the training objective (work in progress).

        Args:
            x: Image sequences laid out as (seq_length, batch_size, ...),
                where the trailing dimensions are those of a single frame.

        Returns:
            None for now; only the per-frame encoding is computed.
        """
        seq_length, batch_size = x.shape[0], x.shape[1]

        # Fold time into the batch axis so every frame is encoded independently.
        frames = x.reshape(seq_length * batch_size, *x.shape[2:])
        a_dist = self.encoder(frames)

        # The encoder returns a distribution, not a tensor, so draw a
        # reparameterised sample before reshaping. The sample has one encoded
        # vector per frame, so the result is (seq_length, batch_size, a_dim),
        # not the image shape.
        a = a_dist.rsample().view(seq_length, batch_size, -1)

        # TODO: the remaining terms of the objective (Kalman filtering /
        # smoothing of `a`, reconstruction through the decoder) are not
        # implemented yet.

In [117]:
# Smoke test: a single 3x32x32 image should round-trip to the same shape.
image = torch.zeros(1, 3, 32, 32)
encoder = Encoder(a_dim=a_dim)
decoder = Decoder(a_dim=a_dim)

# Encoder.forward returns a Normal distribution and Decoder.forward returns a
# (probabilities, logits) tuple, so decode the distribution's mean (which
# keeps this check deterministic) and unpack the result before reading
# `.shape`.
reconstruction, logits = decoder(encoder(image).mean)
reconstruction.shape

torch.Size([1, 3, 32, 32])