# Deep Markov Model

In [24]:
from tqdm import tqdm

import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
from tensorboardX import SummaryWriter

In [25]:
# Training hyper-parameters and the global RNG seed (the seed also fixes how
# the networks created below are initialised).
batch_size, epochs, seed = 128, 5, 1
torch.manual_seed(seed)

<torch._C.Generator at 0x7f77ddc5adf0>

In [26]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [27]:
# Experiment setting
# Each MNIST image is treated as a sequence: its pixel rows are the time steps
# and one row is the observation x_t.
def init_dataset(f_batch_size, data_dir='../data', num_workers=1):
    """Build the MNIST train/test DataLoaders used by the DMM experiment.

    Args:
        f_batch_size: number of images per mini-batch for both loaders.
        data_dir: directory where torchvision stores (and downloads) MNIST.
        num_workers: number of DataLoader worker processes.

    Returns:
        ``(train_loader, test_loader, fixed_t_size)`` where ``fixed_t_size``
        is the number of time steps (pixel rows) per sequence, 28 for MNIST.
    """
    # Local import keeps this cell self-contained.
    import operator

    # Pinned host memory only speeds up host->GPU copies; on a CPU-only
    # machine it merely produces a warning.
    kwargs = {'num_workers': num_workers,
              'pin_memory': torch.cuda.is_available()}
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        # ToTensor gives a (1, H, W) tensor for a grayscale image; keep
        # channel 0 so each sample is (rows, pixels).  operator.itemgetter is
        # used instead of a lambda because it is picklable, which DataLoader
        # workers require on platforms that start them with "spawn".
        transforms.Lambda(operator.itemgetter(0)),
    ])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(data_dir, train=True, download=True,
                       transform=mnist_transform),
        batch_size=f_batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(data_dir, train=False, transform=mnist_transform),
        batch_size=f_batch_size, shuffle=True, **kwargs)

    fixed_t_size = 28
    return train_loader, test_loader, fixed_t_size

train_loader, test_loader, t_max = init_dataset(batch_size)

In [28]:
from pixyz.models import Model
from pixyz.losses import KullbackLeibler, CrossEntropy, IterativeLoss
from pixyz.distributions import Bernoulli, Normal, Deterministic
from pixyz.utils import print_latex

In [29]:
class RNN(Deterministic):
    """Deterministic distribution p(h | x): run the observations through a GRU.

    The GRU is bidirectional, so the returned ``h`` contains, for every time
    step, the forward and backward hidden states concatenated.
    """

    def __init__(self, x_dim, rnn_dim):
        super().__init__(var=["h"], cond_var=["x"])
        self.rnn = nn.GRU(x_dim, rnn_dim, bidirectional=True)
        # Learned initial hidden state: one row per direction, batch axis of
        # size 1 that forward() broadcasts to the real batch size.
        self.h0 = nn.Parameter(torch.zeros(2, 1, self.rnn.hidden_size))
        self.hidden_size = self.rnn.hidden_size

    def forward(self, x):
        """Return ``{"h": per-step GRU outputs}`` for the sequence ``x``.

        nn.GRU is built with the default ``batch_first=False``, so dimension
        1 of ``x`` is the batch.
        """
        n_batch = x.size(1)
        # expand() only yields a broadcast view; the GRU wants its initial
        # state in contiguous memory (required on the GPU).
        init_state = self.h0.expand(2, n_batch, self.rnn.hidden_size)
        init_state = init_state.contiguous()
        outputs, _final_state = self.rnn(x, init_state)
        return {"h": outputs}

In [30]:
class Generator(Bernoulli):
    '''
    Emitter
    Parameterizes the Bernoulli observation likelihood p(x_t | z_t).
    '''
    def __init__(self, z_dim, hidden_dim, x_dim=28):
        '''
        Args:
            z_dim: size of the latent state z_t.
            hidden_dim: width of the single hidden layer.
            x_dim: size of one observation x_t (one MNIST row of pixels).
                Passed explicitly instead of being read from a notebook
                global; the default 28 equals the notebook's ``x_dim``.
        '''
        super(Generator, self).__init__(cond_var=["z"], var=["x"])
        # initialize the two linear transformations used in the neural network
        self.lin_z_to_hidden = nn.Linear(z_dim, hidden_dim)
        self.lin_hidden_to_input = nn.Linear(hidden_dim, x_dim)

        # initialize the two non-linearities used in the neural network
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, z):
        '''
        Given the latent z at a particular time step t, return the vector of
        probabilities that parameterizes the Bernoulli distribution p(x_t | z_t).
        '''
        h1 = self.relu(self.lin_z_to_hidden(z))
        # sigmoid keeps every output in (0, 1), as Bernoulli probs require
        probs = self.sigmoid(self.lin_hidden_to_input(h1))
        return {"probs": probs}

In [31]:
class Inference(Normal):
    """Combiner: the Gaussian variational factor q(z_t | z_{t-1}, h_t).

    This is the basic building block of the guide (the variational
    distribution).  The observations enter only through ``h``, the RNN state
    for step t (a summary of x_{t:T} in the DMM formulation).
    """

    def __init__(self, z_dim, rnn_dim):
        super().__init__(var=["z"], cond_var=["h", "z_prev"])
        # Sized to match h, which a bidirectional RNN emits with
        # 2 * rnn_dim features per step.
        state_width = 2 * rnn_dim
        # Linear maps (kept in this order so seeded initialisation matches).
        self.lin_z_to_hidden = nn.Linear(z_dim, state_width)
        self.lin_hidden_to_loc = nn.Linear(state_width, z_dim)
        self.lin_hidden_to_scale = nn.Linear(state_width, z_dim)
        # Non-linearities.
        self.tanh = nn.Tanh()
        self.softplus = nn.Softplus()

    def forward(self, h, z_prev):
        """Return the loc/scale of q(z_t | z_{t-1}, h_t).

        ``z_prev`` is projected into the RNN-state space, squashed with tanh
        and averaged with ``h``; two linear heads read the result.  Softplus
        keeps the scale positive.
        """
        projected_z = self.tanh(self.lin_z_to_hidden(z_prev))
        combined = (h + projected_z) * 0.5
        return {
            "loc": self.lin_hidden_to_loc(combined),
            "scale": self.softplus(self.lin_hidden_to_scale(combined)),
        }

In [32]:
class Prior(Normal):
    """GatedTransition: the Gaussian latent transition p(z_t | z_{t-1}).

    The mean is a gate-weighted mix of a linear map of z_{t-1} and a
    non-linear "proposed mean"; the scale is a softplus of a linear map of
    the ReLU'd proposed mean.
    """

    def __init__(self, z_dim, transition_dim):
        super().__init__(var=["z"], cond_var=["z_prev"])
        # Six linear maps, created in this order so seeded initialisation
        # is unchanged.
        # Gate branch: z_{t-1} -> hidden -> per-dimension gate.
        self.lin_gate_z_to_hidden = nn.Linear(z_dim, transition_dim)
        self.lin_gate_hidden_to_z = nn.Linear(transition_dim, z_dim)
        # Proposed-mean branch: z_{t-1} -> hidden -> proposed mean.
        self.lin_proposed_mean_z_to_hidden = nn.Linear(z_dim, transition_dim)
        self.lin_proposed_mean_hidden_to_z = nn.Linear(transition_dim, z_dim)
        # Scale head, and the linear part of the mean.
        self.lin_sig = nn.Linear(z_dim, z_dim)
        self.lin_z_to_loc = nn.Linear(z_dim, z_dim)

        # Start the linear part of the mean as the identity map so the
        # dynamics are not forced to be non-linear at initialisation.
        nn.init.eye_(self.lin_z_to_loc.weight)
        nn.init.zeros_(self.lin_z_to_loc.bias)

        # The three non-linearities.
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.softplus = nn.Softplus()

    def forward(self, z_prev):
        """Return the loc/scale of p(z_t | z_{t-1}) for the given ``z_prev``."""
        # Gate in (0, 1) per latent dimension.
        gate = self.sigmoid(
            self.lin_gate_hidden_to_z(self.relu(self.lin_gate_z_to_hidden(z_prev)))
        )
        # Non-linear proposal for the mean.
        proposed_mean = self.lin_proposed_mean_hidden_to_z(
            self.relu(self.lin_proposed_mean_z_to_hidden(z_prev))
        )
        # Mix the linear map of z_{t-1} with the proposal, modulated by gate.
        linear_part = self.lin_z_to_loc(z_prev)
        loc = (1 - gate) * linear_part + gate * proposed_mean
        # Note: the scale is computed from the proposed mean rather than
        # directly from z_{t-1}; softplus keeps it positive.
        scale = self.softplus(self.lin_sig(self.relu(proposed_mean)))
        return {"loc": loc, "scale": scale}

In [33]:
# Model dimensions.
x_dim = 28                  # pixels per image row = size of one observation x_t
hidden_dim = 32             # hidden width of the emitter (Generator)
rnn_dim = 2 * hidden_dim    # GRU hidden size per direction
transition_dim = 32         # hidden width of the gated transition (Prior)
z_dim = 16                  # size of the latent state z_t
t_max = x_dim               # one time step per image row (square images)

In [34]:
# Build the four networks of the DMM.  Keep this construction order: with the
# seeded RNG it determines each network's initial weights.
prior = Prior(z_dim, transition_dim).to(device)      # p(z_t | z_{t-1})
encoder = Inference(z_dim, rnn_dim).to(device)       # q(z_t | z_{t-1}, h_t)
decoder = Generator(z_dim, hidden_dim).to(device)    # p(x_t | z_t)
rnn = RNN(x_dim, rnn_dim).to(device)                 # h = GRU(x)

In [35]:
# Show each network, separated by a row of asterisks.
for position, net in enumerate((prior, encoder, decoder, rnn)):
    if position:
        print("*" * 80)
    print(net)

Distribution:
  p(z|z_{prev})
Network architecture:
  Prior(
    name=p, distribution_name=Normal,
    var=['z'], cond_var=['z_prev'], input_var=['z_prev'], features_shape=torch.Size([])
    (lin_gate_z_to_hidden): Linear(in_features=16, out_features=32, bias=True)
    (lin_gate_hidden_to_z): Linear(in_features=32, out_features=16, bias=True)
    (lin_proposed_mean_z_to_hidden): Linear(in_features=16, out_features=32, bias=True)
    (lin_proposed_mean_hidden_to_z): Linear(in_features=32, out_features=16, bias=True)
    (lin_sig): Linear(in_features=16, out_features=16, bias=True)
    (lin_z_to_loc): Linear(in_features=16, out_features=16, bias=True)
    (relu): ReLU()
    (sigmoid): Sigmoid()
    (softplus): Softplus(beta=1, threshold=20)
  )
********************************************************************************
Distribution:
  p(z|h,z_{prev})
Network architecture:
  Inference(
    name=p, distribution_name=Normal,
    var=['z'], cond_var=['h', 'z_prev'], input_var=['h', 'z_p

In [36]:
# Joint generative distribution p(x, z | z_prev) = p(x | z) p(z | z_prev),
# used by plot_image_from_latent to sample images from the prior.
generate_from_prior = prior * decoder
print(generate_from_prior)
# Kept as the cell's last expression so the notebook renders the returned
# LaTeX object.
print_latex(generate_from_prior)

Distribution:
  p(x,z|z_{prev}) = p(x|z)p(z|z_{prev})
Network architecture:
  Prior(
    name=p, distribution_name=Normal,
    var=['z'], cond_var=['z_prev'], input_var=['z_prev'], features_shape=torch.Size([])
    (lin_gate_z_to_hidden): Linear(in_features=16, out_features=32, bias=True)
    (lin_gate_hidden_to_z): Linear(in_features=32, out_features=16, bias=True)
    (lin_proposed_mean_z_to_hidden): Linear(in_features=16, out_features=32, bias=True)
    (lin_proposed_mean_hidden_to_z): Linear(in_features=32, out_features=16, bias=True)
    (lin_sig): Linear(in_features=16, out_features=16, bias=True)
    (lin_z_to_loc): Linear(in_features=16, out_features=16, bias=True)
    (relu): ReLU()
    (sigmoid): Sigmoid()
    (softplus): Softplus(beta=1, threshold=20)
  )
  Generator(
    name=p, distribution_name=Bernoulli,
    var=['x'], cond_var=['z'], input_var=['z'], features_shape=torch.Size([])
    (lin_z_to_hidden): Linear(in_features=16, out_features=32, bias=True)
    (lin_hidden_t

<IPython.core.display.Math object>

In [37]:
# Per-time-step negative ELBO: reconstruction term (cross-entropy of the
# decoder under the encoder) plus KL(encoder || transition prior).
step_loss = (
    CrossEntropy(encoder, decoder)
    + KullbackLeibler(encoder, prior)
)
# Unroll the step loss over t_max steps, slicing "x" and "h" along time and
# feeding each step's "z" back in as the next step's "z_prev".
_loss = IterativeLoss(
    step_loss,
    max_iter=t_max,
    series_var=["x", "h"],
    update_value={"z": "z_prev"},
)
# h is produced deterministically by the RNN; average over the batch.
loss = _loss.expectation(rnn).mean()

In [38]:
# Trainer wrapping the loss and every distribution whose parameters are
# optimised; gradients are clipped element-wise to [-10, 10].
dmm = Model(
    loss,
    optimizer=optim.RMSprop,
    optimizer_params={"lr": 5e-4},
    clip_grad_value=10,
    distributions=[rnn, encoder, decoder, prior],
)

In [39]:
# Summary of the trainer: distributions, loss function and optimizer.
print(dmm)
# Kept as the cell's last expression so the notebook renders the loss as LaTeX.
print_latex(dmm)

Distributions (for training): 
  p(h|x), p(z|h,z_{prev}), p(x|z), p(z|z_{prev}) 
Loss function: 
  mean \left(\mathbb{E}_{p(h|x)} \left[\sum_{t=1}^{28} \left(D_{KL} \left[p(z|h,z_{prev})||p(z|z_{prev}) \right] - \mathbb{E}_{p(z|h,z_{prev})} \left[\log p(x|z) \right]\right) \right] \right) 
Optimizer: 
  RMSprop (
  Parameter Group 0
      alpha: 0.99
      centered: False
      eps: 1e-08
      lr: 0.0005
      momentum: 0
      weight_decay: 0
  )


<IPython.core.display.Math object>

In [40]:
def data_loop(epoch, loader, model, device, train_mode=False):
    """Run ``model`` over every batch of ``loader`` and return the mean loss.

    Args:
        epoch: epoch number; only used in the training log line.
        loader: DataLoader yielding ``(images, labels)`` batches.
        model: trainer whose ``train`` (when ``train_mode``) or ``test``
            method is called once per batch and returns a scalar tensor.
        device: device the inputs are moved to.
        train_mode: choose between training and evaluation.

    Returns:
        Batch losses weighted by batch size, summed and divided by the
        dataset size, i.e. an average over examples.
    """
    step = model.train if train_mode else model.test
    total_loss = 0
    for images, _labels in tqdm(loader):
        images = images.to(device)
        n_images = images.size()[0]
        # Swap the first two axes so the sequence axis leads (assumed
        # (batch, rows, pixels) -> (rows, batch, pixels)) and start the
        # latent recurrence from zeros.
        feed = {
            'x': images.transpose(0, 1),
            'z_prev': torch.zeros(n_images, z_dim).to(device),
        }
        total_loss += step(feed).item() * n_images
    mean_loss = total_loss / len(loader.dataset)
    if train_mode:
        print('Epoch: {} Train loss: {:.4f}'.format(epoch, mean_loss))
    else:
        print('Test loss: {:.4f}'.format(mean_loss))
    return mean_loss

In [41]:
@torch.no_grad()
def plot_image_from_latent(batch_size):
    """Generate ``batch_size`` images by rolling the generative model forward.

    Starting from z_prev = 0, each step samples z_t from the transition prior
    (through ``generate_from_prior``) and uses the decoder's Bernoulli
    probabilities (``sample_mean``) as pixel row x_t.  Runs without gradient
    tracking because the result is only used for visualisation.

    Returns:
        Tensor of shape (batch_size, t_max, x_dim) of pixel probabilities.
    """
    z_prev = torch.zeros(batch_size, z_dim).to(device)
    rows = []
    for _ in range(t_max):
        z_t = generate_from_prior.sample({'z_prev': z_prev})["z"]
        rows.append(decoder.sample_mean({"z": z_t}))
        z_prev = z_t
    # Stack along a new time axis: (batch, t_max, x_dim).
    return torch.stack(rows, dim=1)


@torch.no_grad()
def reconstruct(data):
    """Reconstruct ``data`` through the encoder/decoder using posterior means.

    Args:
        data: batch of images, assumed (batch, rows, pixels) as produced by
            the MNIST loader.

    Returns:
        Tensor of shape (batch, t_max, x_dim) holding the decoder's pixel
        probabilities for each reconstructed row.
    """
    data = data.to(device)
    x = data.transpose(0, 1)          # time-major: (rows, batch, pixels)
    batch_size = data.size(0)
    # The RNN output depends only on x, so compute it once instead of
    # re-running the GRU over the full sequence at every time step.
    h = rnn.sample_mean({'x': x})
    z_prev = torch.zeros(batch_size, z_dim).to(device)
    rows = []
    for t in range(t_max):
        z_t = encoder.sample_mean({'h': h[t], 'z_prev': z_prev})
        rows.append(decoder.sample_mean({'z': z_t}))
        z_prev = z_t
    return torch.stack(rows, dim=1)


@torch.no_grad()
def sample_after_n_steps(num_step, data):
    """Condition on the first ``num_step`` rows of ``data``, then free-run.

    For t < num_step, z_t is the encoder's posterior mean.  From
    t = num_step onwards, z_t is the transition prior's mean given z_{t-1}.
    Every step emits the decoder's pixel probabilities.

    The RNN is run only on the observed prefix ``x[:num_step]``.  It is
    bidirectional, so running it on the full image would let the conditioning
    states see the very rows the prior is supposed to predict.

    Args:
        num_step: number of leading rows to condition on (clamped to
            [0, t_max]).
        data: batch of images, assumed (batch, rows, pixels).

    Returns:
        Tensor of shape (batch, t_max, x_dim).
    """
    data = data.to(device)
    x = data.transpose(0, 1)          # time-major: (rows, batch, pixels)
    batch_size = data.size(0)
    n_observed = max(0, min(num_step, t_max))
    # GRU over the observed prefix only, computed once for all steps.
    h = rnn.sample_mean({'x': x[:n_observed]}) if n_observed else None
    z_prev = torch.zeros(batch_size, z_dim).to(device)
    rows = []
    for t in range(t_max):
        if t < n_observed:
            z_t = encoder.sample_mean({'h': h[t], 'z_prev': z_prev})
        else:
            z_t = prior.sample_mean({'z_prev': z_prev})
        rows.append(decoder.sample_mean({'z': z_t}))
        z_prev = z_t
    return torch.stack(rows, dim=1)
            

In [43]:
writer = SummaryWriter()
# Fixed batch of test images, reused every epoch for the visualisations.
# Use the builtin next(): DataLoader iterators in recent PyTorch releases no
# longer provide the Python-2 style ``.next()`` method.
_x, _ = next(iter(test_loader))
_x = _x.to(device)


for epoch in range(1, epochs + 1):
    train_loss = data_loop(epoch, train_loader, dmm, device, train_mode=True)
    test_loss = data_loop(epoch, test_loader, dmm, device)

    writer.add_scalar('train_loss', train_loss, epoch)
    writer.add_scalar('test_loss', test_loss, epoch)

    # [:, None] inserts a channel axis so add_images receives (N, 1, H, W).
    sample = plot_image_from_latent(batch_size)[:, None]
    writer.add_images('Image_from_latent', sample, epoch)

    reconst_img = reconstruct(_x)[:, None]
    writer.add_images('reconstructed_img', reconst_img, epoch)

    n_step_after_sample = sample_after_n_steps(14, _x)[:, None]
    writer.add_images('n_step', n_step_after_sample, epoch)

    # Tag spelling ('orignal') kept so existing TensorBoard runs stay comparable.
    writer.add_images('orignal', _x[:, None], epoch)

# Flush any pending events and release the event file.
writer.close()

100%|██████████| 469/469 [01:19<00:00,  5.86it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 1 Train loss: 210.5909


100%|██████████| 79/79 [00:06<00:00, 12.82it/s]


Test loss: 199.9833


100%|██████████| 469/469 [01:20<00:00,  5.79it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 2 Train loss: 194.6889


100%|██████████| 79/79 [00:06<00:00, 13.02it/s]


Test loss: 188.0311


100%|██████████| 469/469 [01:20<00:00,  5.81it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 3 Train loss: 185.8990


100%|██████████| 79/79 [00:06<00:00, 12.29it/s]


Test loss: 182.3530


100%|██████████| 469/469 [01:20<00:00,  5.80it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 4 Train loss: 181.1654


100%|██████████| 79/79 [00:06<00:00, 13.16it/s]


Test loss: 179.2692


100%|██████████| 469/469 [01:20<00:00,  5.84it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

Epoch: 5 Train loss: 179.0450


100%|██████████| 79/79 [00:05<00:00, 13.25it/s]


Test loss: 178.2976
