In [1]:
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tensorboardX import SummaryWriter

from tqdm import tqdm

# Fix torch's global RNG so weight initialisation is reproducible across runs.
seed = 1
torch.manual_seed(seed)

# Prefer the GPU whenever torch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
from Tars.distributions import Normal, Bernoulli
from Tars.distributions.divergences import KullbackLeibler
from Tars.models import VAE

In [3]:
class Conv2dLSTMCell(nn.Module):
    """
    2d convolutional long short-term memory (LSTM) cell.

    Analogous to nn.LSTMCell, with the nn.Linear layers replaced by
    nn.Conv2d layers. Every gate is computed from the channel-wise
    concatenation of the previous hidden state and the current input,
    so the cell is genuinely recurrent.

    Because hidden and input are concatenated, they must share batch and
    spatial size; configurations where the gate convolutions change the
    spatial size (e.g. stride > 1) are therefore not supported.

    :param in_channels: number of channels of the input tensor
    :param out_channels: number of channels of the hidden and cell states
    :param kernel_size: size of the convolution kernel
    :param stride: length of kernel stride
    :param padding: number of pixels to pad with
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(Conv2dLSTMCell, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding)

        # Each gate sees [hidden, input] stacked along the channel axis,
        # so its convolution takes in_channels + out_channels channels.
        gate_in_channels = in_channels + out_channels

        self.forget = nn.Conv2d(gate_in_channels, out_channels, **kwargs)
        self.input  = nn.Conv2d(gate_in_channels, out_channels, **kwargs)
        self.output = nn.Conv2d(gate_in_channels, out_channels, **kwargs)
        self.state  = nn.Conv2d(gate_in_channels, out_channels, **kwargs)

    def forward(self, input, states):
        """
        Advance the cell by one step.

        :param input: input tensor with in_channels channels,
            assumed (batch, in_channels, H, W) as nn.Conv2d requires
        :param states: (hidden, cell) pair from the previous step, each
            with out_channels channels and the same batch/spatial size
            as ``input``
        :return: new (hidden, cell) pair
        """
        (hidden, cell) = states

        # dim=1 is the channel axis in the NCHW layout used by nn.Conv2d.
        gate_input = torch.cat((hidden, input), dim=1)

        forget_gate = torch.sigmoid(self.forget(gate_input))
        input_gate  = torch.sigmoid(self.input(gate_input))
        output_gate = torch.sigmoid(self.output(gate_input))
        state_gate  = torch.tanh(self.state(gate_input))

        # Update internal cell state
        cell = forget_gate * cell + input_gate * state_gate
        hidden = output_gate * torch.tanh(cell)

        return hidden, cell

In [11]:
# Channel / feature sizes shared by the model components in this notebook.
x_dim = 3    # channels of the query image x_q (presumably RGB — confirm)
v_dim = 7    # length of a viewpoint vector v_q (unused in the cells shown — confirm)
h_dim = 128  # channels of the hidden states fed to the Normal heads
r_dim = 256  # channels of the scene representation r (unused in the cells shown)
z_dim = 64   # channels of the latent variable z

In [22]:
# inference model q(z|x_q,v_q,r)
class Inference(Normal):
    """
    Inference (posterior) model q(z|x_q, v_q, r).

    Two 5x5, stride-1, padding-2 convolutions (spatial size preserved) map
    a hidden state ``h_e`` with h_dim channels to the parameters of a
    Normal over ``z`` with z_dim channels.
    """
    def __init__(self):
        super(Inference, self).__init__(cond_var=["x_q", "v_q", "r"], var=["z"])
        self.mean = nn.Conv2d(h_dim, z_dim, kernel_size=5, stride=1, padding=2)
        # NOTE(review): this attribute reuses the name ``var`` that the base
        # class receives as a keyword (the list of variable names); confirm
        # Tars does not read ``self.var`` expecting that list.
        self.var = nn.Conv2d(h_dim, z_dim, kernel_size=5, stride=1, padding=2)

    def forward(self, h_e):
        """
        :param h_e: hidden state, assumed (batch, h_dim, H, W) — confirm
        :return: dict with "loc" and a strictly positive "var"
        """
        # NOTE(review): cond_var names are x_q, v_q, r while forward takes
        # h_e; confirm how Tars binds conditioning variables to forward().
        # softplus keeps the spread parameter > 0: the raw conv output is
        # unconstrained and can be negative, which is invalid for a Normal.
        return {"loc": self.mean(h_e), "var": F.softplus(self.var(h_e))}

        
# generative model g(x_q|z,v_q,r)
class Generator(Normal):
    """
    Generative (observation) model g(x_q|z, v_q, r).

    A 1x1 convolution maps ``u`` (h_dim channels) to the mean of a Normal
    over ``x_q`` (x_dim channels). The spread parameter is not learned
    here; the caller supplies it as ``sigma_t``.
    """
    def __init__(self):
        super(Generator, self).__init__(var=["x_q"], cond_var=["z", "v_q", "r"])
        self.mean = nn.Conv2d(h_dim, x_dim, kernel_size=1, stride=1, padding=0)

    def forward(self, u, sigma_t):
        """
        :param u: feature map, assumed (batch, h_dim, H, W) — confirm
        :param sigma_t: spread parameter forwarded unchanged as "var"
            (presumably a positive scalar or tensor — confirm at call site)
        :return: dict with "loc" and "var"
        """
        loc = self.mean(u)
        params = {"loc": loc, "var": sigma_t}
        return params

        
# prior pi(z|v_q,r)
class Prior(Normal):
    """
    Prior model pi(z|v_q, r).

    Two 5x5, stride-1, padding-2 convolutions (spatial size preserved) map
    a hidden state ``h_g`` with h_dim channels to the parameters of a
    Normal over ``z`` with z_dim channels.
    """
    def __init__(self):
        super(Prior, self).__init__(cond_var=["v_q", "r"], var=["z"])
        self.mean = nn.Conv2d(h_dim, z_dim, kernel_size=5, stride=1, padding=2)
        # NOTE(review): this attribute reuses the name ``var`` that the base
        # class receives as a keyword (the list of variable names); confirm
        # Tars does not read ``self.var`` expecting that list.
        self.var = nn.Conv2d(h_dim, z_dim, kernel_size=5, stride=1, padding=2)

    def forward(self, h_g):
        """
        :param h_g: hidden state, assumed (batch, h_dim, H, W) — confirm
        :return: dict with "loc" and a strictly positive "var"
        """
        # NOTE(review): cond_var names are v_q, r while forward takes h_g;
        # confirm how Tars binds conditioning variables to forward().
        # softplus keeps the spread parameter > 0: the raw conv output is
        # unconstrained and can be negative, which is invalid for a Normal.
        return {"loc": self.mean(h_g), "var": F.softplus(self.var(h_g))}


In [19]:
# Instantiate the prior pi(z|v_q,r) and the inference model q(z|x_q,v_q,r).
# Keep this order: each constructor draws initial conv weights from torch's
# global RNG, so swapping the lines changes both models' initial weights.
# NOTE(review): the execution counts show this cell (In [19]) ran before the
# class-definition cell was last executed (In [22]); re-run top to bottom.
pi = Prior()
q = Inference()

In [20]:
# Display Tars' text rendering of the prior's distribution.
pi.prob_text

'p(z|v_q,r)'

In [21]:
# Display Tars' text rendering of the inference model's distribution.
# NOTE(review): the rendering uses 'p(...)' even though this is q; presumably
# Tars defaults the distribution name to "p" when none is passed — confirm.
q.prob_text

'p(z|x_q,v_q,r)'