# Deep Markov Model
* Original paper: Structured Inference Networks for Nonlinear State Space Models (https://arxiv.org/abs/1609.09869)
* Original code: https://github.com/clinicalml/dmm

 ## Deep Markov Model summary
>Deep Markov models (DMM), a class of
generative models where classic linear emission and transition distributions are replaced with complex multi-layer
perceptrons (MLPs).


In [1]:
from tqdm import tqdm

import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
from tensorboardX import SummaryWriter

In [2]:
batch_size = 256
epochs = 25
seed = 1
torch.manual_seed(seed)

<torch._C.Generator at 0x7f3a96684d90>

In [3]:
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [4]:
# Experiment Setting
# generate MNIST by stacking row images(consider row as time step)
def init_dataset(f_batch_size):
    kwargs = {'num_workers': 1, 'pin_memory': True}
    data_dir = '../data'
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda data: data[0])
    ])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(data_dir, train=True, download=True,
                       transform=mnist_transform),
        batch_size=f_batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(data_dir, train=False, transform=mnist_transform),
        batch_size=f_batch_size, shuffle=True, **kwargs)

    fixed_t_size = 28
    return train_loader, test_loader, fixed_t_size

train_loader, test_loader, t_max = init_dataset(batch_size)

## Define probability distributions
### In the original paper
Prior(Transition model, equation(1) in the paper): $z_{t} \sim  \cal{N}\left(G_{\alpha}\left(z_{t-1}, \Delta_{t}\right), S_{\beta}\left(z_{t-1}, \Delta_{t}\right)\right)$  
Generator(Emission, equation(2) in the paper): $x_{t} \sim \Pi\left(F_{\kappa}\left(z_{t}\right)\right)$  
Inference(equation(5) in the paper): $\begin{aligned} q_{\phi}\left(z_{t} | z_{t-1}, x_{t}, \ldots, x_{T}\right) = \mathcal{N}\left(\mu_{\phi}\left(z_{t-1}, x_{t}, \ldots, x_{T}\right),\right.&\left.\Sigma_{\phi}\left(z_{t-1}, x_{t}, \ldots, x_{T}\right)\right) \end{aligned}$

### MNIST settings
Prior(Transition model): $p_{\theta}(z_{t} | z_{t-1}) =  \cal{N}(\mu = f_{prior_\mu}(z_{t-1}), \sigma^2 = f_{prior_\sigma^2}(z_{t-1})$    
Generator(Emission): $p_{\theta}(x | z)=\mathscr{B}\left(x ; \lambda=g_{x}(z)\right)$  

RNN: $p(h) = RNN(x)$  
Inference(Combiner): $p_{\phi}(z | h, z_{t-1}) = \cal{N}(\mu = f_{\mu}(h, z_{t-1}), \sigma^2 = f_{\sigma^2}(h, z_{t-1})$

In [5]:
from pixyz.utils import print_latex
from pixyz.distributions import Bernoulli, Normal, Deterministic

In [6]:
x_dim = 28
h_dim = 32
hidden_dim = 32
z_dim = 16
t_max = x_dim

In [7]:
# RNN
class RNN(Deterministic):
    """
    h = RNN(x)
    Given observed x, RNN output hidden state
    """
    def __init__(self):
        super(RNN, self).__init__(cond_var=["x"], var=["h"])
        self.rnn = nn.GRU(x_dim, h_dim, bidirectional=True)
#         self.h0 = torch.zeros(2, batch_size, self.rnn.hidden_size).to(device)
        self.h0 = nn.Parameter(torch.zeros(2, 1, self.rnn.hidden_size))
        self.hidden_size = self.rnn.hidden_size
        
    def forward(self, x):
        # if on gpu we need the fully broadcast view of the rnn initial state
        # to be in contiguous gpu memory
        # x(Batch_size, Time, Features)
        h0 = self.h0.expand(2, x.size(1), self.rnn.hidden_size).contiguous()
        h, _ = self.rnn(x, h0)
        return {"h": h}

In [8]:
# Emission p(x_t | z_t)
class Generator(Bernoulli):
    """
    Given the latent z at time step t, return the vector of
    probabilities that parameterizes the bernlulli distribution p(x_t | x_t)
    """
    def __init__(self):
        super(Generator, self).__init__(cond_var=["z"], var=["x"])
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, x_dim)
    
    def forward(self, z):
        h = F.relu(self.fc1(z))
        return {"probs": torch.sigmoid(self.fc2(h))}

In [9]:
# Combiner q(z_t | z_{t-1}, x_{1:T})
class Inference(Normal):
    """
    given the latent z at time step t-1 and the hidden state of the RNN h(x_{0:T} 
    return the loc and scale vectors that
    parameterize the gaussian distribution q(z_t | z_{t-1}, x_{t:T})
    """
    def __init__(self):
        super(Inference, self).__init__(cond_var=["h", "z_prev"], var=["z"])
        self.fc1 = nn.Linear(z_dim, h_dim*2)
        self.fc21 = nn.Linear(h_dim*2, z_dim)
        self.fc22 = nn.Linear(h_dim*2, z_dim)

        
    def forward(self, h, z_prev):
        h_z = torch.tanh(self.fc1(z_prev))
        h = 0.5 * (h + h_z)
        return {"loc": self.fc21(h), "scale": F.softplus(self.fc22(h))}

In [10]:
# Transition model p(z_t | z_{t-1})
class Prior(Normal):
    """
    Given the latent variable at the time step t-1
    return the mean and scale vectors that parameterize the
    gaussian distribution p(z_t | z_{t-1})
    """
    def __init__(self):
        super(Prior, self).__init__(cond_var=["z_prev"], var=["z"])
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, z_dim)
        self.fc22 = nn.Linear(hidden_dim, z_dim)
        
    def forward(self, z_prev):
        h = F.relu(self.fc1(z_prev))
        return {"loc": self.fc21(h), "scale": F.softplus(self.fc22(h))}

In [11]:
prior = Prior().to(device)
encoder = Inference().to(device)
decoder = Generator().to(device)
rnn = RNN().to(device)

In [12]:
print(prior)
print("*"*80)
print(encoder)
print("*"*80)
print(decoder)
print("*"*80)
print(rnn)

Distribution:
  p(z|z_{prev})
Network architecture:
  Prior(
    name=p, distribution_name=Normal,
    var=['z'], cond_var=['z_prev'], input_var=['z_prev'], features_shape=torch.Size([])
    (fc1): Linear(in_features=16, out_features=32, bias=True)
    (fc21): Linear(in_features=32, out_features=16, bias=True)
    (fc22): Linear(in_features=32, out_features=16, bias=True)
  )
********************************************************************************
Distribution:
  p(z|h,z_{prev})
Network architecture:
  Inference(
    name=p, distribution_name=Normal,
    var=['z'], cond_var=['h', 'z_prev'], input_var=['h', 'z_prev'], features_shape=torch.Size([])
    (fc1): Linear(in_features=16, out_features=64, bias=True)
    (fc21): Linear(in_features=64, out_features=16, bias=True)
    (fc22): Linear(in_features=64, out_features=16, bias=True)
  )
********************************************************************************
Distribution:
  p(x|z)
Network architecture:
  Generator(
    na

In [13]:
generate_from_prior = prior * decoder
print(generate_from_prior)
print_latex(generate_from_prior)

Distribution:
  p(x,z|z_{prev}) = p(x|z)p(z|z_{prev})
Network architecture:
  Prior(
    name=p, distribution_name=Normal,
    var=['z'], cond_var=['z_prev'], input_var=['z_prev'], features_shape=torch.Size([])
    (fc1): Linear(in_features=16, out_features=32, bias=True)
    (fc21): Linear(in_features=32, out_features=16, bias=True)
    (fc22): Linear(in_features=32, out_features=16, bias=True)
  )
  Generator(
    name=p, distribution_name=Bernoulli,
    var=['x'], cond_var=['z'], input_var=['z'], features_shape=torch.Size([])
    (fc1): Linear(in_features=16, out_features=32, bias=True)
    (fc2): Linear(in_features=32, out_features=28, bias=True)
  )


<IPython.core.display.Math object>

## Define Loss function
### In the original paper(equation (6) in the paper)
${\mathcal{L}(\vec{x} ;(\theta, \phi))=\sum_{t=1}^{T} \underset{q_{\phi}\left(z_{t} | \vec{x}\right)}{\mathbb{E}}\left[\log p_{\theta}\left(x_{t} | z_{t}\right)\right]} {-\operatorname{KL}\left(q_{\phi}\left(z_{1} | \vec{x}\right) | p_{\theta}\left(z_{1}\right)\right)}  {-\sum_{t=2}^{T} \underset{q_{\phi}\left(z_{t-1} | \vec{x}\right)}{\mathbb{E}}\left[\operatorname{KL}\left(q_{\phi}\left(z_{t} | z_{t-1}\vec{x}\right) | p_{\theta}\left(z_{t} | z_{t-1}\right)\right)\right]}$

In [14]:
from pixyz.losses import KullbackLeibler
from pixyz.losses import Expectation as E
from pixyz.losses import LogProb
from pixyz.losses import IterativeLoss
from pixyz.losses import CrossEntropy

In [15]:
step_loss = CrossEntropy(encoder, decoder) + KullbackLeibler(encoder, prior)
_loss = IterativeLoss(step_loss, max_iter=t_max, 
                      series_var=["x", "h"], update_value={"z": "z_prev"})
loss = _loss.expectation(rnn).mean()

## Define DMM model using Model class

In [16]:
from pixyz.models import Model

dmm = Model(loss, distributions=[rnn, encoder, decoder, prior], optimizer=optim.Adam, optimizer_params={'lr': 3e-3})

In [17]:
print(dmm)
print_latex(dmm)

Distributions (for training): 
  p(h|x), p(z|h,z_{prev}), p(x|z), p(z|z_{prev}) 
Loss function: 
  mean \left(\mathbb{E}_{p(h|x)} \left[\sum_{t=1}^{28} \left(D_{KL} \left[p(z|h,z_{prev})||p(z|z_{prev}) \right] - \mathbb{E}_{p(z|h,z_{prev})} \left[\log p(x|z) \right]\right) \right] \right) 
Optimizer: 
  Adam (
  Parameter Group 0
      amsgrad: False
      betas: (0.9, 0.999)
      eps: 1e-08
      lr: 0.003
      weight_decay: 0
  )


<IPython.core.display.Math object>

## Define Train and Test loop using model

In [18]:
def data_loop(epoch, loader, model, device, train_mode=False):
    mean_loss = 0
    for batch_idx, (data, _) in enumerate(tqdm(loader)):
        data = data.to(device)
        batch_size = data.size()[0]
        x = data.transpose(0, 1)
        z_prev = torch.zeros(batch_size, z_dim).to(device)
        #q_z_prev = torch.zeros(batch_size, z_dim).to(device)
        if train_mode:
            mean_loss += model.train({'x': x, 'z_prev': z_prev}).item() * batch_size
        else:
            mean_loss += model.test({'x': x, 'z_prev': z_prev}).item() * batch_size
    mean_loss /= len(loader.dataset)
    if train_mode:
        print('Epoch: {} Train loss: {:.4f}'.format(epoch, mean_loss))
    else:
        print('Test loss: {:.4f}'.format(mean_loss))
    return mean_loss

## Reconstruction and generation

In [19]:
def plot_image_from_latent(batch_size):
    x = []
    z_prev = torch.zeros(batch_size, z_dim).to(device)
    for t in range(t_max):
        samples = generate_from_prior.sample({'z_prev': z_prev})
        x_t = decoder.sample_mean({"z": samples["z"]})
        z_prev = samples["z"]
        x.append(x_t[None, :])
    x = torch.cat(x, dim=0).transpose(0, 1)
    return x


def reconstruct(data):
    xs = []
    data = data.to(device)
    x = data.transpose(0, 1)
    batch_size = data.size()[0]
    z_prev = torch.zeros(batch_size, z_dim).to(device)
    for t in range(t_max):
        h_t = rnn.sample_mean({'x': x})[t]
        z_t = encoder.sample_mean({'h': h_t, 'z_prev': z_prev})
        dec_x = decoder.sample_mean({'z': z_t})
        z_prev = z_t
        xs.append(dec_x[None, :])
    reconst_img = torch.cat(xs, dim=0).transpose(0, 1)
    return reconst_img


def sample_after_n_steps(num_step, data):
    xs = []
    data = data.to(device)
    x = data.transpose(0, 1)
    batch_size = data.size()[0]
    z_prev = torch.zeros(batch_size, z_dim).to(device)
    for t in range(t_max):
        if t+1 <  num_step:
            h_t = rnn.sample_mean({'x': x})[t]
            z_t = encoder.sample_mean({'h': h_t, 'z_prev': z_prev})
            dec_x = decoder.sample_mean({'z': z_t})
            z_prev = z_t
            xs.append(dec_x[None, :])
        else:
            z_t = prior.sample({'z_prev':z_prev})["z"]
            dec_x = decoder.sample_mean({'z': z_t})
            z_prev = z_t
            xs.append(dec_x[None, :])
    sample_img = torch.cat(xs, dim=0).transpose(0, 1)
    return sample_img
            

In [20]:
writer = SummaryWriter(comment='DMM_masa_original_cross')
_x, _ = iter(test_loader).next()
_x = _x.to(device)

for epoch in range(1, epochs + 1):
    train_loss = data_loop(epoch, train_loader, dmm, device, train_mode=True)
    test_loss = data_loop(epoch, test_loader, dmm, device)

    writer.add_scalar('train_loss', train_loss, epoch)
    writer.add_scalar('test_loss', test_loss, epoch)

    sample = plot_image_from_latent(batch_size)[:, None]
    writer.add_images('Image_from_latent', sample, epoch)
    
    reconst_img = reconstruct(_x)[:, None]
    writer.add_images('Reconstructed', reconst_img, epoch)
    
    generated_img_7 = sample_after_n_steps(7, _x)[:, None]
    writer.add_images('Generate_after_7steps', generated_img_7, epoch)
    
    generated_img_14 = sample_after_n_steps(14, _x)[:, None]
    writer.add_images('Generate_after_14steps', generated_img_14, epoch)
    
    writer.add_images('orignal', _x[:, None], epoch)

100%|██████████| 235/235 [00:34<00:00,  6.79it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 1 Train loss: 255.1766


100%|██████████| 40/40 [00:02<00:00, 14.83it/s]


Test loss: 204.9264


100%|██████████| 235/235 [00:34<00:00,  6.74it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 2 Train loss: 192.0579


100%|██████████| 40/40 [00:03<00:00, 12.35it/s]


Test loss: 184.1096


100%|██████████| 235/235 [00:34<00:00,  6.79it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 3 Train loss: 183.0376


100%|██████████| 40/40 [00:02<00:00, 14.60it/s]


Test loss: 180.6062


100%|██████████| 235/235 [00:35<00:00,  6.64it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 4 Train loss: 180.9568


100%|██████████| 40/40 [00:02<00:00, 14.72it/s]


Test loss: 179.3476


100%|██████████| 235/235 [00:34<00:00,  6.78it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 5 Train loss: 179.8904


100%|██████████| 40/40 [00:02<00:00, 14.03it/s]


Test loss: 178.4508


100%|██████████| 235/235 [00:34<00:00,  6.85it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 6 Train loss: 176.8495


100%|██████████| 40/40 [00:02<00:00, 14.78it/s]


Test loss: 174.6067


100%|██████████| 235/235 [00:34<00:00,  6.79it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 7 Train loss: 175.1140


100%|██████████| 40/40 [00:02<00:00, 14.85it/s]


Test loss: 174.0099


100%|██████████| 235/235 [00:35<00:00,  6.67it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 8 Train loss: 174.4934


100%|██████████| 40/40 [00:02<00:00, 13.78it/s]


Test loss: 173.4077


100%|██████████| 235/235 [00:35<00:00,  6.54it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 9 Train loss: 174.0542


100%|██████████| 40/40 [00:02<00:00, 14.44it/s]


Test loss: 173.2236


100%|██████████| 235/235 [00:36<00:00,  6.52it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 10 Train loss: 173.7850


100%|██████████| 40/40 [00:02<00:00, 14.49it/s]


Test loss: 173.0226


100%|██████████| 235/235 [00:35<00:00,  6.69it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 11 Train loss: 173.4782


100%|██████████| 40/40 [00:03<00:00, 12.58it/s]


Test loss: 172.6737


100%|██████████| 235/235 [00:35<00:00,  6.68it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 12 Train loss: 173.2920


100%|██████████| 40/40 [00:02<00:00, 15.06it/s]


Test loss: 172.2236


100%|██████████| 235/235 [00:34<00:00,  6.71it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 13 Train loss: 173.1253


100%|██████████| 40/40 [00:02<00:00, 13.44it/s]


Test loss: 172.2413


100%|██████████| 235/235 [00:35<00:00,  6.69it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 14 Train loss: 172.9210


100%|██████████| 40/40 [00:02<00:00, 14.92it/s]


Test loss: 172.0694


100%|██████████| 235/235 [00:34<00:00,  6.78it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 15 Train loss: 172.8624


100%|██████████| 40/40 [00:02<00:00, 13.92it/s]


Test loss: 171.8137


100%|██████████| 235/235 [00:34<00:00,  6.90it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 16 Train loss: 172.7550


100%|██████████| 40/40 [00:02<00:00, 13.52it/s]


Test loss: 172.2008


100%|██████████| 235/235 [00:34<00:00,  6.86it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 17 Train loss: 172.6063


100%|██████████| 40/40 [00:02<00:00, 14.37it/s]


Test loss: 171.8161


100%|██████████| 235/235 [00:35<00:00,  6.59it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 18 Train loss: 172.5269


100%|██████████| 40/40 [00:03<00:00, 12.88it/s]


Test loss: 171.5653


100%|██████████| 235/235 [00:35<00:00,  6.69it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 19 Train loss: 172.4226


100%|██████████| 40/40 [00:03<00:00, 13.23it/s]


Test loss: 171.5906


100%|██████████| 235/235 [00:34<00:00,  6.73it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 20 Train loss: 172.3534


100%|██████████| 40/40 [00:02<00:00, 14.56it/s]


Test loss: 171.6463


100%|██████████| 235/235 [00:35<00:00,  6.68it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 21 Train loss: 172.2563


100%|██████████| 40/40 [00:03<00:00, 13.27it/s]


Test loss: 171.5527


100%|██████████| 235/235 [00:34<00:00,  6.80it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 22 Train loss: 172.3164


100%|██████████| 40/40 [00:02<00:00, 15.25it/s]


Test loss: 171.6282


100%|██████████| 235/235 [00:36<00:00,  6.53it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 23 Train loss: 172.0883


100%|██████████| 40/40 [00:03<00:00, 11.78it/s]


Test loss: 171.2768


100%|██████████| 235/235 [00:34<00:00,  6.77it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 24 Train loss: 172.1598


100%|██████████| 40/40 [00:02<00:00, 14.85it/s]


Test loss: 171.1873


100%|██████████| 235/235 [00:35<00:00,  6.67it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 25 Train loss: 171.9704


100%|██████████| 40/40 [00:02<00:00, 14.28it/s]


Test loss: 171.2867


## Image label conditioning on latent variable z

### MNIST settings
Prior(Transition model): $p_{\theta}(z_{t} | z_{t-1}, u) =  \cal{N}(\mu = f_{prior_\mu}(z_{t-1}, u), \sigma^2 = f_{prior_\sigma^2}(z_{t-1}, u)$    
Generator(Emission): $p_{\theta}(x | z)=\mathscr{B}\left(x ; \lambda=g_{x}(z)\right)$  

RNN: $p(h) = RNN(x)$  
Inference(Combiner): $p_{\phi}(z | h, z_{t-1}, u) = \cal{N}(\mu = f_{\mu}(h, z_{t-1}, u), \sigma^2 = f_{\sigma^2}(h, z_{t-1}, u)$

In [21]:
x_dim = 28
h_dim = 32
hidden_dim = 32
z_dim = 16
t_max = x_dim

# label dim
u_dim = 10

In [22]:
# RNN
class RNN(Deterministic):
    """
    h = RNN(x)
    Given observed x, RNN output hidden state
    """
    def __init__(self):
        super(RNN, self).__init__(cond_var=["x"], var=["h"])
        self.rnn = nn.GRU(x_dim, h_dim, bidirectional=True)
#         self.h0 = torch.zeros(2, batch_size, self.rnn.hidden_size).to(device)
        self.h0 = nn.Parameter(torch.zeros(2, 1, self.rnn.hidden_size))
        self.hidden_size = self.rnn.hidden_size
        
    def forward(self, x):
        # if on gpu we need the fully broadcast view of the rnn initial state
        # to be in contiguous gpu memory
        # x(Batch_size, Time, Features)
        h0 = self.h0.expand(2, x.size(1), self.rnn.hidden_size).contiguous()
        h, _ = self.rnn(x, h0)
        return {"h": h}


# Emission p(x_t | z_t)
class Generator(Bernoulli):
    """
    Given the latent z at time step t, return the vector of
    probabilities that parameterizes the bernlulli distribution p(x_t | x_t)
    """
    def __init__(self):
        super(Generator, self).__init__(cond_var=["z"], var=["x"])
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, x_dim)
    
    def forward(self, z):
        h = F.relu(self.fc1(z))
        return {"probs": torch.sigmoid(self.fc2(h))}


class Inference(Normal):
    """
    given the latent z at time step t-1, the hidden state of the RNN h(x_{0:T} and u
    return the loc and scale vectors that
    parameterize the gaussian distribution q(z_t | z_{t-1}, x_{t:T}, u)
    """
    def __init__(self):
        super(Inference, self).__init__(cond_var=["h", "z_prev", "u"], var=["z"])
        self.fc1 = nn.Linear(z_dim+u_dim, h_dim*2)
        self.fc21 = nn.Linear(h_dim*2, z_dim)
        self.fc22 = nn.Linear(h_dim*2, z_dim)

        
    def forward(self, h, z_prev, u):
        feature = torch.cat((z_prev, u), 1)
        h_z = torch.tanh(self.fc1(feature))
        h = 0.5 * (h + h_z)
        return {"loc": self.fc21(h), "scale": F.softplus(self.fc22(h))}


class Prior(Normal):
    """
    Given the latent variable at the time step t-1 and u,
    return the mean and scale vectors that parameterize the
    gaussian distribution p(z_t | z_{t-1}, u)
    """
    def __init__(self):
        super(Prior, self).__init__(cond_var=["z_prev", "u"], var=["z"])
        self.fc1 = nn.Linear(z_dim+u_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, z_dim)
        self.fc22 = nn.Linear(hidden_dim, z_dim)
        
    def forward(self, z_prev, u):
        feature = torch.cat((z_prev, u), 1)
        h = F.relu(self.fc1(feature))
        return {"loc": self.fc21(h), "scale": F.softplus(self.fc22(h))}

In [23]:
prior = Prior().to(device)
encoder = Inference().to(device)
decoder = Generator().to(device)
rnn = RNN().to(device)


step_loss = CrossEntropy(encoder, decoder) + KullbackLeibler(encoder, prior)
_loss = IterativeLoss(step_loss, max_iter=t_max, 
                      series_var=["x", "h", "u"], update_value={"z": "z_prev"})
loss = _loss.expectation(rnn).mean()

In [24]:
dmm = Model(loss, distributions=[rnn, encoder, decoder, prior], optimizer=optim.Adam, optimizer_params={'lr': 3e-3})

In [25]:
# check label dims
x, label = iter(train_loader).next()
label = torch.eye(10)[label]
print(x.shape)
print(label.shape)
# copy labels for each time step
print(torch.stack([label for num in range(28)]).shape)

torch.Size([256, 28, 28])
torch.Size([256, 10])
torch.Size([28, 256, 10])


In [26]:
def data_loop(epoch, loader, model, device, train_mode=False):
    mean_loss = 0
    for batch_idx, (data, label) in enumerate(tqdm(loader)):
        data = data.to(device)
        label = torch.eye(10)[label].to(device)
    
        batch_size = data.size()[0]
        # convert to (timestep, batch_size, feature)
        x = data.transpose(0, 1)
        label = torch.stack([label for num in range(t_max)])
        
        z_prev = torch.zeros(batch_size, z_dim).to(device)
        #q_z_prev = torch.zeros(batch_size, z_dim).to(device)
        if train_mode:
            mean_loss += model.train({'x': x, 'z_prev': z_prev, 'u': label}).item() * batch_size
        else:
            mean_loss += model.test({'x': x, 'z_prev': z_prev, 'u': label}).item() * batch_size
    mean_loss /= len(loader.dataset)
    if train_mode:
        print('Epoch: {} Train loss: {:.4f}'.format(epoch, mean_loss))
    else:
        print('Test loss: {:.4f}'.format(mean_loss))
    return mean_loss

In [27]:
def plot_image_from_latent():
    # plot 100 images, 10 images each label
    sample_num = 100 
    x = []
    label = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]*10)
    label = torch.eye(10)[label].to(device)
    label = torch.stack([label for num in range(t_max)])
    
    z_prev = torch.zeros(sample_num, z_dim).to(device)
    for step in range(t_max):
        z_t = prior.sample({'z_prev':z_prev, 'u': label[t]})["z"]
        dec_x = decoder.sample_mean({'z': z_t})
        z_prev = z_t
        x.append(x_t[None, :])
    x = torch.cat(x, dim=0).transpose(0, 1)
    return x


def reconstruct(data, label):
    xs = []
    data = data.to(device)
    label = torch.eye(10)[label].to(device)
    
    x = data.transpose(0, 1)
    label = torch.stack([label for num in range(t_max)])
    batch_size = data.size()[0]
    z_prev = torch.zeros(batch_size, z_dim).to(device)
    for t in range(t_max):
        h_t = rnn.sample_mean({'x': x})[t]
        z_t = encoder.sample_mean({'h': h_t, 'z_prev': z_prev, 'u': label[t]})
        dec_x = decoder.sample_mean({'z': z_t})
        z_prev = z_t
        xs.append(dec_x[None, :])
    reconst_img = torch.cat(xs, dim=0).transpose(0, 1)
    return reconst_img


def sample_after_n_steps(num_step, data, label):
    xs = []
    data = data.to(device)
    label = torch.eye(10)[label].to(device)
    
    x = data.transpose(0, 1)
    label = torch.stack([label for num in range(t_max)])
    batch_size = data.size()[0]
    z_prev = torch.zeros(batch_size, z_dim).to(device)
    for t in range(t_max):
        if t+1 <  num_step:
            h_t = rnn.sample_mean({'x': x})[t]
            z_t = encoder.sample_mean({'h': h_t, 'z_prev': z_prev, 'u': label[t]})
            dec_x = decoder.sample_mean({'z': z_t})
            z_prev = z_t
            xs.append(dec_x[None, :])
        else:
            z_t = prior.sample({'z_prev':z_prev, 'u': label[t]})["z"]
            dec_x = decoder.sample_mean({'z': z_t})
            z_prev = z_t
            xs.append(dec_x[None, :])
    sample_img = torch.cat(xs, dim=0).transpose(0, 1)
    return sample_img
            

In [28]:
writer = SummaryWriter(comment='DMM_masa_original_cross_action_conditional')
_x, _label = iter(test_loader).next()
_x = _x.to(device)


for epoch in range(1, epochs + 1):
    train_loss = data_loop(epoch, train_loader, dmm, device, train_mode=True)
    test_loss = data_loop(epoch, test_loader, dmm, device)

    writer.add_scalar('train_loss', train_loss, epoch)
    writer.add_scalar('test_loss', test_loss, epoch)

    sample = plot_image_from_latent()[:, None]
    writer.add_images('Image_from_latent', sample, epoch)
    
    reconst_img = reconstruct(_x, _label)[:, None]
    writer.add_images('Reconstructed', reconst_img, epoch)
    
    generated_img_7 = sample_after_n_steps(7, _x, _label)[:, None]
    writer.add_images('Generate_after_7steps', generated_img_7, epoch)
    
    generated_img_14 = sample_after_n_steps(14, _x, _label)[:, None]
    writer.add_images('Generate_after_14steps', generated_img_14, epoch)
    
    writer.add_images('orignal', _x[:, None], epoch)

100%|██████████| 235/235 [00:36<00:00,  6.43it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 1 Train loss: 250.0728


100%|██████████| 40/40 [00:03<00:00, 13.21it/s]

Test loss: 196.4350





NameError: name 't' is not defined