In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer import SVI, Trace_ELBO

In [4]:
class Emitter(nn.Module):
    def __init__(self, input_dim, z_dim, emission_dim):
        super(Emitter, self).__init__()
        self.lin_z_to_hidden = nn.Linear(z_dim, emission_dim)
        self.lin_hidden_to_hidden = nn.Linear(emission_dim, emission_dim)
        self.lin_hidden_to_input = nn.Linear(emission_dim, input_dim)
        
    def __call__(self, z_t):
        h1 = F.relu(self.lin_z_to_hidden(z_t))
        h2 = F.relu(self.lin_hidden_to_hidden(h1))
        ps = F.sigmoid(self.lin_hidden_to_input(h2))
        return ps

In [5]:
class GatedTransition(nn.Module):
    """Maps a previous latent state ``z_{t-1}`` to the ``(loc, scale)`` of the next.

    ``loc`` is a per-dimension gated blend of two candidates:

    * a linear map ``lin_z_to_loc(z_{t-1})``, initialised to the identity, and
    * a nonlinear ``proposed_mean`` from a one-hidden-layer ReLU network.

    The gate comes from another one-hidden-layer network with a sigmoid output.
    ``scale = softplus(lin_sig(relu(proposed_mean)))``, so it is positive.

    Args:
        z_dim: Width of the latent state (input and both outputs).
        transition_dim: Width of the hidden layers of the gate and the
            proposed-mean networks.
    """

    def __init__(self, z_dim, transition_dim):
        super(GatedTransition, self).__init__()
        self.lin_gate_z_to_hidden = nn.Linear(z_dim, transition_dim)
        self.lin_gate_hidden_to_z = nn.Linear(transition_dim, z_dim)
        self.lin_proposed_mean_z_to_hidden = nn.Linear(z_dim, transition_dim)
        self.lin_proposed_mean_hidden_to_z = nn.Linear(transition_dim, z_dim)
        self.lin_sig = nn.Linear(z_dim, z_dim)
        self.lin_z_to_loc = nn.Linear(z_dim, z_dim)
        # Start the linear path as the identity map. The write is done in place
        # under no_grad (rather than by reassigning ``.data``) so the existing
        # Parameter objects, with their dtype/device, are kept.
        with torch.no_grad():
            self.lin_z_to_loc.weight.copy_(torch.eye(z_dim))
            self.lin_z_to_loc.bias.zero_()

    def forward(self, z_t_1):
        """Compute ``(loc, scale)`` for the next latent state.

        Args:
            z_t_1: Previous latent state; last dimension is ``z_dim``.

        Returns:
            Tuple ``(loc, scale)``, each shaped like ``z_t_1``. ``scale`` is
            the output of a softplus.
        """
        gate_hidden = F.relu(self.lin_gate_z_to_hidden(z_t_1))
        gate = torch.sigmoid(self.lin_gate_hidden_to_z(gate_hidden))
        proposed_hidden = F.relu(self.lin_proposed_mean_z_to_hidden(z_t_1))
        proposed_mean = self.lin_proposed_mean_hidden_to_z(proposed_hidden)
        # gate -> 0 keeps the (initially identity) linear path;
        # gate -> 1 uses the nonlinear proposal.
        loc = (1 - gate) * self.lin_z_to_loc(z_t_1) + gate * proposed_mean
        scale = F.softplus(self.lin_sig(F.relu(proposed_mean)))
        return loc, scale

In [13]:
class DeepMarkov(nn.Module):
    """Generative model of a Deep Markov Model, written as a Pyro ``model``.

    The latent chain starts from a learnable ``z_0``. At each step
    ``self.trans`` gives the Normal parameters of ``z_t`` from ``z_{t-1}``, and
    ``self.emitter`` gives Bernoulli probabilities for the observation ``x_t``.
    """

    def __init__(self, input_dim=88, z_dim=100, emission_dim=100,
                 transition_dim=200):
        """Build the sub-networks that ``model`` relies on.

        All arguments have defaults, so ``DeepMarkov()`` still works.
        NOTE(review): the default sizes are placeholders — confirm against the
        data set (``input_dim`` must equal the last dimension of ``mini_batch``).

        Args:
            input_dim: Observation width.
            z_dim: Latent state width.
            emission_dim: Hidden width of the ``Emitter``.
            transition_dim: Hidden width of the ``GatedTransition``.
        """
        super(DeepMarkov, self).__init__()
        self.emitter = Emitter(input_dim, z_dim, emission_dim)
        self.trans = GatedTransition(z_dim, transition_dim)
        # Learnable initial latent state, expanded across the mini-batch in model().
        self.z_0 = nn.Parameter(torch.zeros(z_dim))

    def model(self, mini_batch, mini_batch_reversed, mini_batch_mask,
              mini_batch_seq_lengths, annealing_factor=1.0):
        """Pyro model: one latent site ``z_t`` and one observed site ``obs_x_t`` per step.

        Args:
            mini_batch: Observations indexed as ``[batch, time, feature]``; the
                time axis (dim 1) sets the number of steps ``T_max``.
            mini_batch_reversed: Unused here; presumably consumed by a matching
                guide and kept so model and guide share a signature.
            mini_batch_mask: Indexed ``[batch, time]``. Entry ``t - 1`` masks
                both sites at step ``t``. NOTE(review): recent Pyro versions
                expect a boolean mask — confirm the dtype.
            mini_batch_seq_lengths: Unused here (see ``mini_batch_reversed``).
            annealing_factor: Multiplier applied to the latent ``z_t`` sites only
                (e.g. for KL annealing); the observation sites are not scaled.
        """
        T_max = mini_batch.size(1)
        # Register all PyTorch submodules/parameters with Pyro's param store.
        pyro.module('dmm', self)

        z_prev = self.z_0.expand(mini_batch.size(0), self.z_0.size(0))

        # pyro.plate / .to_event replace the removed pyro.iarange / .independent.
        with pyro.plate('z_minibatch', len(mini_batch)):
            for t in range(1, T_max + 1):
                z_loc, z_scale = self.trans(z_prev)
                # Each site name must be sampled exactly once per step.
                with poutine.scale(scale=annealing_factor):
                    z_t = pyro.sample('z_%d' % t,
                                      dist.Normal(z_loc, z_scale)
                                      .mask(mini_batch_mask[:, t - 1:t])
                                      .to_event(1))

                emission_probs_t = self.emitter(z_t)
                pyro.sample('obs_x_%d' % t,
                            dist.Bernoulli(emission_probs_t)
                            .mask(mini_batch_mask[:, t - 1:t])
                            .to_event(1),
                            obs=mini_batch[:, t - 1, :])

                z_prev = z_t

In [7]:
import pyro

$y = x + \sin(\alpha(x + w)) + \sin(\beta(x + w)) + w$, with $w \sim \mathcal{N}(0, 0.03^2)$, $\alpha = 4$, $\beta = 13$.