# DynaNet
- Original paper: [DynaNet: Neural Kalman Dynamical Model for Motion Estimation and Prediction](https://arxiv.org/abs/1908.03918)

 ## install modules and env setting

In [1]:
from tqdm import tqdm

import numpy as np
import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
from tensorboardX import SummaryWriter

In [2]:
seed = 1
torch.manual_seed(seed)

<torch._C.Generator at 0x7f841d32a610>

In [3]:
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

## Define probability distribution

Encoder(equation (3) in the paper): $\mathbf{a}_{t}, \boldsymbol{\sigma}_{t}=f_{\mathrm{encoder}}\left(\mathbf{x}_{t}\right)$  

Deterministic Transition(equation (5) in the paper): $\mathbf{A}_{t}=\mathrm{LSTM}\left(\mathbf{z}_{t-1}, \mathbf{h}_{t-1}\right)$

Resampled Transition(equation (6), (7) in the paper):  
$\boldsymbol{\alpha}=\operatorname{LSTM}\left(\mathbf{z}_{t-1}, \mathbf{h}_{t-1}\right)$  
$\mathbf{A}_{t} \sim \operatorname{Dirichlet}(\boldsymbol{\alpha})$

Kalman predictor(equation (8) in the paper):  
$\begin{aligned} \mathbf{z}_{t | t-1} &=\mathbf{A}_{t} \mathbf{z}_{t-1 | t-1} \\ \mathbf{P}_{t | t-1} &=\mathbf{A}_{t} \mathbf{P}_{t-1 | t-1} \mathbf{A}_{t}^{T}+\mathbf{Q}_{t} \end{aligned}$  

Kalman updator(equation (9) in the paper):  
$\begin{aligned} \mathbf{r}_{t} &=\mathbf{a}_{t}-\mathbf{H}_{t} \mathbf{z}_{t | t-1} \\ \mathbf{S}_{t} &=\mathbf{R}_{t}+\mathbf{H}_{t} \mathbf{P}_{t | t-1} \mathbf{H}_{t}^{T} \\ \mathbf{K}_{t} &=\mathbf{P}_{t | t-1} \mathbf{H}_{t}^{T} \mathbf{S}_{t}^{-1} \\ \mathbf{z}_{t | t} &=\mathbf{z}_{t | t-1}+\mathbf{K}_{t} \mathbf{r}_{t} \\ \mathbf{P}_{t | t} &=\left(\mathbf{I}-\mathbf{K}_{t} \mathbf{H}_{t}\right) \mathbf{P}_{t | t-1} \end{aligned}$  

f predictor(equation (10), (11) in the paper):  
$\tilde{\mathbf{y}}_{t}=f_{\text {predictor }}\left(\mathbf{z}_{t | t}\right)$  
$\hat{\mathbf{y}}_{t}=f_{\text {predictor }}\left(\mathbf{z}_{t | t-1}\right)$  

In [4]:
import pixyz
from pixyz.utils import print_latex
from pixyz.distributions import Normal, Deterministic, Dirichlet, Bernoulli

In [5]:
print(pixyz.__version__)

0.2.0


In [6]:
batch_size = 3
x_dim = 10
a_dim = 10
sigma_dim =10
z_dim = 10
h_dim = 10
k_dim = 10
y_dim = 10
t_max = 3

In [7]:
# f_encoder
class Encoder(Deterministic):
    def __init__(self, x_dim, a_dim, sigma_dim):
        super(Encoder, self).__init__(cond_var=["x"], var=["a", "R"])
        self.fc1 = nn.Linear(x_dim, 100)
        self.fc21 = nn.Linear(100, a_dim)
        self.fc22 = nn.Linear(100, sigma_dim)
    
    def forward(self, x):
        h = F.relu(self.fc1(x))
        sigma =F.softplus(self.fc22(h))
        R = torch.diag_embed(sigma)
        return {"a": self.fc21(h), "R": R}

In [8]:
print_latex(Encoder(x_dim, a_dim, sigma_dim).to(device))

<IPython.core.display.Math object>

In [9]:
from torch.distributions.dirichlet import Dirichlet
class ResampledTransition(Deterministic):
    def __init__(self, z_dim, h_dim, k_dim):
        """
        what is k_dim?
        h_dim = k_dim =  z_dim
        """
        super(ResampledTransition, self).__init__(name="Transition", cond_var=["z_prev_prev", "h_prev", "c_prev"], var=["h", "c", "A", "sigma_Q"])
        self.rnn_1 = nn.LSTMCell(z_dim, h_dim)
        self.rnn_2 = nn.LSTMCell(h_dim, h_dim)
        self.dropout = nn.Dropout(0.2)
        
        self.fc1 = nn.Linear(h_dim, k_dim)
        self.fc2 = nn.Linear(h_dim, k_dim)
        
        self.alpha_diagonal = torch.eye(k_dim).to(device)
    
    def forward(self, z_prev_prev, h_prev, c_prev):
        h_1, c_1 = self.rnn_1(z_prev_prev, (h_prev[0], c_prev[0]))
        rnn_2_input = self.dropout(h_1)
        h_2, c_2 = self.rnn_2(h_1, (h_prev[1], c_prev[1]))
        
        alpha = F.relu(self.fc1(h_2))
        alpha = Dirichlet(alpha.cpu()).rsample().to(device)
        sigma_q = F.relu(self.fc2(alpha))
        
        # convert alpha, sigma_q to (batch_size, z_dim, z_dim) diagonal matrix
        A = alpha.unsqueeze(2).expand(*alpha.size(), alpha.size(1))*self.alpha_diagonal
        sigma_Q = sigma_q.unsqueeze(2).expand(*sigma_q.size(), sigma_q.size(1))*self.alpha_diagonal
        
        h = [h_1, h_2]
        c = [c_1, c_2]
        
        return {"h": h, "c": c, "A": A, "sigma_Q": sigma_Q}

In [10]:
print_latex(ResampledTransition(z_dim, h_dim, k_dim).to(device))

<IPython.core.display.Math object>

In [11]:
class KalmanPredictor(Deterministic):
    def __init__(self, z_dim):
        super(KalmanPredictor, self).__init__(name="KalmanPredictor", cond_var=["z_prev_prev", "P_prev_prev", "A", "sigma_Q"], var=["z_prev", "P_prev"])
        
        # init Q somehow
        self.Q = torch.tensor(0.08 * np.eye(z_dim, dtype=np.float32)).to(device)
        self.Q = self.Q.unsqueeze(0)
        self.Q = self.Q.repeat(batch_size, 1, 1) # (bs, dim_z, dim_z)
    
    def forward(self, z_prev_prev, P_prev_prev, A, sigma_Q):
        """
        z_prev_prev: (batch_size, z_dim)
        P_prev_prev: (batch_size, z_dim, z_dim)
        A: (batch_size, z_dim, z_dim)
        """
        Q = self.Q + sigma_Q
    
        # z_prev: (batch_size, z_dim)
        z_prev = torch.bmm(A, z_prev_prev.unsqueeze(2)).squeeze(2)
        # P_prev: (batch_size, z_dim, z_dim)
        P_prev = torch.bmm(torch.bmm(A, P_prev_prev), A.transpose(2, 1)) + self.Q
        return {"z_prev": z_prev, "P_prev": P_prev}

In [12]:
print_latex(KalmanPredictor(z_dim).to(device))

<IPython.core.display.Math object>

In [13]:
class KalmanUpdator(Deterministic):
    def __init__(self, a_dim, z_dim):
        super(KalmanUpdator, self).__init__(name="KalmanUpdator", cond_var=["z_prev", "P_prev", "a", "R"], var=["z", "P"])
        # assume a_dim == z_dim
        
        # for section 3.3 setting
        # self.H (batch_size, a_dim, 2 * a_dim)
        # self.H =torch.from_numpy(np.array([np.concatenate((np.eye(a_dim).astype(np.float32),
        #                                                             np.zeros((a_dim, a_dim)).astype(np.float32)), axis=1)
        #                                             for _ in range(batch_size)])
        
        self.H = torch.eye(a_dim, z_dim).to(device)
        self.H = self.H.unsqueeze(0)
        self.H = self.H.repeat(batch_size, 1, 1)
        
        # self.I (batch_size, a_dim, z_dim)
        self.I = torch.eye(a_dim, z_dim).to(device)
        self.I = self.I.unsqueeze(0)
        self.I = self.I.repeat(batch_size, 1, 1)
        
        
    def forward(self, z_prev, P_prev, a, R):
        """
        z_prev: (batch_size, z_dim)
        P_prev: (batch_size, z_dim, z_dim)
        a (observation): (batch_size, a_dim)
        H: (a_dim, z_dim)
        R: (batch_size, a_dim, a_dim)
        """
        # r (batch_size, a_dim)
        r = a - torch.bmm(self.H, z_prev.unsqueeze(2)).squeeze(2)
        
        # S: (batch_size, a_dim, a_dim)
        S = R + torch.bmm(torch.bmm(self.H, P_prev), self.H.transpose(2, 1))
        
        # K: (batch_size, a_dim, a_dim)
        K = torch.bmm(torch.bmm(P_prev, self.H.transpose(2, 1)), S.inverse())
        
        # z: (batch_size, z_dim)
        z = z_prev + torch.bmm(K, r.unsqueeze(2)).squeeze(2)
        
        # P: (batch_size, z_dim, z_dim)
        P = torch.bmm((self.I - torch.bmm(K, self.H)), P_prev)
        return {"z": z, "P": P}

In [14]:
print_latex(KalmanUpdator(a_dim, z_dim).to(device))

<IPython.core.display.Math object>

In [15]:
class FPredictor(Bernoulli):
    def __init__(self, z_dim, y_dim):
        super(FPredictor, self).__init__(name="f_predictor", cond_var=["z"], var=["y"])
        
        self.fc1 = nn.Linear(z_dim, 100)
        self.fc2 = nn.Linear(100, y_dim)
        self.scale = torch.ones(y_dim).to(device)
        
    def forward(self, z):
        h = F.relu(self.fc1(z))
        #return {"loc": self.fc2(h), "scale": self.scale}
        return {"probs": torch.sigmoid(self.fc2(h))}

In [16]:
print_latex(FPredictor(z_dim, y_dim).to(device))

<IPython.core.display.Math object>

## Define Loss(in progress)
Loss(equation (12) in the paper)  
$L(\theta)=\frac{1}{T} \sum_{t=1}^{T}\left(\left\|\mathbf{y}_{t}-\tilde{\mathbf{y}}_{t}\right\|^{2}+\left\|\mathbf{y}_{t}-\hat{\mathbf{y}}_{t}\right\|^{2}\right)$

In [17]:
from pixyz.losses import LogProb, IterativeLoss
from pixyz.distributions import ReplaceVarDistribution

encoder = Encoder(x_dim, a_dim, sigma_dim).to(device)
transition = ResampledTransition(z_dim, h_dim, k_dim).to(device)
kalman_predictor = KalmanPredictor(z_dim).to(device)
kalman_updator = KalmanUpdator(a_dim, z_dim).to(device)
f_predictor = FPredictor(z_dim, y_dim).to(device)

# loss
loss_tilda = -LogProb(f_predictor).sum()
loss_hat = -LogProb(ReplaceVarDistribution(f_predictor, {"z": "z_prev"})).sum()

step_loss = loss_tilda + loss_hat
# loss = IterativeLoss(step_loss, max_iter=t_max, 
#                       series_var=["x"], update_value={"z": "z_prev_prev", "P": "P_prev_prev", "h": "h_prev", "c": "c_prev"}).mean()
print_latex(step_loss)

<IPython.core.display.Math object>

## MNIST experiment

In [18]:
batch_size = 256
epochs = 5
seed = 1
torch.manual_seed(seed)

x_dim = 28
a_dim = 10
sigma_dim =10
z_dim = 10
h_dim = 10
k_dim = 10
y_dim = 28

In [19]:
# generate MNIST by stacking row images(consider row as time step)
def init_dataset(f_batch_size):
    kwargs = {'num_workers': 1, 'pin_memory': True}
    data_dir = '../data'
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda data: data[0])
    ])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(data_dir, train=True, download=True,
                       transform=mnist_transform),
        batch_size=f_batch_size, shuffle=True, drop_last=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(data_dir, train=False, transform=mnist_transform),
        batch_size=f_batch_size, shuffle=True, drop_last=True , **kwargs)

    fixed_t_size = 28
    return train_loader, test_loader, fixed_t_size

train_loader, test_loader, t_max = init_dataset(batch_size)

In [20]:
from pixyz.models import Model

class DynaNet(Model):
    def __init__(self,
                 optimizer=optim.Adam,
                 optimizer_params={},
                 clip_grad_norm=None,
                 clip_grad_value=None):
        
        # distribution
        self.encoder = Encoder(x_dim, a_dim, sigma_dim).to(device)
        self.transition = ResampledTransition(z_dim, h_dim, k_dim).to(device)
        self.kalman_predictor = KalmanPredictor(z_dim).to(device)
        self.kalman_updator = KalmanUpdator(a_dim, z_dim).to(device)
        self.f_predictor = FPredictor(z_dim, y_dim).to(device)
        
        self.sampler = self.encoder * self.transition * self.kalman_predictor * self.kalman_updator
        
        # loss
        loss_tilda = -LogProb(self.f_predictor)
        loss_hat = -LogProb(ReplaceVarDistribution(self.f_predictor, {"z": "z_prev"}))
        self.step_loss = loss_tilda.sum() + loss_hat.sum()
        
        distributions = [self.encoder, self.transition, self.f_predictor]
        self.distributions = nn.ModuleList(distributions)

        # set params and optim
        params = self.distributions.parameters()
        self.optimizer = optimizer(params, **optimizer_params)

        self.clip_norm = clip_grad_norm
        self.clip_value = clip_grad_value
        
        
    
    def calculate_loss(self, input_var_dict={}):
        batch_size = input_var_dict['x'].size()[1]
        x = input_var_dict['x']
        
        z_prev_prev = torch.zeros([batch_size, z_dim]).to(device)
        h_prev = -1. + 2 * torch.rand([2, batch_size, h_dim]).to(device)
        c_prev = torch.randn([2, batch_size, h_dim]).to(device)
        P_prev_prev = 20 * torch.eye(z_dim).to(device)
        P_prev_prev = P_prev_prev.unsqueeze(0)
        P_prev_prev = P_prev_prev.repeat(batch_size, 1, 1)
        
        input_var_dict["z_prev_prev"] = z_prev_prev
        input_var_dict["h_prev"] = h_prev
        input_var_dict["c_prev"] = c_prev
        input_var_dict["P_prev_prev"] = P_prev_prev
        # Without Iterative Loss
        total_loss = 0
        for time_step in range(t_max):
            input_var_dict["x"] = x[time_step]
            sampled_dict = self.sampler.sample(input_var_dict)
            
            # following procedure shows what self.sampler.sample() does
            """
            ---------------------------------------------------
            encoded = self.encoder.sample({"x": x[time_step]})
            
            a, R = encoded["a"], encoded["R"]
            
            transition_output = self.transition.sample({"z_prev_prev": z_prev_prev, "h_prev": h_prev, "c_prev": c_prev})
            h = transition_output["h"]
            c = transition_output["c"]
            A = transition_output["A"]
            sigma_Q = transition_output["sigma_Q"]
            
            kalman_predicted = self.kalman_predictor.sample({"z_prev_prev": z_prev_prev, "P_prev_prev": P_prev_prev, "A": A, "sigma_Q": sigma_Q})
            z_prev = kalman_predicted["z_prev"]
            P_prev = kalman_predicted["P_prev"]

            kalman_updated = self.kalman_updator.sample({"z_prev": z_prev, "P_prev": P_prev, "a": a, "R": R})
            z, P = kalman_updated["z"], kalman_updated["P"]
            ----------------------------------------------------
            """
            
            # update
            input_var_dict["h_prev"] = sampled_dict["h"]
            input_var_dict["c_prev"] = sampled_dict["c"]
            input_var_dict["z_prev_prev"] = sampled_dict["z"]
            input_var_dict["P_prev_prev"] = sampled_dict["P"]
            
            sampled_dict["y"] = x[time_step]
            
            # loss
            total_loss += self.step_loss.eval(sampled_dict)
        
        total_loss = total_loss / t_max
        return total_loss
    
    def train(self, train_x_dict={}):
        self.distributions.train()

        self.optimizer.zero_grad()
        loss = self.calculate_loss(train_x_dict)

        # backprop
        loss.backward()

        if self.clip_norm:
            clip_grad_norm_(self.distributions.parameters(), self.clip_norm)
        if self.clip_value:
            clip_grad_value_(self.distributions.parameters(), self.clip_value)

        # update params
        self.optimizer.step()

        return loss.item()
    def test(self, test_x_dict={}):
        self.distributions.eval()

        with torch.no_grad():
            loss = self.calculate_loss(test_x_dict)

        return loss.item()

    def reconst_image(self, original_data):
        self.distributions.eval()
        with torch.no_grad():
            xs = []
            x = original_data.transpose(0, 1)
            batch_size = original_data.size()[0]
            z_prev_prev = torch.zeros([batch_size, z_dim]).to(device)
            h_prev = -1. + 2 * torch.rand([2, batch_size, h_dim]).to(device)
            c_prev = torch.randn([2, batch_size, h_dim]).to(device)
            P_prev_prev = 20 * torch.eye(z_dim).to(device)
            P_prev_prev = P_prev_prev.unsqueeze(0)
            P_prev_prev = P_prev_prev.repeat(batch_size, 1, 1)
            
            input_var_dict = {"z_prev_prev": z_prev_prev, "h_prev": h_prev, "c_prev": c_prev, "P_prev_prev": P_prev_prev}
            for time_step in range(t_max):
                input_var_dict["x"] = x[time_step]
                sampled_dict = self.sampler.sample(input_var_dict)

                # update
                input_var_dict["h_prev"] = sampled_dict["h"]
                input_var_dict["c_prev"] = sampled_dict["c"]
                input_var_dict["z_prev_prev"] = sampled_dict["z"]
                input_var_dict["P_prev_prev"] = sampled_dict["P"]
                dec_x = self.f_predictor.sample_mean({"z": sampled_dict["z"]})
                xs.append(dec_x[None, :])
            recon_img = torch.cat(xs, dim=0).transpose(0, 1)
        return recon_img
    
    def generate_nstep(self, original_data, n_step=14):
        self.distributions.eval()
        with torch.no_grad():
            xs = []
            x = original_data.transpose(0, 1)
            batch_size = original_data.size()[0]
            z_prev_prev = torch.zeros([batch_size, z_dim]).to(device)
            h_prev = -1. + 2 * torch.rand([2, batch_size, h_dim]).to(device)
            c_prev = torch.randn([2, batch_size, h_dim]).to(device)
            P_prev_prev = 20 * torch.eye(z_dim).to(device)
            P_prev_prev = P_prev_prev.unsqueeze(0)
            P_prev_prev = P_prev_prev.repeat(batch_size, 1, 1)
            
            input_var_dict = {"z_prev_prev": z_prev_prev, "h_prev": h_prev, "c_prev": c_prev, "P_prev_prev": P_prev_prev}
            for time_step in range(t_max):
                if time_step < n_step:
                    input_var_dict["x"] = x[time_step]
                    sampled_dict = self.sampler.sample(input_var_dict)

                    # update
                    input_var_dict["h_prev"] = sampled_dict["h"]
                    input_var_dict["c_prev"] = sampled_dict["c"]
                    input_var_dict["z_prev_prev"] = sampled_dict["z"]
                    input_var_dict["P_prev_prev"] = sampled_dict["P"]
                    dec_x = self.f_predictor.sample_mean({"z": sampled_dict["z"]})
                else:
                    transition_output = self.transition.sample({"z_prev_prev": z_prev_prev, "h_prev": h_prev, "c_prev": c_prev})
                    h = transition_output["h"]
                    c = transition_output["c"]
                    A = transition_output["A"]
                    sigma_Q = transition_output["sigma_Q"]


                    kalman_predicted = self.kalman_predictor.sample({"z_prev_prev": z_prev_prev, "P_prev_prev": P_prev_prev, "A": A, "sigma_Q": sigma_Q})
                    z_prev = kalman_predicted["z_prev"]
                    P_prev = kalman_predicted["P_prev"]
                    # update
                    h_prev = h
                    c_prev = c
                    z_prev_prev = z_prev
                    P_prev_prev = P_prev
                    dec_x = self.f_predictor.sample_mean({"z": z_prev})
                    
                xs.append(dec_x[None, :])
            generated_img = torch.cat(xs, dim=0).transpose(0, 1)
        return generated_img

In [21]:
_x, _ = iter(test_loader).next()
fixed_batch = _x.to(device)
batch_size = fixed_batch.size()[0]
sequential_x = fixed_batch.transpose(0, 1)

dynanet = DynaNet()
dynanet.calculate_loss(input_var_dict={'x': sequential_x})

tensor(9857.7939, device='cuda:0', grad_fn=<DivBackward0>)

In [22]:
def data_loop(epoch, loader, model, device, train_mode=False):
    mean_loss = 0
    for batch_idx, (data, _) in enumerate(tqdm(loader)):
        data = data.to(device)
        batch_size = data.size()[0]
        x = data.transpose(0, 1)
        #q_z_prev = torch.zeros(batch_size, z_dim).to(device)
        if train_mode:
            mean_loss += model.train({'x': x}) * batch_size
        else:
            mean_loss += model.test({'x': x}) * batch_size
    mean_loss /= len(loader.dataset)
    if train_mode:
        print('Epoch: {} Train loss: {:.4f}'.format(epoch, mean_loss))
    else:
        print('Test loss: {:.4f}'.format(mean_loss))
    return mean_loss

In [None]:
import datetime

dt_now = datetime.datetime.now()
exp_time = dt_now.strftime('%Y%m%d_%H:%M:%S')

In [23]:
import pixyz
v = pixyz.__version__
writer = SummaryWriter("../runs/" + v + ".dynanet" + exp_time)
# fixed _x for watching reconstruction improvement
_x, _ = iter(test_loader).next()
_x = _x.to(device)
dynanet = DynaNet(optimizer=optim.Adam, optimizer_params={'lr': 1e-3})

import time
start = time.time()

for epoch in range(1, epochs + 1):
    train_loss = data_loop(epoch, train_loader, dynanet, device, train_mode=True)
    test_loss = data_loop(epoch, test_loader, dynanet, device)

    writer.add_scalar('train_loss', train_loss, epoch)
    writer.add_scalar('test_loss', test_loss, epoch)
    
    recon_img = dynanet.reconst_image(_x)
    writer.add_images('Reconstructed',  recon_img[:, None], epoch)
    
    generated_img = dynanet.generate_nstep(_x)
    writer.add_images('Generated',  generated_img[:, None], epoch)
    
    writer.add_images('orignal', _x[:, None], epoch)
elapsed_time = time.time() - start
writer.add_scalar('Exp time second', elapsed_time)
writer.close()

100%|██████████| 234/234 [01:06<00:00,  3.54it/s]
  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 1 Train loss: 4728.1433


100%|██████████| 39/39 [00:04<00:00,  7.99it/s]


Test loss: 2780.0844


100%|██████████| 234/234 [01:07<00:00,  3.47it/s]
  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 2 Train loss: 2540.3611


100%|██████████| 39/39 [00:05<00:00,  7.54it/s]


Test loss: 2402.5258


100%|██████████| 234/234 [01:07<00:00,  3.46it/s]
  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 3 Train loss: 2367.6742


100%|██████████| 39/39 [00:05<00:00,  7.66it/s]


Test loss: 2320.4956


100%|██████████| 234/234 [01:07<00:00,  3.46it/s]
  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 4 Train loss: 2305.7899


100%|██████████| 39/39 [00:05<00:00,  7.77it/s]


Test loss: 2266.7658


100%|██████████| 234/234 [01:07<00:00,  3.45it/s]
  0%|          | 0/39 [00:00<?, ?it/s]

Epoch: 5 Train loss: 2258.3235


100%|██████████| 39/39 [00:05<00:00,  7.27it/s]


Test loss: 2224.5651
