# VRNN
* Original paper: A Recurrent Latent Variable Model for Sequential Data (https://arxiv.org/pdf/1506.02216.pdf)
* Original code: https://github.com/jych/nips2015_vrnn

## VRNN summary
VRNN extends the VAE into a recurrent framework for modelling high-dimensional sequences.  
VRNN integrates random variables into the RNN hidden state, and integrates the dependencies between the latent random variables at neighboring timesteps.

In [1]:
from tqdm import tqdm

import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
from tensorboardX import SummaryWriter

# Training hyper-parameters.
batch_size = 256
epochs = 5

# Fix the PyTorch RNG seed.
seed = 1
torch.manual_seed(seed)

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
# Experiment setting
# Each MNIST image is treated as a sequence of rows: one image row per time
# step.
def init_dataset(f_batch_size):
    """Build the MNIST train and test data loaders.

    Parameters
    ----------
    f_batch_size : int
        Batch size used by both loaders.

    Returns
    -------
    train_loader : torch.utils.data.DataLoader
        Shuffled loader over the MNIST training split (downloaded to
        ``../data`` if missing).
    test_loader : torch.utils.data.DataLoader
        Shuffled loader over the MNIST test split.
    fixed_t_size : int
        Number of time steps per sample (28).
    """
    # Pinned host memory only speeds up host-to-GPU copies; asking for it on
    # a CPU-only machine just makes the DataLoader emit a warning.
    kwargs = {'num_workers': 1, 'pin_memory': torch.cuda.is_available()}
    data_dir = '../data'
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        # ToTensor() yields a (C, H, W) tensor; keep entry 0 along the first
        # axis so each sample is a 2-D image.
        transforms.Lambda(lambda data: data[0])
    ])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(data_dir, train=True, download=True,
                       transform=mnist_transform),
        batch_size=f_batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(data_dir, train=False, transform=mnist_transform),
        batch_size=f_batch_size, shuffle=True, **kwargs)

    fixed_t_size = 28
    return train_loader, test_loader, fixed_t_size

train_loader, test_loader, t_max = init_dataset(batch_size)

In [3]:
from pixyz.utils import print_latex

In [4]:
import pixyz
print(pixyz.__version__)

0.1.3


## Define probability distribution
### In the original paper
Prior(equation (5) in the paper): $p\left(\mathbf{z}_{t} | \mathbf{x}_{<t}, \mathbf{z}_{<t}\right) = \mathcal{N}\left(\boldsymbol{\mu}_{0, t}, \operatorname{diag}\left(\boldsymbol{\sigma}_{0, t}^{2}\right)\right), \text { where }\left[\boldsymbol{\mu}_{0, t}, \boldsymbol{\sigma}_{0, t}\right]=\varphi_{\tau}^{\text {prior }}\left(\mathbf{h}_{t-1}\right)$

Generator(equation (6) in the paper): $p\left(\mathbf{x}_{t} | \mathbf{z}_{\leq t}, \mathbf{x}_{<t}\right) = \mathcal{N}\left(\boldsymbol{\mu}_{x, t}, \operatorname{diag}\left(\boldsymbol{\sigma}_{x, t}^{2}\right)\right), \text { where }\left[\boldsymbol{\mu}_{x, t}, \boldsymbol{\sigma}_{x, t}\right]=\varphi_{\tau}^{\mathrm{dec}}\left(\varphi_{\tau}^{\mathbf{z}}\left(\mathbf{z}_{t}\right), \mathbf{h}_{t-1}\right)$

Recurrence(equation (7) in the paper): $p\left(\mathbf{h}_{t} | \mathbf{z}_{t}, \mathbf{x}_{t}, \mathbf{h}_{t-1}\right) = f_{\theta}\left(\varphi_{\tau}^{\mathbf{x}}\left(\mathbf{x}_{t}\right), \varphi_{\tau}^{\mathbf{z}}\left(\mathbf{z}_{t}\right), \mathbf{h}_{t-1}\right)$

Inference(equation (9) in the paper): $q\left(\mathbf{z}_{t} | \mathbf{x}_{\leq t}, \mathbf{z}_{<t}\right) = \mathcal{N}\left(\boldsymbol{\mu}_{z, t}, \operatorname{diag}\left(\boldsymbol{\sigma}_{z, t}^{2}\right)\right), \text { where }\left[\boldsymbol{\mu}_{z, t}, \boldsymbol{\sigma}_{z, t}\right]=\varphi_{\tau}^{\mathrm{enc}}\left(\varphi_{\tau}^{\mathbf{x}}\left(\mathbf{x}_{t}\right), \mathbf{h}_{t-1}\right)$

### MNIST Settings
Prior: $p_{\theta}(z_t | h_{t-1}) = \mathcal{N}(\mu=f_{prior_\mu}(h_{t-1}),\sigma^2=f_{prior_\sigma^2}(h_{t-1}))$

Generator: $p_{\theta}(x_t | z_t, h_{t-1})=\mathcal{B}\left(x_t ; \lambda=g_{x}(z_t, h_{t-1})\right)$

Recurrence: $p(h_{t} | z_t, x_t, h_{t-1}) = RNN(\varphi_{\tau}^{\mathbf{x}}\left(\mathbf{x}_{t}\right), \varphi_{\tau}^{\mathbf{z}}\left(\mathbf{z}_{t}\right), \mathbf{h}_{t-1})$

Inference: $q_\phi(z_t | h_{t-1}, x_t) =  \mathcal{N}(\mu=f_{infer_\mu}(h_{t-1}, \varphi_{\tau}^{\mathbf{x}}(x_t)),\sigma^2=f_{infer_\sigma^2}(h_{t-1}, \varphi_{\tau}^{\mathbf{x}}(x_t)))$

In [5]:
# Model dimensions shared by every network defined below.
x_dim = 28  # size of one observation x_t (input width of Phi_x)
h_dim = 100  # size of the hidden state and of the Phi_x / Phi_z features
z_dim = 16  # size of the latent variable z_t
t_max = x_dim  # number of time steps, set equal to x_dim

# feature extraction for x
class Phi_x(nn.Module):
    """Feature extractor for an observation: one linear layer followed by ReLU."""

    def __init__(self):
        super().__init__()
        self.fc0 = nn.Linear(x_dim, h_dim)

    def forward(self, x):
        features = self.fc0(x)
        return F.relu(features)


# feature extraction for z
class Phi_z(nn.Module):
    """Feature extractor for a latent sample: one linear layer followed by ReLU."""

    def __init__(self):
        super().__init__()
        self.fc0 = nn.Linear(z_dim, h_dim)

    def forward(self, z):
        features = self.fc0(z)
        return F.relu(features)

# One instance of each feature extractor, moved to the selected device. The
# distribution classes below store references to these module-level objects
# instead of constructing their own.
f_phi_x = Phi_x().to(device)
f_phi_z = Phi_z().to(device)

In [6]:
from pixyz.distributions import Bernoulli, Normal, Deterministic

class Generator(Bernoulli):
    """
    Bernoulli observation likelihood p(x_t | z_t, h_{t-1}) (used for MNIST).

    Maps the latent z at one time step together with the previous hidden
    state to the vector of probabilities that parameterizes the Bernoulli
    distribution over x.
    """

    def __init__(self):
        super().__init__(cond_var=["z", "h_prev"], var=["x"])
        self.fc1 = nn.Linear(2 * h_dim, h_dim)
        self.fc2 = nn.Linear(h_dim, h_dim)
        self.fc3 = nn.Linear(h_dim, x_dim)
        # Reference to the module-level feature extractor for z.
        self.f_phi_z = f_phi_z

    def forward(self, z, h_prev):
        z_features = self.f_phi_z(z)
        hidden = torch.cat((z_features, h_prev), dim=-1)
        hidden = F.relu(self.fc1(hidden))
        hidden = F.relu(self.fc2(hidden))
        probs = torch.sigmoid(self.fc3(hidden))
        return {"probs": probs}

class Prior(Normal):
    """
    Conditional prior p(z_t | h_{t-1}).

    A Normal over z whose loc and scale are both computed from the previous
    hidden state; the scale goes through a softplus.
    """

    def __init__(self):
        super().__init__(cond_var=["h_prev"], var=["z"])
        self.fc1 = nn.Linear(h_dim, h_dim)
        self.fc21 = nn.Linear(h_dim, z_dim)
        self.fc22 = nn.Linear(h_dim, z_dim)

    def forward(self, h_prev):
        hidden = F.relu(self.fc1(h_prev))
        loc = self.fc21(hidden)
        scale = F.softplus(self.fc22(hidden))
        return {"loc": loc, "scale": scale}

class Inference(Normal):
    """
    Approximate posterior q(z_t | h_{t-1}, x_t).

    A Normal over z whose loc and scale are computed from the features of
    x_t concatenated with the previous hidden state; the scale goes through
    a softplus.
    """

    def __init__(self):
        super().__init__(cond_var=["x", "h_prev"], var=["z"], name="q")
        self.fc1 = nn.Linear(2 * h_dim, h_dim)
        self.fc21 = nn.Linear(h_dim, z_dim)
        self.fc22 = nn.Linear(h_dim, z_dim)
        # Reference to the module-level feature extractor for x.
        self.f_phi_x = f_phi_x

    def forward(self, x, h_prev):
        x_features = self.f_phi_x(x)
        hidden = torch.cat((x_features, h_prev), dim=-1)
        hidden = F.relu(self.fc1(hidden))
        loc = self.fc21(hidden)
        scale = F.softplus(self.fc22(hidden))
        return {"loc": loc, "scale": scale}

class Recurrence(Deterministic):
    """
    Deterministic hidden-state update p(h_t | x_t, z_t, h_prev).

    A GRU cell fed with the concatenated features of z_t and x_t, with
    h_prev as its previous state.
    """

    def __init__(self):
        super().__init__(cond_var=["x", "z", "h_prev"], var=["h"])
        self.rnncell = nn.GRUCell(2 * h_dim, h_dim).to(device)
        # References to the module-level feature extractors.
        self.f_phi_x = f_phi_x
        self.f_phi_z = f_phi_z
        self.hidden_size = self.rnncell.hidden_size

    def forward(self, x, z, h_prev):
        z_features = self.f_phi_z(z)
        x_features = self.f_phi_x(x)
        rnn_input = torch.cat((z_features, x_features), dim=-1)
        return {"h": self.rnncell(rnn_input, h_prev)}

## Define Loss function
### In the original paper(equation (11) in the original paper)
$\mathbb{E}_{q(\mathbf{z}_{\leq T} | \mathbf{x}_{\leq T})}\left[\sum_{t=1}^{T}\left(-\mathrm{KL}\left(q\left(\mathbf{z}_{t} | \mathbf{x}_{\leq t}, \mathbf{z}_{<t}\right) \| p\left(\mathbf{z}_{t} | \mathbf{x}_{<t}, \mathbf{z}_{<t}\right)\right)+\log p\left(\mathbf{x}_{t} | \mathbf{z}_{\leq t}, \mathbf{x}_{<t}\right)\right)\right]$

In [7]:
from pixyz.losses import KullbackLeibler, LogProb, Expectation as E
from pixyz.losses import IterativeLoss

In [8]:
from pixyz.models import Model

class VRNN(Model):
    """Variational RNN over row sequences.

    The per-step loss is KL(q(z_t | x_t, h_{t-1}) || p(z_t | h_{t-1})) minus
    the expected decoder log-likelihood of x_t, averaged over the batch;
    ``calculate_loss`` iterates it over the time steps with ``IterativeLoss``.
    """

    def __init__(self,
                 optimizer=optim.Adam,
                 optimizer_params={},
                 clip_grad_norm=None,
                 clip_grad_value=None):
        """
        Parameters
        ----------
        optimizer : torch.optim
            Optimization algorithm.
        optimizer_params : dict
            Keyword arguments passed to ``optimizer``.
        clip_grad_norm : float or int
            Maximum allowed norm of the gradients (no clipping if falsy).
        clip_grad_value : float or int
            Maximum allowed value of the gradients (no clipping if falsy).
        """
        # NOTE(review): Model.__init__ is not invoked; every attribute the
        # methods of this class read is assigned here instead.
        self.prior = Prior().to(device)
        self.encoder = Inference().to(device)
        self.decoder = Generator().to(device)
        self.recurrence = Recurrence().to(device)
        # Encoder combined with the recurrence, so one step provides both
        # z_t and the next hidden state "h".
        self.encoder_with_recurrence = self.encoder * self.recurrence

        self.reconst_loss = E(self.encoder_with_recurrence, LogProb(self.decoder))
        self.kl_loss = KullbackLeibler(self.encoder, self.prior)
        self.step_loss = (self.kl_loss - self.reconst_loss).mean()

        distributions = [self.prior, self.encoder, self.decoder, self.recurrence]
        self.distributions = nn.ModuleList(distributions)

        # set params and optim
        params = self.distributions.parameters()
        self.optimizer = optimizer(params, **optimizer_params)

        self.clip_norm = clip_grad_norm
        self.clip_value = clip_grad_value

    def calculate_loss(self, input_var_dict={}):
        """Evaluate the loss accumulated over every time step of a batch.

        Parameters
        ----------
        input_var_dict : dict
            Must contain ``'x'``, a tensor whose dim 0 is time and whose
            dim 1 is the batch.

        Returns
        -------
        total_loss : torch.Tensor
            Loss over the whole sequence.
        """
        x = input_var_dict['x']
        time_dimension = x.size()[0]
        batch_size = x.size()[1]

        h_prev = torch.zeros(batch_size, self.recurrence.hidden_size).to(device)

        # Equivalent to looping t over the time steps, evaluating
        # ``self.step_loss`` on x[t] and carrying the resulting "h" forward
        # as the next step's "h_prev". The step count is taken from the
        # input itself rather than from a module-level constant.
        loss = IterativeLoss(self.step_loss, max_iter=time_dimension,
                             series_var=['x'],
                             update_value={"h": "h_prev"})
        total_loss = loss.eval({'x': x, 'h_prev': h_prev})

        return total_loss

    def train(self, train_x_dict={}):
        """Run one optimization step on a batch.

        Parameters
        ----------
        train_x_dict : dict
            Input data, in the format expected by ``calculate_loss``.

        Returns
        -------
        loss : float
            Train loss value for this batch.

        """
        self.distributions.train()

        self.optimizer.zero_grad()
        loss = self.calculate_loss(train_x_dict)

        # backprop
        loss.backward()

        # The clipping helpers are reached through ``nn.utils``; the bare
        # names are not imported anywhere in this notebook.
        if self.clip_norm:
            nn.utils.clip_grad_norm_(self.distributions.parameters(), self.clip_norm)
        if self.clip_value:
            nn.utils.clip_grad_value_(self.distributions.parameters(), self.clip_value)

        # update params
        self.optimizer.step()

        return loss.item()

    def test(self, test_x_dict={}):
        """Evaluate the loss on a batch without tracking gradients.

        Parameters
        ----------
        test_x_dict : dict
            Input data, in the format expected by ``calculate_loss``.

        Returns
        -------
        loss : float
            Test loss value for this batch.

        """
        self.distributions.eval()

        with torch.no_grad():
            loss = self.calculate_loss(test_x_dict)

        return loss.item()

    def _rollout(self, x, n_observed):
        """Decode one row per time step, encoding only the first ``n_observed``.

        For ``t < n_observed`` z_t is the encoder mean given the real row
        ``x[t]`` and that real row drives the recurrence. From then on z_t is
        the prior mean and the decoded row is fed back into the recurrence.

        Parameters
        ----------
        x : torch.Tensor
            Time-major input (dim 0 = time, dim 1 = batch).
        n_observed : int
            Number of leading time steps taken from ``x``.

        Returns
        -------
        torch.Tensor
            Decoded rows with the batch on dim 0 and time on dim 1.
        """
        n_steps = x.size()[0]
        batch_size = x.size()[1]
        h_prev = torch.zeros(batch_size, self.recurrence.hidden_size).to(device)
        rows = []
        # Only used for visualisation, so no autograd graph is needed.
        with torch.no_grad():
            for t in range(n_steps):
                observed = t < n_observed
                if observed:
                    z_t = self.encoder.sample_mean({'x': x[t], 'h_prev': h_prev})
                else:
                    z_t = self.prior.sample_mean({'h_prev': h_prev})
                dec_x = self.decoder.sample_mean({'h_prev': h_prev, 'z': z_t})
                rnn_x = x[t] if observed else dec_x
                h_prev = self.recurrence.sample_mean({'x': rnn_x, 'h_prev': h_prev, 'z': z_t})
                rows.append(dec_x[None, :])
        return torch.cat(rows, dim=0).transpose(0, 1)

    def generate_image_after_nsteps(self, n_step_num, original_data):
        """Reconstruct the rows before step ``n_step_num``, then generate the rest.

        Parameters
        ----------
        n_step_num : int
            1-based step at which generation from the prior starts; the
            ``n_step_num - 1`` rows before it are reconstructed from
            ``original_data``.
        original_data : torch.Tensor
            Batch-major images (dim 0 = batch, dim 1 = rows/time).

        Returns
        -------
        generated_img : torch.Tensor
            Images with the same batch/time layout as ``original_data``.
        """
        x = original_data.transpose(0, 1)
        generated_img = self._rollout(x, n_step_num - 1)
        return generated_img

    def reconst_image(self, original_data):
        """Reconstruct every row of ``original_data`` through the encoder.

        Parameters
        ----------
        original_data : torch.Tensor
            Batch-major images (dim 0 = batch, dim 1 = rows/time).

        Returns
        -------
        recon_img : torch.Tensor
            Reconstructions with the same batch/time layout as the input.
        """
        x = original_data.transpose(0, 1)
        recon_img = self._rollout(x, x.size()[0])
        return recon_img

    def plot_image_from_latent(self, batch_size):
        """Generate ``batch_size`` images by sampling z_t from the prior.

        Runs for the module-level ``t_max`` steps, feeding each decoded row
        back into the recurrence.

        Returns
        -------
        plotted_image : torch.Tensor
            Generated images with the batch on dim 0 and time on dim 1.
        """
        xs = []
        h_prev = torch.zeros(batch_size, self.recurrence.hidden_size).to(device)
        # Only used for visualisation, so no autograd graph is needed.
        with torch.no_grad():
            for step in range(t_max):
                z_t = self.prior.sample({'h_prev': h_prev})['z']
                dec_x = self.decoder.sample_mean({'h_prev': h_prev, 'z': z_t})
                h_prev = self.recurrence.sample_mean({'x': dec_x, 'h_prev': h_prev, 'z': z_t})
                xs.append(dec_x[None, :])
        plotted_image = torch.cat(xs, dim=0).transpose(0, 1)
        return plotted_image

In [9]:
# Sanity check: evaluate the loss of an untrained model on one test batch.
# ``next(iter(...))`` is used because DataLoader iterators are only guaranteed
# to support the built-in iterator protocol, not a ``.next()`` method.
_x, _ = next(iter(test_loader))
fixed_batch = _x.to(device)
# NOTE: this rebinds the module-level ``batch_size`` to the size of this batch.
batch_size = fixed_batch.size()[0]
# Swap dims 0 and 1 so time comes first, as ``calculate_loss`` expects.
sequential_x = fixed_batch.transpose(0, 1)
vrnn = VRNN(optimizer=optim.Adam, optimizer_params={'lr': 1e-3})
vrnn.calculate_loss({'x': sequential_x})

tensor(549.5359, device='cuda:0', grad_fn=<AddBackward0>)

In [10]:
def data_loop(epoch, loader, model, device, train_mode=False):
    """Run ``model`` over every batch of ``loader`` and report the mean loss.

    Parameters
    ----------
    epoch : int
        Epoch number, only used in the printed training message.
    loader : iterable
        Yields ``(data, label)`` batches and exposes ``dataset``.
    model : object
        Provides ``train(dict)`` and ``test(dict)``, each returning a loss.
    device : str or torch.device
        Device each batch is moved to.
    train_mode : bool
        Call ``model.train`` when true, ``model.test`` otherwise.

    Returns
    -------
    float
        Sum of per-batch losses weighted by batch size, divided by
        ``len(loader.dataset)``.
    """
    run_step = model.train if train_mode else model.test

    mean_loss = 0
    for data, _ in tqdm(loader):
        data = data.to(device)
        n_samples = data.size()[0]
        # The model expects time on dim 0, so swap batch and time.
        step_loss = run_step({'x': data.transpose(0, 1)})
        mean_loss += step_loss * n_samples
    mean_loss /= len(loader.dataset)

    if train_mode:
        print('Epoch: {} Train loss: {:.4f}'.format(epoch, mean_loss))
    else:
        print('Test loss: {:.4f}'.format(mean_loss))
    return mean_loss

In [11]:
writer = SummaryWriter('Original_VRNN_Model_without_Stochastic')
# fixed _x for watching reconstruction improvement
# (``next(iter(...))`` rather than the Python 2 style ``.next()`` method)
_x, _ = next(iter(test_loader))
_x = _x.to(device)
vrnn = VRNN(optimizer=optim.Adam, optimizer_params={'lr': 1e-3})
for epoch in range(1, epochs + 1):
    train_loss = data_loop(epoch, train_loader, vrnn, device, train_mode=True)
    test_loss = data_loop(epoch, test_loader, vrnn, device)

    writer.add_scalar('train_loss', train_loss, epoch)
    writer.add_scalar('test_loss', test_loss, epoch)

    # ``[:, None]`` inserts a channel axis at dim 1 before logging.
    sample = vrnn.plot_image_from_latent(batch_size)[:, None]
    writer.add_images('Image_from_latent', sample, epoch)
    generated_img_7 = vrnn.generate_image_after_nsteps(7, _x)
    writer.add_images('Generate_after_7steps', generated_img_7[:, None], epoch)

    generated_img_14 = vrnn.generate_image_after_nsteps(14, _x)
    writer.add_images('Generate_after_14steps', generated_img_14[:, None], epoch)

    recon_img = vrnn.reconst_image(_x)
    writer.add_images('Reconstructed',  recon_img[:, None], epoch)

    writer.add_images('orignal', _x[:, None], epoch)

# Flush pending events and release the log file.
writer.close()

100%|██████████| 235/235 [00:45<00:00,  5.15it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 1 Train loss: 198.6982


100%|██████████| 40/40 [00:03<00:00, 11.41it/s]


Test loss: 110.7391


100%|██████████| 235/235 [00:45<00:00,  5.15it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 2 Train loss: 98.7199


100%|██████████| 40/40 [00:03<00:00, 11.66it/s]


Test loss: 90.9959


100%|██████████| 235/235 [00:45<00:00,  5.16it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 3 Train loss: 89.7438


100%|██████████| 40/40 [00:03<00:00, 10.74it/s]


Test loss: 86.8309


100%|██████████| 235/235 [00:45<00:00,  5.19it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 4 Train loss: 86.8837


100%|██████████| 40/40 [00:03<00:00, 10.84it/s]


Test loss: 84.8012


100%|██████████| 235/235 [00:45<00:00,  5.14it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 5 Train loss: 84.8467


100%|██████████| 40/40 [00:03<00:00, 10.86it/s]


Test loss: 83.2546


## With Stochastic Reconstruction Loss

In [17]:
# Model dimensions shared by every network defined below.
x_dim = 28  # size of one observation x_t (input width of Phi_x)
h_dim = 100  # size of the hidden state and of the Phi_x / Phi_z features
z_dim = 16  # size of the latent variable z_t
t_max = x_dim  # number of time steps, set equal to x_dim

# feature extraction for x
class Phi_x(nn.Module):
    """Feature extractor for an observation: one linear layer followed by ReLU."""

    def __init__(self):
        super().__init__()
        self.fc0 = nn.Linear(x_dim, h_dim)

    def forward(self, x):
        features = self.fc0(x)
        return F.relu(features)


# feature extraction for z
class Phi_z(nn.Module):
    """Feature extractor for a latent sample: one linear layer followed by ReLU."""

    def __init__(self):
        super().__init__()
        self.fc0 = nn.Linear(z_dim, h_dim)

    def forward(self, z):
        features = self.fc0(z)
        return F.relu(features)

# One instance of each feature extractor, moved to the selected device. The
# distribution classes below store references to these module-level objects
# instead of constructing their own.
f_phi_x = Phi_x().to(device)
f_phi_z = Phi_z().to(device)


from pixyz.distributions import Bernoulli, Normal, Deterministic

class Generator(Bernoulli):
    """
    Bernoulli observation likelihood p(x_t | z_t, h_{t-1}) (used for MNIST).

    Maps the latent z at one time step together with the previous hidden
    state to the vector of probabilities that parameterizes the Bernoulli
    distribution over x.
    """

    def __init__(self):
        super().__init__(cond_var=["z", "h_prev"], var=["x"])
        self.fc1 = nn.Linear(2 * h_dim, h_dim)
        self.fc2 = nn.Linear(h_dim, h_dim)
        self.fc3 = nn.Linear(h_dim, x_dim)
        # Reference to the module-level feature extractor for z.
        self.f_phi_z = f_phi_z

    def forward(self, z, h_prev):
        z_features = self.f_phi_z(z)
        hidden = torch.cat((z_features, h_prev), dim=-1)
        hidden = F.relu(self.fc1(hidden))
        hidden = F.relu(self.fc2(hidden))
        probs = torch.sigmoid(self.fc3(hidden))
        return {"probs": probs}

class Prior(Normal):
    """
    Conditional prior p(z_t | h_{t-1}).

    A Normal over z whose loc and scale are both computed from the previous
    hidden state; the scale goes through a softplus.
    """

    def __init__(self):
        super().__init__(cond_var=["h_prev"], var=["z"])
        self.fc1 = nn.Linear(h_dim, h_dim)
        self.fc21 = nn.Linear(h_dim, z_dim)
        self.fc22 = nn.Linear(h_dim, z_dim)

    def forward(self, h_prev):
        hidden = F.relu(self.fc1(h_prev))
        loc = self.fc21(hidden)
        scale = F.softplus(self.fc22(hidden))
        return {"loc": loc, "scale": scale}

class Inference(Normal):
    """
    Approximate posterior q(z_t | h_{t-1}, x_t).

    A Normal over z whose loc and scale are computed from the features of
    x_t concatenated with the previous hidden state; the scale goes through
    a softplus.
    """

    def __init__(self):
        super().__init__(cond_var=["x", "h_prev"], var=["z"], name="q")
        self.fc1 = nn.Linear(2 * h_dim, h_dim)
        self.fc21 = nn.Linear(h_dim, z_dim)
        self.fc22 = nn.Linear(h_dim, z_dim)
        # Reference to the module-level feature extractor for x.
        self.f_phi_x = f_phi_x

    def forward(self, x, h_prev):
        x_features = self.f_phi_x(x)
        hidden = torch.cat((x_features, h_prev), dim=-1)
        hidden = F.relu(self.fc1(hidden))
        loc = self.fc21(hidden)
        scale = F.softplus(self.fc22(hidden))
        return {"loc": loc, "scale": scale}

class Recurrence(Deterministic):
    """
    Deterministic hidden-state update p(h_t | x_t, z_t, h_prev).

    A GRU cell fed with the concatenated features of z_t and x_t, with
    h_prev as its previous state.
    """

    def __init__(self):
        super().__init__(cond_var=["x", "z", "h_prev"], var=["h"])
        self.rnncell = nn.GRUCell(2 * h_dim, h_dim).to(device)
        # References to the module-level feature extractors.
        self.f_phi_x = f_phi_x
        self.f_phi_z = f_phi_z
        self.hidden_size = self.rnncell.hidden_size

    def forward(self, x, z, h_prev):
        z_features = self.f_phi_z(z)
        x_features = self.f_phi_x(x)
        rnn_input = torch.cat((z_features, x_features), dim=-1)
        return {"h": self.rnncell(rnn_input, h_prev)}

In [18]:
from pixyz.models import Model
from pixyz.losses import KullbackLeibler, LogProb, Expectation as E
from pixyz.losses import IterativeLoss
from pixyz.losses import StochasticReconstructionLoss

class VRNN(Model):
    """Variational RNN over row sequences, using StochasticReconstructionLoss.

    The per-step loss is KL(q(z_t | x_t, h_{t-1}) || p(z_t | h_{t-1})) plus
    ``StochasticReconstructionLoss`` of the decoder under the encoder,
    averaged over the batch; ``calculate_loss`` iterates it over the time
    steps with ``IterativeLoss``.
    """

    def __init__(self,
                 optimizer=optim.Adam,
                 optimizer_params={},
                 clip_grad_norm=None,
                 clip_grad_value=None):
        """
        Parameters
        ----------
        optimizer : torch.optim
            Optimization algorithm.
        optimizer_params : dict
            Keyword arguments passed to ``optimizer``.
        clip_grad_norm : float or int
            Maximum allowed norm of the gradients (no clipping if falsy).
        clip_grad_value : float or int
            Maximum allowed value of the gradients (no clipping if falsy).
        """
        # NOTE(review): Model.__init__ is not invoked; every attribute the
        # methods of this class read is assigned here instead.
        self.prior = Prior().to(device)
        self.encoder = Inference().to(device)
        self.decoder = Generator().to(device)
        self.recurrence = Recurrence().to(device)
        # Encoder combined with the recurrence, so one step provides both
        # z_t and the next hidden state "h".
        self.encoder_with_recurrence = self.encoder * self.recurrence

        self.reconst_loss = StochasticReconstructionLoss(self.encoder_with_recurrence, self.decoder)
        self.kl_loss = KullbackLeibler(self.encoder, self.prior)
        self.step_loss = (self.kl_loss + self.reconst_loss).mean()

        distributions = [self.prior, self.encoder, self.decoder, self.recurrence]
        self.distributions = nn.ModuleList(distributions)

        # set params and optim
        params = self.distributions.parameters()
        self.optimizer = optimizer(params, **optimizer_params)

        self.clip_norm = clip_grad_norm
        self.clip_value = clip_grad_value

    def calculate_loss(self, input_var_dict={}):
        """Evaluate the loss accumulated over every time step of a batch.

        Parameters
        ----------
        input_var_dict : dict
            Must contain ``'x'``, a tensor whose dim 0 is time and whose
            dim 1 is the batch.

        Returns
        -------
        total_loss : torch.Tensor
            Loss over the whole sequence.
        """
        x = input_var_dict['x']
        time_dimension = x.size()[0]
        batch_size = x.size()[1]

        h_prev = torch.zeros(batch_size, self.recurrence.hidden_size).to(device)

        # Equivalent to looping t over the time steps, evaluating
        # ``self.step_loss`` on x[t] and carrying the resulting "h" forward
        # as the next step's "h_prev". The step count is taken from the
        # input itself rather than from a module-level constant.
        loss = IterativeLoss(self.step_loss, max_iter=time_dimension,
                             series_var=['x'],
                             update_value={"h": "h_prev"})
        total_loss = loss.eval({'x': x, 'h_prev': h_prev})

        return total_loss

    def train(self, train_x_dict={}):
        """Run one optimization step on a batch.

        Parameters
        ----------
        train_x_dict : dict
            Input data, in the format expected by ``calculate_loss``.

        Returns
        -------
        loss : float
            Train loss value for this batch.

        """
        self.distributions.train()

        self.optimizer.zero_grad()
        loss = self.calculate_loss(train_x_dict)

        # backprop
        loss.backward()

        # The clipping helpers are reached through ``nn.utils``; the bare
        # names are not imported anywhere in this notebook.
        if self.clip_norm:
            nn.utils.clip_grad_norm_(self.distributions.parameters(), self.clip_norm)
        if self.clip_value:
            nn.utils.clip_grad_value_(self.distributions.parameters(), self.clip_value)

        # update params
        self.optimizer.step()

        return loss.item()

    def test(self, test_x_dict={}):
        """Evaluate the loss on a batch without tracking gradients.

        Parameters
        ----------
        test_x_dict : dict
            Input data, in the format expected by ``calculate_loss``.

        Returns
        -------
        loss : float
            Test loss value for this batch.

        """
        self.distributions.eval()

        with torch.no_grad():
            loss = self.calculate_loss(test_x_dict)

        return loss.item()

    def _rollout(self, x, n_observed):
        """Decode one row per time step, encoding only the first ``n_observed``.

        For ``t < n_observed`` z_t is the encoder mean given the real row
        ``x[t]`` and that real row drives the recurrence. From then on z_t is
        the prior mean and the decoded row is fed back into the recurrence.

        Parameters
        ----------
        x : torch.Tensor
            Time-major input (dim 0 = time, dim 1 = batch).
        n_observed : int
            Number of leading time steps taken from ``x``.

        Returns
        -------
        torch.Tensor
            Decoded rows with the batch on dim 0 and time on dim 1.
        """
        n_steps = x.size()[0]
        batch_size = x.size()[1]
        h_prev = torch.zeros(batch_size, self.recurrence.hidden_size).to(device)
        rows = []
        # Only used for visualisation, so no autograd graph is needed.
        with torch.no_grad():
            for t in range(n_steps):
                observed = t < n_observed
                if observed:
                    z_t = self.encoder.sample_mean({'x': x[t], 'h_prev': h_prev})
                else:
                    z_t = self.prior.sample_mean({'h_prev': h_prev})
                dec_x = self.decoder.sample_mean({'h_prev': h_prev, 'z': z_t})
                rnn_x = x[t] if observed else dec_x
                h_prev = self.recurrence.sample_mean({'x': rnn_x, 'h_prev': h_prev, 'z': z_t})
                rows.append(dec_x[None, :])
        return torch.cat(rows, dim=0).transpose(0, 1)

    def generate_image_after_nsteps(self, n_step_num, original_data):
        """Reconstruct the rows before step ``n_step_num``, then generate the rest.

        Parameters
        ----------
        n_step_num : int
            1-based step at which generation from the prior starts; the
            ``n_step_num - 1`` rows before it are reconstructed from
            ``original_data``.
        original_data : torch.Tensor
            Batch-major images (dim 0 = batch, dim 1 = rows/time).

        Returns
        -------
        generated_img : torch.Tensor
            Images with the same batch/time layout as ``original_data``.
        """
        x = original_data.transpose(0, 1)
        generated_img = self._rollout(x, n_step_num - 1)
        return generated_img

    def reconst_image(self, original_data):
        """Reconstruct every row of ``original_data`` through the encoder.

        Parameters
        ----------
        original_data : torch.Tensor
            Batch-major images (dim 0 = batch, dim 1 = rows/time).

        Returns
        -------
        recon_img : torch.Tensor
            Reconstructions with the same batch/time layout as the input.
        """
        x = original_data.transpose(0, 1)
        recon_img = self._rollout(x, x.size()[0])
        return recon_img

    def plot_image_from_latent(self, batch_size):
        """Generate ``batch_size`` images by sampling z_t from the prior.

        Runs for the module-level ``t_max`` steps, feeding each decoded row
        back into the recurrence.

        Returns
        -------
        plotted_image : torch.Tensor
            Generated images with the batch on dim 0 and time on dim 1.
        """
        xs = []
        h_prev = torch.zeros(batch_size, self.recurrence.hidden_size).to(device)
        # Only used for visualisation, so no autograd graph is needed.
        with torch.no_grad():
            for step in range(t_max):
                z_t = self.prior.sample({'h_prev': h_prev})['z']
                dec_x = self.decoder.sample_mean({'h_prev': h_prev, 'z': z_t})
                h_prev = self.recurrence.sample_mean({'x': dec_x, 'h_prev': h_prev, 'z': z_t})
                xs.append(dec_x[None, :])
        plotted_image = torch.cat(xs, dim=0).transpose(0, 1)
        return plotted_image

In [19]:
def data_loop(epoch, loader, model, device, train_mode=False):
    """Run ``model`` over every batch of ``loader`` and report the mean loss.

    Parameters
    ----------
    epoch : int
        Epoch number, only used in the printed training message.
    loader : iterable
        Yields ``(data, label)`` batches and exposes ``dataset``.
    model : object
        Provides ``train(dict)`` and ``test(dict)``, each returning a loss.
    device : str or torch.device
        Device each batch is moved to.
    train_mode : bool
        Call ``model.train`` when true, ``model.test`` otherwise.

    Returns
    -------
    float
        Sum of per-batch losses weighted by batch size, divided by
        ``len(loader.dataset)``.
    """
    run_step = model.train if train_mode else model.test

    mean_loss = 0
    for data, _ in tqdm(loader):
        data = data.to(device)
        n_samples = data.size()[0]
        # The model expects time on dim 0, so swap batch and time.
        step_loss = run_step({'x': data.transpose(0, 1)})
        mean_loss += step_loss * n_samples
    mean_loss /= len(loader.dataset)

    if train_mode:
        print('Epoch: {} Train loss: {:.4f}'.format(epoch, mean_loss))
    else:
        print('Test loss: {:.4f}'.format(mean_loss))
    return mean_loss

In [20]:
# Sanity check: evaluate the loss of an untrained model on one test batch.
# ``next(iter(...))`` is used because DataLoader iterators are only guaranteed
# to support the built-in iterator protocol, not a ``.next()`` method.
_x, _ = next(iter(test_loader))
fixed_batch = _x.to(device)
# NOTE: this rebinds the module-level ``batch_size`` to the size of this batch.
batch_size = fixed_batch.size()[0]
# Swap dims 0 and 1 so time comes first, as ``calculate_loss`` expects.
sequential_x = fixed_batch.transpose(0, 1)
vrnn = VRNN(optimizer=optim.Adam, optimizer_params={'lr': 1e-3})
vrnn.calculate_loss({'x': sequential_x})

tensor(547.7822, device='cuda:0', grad_fn=<AddBackward0>)

In [21]:
writer = SummaryWriter('Original_VRNN_Model_with_Stochastic')
# fixed _x for watching reconstruction improvement
# (``next(iter(...))`` rather than the Python 2 style ``.next()`` method)
_x, _ = next(iter(test_loader))
_x = _x.to(device)
vrnn = VRNN(optimizer=optim.Adam, optimizer_params={'lr': 1e-3})
for epoch in range(1, epochs + 1):
    train_loss = data_loop(epoch, train_loader, vrnn, device, train_mode=True)
    test_loss = data_loop(epoch, test_loader, vrnn, device)

    writer.add_scalar('train_loss', train_loss, epoch)
    writer.add_scalar('test_loss', test_loss, epoch)

    # ``[:, None]`` inserts a channel axis at dim 1 before logging.
    sample = vrnn.plot_image_from_latent(batch_size)[:, None]
    writer.add_images('Image_from_latent', sample, epoch)
    generated_img_7 = vrnn.generate_image_after_nsteps(7, _x)
    writer.add_images('Generate_after_7steps', generated_img_7[:, None], epoch)

    generated_img_14 = vrnn.generate_image_after_nsteps(14, _x)
    writer.add_images('Generate_after_14steps', generated_img_14[:, None], epoch)

    recon_img = vrnn.reconst_image(_x)
    writer.add_images('Reconstructed',  recon_img[:, None], epoch)

    writer.add_images('orignal', _x[:, None], epoch)

# Flush pending events and release the log file.
writer.close()

100%|██████████| 235/235 [00:45<00:00,  5.18it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 1 Train loss: 198.2223


100%|██████████| 40/40 [00:03<00:00, 10.74it/s]


Test loss: 113.1748


100%|██████████| 235/235 [00:45<00:00,  5.20it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 2 Train loss: 99.7568


100%|██████████| 40/40 [00:03<00:00, 10.88it/s]


Test loss: 91.9831


100%|██████████| 235/235 [00:45<00:00,  5.14it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 3 Train loss: 90.6161


100%|██████████| 40/40 [00:03<00:00, 11.27it/s]


Test loss: 87.4282


100%|██████████| 235/235 [00:45<00:00,  5.21it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 4 Train loss: 87.0113


100%|██████████| 40/40 [00:03<00:00, 11.62it/s]


Test loss: 84.7338


100%|██████████| 235/235 [00:44<00:00,  5.26it/s]
  0%|          | 0/40 [00:00<?, ?it/s]

Epoch: 5 Train loss: 84.6224


100%|██████████| 40/40 [00:03<00:00, 10.92it/s]


Test loss: 82.6811
