# Hierarchical Recurrent State Space Model
* Original paper: Variational Temporal Abstraction (https://arxiv.org/pdf/1910.00775.pdf)
* Original code: https://github.com/taesupkim/vta

## Hierarchical Recurrent State Space Model summary
HRSSM infers the latent temporal structure of a sequence and uses it to perform stochastic state transitions hierarchically.

In [1]:
import pixyz
from tqdm import tqdm

import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
from tensorboardX import SummaryWriter

# Global experiment hyper-parameters.
seed = 1
epochs = 1
batch_size = 128

# Seed PyTorch's global RNG with the value above.
torch.manual_seed(seed)

# Use CUDA whenever PyTorch reports it as available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
# Experiment setting
# Each MNIST image is treated as a sequence of its rows (one row = one time step).
def init_dataset(f_batch_size):
    """Create shuffled MNIST train and test data loaders.

    Every sample is converted with ``ToTensor`` and then indexed with ``[0]``,
    so the leading axis produced by ``ToTensor`` is dropped.

    Args:
        f_batch_size: Batch size used for both loaders.

    Returns:
        A tuple ``(train_loader, test_loader, fixed_t_size)`` where
        ``fixed_t_size`` is the constant 28.
    """
    data_dir = '../data'
    loader_options = dict(
        batch_size=f_batch_size,
        shuffle=True,
        num_workers=1,
        pin_memory=True,
    )
    to_2d_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda data: data[0]),
    ])

    # Only the training split asks for a download, as in the original setup.
    train_set = datasets.MNIST(
        data_dir, train=True, download=True, transform=to_2d_tensor)
    train_loader = torch.utils.data.DataLoader(train_set, **loader_options)

    test_set = datasets.MNIST(data_dir, train=False, transform=to_2d_tensor)
    test_loader = torch.utils.data.DataLoader(test_set, **loader_options)

    return train_loader, test_loader, 28

# Build both loaders once; t_max receives the fixed sequence length returned by init_dataset.
train_loader, test_loader, t_max = init_dataset(batch_size)

In [3]:
from pixyz.utils import print_latex

## Define probability distribution

## In the original paper
### Generative model
Generative process:  
$p(X, Z, S, M)=\prod_{t=1}^{T} p\left(x_{t} | s_{t}\right) p\left(m_{t} | s_{t}\right) p\left(s_{t} | s_{<t}, z_{t}, m_{t-1}\right) p\left(z_{t} | z_{<t}, m_{t-1}\right)$  

Prior on Temporal Structure:  
$p(m_t | s_t) = \cal B(\sigma(f_{m-mlp}(s_t)))$  
$p\left(m_{t}=1 | s_{t}\right)=\left\{\begin{array}{ll}{0} & {\text { if } n\left(m_{<t}\right) \geq N_{\max }} \\ {1} & {\text { elseif } l\left(m_{<t}\right) \geq l_{\max }} \\ {\sigma\left(f_{m-\operatorname{mlp}}\left(s_{t}\right)\right)} & {\text { otherwise }}\end{array}\right.$  


Hierarchical Transitions:  
$p\left(z_{t} | z_{<t}, m_{<t}\right)=\left\{\begin{array}{ll}{\delta\left(z_{t}=z_{t-1}\right)} & {\text { if } m_{t-1}=0(\mathrm{COPY})} \\ {\tilde{p}\left(z_{t} | c_{t}\right)} & {\text { otherwise (UPDATE) }}\end{array}\right.$  
${\tilde{p}\left(z_{t} | c_{t}\right)} = \cal N(z_t | \mu_z(c_t), \sigma_z(c_t))$  

$c_{t}=\left\{\begin{array}{ll}{c_{t-1}} & {\text { if } m_{t-1}=0(\mathrm{COPY})} \\ {f_{z-\text { rnn }}\left(z_{t-1}, c_{t-1}\right)} & {\text { otherwise (UPDATE) }}\end{array}\right.$  


Observation transition:  
$p\left(s_{t} | s_{<t}, z_{t}, m_{<t}\right)=\tilde{p}\left(s_{t} | h_{t}\right) \quad \text { where } \quad h_{t}=\left\{\begin{array}{ll}{f_{s-\text { rnn }}\left(s_{t-1} \| z_{t}, h_{t-1}\right)} & {\text { if } m_{t-1}=0 \text { (UPDATE) }} \\ {f_{s-\text { mlp }}\left(z_{t}\right)} & {\text { otherwise (INIT) }}\end{array}\right.$

${\tilde{p}\left(s_{t} | h_{t}\right)} = \cal N(s_t | \mu_s(h_t), \sigma_s(h_t))$  

### Inference
$q(Z, S, M | X)=q(M | X) q(Z | M, X) q(S | Z, M, X)$

Sequence Decomposition:  
$q(M | X)=\prod_{t=1}^{T} q\left(m_{t} | X\right)=\prod_{t=1}^{T} \operatorname{Bern}\left(m_{t} | \sigma(\varphi(X))\right)$  

State Inference:  
$q(Z | M, X)=\prod_{t=1}^{T} q\left(z_{t} | M, X\right)$  

$q\left(z_{t} | M, X\right)=\left\{\begin{array}{ll}{\delta\left(z_{t}=z_{t-1}\right)} & {\text { if } m_{t-1}=0(\mathrm{COPY})} \\ {\tilde{q}\left(z_{t} | \psi_{t-1}^{\mathrm{fwd}}, \psi_{t}^{\mathrm{bwd}}\right)} & {\text { otherwise (UPDATE) }}\end{array}\right.$

Observation abstraction predictor:  
$q(S | Z, M, X)=\prod_{t=1}^{T} q\left(s_{t} | z_{t}, M, X\right)$  
$q\left(s_{t} | z_{t}, M, X\right)=\tilde{q}\left(s_{t} | z_{t}, \phi_{t}^{\mathrm{fwd}}\right)$

In [4]:
from pixyz.distributions import Bernoulli, Normal, Deterministic

In [23]:
### Generator

In [None]:
class AbstractionExtractor(nn.Module):
    """Linear projection of two inputs concatenated along the last axis.

    Args:
        c_dim: Size of the last axis of the first input.
        z_dim: Size of the last axis of the second input.
        abs_feat_size: Size of the produced feature vector.
    """

    def __init__(self, c_dim, z_dim, abs_feat_size):
        super().__init__()
        self.c_dim, self.z_dim = c_dim, z_dim
        self.abs_feat_size = abs_feat_size

        in_features = c_dim + z_dim
        self.fc = nn.Linear(in_features, abs_feat_size)

    def forward(self, c_dim, z_dim):
        """Concatenate both inputs on the last axis and apply the linear layer.

        Despite their names, ``c_dim`` and ``z_dim`` are the tensors to be
        concatenated, not sizes.
        """
        joined = torch.cat([c_dim, z_dim], dim=-1)
        return self.fc(joined)


class ObservationExtractor(nn.Module):
    """Linear projection of two inputs concatenated along the last axis.

    Args:
        h_dim: Size of the last axis of the first input.
        s_dim: Size of the last axis of the second input.
        obs_feat_size: Size of the produced feature vector.
    """

    def __init__(self, h_dim, s_dim, obs_feat_size):
        super().__init__()
        self.h_dim, self.s_dim = h_dim, s_dim
        self.obs_feat_size = obs_feat_size

        in_features = h_dim + s_dim
        self.fc = nn.Linear(in_features, obs_feat_size)

    def forward(self, h_dim, s_dim):
        """Concatenate both inputs on the last axis and apply the linear layer.

        Despite their names, ``h_dim`` and ``s_dim`` are the tensors to be
        concatenated, not sizes.
        """
        joined = torch.cat([h_dim, s_dim], dim=-1)
        return self.fc(joined)

In [19]:
class PriorOnTemporalStructure(Bernoulli):
    """
    Prior over the binary boundary indicator, p(m_t | s_t).

    Referred to as PriorBoundaryDetector in the author's notes.

    Open questions kept from the original notes (translated):
    why is output_dim 2 there? The output is passed through nn.Identity;
    is a sigmoid applied afterwards?

    Args:
        s_dim: size of the last axis of the conditioning variable ``s``.
    """
    def __init__(self, s_dim):
        super(PriorOnTemporalStructure, self).__init__(cond_var=["s"], var=["m"])
        self.s_dim = s_dim
        self.fc1 = nn.Linear(s_dim, 1)

    def forward(self, s):
        """Return the Bernoulli parameter ``probs`` computed from ``s``."""
        logits = self.fc1(s)
        # The sigmoid must be applied to the linear layer's output; applying it
        # to ``s`` itself discards ``fc1`` and yields ``s_dim`` probabilities
        # instead of one.
        return {"probs": torch.sigmoid(logits)}


class AbstRecurrence(Deterministic):
    """
    Recurrent layer for the temporal-abstraction level
    (the ``self.update_abs_belief`` part, per the author's notes).

    c_t = f_z_rnn(z_{t-1}, c_{t-1})

    Args:
        c_dim: hidden size of the GRU cell (size of ``c``).
        z_dim: input size of the GRU cell (size of ``z``).
    """
    def __init__(self, c_dim, z_dim):
        # The class must name itself in super(); the previous reference to an
        # undefined ``Recurrence`` raised NameError on instantiation.
        super(AbstRecurrence, self).__init__(cond_var=["c_prev", "z_prev"], var=["c"])
        self.c_dim = c_dim
        self.z_dim = z_dim

        self.rnn_cell = nn.GRUCell(input_size=self.z_dim, hidden_size=self.c_dim)

    def forward(self, c_prev, z_prev):
        """Advance the GRU cell with input ``z_prev`` and hidden state ``c_prev``."""
        c = self.rnn_cell(z_prev, c_prev)
        return {"c": c}


# Alias for the name still used by a later cell (``Recurrence(3, 3)``).
Recurrence = AbstRecurrence


class ObsRecurrecne(Deterministic):
    """
    Recurrent layer for the observation level
    (the ``self.update_obs_belief`` part, per the author's notes).

    h_t = f_s_rnn(s_{t-1} || z_t, h_{t-1})

    NOTE(review): the conditioning variable is named ``z_prev`` although the
    formula above uses z_t; the name is kept so existing callers keep working
    -- confirm which time step is intended.

    Args:
        s_dim: size of ``s``.
        z_dim: size of ``z``.
        h_dim: hidden size of the GRU cell (size of ``h``).
    """
    def __init__(self, s_dim, z_dim, h_dim):
        super(ObsRecurrecne, self).__init__(cond_var=["s_prev", "z_prev", "h_prev"], var=["h"])
        self.s_dim = s_dim
        self.z_dim = z_dim
        self.h_dim = h_dim

        self.rnn_cell = nn.GRUCell(input_size=self.s_dim + self.z_dim, hidden_size=self.h_dim)

    def forward(self, s_prev, z_prev, h_prev):
        """Advance the GRU cell with input ``s_prev || z_prev`` and state ``h_prev``."""
        concat_feature = torch.cat((s_prev, z_prev), dim=-1)
        h = self.rnn_cell(concat_feature, h_prev)
        return {"h": h}


# Correctly spelled alias; the original (misspelled) class name is kept for
# backward compatibility.
ObsRecurrence = ObsRecurrecne

    
class TransitionOfTemporalAbstraction(Normal):
    """
    p(z_t | c_t) = N(z_t | mu(c_t), sigma(c_t))

    Latent distribution (the ``self.prior_abs_state`` part, per the author's
    notes).

    Args:
        c_dim: size of the conditioning variable ``c``.
        z_dim: size of the latent variable ``z``.
    """
    def __init__(self, c_dim, z_dim):
        super(TransitionOfTemporalAbstraction, self).__init__(cond_var=["c"], var=["z"])
        self.c_dim = c_dim
        self.z_dim = z_dim

        self.fc1 = nn.Linear(c_dim, c_dim)

        self.fc21 = nn.Linear(c_dim, z_dim)
        self.fc22 = nn.Linear(c_dim, z_dim)

    def forward(self, c):
        """Return the Normal parameters ``loc`` and ``scale`` computed from ``c``."""
        # Use the functional ELU: ``nn.ELU(x, inplace=True)`` only constructs a
        # module (with ``x`` taken as alpha) and never applies the activation.
        h = F.elu(self.fc1(c))
        return {"loc": self.fc21(h), "scale": F.softplus(self.fc22(h))}

    
class TransitionOfObservation(Normal):
    """
    p(s_t | h_t)

    The ``self.prior_obs_state = LatentDistribution`` part, per the author's
    notes.

    Args:
        h_dim: size of the conditioning variable ``h``.
        s_dim: size of the latent variable ``s``.
    """
    def __init__(self, h_dim, s_dim):
        super(TransitionOfObservation, self).__init__(cond_var=["h"], var=["s"])
        self.h_dim = h_dim
        self.s_dim = s_dim

        self.fc1 = nn.Linear(h_dim, h_dim)

        self.fc21 = nn.Linear(h_dim, s_dim)
        self.fc22 = nn.Linear(h_dim, s_dim)

    def forward(self, h):
        """Return the Normal parameters ``loc`` and ``scale`` computed from ``h``."""
        # Use the functional ELU: ``nn.ELU(x, inplace=True)`` only constructs a
        # module (with ``x`` taken as alpha) and never applies the activation.
        h_ = F.elu(self.fc1(h))
        return {"loc": self.fc21(h_), "scale": F.softplus(self.fc22(h_))}

In [20]:
# Instantiate the boundary prior with s_dim=3 and display it via print_latex.
prior = PriorOnTemporalStructure(3)
print_latex(prior)

<IPython.core.display.Math object>

In [21]:
# Instantiate the abstraction-level transition (c_dim=3, z_dim=3) and display it via print_latex.
p = TransitionOfTemporalAbstraction(3, 3)
print_latex(p)

<IPython.core.display.Math object>

In [22]:
# Instantiate the observation-level transition (h_dim=3, s_dim=3) and display it via print_latex.
p = TransitionOfObservation(3, 3)
print_latex(p)

<IPython.core.display.Math object>

In [18]:
# NOTE(review): no class named `Recurrence` is defined in the visible cells; it is
# presumably an earlier name of AbstRecurrence (which also takes two sizes) -- confirm
# before re-running this cell.
r = Recurrence(3, 3)
print_latex(r)

<IPython.core.display.Math object>

In [7]:
# Show the installed pixyz version this notebook was run with.
pixyz.__version__

'0.1.3'