# Deep Markov Model
* Original paper: Structured Inference Networks for Nonlinear State Space Models (https://arxiv.org/abs/1609.09869)
* Original code: https://github.com/clinicalml/dmm

## Deep Markov Model summary
>Deep Markov models (DMM), a class of
generative models where classic linear emission and transition distributions are replaced with complex multi-layer
perceptrons (MLPs).


In [1]:
from tqdm import tqdm

import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
from tensorboardX import SummaryWriter

In [2]:
batch_size = 256
epochs = 25
seed = 1
torch.manual_seed(seed)

<torch._C.Generator at 0x7fc1dd8a7f10>

In [3]:
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [4]:
# Experiment Setting
# Treat each MNIST image as a sequence: every 28-pixel row is one time step
def init_dataset(f_batch_size):
    kwargs = {'num_workers': 1, 'pin_memory': True}
    data_dir = '../data'
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda data: data[0])
    ])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(data_dir, train=True, download=True,
                       transform=mnist_transform),
        batch_size=f_batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(data_dir, train=False, transform=mnist_transform),
        batch_size=f_batch_size, shuffle=True, **kwargs)

    fixed_t_size = 28
    return train_loader, test_loader, fixed_t_size

train_loader, test_loader, t_max = init_dataset(batch_size)
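
The loader yields batches shaped (batch, 28, 28), while the later cells feed the model time-major data via `transpose(0, 1)`. As a small illustration of that layout change, here is a sketch on nested lists using only the standard library; `to_time_major` is a hypothetical helper and not part of the notebook.

```python
def to_time_major(batch):
    """Convert [batch][time][feature] nested lists to [time][batch][feature]."""
    return [list(step) for step in zip(*batch)]


# Two tiny "images" with 3 rows (time steps) of 2 pixels (features) each.
images = [
    [[0, 1], [2, 3], [4, 5]],
    [[6, 7], [8, 9], [10, 11]],
]
sequence = to_time_major(images)
print(sequence[0])  # first time step: the first row of every image
```

Each entry of `sequence` then holds one row from every image in the batch, which is the slice the model consumes at a single time step.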

In [5]:
import pixyz
print(pixyz.__version__)

0.1.3


## Define probability distributions
### In the original paper
Prior (transition model, equation (1) in the paper): $z_{t} \sim \mathcal{N}\left(G_{\alpha}\left(z_{t-1}, \Delta_{t}\right), S_{\beta}\left(z_{t-1}, \Delta_{t}\right)\right)$  
Generator (emission, equation (2) in the paper): $x_{t} \sim \Pi\left(F_{\kappa}\left(z_{t}\right)\right)$  
Inference (equation (5) in the paper): $q_{\phi}\left(z_{t} | z_{t-1}, x_{t}, \ldots, x_{T}\right) = \mathcal{N}\left(\mu_{\phi}\left(z_{t-1}, x_{t}, \ldots, x_{T}\right), \Sigma_{\phi}\left(z_{t-1}, x_{t}, \ldots, x_{T}\right)\right)$

### MNIST settings
Prior (Transition model): $p_{\theta}(z_{t} | z_{t-1}) = \mathcal{N}\left(\mu = f_{prior_\mu}(z_{t-1}), \sigma^2 = f_{prior_{\sigma^2}}(z_{t-1})\right)$    
Generator(Emission): $p_{\theta}(x | z)=\mathscr{B}\left(x ; \lambda=g_{x}(z)\right)$  

RNN: $h = \mathrm{RNN}(x)$  
Inference (Combiner): $q_{\phi}(z_{t} | h, z_{t-1}) = \mathcal{N}\left(\mu = f_{\mu}(h, z_{t-1}), \sigma^2 = f_{\sigma^2}(h, z_{t-1})\right)$
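
As a rough illustration of the generative side of these definitions, the following sketch draws a short sequence by ancestral sampling: a Gaussian transition for $z_t$ followed by a Bernoulli emission for $x_t$. It uses only the standard library. The functions `toy_transition` and `toy_emission` are hypothetical stand-ins for the MLPs, not the networks trained in this notebook.

```python
import math
import random


def toy_transition(z_prev):
    """Hypothetical stand-in for the prior MLP: returns (loc, scale) per dimension."""
    loc = [0.9 * z for z in z_prev]
    scale = [0.1 for _ in z_prev]
    return loc, scale


def toy_emission(z, x_dim):
    """Hypothetical stand-in for the emission MLP: returns Bernoulli probabilities."""
    s = sum(z)
    return [1.0 / (1.0 + math.exp(-(s + 0.1 * i))) for i in range(x_dim)]


def sample_sequence(t_max, z_dim, x_dim, rng):
    """Ancestral sampling: z_t ~ N(loc, scale^2), then x_t ~ Bernoulli(probs)."""
    z_prev = [0.0] * z_dim  # start from an all-zero latent, as the notebook does
    xs = []
    for _ in range(t_max):
        loc, scale = toy_transition(z_prev)
        z_t = [rng.gauss(m, s) for m, s in zip(loc, scale)]
        probs = toy_emission(z_t, x_dim)
        x_t = [1 if rng.random() < p else 0 for p in probs]
        xs.append(x_t)
        z_prev = z_t
    return xs


rng = random.Random(0)
rows = sample_sequence(t_max=5, z_dim=2, x_dim=4, rng=rng)
print(len(rows), len(rows[0]))
```

In the MNIST setting each sampled `x_t` would correspond to one 28-pixel image row.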

In [6]:
from pixyz.utils import print_latex
from pixyz.distributions import Bernoulli, Normal, Deterministic

In [7]:
x_dim = 28
h_dim = 32
hidden_dim = 32
z_dim = 16
t_max = x_dim

In [8]:
# RNN
class RNN(Deterministic):
    """
    h = RNN(x)
    Given observed x, RNN output hidden state
    """
    def __init__(self):
        super(RNN, self).__init__(cond_var=["x"], var=["h"])
        self.rnn = nn.GRU(x_dim, h_dim, bidirectional=True)
#         self.h0 = torch.zeros(2, batch_size, self.rnn.hidden_size).to(device)
        self.h0 = nn.Parameter(torch.zeros(2, 1, self.rnn.hidden_size))
        self.hidden_size = self.rnn.hidden_size
        
    def forward(self, x):
        # if on gpu we need the fully broadcast view of the rnn initial state
        # to be in contiguous gpu memory
        # x: (Time, Batch_size, Features), since nn.GRU defaults to batch_first=False
        h0 = self.h0.expand(2, x.size(1), self.rnn.hidden_size).contiguous()
        h, _ = self.rnn(x, h0)
        return {"h": h}

In [9]:
# Emission p(x_t | z_t)
class Generator(Bernoulli):
    """
    Given the latent z at time step t, return the vector of
    probabilities that parameterizes the Bernoulli distribution p(x_t | z_t)
    """
    def __init__(self):
        super(Generator, self).__init__(cond_var=["z"], var=["x"])
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, x_dim)
    
    def forward(self, z):
        h = F.relu(self.fc1(z))
        return {"probs": torch.sigmoid(self.fc2(h))}

In [10]:
# Combiner q(z_t | z_{t-1}, x_{1:T})
class Inference(Normal):
    """
    Given the latent z at time step t-1 and the hidden state of the RNN h(x_{1:T}),
    return the loc and scale vectors that
    parameterize the Gaussian distribution q(z_t | z_{t-1}, x_{t:T})
    """
    def __init__(self):
        super(Inference, self).__init__(cond_var=["h", "z_prev"], var=["z"])
        self.fc1 = nn.Linear(z_dim, h_dim*2)
        self.fc21 = nn.Linear(h_dim*2, z_dim)
        self.fc22 = nn.Linear(h_dim*2, z_dim)

        
    def forward(self, h, z_prev):
        h_z = torch.tanh(self.fc1(z_prev))
        h = 0.5 * (h + h_z)
        return {"loc": self.fc21(h), "scale": F.softplus(self.fc22(h))}
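
The combiner above averages the RNN state with a tanh projection of the previous latent, and passes the scale head through softplus so that the standard deviation stays positive. A minimal standard-library sketch of those two operations follows; `softplus` and `combine` are illustrative helpers, not the PyTorch functions used in the cell.

```python
import math


def softplus(v):
    """Numerically stable log(1 + exp(v)); the result is always positive."""
    return max(v, 0.0) + math.log1p(math.exp(-abs(v)))


def combine(h_t, h_z):
    """Element-wise average of the RNN state and the projected previous latent."""
    return [0.5 * (a + b) for a, b in zip(h_t, h_z)]


mixed = combine([1.0, 3.0], [3.0, 5.0])
scales = [softplus(v) for v in (-50.0, 0.0, 50.0)]
print(mixed, scales)
```

Even a strongly negative pre-activation maps to a small positive scale, which keeps the Gaussian well defined.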

In [11]:
# Transition model p(z_t | z_{t-1})
class Prior(Normal):
    """
    Given the latent variable at the time step t-1
    return the mean and scale vectors that parameterize the
    gaussian distribution p(z_t | z_{t-1})
    """
    def __init__(self):
        super(Prior, self).__init__(cond_var=["z_prev"], var=["z"])
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, z_dim)
        self.fc22 = nn.Linear(hidden_dim, z_dim)
        
    def forward(self, z_prev):
        h = F.relu(self.fc1(z_prev))
        return {"loc": self.fc21(h), "scale": F.softplus(self.fc22(h))}

In [12]:
prior = Prior().to(device)
encoder = Inference().to(device)
decoder = Generator().to(device)
rnn = RNN().to(device)

In [13]:
print(prior)
print("*"*80)
print(encoder)
print("*"*80)
print(decoder)
print("*"*80)
print(rnn)

Distribution:
  p(z|z_{prev})
Network architecture:
  Prior(
    name=p, distribution_name=Normal,
    var=['z'], cond_var=['z_prev'], input_var=['z_prev'], features_shape=torch.Size([])
    (fc1): Linear(in_features=16, out_features=32, bias=True)
    (fc21): Linear(in_features=32, out_features=16, bias=True)
    (fc22): Linear(in_features=32, out_features=16, bias=True)
  )
********************************************************************************
Distribution:
  p(z|h,z_{prev})
Network architecture:
  Inference(
    name=p, distribution_name=Normal,
    var=['z'], cond_var=['h', 'z_prev'], input_var=['h', 'z_prev'], features_shape=torch.Size([])
    (fc1): Linear(in_features=16, out_features=64, bias=True)
    (fc21): Linear(in_features=64, out_features=16, bias=True)
    (fc22): Linear(in_features=64, out_features=16, bias=True)
  )
********************************************************************************
Distribution:
  p(x|z)
Network architecture:
  Generator(
    na

In [14]:
generate_from_prior = prior * decoder
print(generate_from_prior)
print_latex(generate_from_prior)

Distribution:
  p(x,z|z_{prev}) = p(x|z)p(z|z_{prev})
Network architecture:
  Prior(
    name=p, distribution_name=Normal,
    var=['z'], cond_var=['z_prev'], input_var=['z_prev'], features_shape=torch.Size([])
    (fc1): Linear(in_features=16, out_features=32, bias=True)
    (fc21): Linear(in_features=32, out_features=16, bias=True)
    (fc22): Linear(in_features=32, out_features=16, bias=True)
  )
  Generator(
    name=p, distribution_name=Bernoulli,
    var=['x'], cond_var=['z'], input_var=['z'], features_shape=torch.Size([])
    (fc1): Linear(in_features=16, out_features=32, bias=True)
    (fc2): Linear(in_features=32, out_features=28, bias=True)
  )


<IPython.core.display.Math object>

## Define Loss function
### In the original paper(equation (6) in the paper)
$\mathcal{L}(\vec{x} ;(\theta, \phi))=\sum_{t=1}^{T} \underset{q_{\phi}\left(z_{t} | \vec{x}\right)}{\mathbb{E}}\left[\log p_{\theta}\left(x_{t} | z_{t}\right)\right] -\operatorname{KL}\left(q_{\phi}\left(z_{1} | \vec{x}\right) \| p_{\theta}\left(z_{1}\right)\right) -\sum_{t=2}^{T} \underset{q_{\phi}\left(z_{t-1} | \vec{x}\right)}{\mathbb{E}}\left[\operatorname{KL}\left(q_{\phi}\left(z_{t} | z_{t-1}, \vec{x}\right) \| p_{\theta}\left(z_{t} | z_{t-1}\right)\right)\right]$
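
To make the two kinds of terms in this bound concrete, here is a minimal sketch of one time step of the negative bound, using the closed-form KL divergence between diagonal Gaussians and the Bernoulli log-likelihood. It uses only the standard library; the function names are illustrative and the inputs are made-up numbers, not values from the trained model.

```python
import math


def kl_diag_normal(loc_q, scale_q, loc_p, scale_p):
    """KL(N(loc_q, scale_q^2) || N(loc_p, scale_p^2)), summed over dimensions."""
    total = 0.0
    for mq, sq, mp, sp in zip(loc_q, scale_q, loc_p, scale_p):
        total += math.log(sp / sq) + (sq ** 2 + (mq - mp) ** 2) / (2.0 * sp ** 2) - 0.5
    return total


def bernoulli_log_prob(x, probs):
    """log p(x | probs) for x with entries in [0, 1], summed over dimensions."""
    return sum(xi * math.log(p) + (1 - xi) * math.log(1.0 - p) for xi, p in zip(x, probs))


def step_negative_elbo(x, probs, loc_q, scale_q, loc_p, scale_p):
    """One time step of the loss: KL term minus reconstruction log-likelihood."""
    return kl_diag_normal(loc_q, scale_q, loc_p, scale_p) - bernoulli_log_prob(x, probs)


loss_t = step_negative_elbo(
    x=[1, 0], probs=[0.5, 0.5],
    loc_q=[0.0, 0.0], scale_q=[1.0, 1.0],
    loc_p=[0.0, 0.0], scale_p=[1.0, 1.0],
)
print(loss_t)
```

When the posterior and prior coincide the KL term vanishes, so the step loss in this example is just the reconstruction cost of two uninformative pixels.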

In [15]:
from pixyz.losses import KullbackLeibler
from pixyz.losses import Expectation as E
from pixyz.losses import LogProb
from pixyz.losses import IterativeLoss
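
Conceptually, `IterativeLoss` is used in this notebook to sum a per-step loss over the time axis while feeding each step's latent `z` into the next step as `z_prev`. The sketch below imitates that control flow in plain Python under that assumption; `iterate_step_loss` and `toy_step` are hypothetical names and this is not the Pixyz implementation.

```python
def iterate_step_loss(step_fn, series, z_prev, max_iter):
    """Sum a per-step loss over time, passing each step's z on as the next z_prev.

    step_fn is a hypothetical callable (x_t, h_t, z_prev) -> (loss_t, z_t).
    """
    total = 0.0
    for t in range(max_iter):
        x_t, h_t = series["x"][t], series["h"][t]  # slice the series variables at step t
        loss_t, z_t = step_fn(x_t, h_t, z_prev)
        total += loss_t
        z_prev = z_t  # plays the role of update_value={"z": "z_prev"}
    return total, z_prev


def toy_step(x_t, h_t, z_prev):
    """Toy step: z_t averages h_t and z_prev; the loss is the squared error to x_t."""
    z_t = 0.5 * (h_t + z_prev)
    return (x_t - z_t) ** 2, z_t


series = {"x": [1.0, 1.0, 1.0], "h": [1.0, 1.0, 1.0]}
total, z_last = iterate_step_loss(toy_step, series, z_prev=0.0, max_iter=3)
print(total, z_last)
```

The variables listed in `series_var` are the ones indexed by `t`, and the `update_value` mapping names which output becomes which input at the next step.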

In [16]:
step_loss = - E(encoder, LogProb(decoder)) + KullbackLeibler(encoder, prior)
_loss = IterativeLoss(step_loss, max_iter=t_max, 
                      series_var=["x", "h"], update_value={"z": "z_prev"})
loss = E(rnn, _loss).mean()
print_latex(loss)

<IPython.core.display.Math object>

In [17]:
step_loss = - E(encoder, LogProb(decoder)) + KullbackLeibler(encoder, prior)
# _loss = IterativeLoss(step_loss, max_iter=t_max, 
#                       series_var=["x", "h"], update_value={"z": "z_prev"})
loss = E(rnn, step_loss).mean()
print_latex(loss)

<IPython.core.display.Math object>

In [18]:
from torch.nn.utils import clip_grad_norm_, clip_grad_value_  # used in DMM.train
from pixyz.models import Model

class DMM(Model):
    def __init__(self,
                 optimizer=optim.Adam,
                 optimizer_params={},
                 clip_grad_norm=None,
                 clip_grad_value=None):
        """
        Parameters
        ----------
        optimizer : torch.optim
            Optimization algorithm.
        optimizer_params : dict
            Parameters of optimizer
        clip_grad_norm : float or int
            Maximum allowed norm of the gradients.
        clip_grad_value : float or int
            Maximum allowed value of the gradients.
        """
        self.prior = Prior().to(device)
        self.encoder = Inference().to(device)
        self.decoder = Generator().to(device)
        self.rnn = RNN().to(device)
        
        self.reconst_loss = E(self.encoder, LogProb(self.decoder))
        self.kl_loss = KullbackLeibler(self.encoder, self.prior)
        
        self.step_loss = self.kl_loss - self.reconst_loss
        
        distributions = [self.prior, self.encoder, self.decoder, self.rnn]
        self.distributions = nn.ModuleList(distributions)

        # set params and optim
        params = self.distributions.parameters()
        self.optimizer = optimizer(params, **optimizer_params)

        self.clip_norm = clip_grad_norm
        self.clip_value = clip_grad_value
        
        
    
    def calculate_loss(self, input_var_dict={}):        
        batch_size = input_var_dict['x'].size()[1]
        # time_dimension = input_var_dict['x'].size()[0]
        
        z_prev = torch.zeros(batch_size, z_dim).to(device)
        
        """
        # Without IterativeLoss (sketch: unroll the time loop by hand)
        x = input_var_dict['x']
        h = self.rnn.sample_mean({'x': x})
        total_loss = 0
        for t in range(t_max):
            step_loss, samples = self.step_loss.eval(
                {'x': x[t], 'h': h[t], 'z_prev': z_prev}, return_dict=True)
            total_loss += step_loss
            z_prev = samples["z"]
        """
            
        # With IterativeLoss            
        _loss = IterativeLoss(self.step_loss, max_iter=t_max,
                             series_var=['x', 'h'],
                             update_value={"z": "z_prev"})
        loss = E(self.rnn, _loss).mean()
        total_loss = loss.eval({'x': input_var_dict['x'], 'z_prev': z_prev})
        
        return total_loss
    
    def train(self, train_x_dict={}):
        """Train the model.

        Parameters
        ----------
        train_x_dict : dict
            Input data.

        Returns
        -------
        loss : float
            Train loss value

        """
        self.distributions.train()

        self.optimizer.zero_grad()
        loss = self.calculate_loss(train_x_dict)

        # backprop
        loss.backward()

        if self.clip_norm:
            clip_grad_norm_(self.distributions.parameters(), self.clip_norm)
        if self.clip_value:
            clip_grad_value_(self.distributions.parameters(), self.clip_value)

        # update params
        self.optimizer.step()

        return loss.item()
    
    def test(self, test_x_dict={}):
        """Test the model.

        Parameters
        ----------
        test_x_dict : dict
            Input data

        Returns
        -------
        loss : float
            Test loss value

        """
        self.distributions.eval()

        with torch.no_grad():
            loss = self.calculate_loss(test_x_dict)

        return loss.item()
    
    def generate_image_after_nsteps(self, n_step_num, original_data):
        self.distributions.eval()
        with torch.no_grad():
            xs = []
            x = original_data.transpose(0, 1)
            batch_size = original_data.size()[0]
            z_prev = torch.zeros(batch_size, z_dim).to(device)
            h = self.rnn.sample_mean({'x': x})
            for t in range(t_max):
                if t < n_step_num - 1:
                    # before n_step, reconstruct
                    h_t = h[t]
                    z_t = self.encoder.sample_mean({'h': h_t, 'z_prev': z_prev})
                    dec_x = self.decoder.sample_mean({'z': z_t})
                    z_prev = z_t
                    xs.append(dec_x[None, :])
                else:
                    # generate
                    z_t = self.prior.sample({'z_prev': z_prev})["z"]
                    dec_x = self.decoder.sample_mean({'z': z_t})
                    z_prev = z_t
                    xs.append(dec_x[None, :])
            generated_img = torch.cat(xs, dim=0).transpose(0, 1)
        return generated_img

    def reconst_image(self, original_data):
        self.distributions.eval()
        with torch.no_grad():
            xs = []
            x = original_data.transpose(0, 1)
            batch_size = original_data.size()[0]
            z_prev = torch.zeros(batch_size, z_dim).to(device)
            h = self.rnn.sample_mean({'x': x})
            for t in range(t_max):
                h_t = h[t]
                z_t = self.encoder.sample_mean({'h': h_t, 'z_prev': z_prev})
                dec_x = self.decoder.sample_mean({'z': z_t})
                z_prev = z_t
                xs.append(dec_x[None, :])
            recon_img = torch.cat(xs, dim=0).transpose(0, 1)
        return recon_img

    def plot_image_from_latent(self, batch_size):
        self.distributions.eval()
        with torch.no_grad():
            xs = []
            z_prev = torch.zeros(batch_size, z_dim).to(device)
            for step in range(t_max):
                z_t = self.prior.sample({'z_prev': z_prev})['z']
                dec_x = self.decoder.sample_mean({'z': z_t})
                z_prev = z_t
                xs.append(dec_x[None, :])
            plotted_image = torch.cat(xs, dim=0).transpose(0, 1)
        return plotted_image
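
The `generate_image_after_nsteps` method above reconstructs the early rows with the inference network and then lets the transition prior run freely. The following sketch isolates that switch-over logic with hypothetical per-step callables (`posterior_step`, `prior_step`), so the index arithmetic can be checked without any trained model.

```python
def rollout_sources(n_step_num, t_max):
    """Label each time step by which model would produce z_t.

    Mirrors the condition `t < n_step_num - 1`: early steps use the inference
    network (reconstruction), later steps use the transition prior.
    """
    return ["posterior" if t < n_step_num - 1 else "prior" for t in range(t_max)]


def rollout(n_step_num, t_max, posterior_step, prior_step, z0):
    """Run the switch-over loop with hypothetical per-step callables."""
    z_prev, zs = z0, []
    for t in range(t_max):
        if t < n_step_num - 1:
            z_t = posterior_step(t, z_prev)  # would use the observation at step t
        else:
            z_t = prior_step(z_prev)  # free-running generation
        zs.append(z_t)
        z_prev = z_t
    return zs


sources = rollout_sources(n_step_num=3, t_max=5)
zs = rollout(3, 5, posterior_step=lambda t, z: float(t),
             prior_step=lambda z: z + 10.0, z0=0.0)
print(sources)
print(zs)
```

Note that with this condition `n_step_num - 1` steps are reconstructed, so a call with `n_step_num=7` conditions on the first six rows of the image.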

In [19]:
_x, _ = next(iter(test_loader))
fixed_batch = _x.to(device)
batch_size = fixed_batch.size()[0]
sequential_x = fixed_batch.transpose(0, 1)

In [20]:
dmm = DMM()
dmm.calculate_loss(input_var_dict={'x': sequential_x})

tensor(600.5259, device='cuda:0', grad_fn=<MeanBackward0>)

## Define DMM model using Model class

## Define Train and Test loop using model

In [21]:
def data_loop(epoch, loader, model, device, train_mode=False):
    mean_loss = 0
    for batch_idx, (data, _) in enumerate(tqdm(loader)):
        data = data.to(device)
        batch_size = data.size()[0]
        x = data.transpose(0, 1)
        z_prev = torch.zeros(batch_size, z_dim).to(device)
        #q_z_prev = torch.zeros(batch_size, z_dim).to(device)
        if train_mode:
            mean_loss += model.train({'x': x, 'z_prev': z_prev}) * batch_size
        else:
            mean_loss += model.test({'x': x, 'z_prev': z_prev}) * batch_size
    mean_loss /= len(loader.dataset)
    if train_mode:
        print('Epoch: {} Train loss: {:.4f}'.format(epoch, mean_loss))
    else:
        print('Test loss: {:.4f}'.format(mean_loss))
    return mean_loss
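
`data_loop` multiplies each batch loss by its batch size before dividing by the dataset size, because the last batch of an epoch can be smaller than the others. A tiny standard-library sketch of that weighting, with made-up loss values and a hypothetical helper name:

```python
def dataset_mean(batch_losses, batch_sizes):
    """Per-example mean loss from per-batch mean losses of unequal batch sizes."""
    total = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)


weighted = dataset_mean([2.0, 4.0], [3, 1])
unweighted = (2.0 + 4.0) / 2
print(weighted, unweighted)
```

A plain average over batches would give the small final batch the same weight as a full one, which is why the two numbers differ.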

In [22]:
writer = SummaryWriter(comment='DMM_Iterative_Pixyz')
# fixed _x for watching reconstruction improvement
_x, _ = next(iter(test_loader))
_x = _x.to(device)
dmm = DMM(optimizer=optim.Adam, optimizer_params={'lr': 3e-3})

for epoch in range(1, epochs + 1):
    train_loss = data_loop(epoch, train_loader, dmm, device, train_mode=True)
    test_loss = data_loop(epoch, test_loader, dmm, device)

    writer.add_scalar('train_loss', train_loss, epoch)
    writer.add_scalar('test_loss', test_loss, epoch)

    sample = dmm.plot_image_from_latent(batch_size)[:, None]
    writer.add_images('Image_from_latent', sample, epoch)
    generated_img_7 = dmm.generate_image_after_nsteps(7, _x)
    writer.add_images('Generate_after_7steps', generated_img_7[:, None], epoch)
    
    generated_img_14 = dmm.generate_image_after_nsteps(14, _x)
    writer.add_images('Generate_after_14steps', generated_img_14[:, None], epoch)
    
    recon_img = dmm.reconst_image(_x)
    writer.add_images('Reconstructed',  recon_img[:, None], epoch)
    
    writer.add_images('original', _x[:, None], epoch)

Epoch: 1 Train loss: 241.9602
Test loss: 174.9231
Epoch: 2 Train loss: 156.8226
Test loss: 144.5665
Epoch: 3 Train loss: 139.5204
Test loss: 133.8730
Epoch: 4 Train loss: 132.8676
Test loss: 129.7791
Epoch: 5 Train loss: 129.5975
Test loss: 127.6434
Epoch: 6 Train loss: 127.2680
Test loss: 125.6527
Epoch: 7 Train loss: 125.5457
Test loss: 124.0644
Epoch: 8 Train loss: 124.2994
Test loss: 123.6804
Epoch: 9 Train loss: 123.3705
Test loss: 122.1760
Epoch: 10 Train loss: 122.4995
Test loss: 121.2313
Epoch: 11 Train loss: 121.8245
Test loss: 121.0417
Epoch: 12 Train loss: 121.3876
Test loss: 120.2135
Epoch: 13 Train loss: 120.8118
Test loss: 119.6426
Epoch: 14 Train loss: 120.5697
Test loss: 119.0193
Epoch: 15 Train loss: 120.1707
Test loss: 118.6662
Epoch: 16 Train loss: 119.8397
Test loss: 118.4765
Epoch: 17 Train loss: 119.5151
Test loss: 118.4786
Epoch: 18 Train loss: 119.3542
Test loss: 118.2453
Epoch: 19 Train loss: 119.1043
Test loss: 118.0624
Epoch: 20 Train loss: 119.0036
Test loss: 117.9668
Epoch: 21 Train loss: 118.8365
Test loss: 117.7396
Epoch: 22 Train loss: 118.6330
Test loss: 118.0164
Epoch: 23 Train loss: 118.4450
Test loss: 117.4074
Epoch: 24 Train loss: 118.3671
Test loss: 118.1019
Epoch: 25 Train loss: 118.2798
Test loss: 116.9339


## Conditioning the latent variable z on the image label

### MNIST settings
Prior (Transition model): $p_{\theta}(z_{t} | z_{t-1}, u) = \mathcal{N}\left(\mu = f_{prior_\mu}(z_{t-1}, u), \sigma^2 = f_{prior_{\sigma^2}}(z_{t-1}, u)\right)$    
Generator(Emission): $p_{\theta}(x | z)=\mathscr{B}\left(x ; \lambda=g_{x}(z)\right)$  

RNN: $h = \mathrm{RNN}(x)$  
Inference (Combiner): $q_{\phi}(z_{t} | h, z_{t-1}, u) = \mathcal{N}\left(\mu = f_{\mu}(h, z_{t-1}, u), \sigma^2 = f_{\sigma^2}(h, z_{t-1}, u)\right)$

In [23]:
x_dim = 28
h_dim = 32
hidden_dim = 32
z_dim = 16
t_max = x_dim

# label dim
u_dim = 10

In [24]:
# RNN
class RNN(Deterministic):
    """
    h = RNN(x)
    Given observed x, RNN output hidden state
    """
    def __init__(self):
        super(RNN, self).__init__(cond_var=["x"], var=["h"])
        self.rnn = nn.GRU(x_dim, h_dim, bidirectional=True)
#         self.h0 = torch.zeros(2, batch_size, self.rnn.hidden_size).to(device)
        self.h0 = nn.Parameter(torch.zeros(2, 1, self.rnn.hidden_size))
        self.hidden_size = self.rnn.hidden_size
        
    def forward(self, x):
        # if on gpu we need the fully broadcast view of the rnn initial state
        # to be in contiguous gpu memory
        # x: (Time, Batch_size, Features), since nn.GRU defaults to batch_first=False
        h0 = self.h0.expand(2, x.size(1), self.rnn.hidden_size).contiguous()
        h, _ = self.rnn(x, h0)
        return {"h": h}


# Emission p(x_t | z_t)
class Generator(Bernoulli):
    """
    Given the latent z at time step t, return the vector of
    probabilities that parameterizes the Bernoulli distribution p(x_t | z_t)
    """
    def __init__(self):
        super(Generator, self).__init__(cond_var=["z"], var=["x"])
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, x_dim)
    
    def forward(self, z):
        h = F.relu(self.fc1(z))
        return {"probs": torch.sigmoid(self.fc2(h))}


class Inference(Normal):
    """
    Given the latent z at time step t-1, the hidden state of the RNN h(x_{1:T}) and the label u,
    return the loc and scale vectors that
    parameterize the Gaussian distribution q(z_t | z_{t-1}, x_{t:T}, u)
    """
    def __init__(self):
        super(Inference, self).__init__(cond_var=["h", "z_prev", "u"], var=["z"])
        self.fc1 = nn.Linear(z_dim+u_dim, h_dim*2)
        self.fc21 = nn.Linear(h_dim*2, z_dim)
        self.fc22 = nn.Linear(h_dim*2, z_dim)

        
    def forward(self, h, z_prev, u):
        feature = torch.cat((z_prev, u), 1)
        h_z = torch.tanh(self.fc1(feature))
        h = 0.5 * (h + h_z)
        return {"loc": self.fc21(h), "scale": F.softplus(self.fc22(h))}


class Prior(Normal):
    """
    Given the latent variable at the time step t-1 and u,
    return the mean and scale vectors that parameterize the
    gaussian distribution p(z_t | z_{t-1}, u)
    """
    def __init__(self):
        super(Prior, self).__init__(cond_var=["z_prev", "u"], var=["z"])
        self.fc1 = nn.Linear(z_dim+u_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, z_dim)
        self.fc22 = nn.Linear(hidden_dim, z_dim)
        
    def forward(self, z_prev, u):
        feature = torch.cat((z_prev, u), 1)
        h = F.relu(self.fc1(feature))
        return {"loc": self.fc21(h), "scale": F.softplus(self.fc22(h))}

In [25]:
from torch.nn.utils import clip_grad_norm_, clip_grad_value_  # used in DMM.train
from pixyz.models import Model

class DMM(Model):
    def __init__(self,
                 optimizer=optim.Adam,
                 optimizer_params={},
                 clip_grad_norm=None,
                 clip_grad_value=None):
        """
        Parameters
        ----------
        optimizer : torch.optim
            Optimization algorithm.
        optimizer_params : dict
            Parameters of optimizer
        clip_grad_norm : float or int
            Maximum allowed norm of the gradients.
        clip_grad_value : float or int
            Maximum allowed value of the gradients.
        """
        self.prior = Prior().to(device)
        self.encoder = Inference().to(device)
        self.decoder = Generator().to(device)
        self.rnn = RNN().to(device)
        
        self.reconst_loss = E(self.encoder, LogProb(self.decoder))
        self.kl_loss = KullbackLeibler(self.encoder, self.prior)
        
        self.step_loss = self.kl_loss - self.reconst_loss
        
        distributions = [self.prior, self.encoder, self.decoder, self.rnn]
        self.distributions = nn.ModuleList(distributions)

        # set params and optim
        params = self.distributions.parameters()
        self.optimizer = optimizer(params, **optimizer_params)

        self.clip_norm = clip_grad_norm
        self.clip_value = clip_grad_value
        
        
    
    def calculate_loss(self, input_var_dict={}):        
        batch_size = input_var_dict['x'].size()[1]
        # time_dimension = input_var_dict['x'].size()[0]
        
        z_prev = torch.zeros(batch_size, z_dim).to(device)
        
        """
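        # Without IterativeLoss (sketch: unroll the time loop by hand)
        x, u = input_var_dict['x'], input_var_dict['u']
        h = self.rnn.sample_mean({'x': x})
        total_loss = 0
        for t in range(t_max):
            step_loss, samples = self.step_loss.eval(
                {'x': x[t], 'h': h[t], 'u': u[t], 'z_prev': z_prev}, return_dict=True)
            total_loss += step_loss
            z_prev = samples["z"]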
        """
            
        # With IterativeLoss            
        _loss = IterativeLoss(self.step_loss, max_iter=t_max,
                             series_var=['x', 'h', 'u'],
                             update_value={"z": "z_prev"})
        loss = E(self.rnn, _loss).mean()
        total_loss = loss.eval({'x': input_var_dict['x'], 'u': input_var_dict['u'], 'z_prev': z_prev})
        
        return total_loss
    
    def train(self, train_x_dict={}):
        """Train the model.

        Parameters
        ----------
        train_x_dict : dict
            Input data.

        Returns
        -------
        loss : float
            Train loss value

        """
        self.distributions.train()

        self.optimizer.zero_grad()
        loss = self.calculate_loss(train_x_dict)

        # backprop
        loss.backward()

        if self.clip_norm:
            clip_grad_norm_(self.distributions.parameters(), self.clip_norm)
        if self.clip_value:
            clip_grad_value_(self.distributions.parameters(), self.clip_value)

        # update params
        self.optimizer.step()

        return loss.item()
    
    def test(self, test_x_dict={}):
        """Test the model.

        Parameters
        ----------
        test_x_dict : dict
            Input data

        Returns
        -------
        loss : float
            Test loss value

        """
        self.distributions.eval()

        with torch.no_grad():
            loss = self.calculate_loss(test_x_dict)

        return loss.item()
    
    def generate_image_after_nsteps(self, n_step_num, original_data, label):
        self.distributions.eval()
        with torch.no_grad():
            xs = []
            x = original_data.transpose(0, 1)
            batch_size = original_data.size()[0]
            z_prev = torch.zeros(batch_size, z_dim).to(device)
            h = self.rnn.sample_mean({'x': x})
            
            label = torch.eye(10)[label].to(device)
            label = torch.stack([label for num in range(t_max)])
            for t in range(t_max):
                if t < n_step_num - 1:
                    # before n_step, reconstruct
                    h_t = h[t]
                    z_t = self.encoder.sample_mean({'h': h_t, 'z_prev': z_prev, 'u': label[t]})
                    dec_x = self.decoder.sample_mean({'z': z_t})
                    z_prev = z_t
                    xs.append(dec_x[None, :])
                else:
                    # generate
                    z_t = self.prior.sample({'z_prev': z_prev, 'u': label[t]})["z"]
                    dec_x = self.decoder.sample_mean({'z': z_t})
                    z_prev = z_t
                    xs.append(dec_x[None, :])
            generated_img = torch.cat(xs, dim=0).transpose(0, 1)
        return generated_img

    def reconst_image(self, original_data, label):
        self.distributions.eval()
        with torch.no_grad():
            xs = []
            x = original_data.transpose(0, 1)
            batch_size = original_data.size()[0]
            z_prev = torch.zeros(batch_size, z_dim).to(device)
            h = self.rnn.sample_mean({'x': x})
            
            label = torch.eye(10)[label].to(device)
            label = torch.stack([label for num in range(t_max)])
            for t in range(t_max):
                h_t = h[t]
                z_t = self.encoder.sample_mean({'h': h_t, 'z_prev': z_prev, 'u': label[t]})
                dec_x = self.decoder.sample_mean({'z': z_t})
                z_prev = z_t
                xs.append(dec_x[None, :])
            recon_img = torch.cat(xs, dim=0).transpose(0, 1)
        return recon_img

    def plot_image_from_latent(self, batch_size=100):
        self.distributions.eval()
        with torch.no_grad():
            xs = []
            z_prev = torch.zeros(batch_size, z_dim).to(device)
            label = torch.arange(batch_size) % 10  # cycle through digits 0-9 for any batch size
            label = torch.eye(10)[label].to(device)
            label = torch.stack([label for num in range(t_max)])
            for t in range(t_max):
                z_t = self.prior.sample({'z_prev': z_prev, 'u': label[t]})['z']
                dec_x = self.decoder.sample_mean({'z': z_t})
                z_prev = z_t
                xs.append(dec_x[None, :])
            plotted_image = torch.cat(xs, dim=0).transpose(0, 1)
        return plotted_image

In [26]:
# check label dims
x, label = next(iter(train_loader))
label = torch.eye(10)[label].to(device)
print(x.shape)
print(label.shape)
# copy labels for each time step
labels_t = torch.stack([label for num in range(28)])
print(labels_t.shape)

dmm = DMM()
dmm.calculate_loss(input_var_dict={'x': sequential_x, 'u': labels_t})

torch.Size([256, 28, 28])
torch.Size([256, 10])
torch.Size([28, 256, 10])


tensor(586.3669, device='cuda:0', grad_fn=<MeanBackward0>)
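The label handling checked above can be looked at in isolation. The snippet below is a minimal sketch, not part of the original notebook: the labels are made-up stand-ins, and the names `num_classes`, `stacked` and `expanded` are illustrative. It shows that indexing `torch.eye(10)` with integer labels gives the same one-hot encoding as `F.one_hot`, and that `expand` produces the same `(t_max, batch, 10)` tensor as `torch.stack` without copying the data.

```python
import torch
import torch.nn.functional as F

t_max, batch, num_classes = 28, 4, 10      # illustrative sizes (assumed)
label = torch.tensor([3, 0, 9, 3])         # stand-in integer class labels

# Indexing an identity matrix picks one row per label: a one-hot encoding.
one_hot = torch.eye(num_classes)[label]    # shape (batch, num_classes), float32
assert torch.equal(one_hot, F.one_hot(label, num_classes).float())

# Repeat the same label vector at every time step.
stacked = torch.stack([one_hot for _ in range(t_max)])   # copies the data
expanded = one_hot.unsqueeze(0).expand(t_max, -1, -1)     # view, no copy
assert torch.equal(stacked, expanded)
print(stacked.shape)
```

Whether the view returned by `expand` can replace the stacked copy depends on how the downstream distributions use the tensor, so treat it as an option to test rather than a drop-in change.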

In [29]:
def data_loop(epoch, loader, model, device, train_mode=False):
    # accumulate batch losses weighted by batch size, then average per sample
    mean_loss = 0
    for batch_idx, (data, label) in enumerate(tqdm(loader)):
        data = data.to(device)
        # one-hot encode the digit labels: (batch_size, 10)
        label = torch.eye(10)[label].to(device)

        # the last batch may be smaller than the configured batch size
        batch_size = data.size()[0]
        # convert to (timestep, batch_size, feature)
        x = data.transpose(0, 1)
        # repeat the label at every time step: (t_max, batch_size, 10)
        label = torch.stack([label for _ in range(t_max)])

        # initial latent state
        z_prev = torch.zeros(batch_size, z_dim).to(device)
        if train_mode:
            mean_loss += model.train({'x': x, 'z_prev': z_prev, 'u': label}) * batch_size
        else:
            mean_loss += model.test({'x': x, 'z_prev': z_prev, 'u': label}) * batch_size
    mean_loss /= len(loader.dataset)
    if train_mode:
        print('Epoch: {} Train loss: {:.4f}'.format(epoch, mean_loss))
    else:
        print('Test loss: {:.4f}'.format(mean_loss))
    return mean_loss
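`data_loop` multiplies each batch loss by the batch size before dividing by the dataset size. Assuming the model returns a per-batch mean, this gives a per-sample average even when the last batch is short. The sketch below illustrates the difference with plain Python: the helper `per_sample_mean` and the loss values are made up for illustration, and only the dataset size (60,000 MNIST training images) and the batch size of 256 come from the setup.

```python
import math

def per_sample_mean(batch_losses, batch_sizes):
    """Average per-batch mean losses, weighting each by its batch size."""
    total = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)

dataset_size, batch_size = 60000, 256      # MNIST training set, batch size 256
num_batches = math.ceil(dataset_size / batch_size)
last_batch = dataset_size - (num_batches - 1) * batch_size
print(num_batches, last_batch)             # 235 96

# Toy losses (made up): every full batch scores 100.0, the short one 200.0.
sizes = [batch_size] * (num_batches - 1) + [last_batch]
losses = [100.0] * (num_batches - 1) + [200.0]

weighted = per_sample_mean(losses, sizes)  # what data_loop computes
unweighted = sum(losses) / len(losses)     # naive mean over batches
print(round(weighted, 4), round(unweighted, 4))
```

The naive mean over batches gives the 96-sample batch the same weight as a full batch, so it overstates that batch's influence.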

In [30]:
writer = SummaryWriter(comment='DMM_Iterative_Pixyz_Action_Conditional')
_x, _label = next(iter(test_loader))
_x = _x.to(device)

dmm = DMM(optimizer=optim.Adam, optimizer_params={'lr': 3e-3})


for epoch in range(1, epochs + 1):
    train_loss = data_loop(epoch, train_loader, dmm, device, train_mode=True)
    test_loss = data_loop(epoch, test_loader, dmm, device)

    writer.add_scalar('train_loss', train_loss, epoch)
    writer.add_scalar('test_loss', test_loss, epoch)

    sample = dmm.plot_image_from_latent()[:, None]
    writer.add_images('Image_from_latent', sample, epoch)
    generated_img_7 = dmm.generate_image_after_nsteps(7, _x, _label)
    writer.add_images('Generate_after_7steps', generated_img_7[:, None], epoch)
    
    generated_img_14 = dmm.generate_image_after_nsteps(14, _x, _label)
    writer.add_images('Generate_after_14steps', generated_img_14[:, None], epoch)
    
    recon_img = dmm.reconst_image(_x, _label)
    writer.add_images('Reconstructed',  recon_img[:, None], epoch)
    
    writer.add_images('original', _x[:, None], epoch)

Epoch: 1 Train loss: 240.5575
Test loss: 161.5460
Epoch: 2 Train loss: 151.3338
Test loss: 142.2698
Epoch: 3 Train loss: 135.3081
Test loss: 129.0952
Epoch: 4 Train loss: 128.5064
Test loss: 125.9396
Epoch: 5 Train loss: 125.6956
Test loss: 123.5771
Epoch: 6 Train loss: 124.0109
Test loss: 122.5931
Epoch: 7 Train loss: 122.6570
Test loss: 120.9117
Epoch: 8 Train loss: 121.6669
Test loss: 120.3675
Epoch: 9 Train loss: 120.8751
Test loss: 119.2471
Epoch: 10 Train loss: 119.9530
Test loss: 119.0139
Epoch: 11 Train loss: 119.2717
Test loss: 118.1272
Epoch: 12 Train loss: 118.6759
Test loss: 117.8931
Epoch: 13 Train loss: 118.1336
Test loss: 117.3202
Epoch: 14 Train loss: 117.6876
Test loss: 116.3162
Epoch: 15 Train loss: 117.1790
Test loss: 115.9207
Epoch: 16 Train loss: 116.9519
Test loss: 115.8908
Epoch: 17 Train loss: 116.5182
Test loss: 115.1156
Epoch: 18 Train loss: 116.2762
Test loss: 115.2587
Epoch: 19 Train loss: 115.9145
Test loss: 114.7585
Epoch: 20 Train loss: 115.6697
Test loss: 115.0261
Epoch: 21 Train loss: 115.4622
Test loss: 114.4268
Epoch: 22 Train loss: 115.2438
Test loss: 114.2547
Epoch: 23 Train loss: 115.0140
Test loss: 114.4710
Epoch: 24 Train loss: 114.8183
Test loss: 113.7066
Epoch: 25 Train loss: 114.7482
Test loss: 113.7439
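
The test losses printed above can be summarized with a few lines of plain Python. This is a minimal sketch added for reading the log, not part of the original training code: the values are copied from the output above, and the names `best_epoch` and `regressions` are illustrative.

```python
# Test losses copied from the log above, epochs 1 to 25.
test_loss = [
    161.5460, 142.2698, 129.0952, 125.9396, 123.5771,
    122.5931, 120.9117, 120.3675, 119.2471, 119.0139,
    118.1272, 117.8931, 117.3202, 116.3162, 115.9207,
    115.8908, 115.1156, 115.2587, 114.7585, 115.0261,
    114.4268, 114.2547, 114.4710, 113.7066, 113.7439,
]

# Epoch (1-indexed) with the lowest test loss.
best_epoch = min(range(len(test_loss)), key=test_loss.__getitem__) + 1

# Epochs where the test loss rose relative to the previous epoch.
regressions = [
    e + 1 for e in range(1, len(test_loss)) if test_loss[e] > test_loss[e - 1]
]

print(best_epoch, test_loss[best_epoch - 1])   # 24 113.7066
print(regressions)                             # [18, 20, 23, 25]
```

The loss is still drifting downward at the end of the run, with small upticks at a handful of late epochs, so the lowest logged value falls at epoch 24 rather than at the final epoch.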
