In [102]:
import os

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [103]:
a_dim: int = 10  # size of the Encoder output / Decoder input vector

In [113]:
class Encoder(nn.Module):
    """Convolutional encoder mapping an image to an ``a_dim``-dimensional vector.

    Three 3x3 convolutions with stride 2 and padding 1 each halve the spatial
    size, so a 32x32 input becomes 32 feature maps of size 4x4. These are
    flattened and projected to ``a_dim`` by a linear layer. Because that layer
    expects ``32 * 4 * 4`` inputs, the encoder assumes 32x32 images.

    Args:
        a_dim: Output dimensionality.
        in_channels: Number of input image channels (default 3, the previous
            hard-coded value).
    """

    def __init__(self, a_dim, in_channels=3):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1)
        self.fc = nn.Linear(in_features=32 * 4 * 4, out_features=a_dim)

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Tensor of shape (batch, in_channels, 32, 32).

        Returns:
            Tensor of shape (batch, a_dim).
        """
        x = F.relu(self.conv1(x))  # -> (batch, 32, 16, 16)
        x = F.relu(self.conv2(x))  # -> (batch, 32, 8, 8)
        x = F.relu(self.conv3(x))  # -> (batch, 32, 4, 4)
        return self.fc(torch.flatten(x, start_dim=1))

In [118]:
class Decoder(nn.Module):
    """Transposed-convolution decoder mapping an ``a_dim`` vector to an image.

    A linear layer projects the input to 32 feature maps of size 4x4. Three
    3x3 transposed convolutions (stride 2, padding 1, output_padding 1) then
    double the spatial size each time: 4 -> 8 -> 16 -> 32.

    Args:
        a_dim: Input dimensionality.
        upscale_factor: Currently unused. It belonged to a disabled PixelShuffle
            variant and is kept so that existing calls keep working.
        out_channels: Number of output image channels (default 3, the previous
            hard-coded value).
    """

    def __init__(self, a_dim, upscale_factor=2, out_channels=3):
        super().__init__()
        self.fc = nn.Linear(in_features=a_dim, out_features=32 * 4 * 4)
        self.deconv1 = nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv2 = nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv3 = nn.ConvTranspose2d(in_channels=32, out_channels=out_channels, kernel_size=3, stride=2, padding=1, output_padding=1)

    def forward(self, x):
        """Decode a batch of vectors into images.

        Args:
            x: Tensor of shape (batch, a_dim).

        Returns:
            A tuple ``(probs, logits)`` of tensors, each of shape
            (batch, out_channels, 32, 32). ``probs = sigmoid(logits)``.
        """
        x = F.relu(self.fc(x))
        x = x.view(-1, 32, 4, 4)
        x = F.relu(self.deconv1(x))  # -> (batch, 32, 8, 8)
        x = F.relu(self.deconv2(x))  # -> (batch, 32, 16, 16)
        logits = self.deconv3(x)     # -> (batch, out_channels, 32, 32)
        # torch.sigmoid replaces the deprecated F.sigmoid.
        return torch.sigmoid(logits), logits

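## Shape

Expected input dimensions:

* Batch size
* Sequence length
* Channels: 1 (**TODO:** the `Encoder`/`Decoder` below default to 3 channels, and the smoke test uses 3 — reconcile)
* Height: 32
* Width: 32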

In [120]:
class StateSpaceModel(nn.Module):
    """Learnable per-component matrices of a linear-Gaussian state-space model.

    Holds ``K`` stacked sets of matrices. ``K`` presumably indexes mixture
    components or regimes — TODO confirm.

    * ``mat_A``: (K, z_dim, z_dim). Maps a z_dim vector to a z_dim vector.
    * ``mat_C``: (K, a_dim, z_dim). Maps a z_dim vector to an a_dim vector.
    * ``mat_Q``: (K, z_dim, z_dim). Computed as ``L @ L^T`` from ``mat_Q_L``,
      presumably the transition-noise covariance.
    * ``mat_R``: (K, a_dim, a_dim). Computed as ``L @ L^T`` from ``mat_R_L``,
      presumably the observation-noise covariance.

    ``L @ L^T`` is symmetric positive semi-definite by construction. It is not
    guaranteed to be strictly positive definite.

    All parameters are initialised with ``torch.randn``. Seed the RNG for
    reproducible runs.
    """

    def __init__(self, a_dim, z_dim, K):
        super().__init__()
        self.mat_A = nn.Parameter(torch.randn(K, z_dim, z_dim))
        self.mat_C = nn.Parameter(torch.randn(K, a_dim, z_dim))
        self.mat_Q_L = nn.Parameter(torch.randn(K, z_dim, z_dim))
        self.mat_R_L = nn.Parameter(torch.randn(K, a_dim, a_dim))

    @property
    def mat_Q(self):
        """(K, z_dim, z_dim) batch of symmetric PSD matrices ``Q_L @ Q_L^T``."""
        return torch.bmm(self.mat_Q_L, self.mat_Q_L.transpose(-1, -2))

    @property
    def mat_R(self):
        """(K, a_dim, a_dim) batch of symmetric PSD matrices ``R_L @ R_L^T``."""
        return torch.bmm(self.mat_R_L, self.mat_R_L.transpose(-1, -2))

In [117]:
# Smoke test: an all-zero 3x32x32 image must round-trip to a same-shaped reconstruction.
image = torch.zeros(1, 3, 32, 32)
encoder = Encoder(a_dim=a_dim)
decoder = Decoder(a_dim=a_dim)
# Decoder.forward returns a (probs, logits) tuple, so unpack it before reading .shape.
recon_probs, recon_logits = decoder(encoder(image))
assert recon_probs.shape == recon_logits.shape == image.shape
recon_probs.shape

torch.Size([1, 3, 32, 32])