# D$^2$PCCA Model

Code adapted from https://github.com/pyro-ppl/pyro/blob/dev/examples/dmm.py

In [None]:
!pip install pyro-ppl

import os
from google.colab import files

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

import pyro
import pyro.poutine as poutine
import pyro.distributions as dist
from pyro.distributions import TransformedDistribution
from pyro.distributions.transforms import affine_autoregressive
from pyro.infer import (
    SVI,
    JitTrace_ELBO,
    Trace_ELBO,
    TraceEnum_ELBO,
    TraceTMC_ELBO,
    config_enumerate,
)
from pyro.optim import (
    Adam,
    ClippedAdam,
)

In [None]:
# p(x_t | z_t) or p(y_t | z_t)
class Emitter_1_input(nn.Module):
    """Emission network from one latent vector to a diagonal Gaussian.

    Parameterizes ``p(x_t | z_t)`` (or ``p(y_t | z_t)``): two ReLU hidden
    layers feed two linear heads, one for the location and one whose
    exponential is the scale.
    """

    def __init__(self, output_dim, z_dim, emission_dim):
        super().__init__()
        # Feature extractor: z -> hidden -> hidden.
        self.lin_z_to_hidden = nn.Linear(z_dim, emission_dim)
        self.lin_hidden_to_hidden = nn.Linear(emission_dim, emission_dim)
        # Output heads.
        self.lin_hidden_to_loc = nn.Linear(emission_dim, output_dim)
        self.lin_hidden_to_scale = nn.Linear(emission_dim, output_dim)
        # Non-linearities (sigmoid is registered but not used by forward()).
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, z_t):
        """Return ``(loc, scale)`` of the emission distribution for ``z_t``."""
        hidden = self.relu(self.lin_z_to_hidden(z_t))
        hidden = self.relu(self.lin_hidden_to_hidden(hidden))
        loc = self.lin_hidden_to_loc(hidden)
        # The scale head outputs a log-scale; exponentiate to keep it positive.
        log_scale = self.lin_hidden_to_scale(hidden)
        return loc, log_scale.exp()

# p(x_t | z_t, z_t^1) or p(y_t | z_t, z_t^2)
class Emitter_2_input(nn.Module):
    """Emission network from a shared and a private latent to a diagonal Gaussian.

    Parameterizes ``p(x_t | z_t, z_t^1)`` (or ``p(y_t | z_t, z_t^2)``): the two
    latent vectors are concatenated on the last axis, passed through two ReLU
    hidden layers, and mapped to a location and a (log-)scale.
    """

    def __init__(self, x_dim, z_dim, zx_dim, emission_dim_zx_x):
        super().__init__()
        # Feature extractor over the concatenated latent [z, z_i].
        self.lin_z_to_hidden = nn.Linear(z_dim + zx_dim, emission_dim_zx_x)
        self.lin_hidden_to_hidden = nn.Linear(emission_dim_zx_x, emission_dim_zx_x)
        # Output heads.
        self.lin_hidden_to_loc = nn.Linear(emission_dim_zx_x, x_dim)
        self.lin_hidden_to_scale = nn.Linear(emission_dim_zx_x, x_dim)
        # Non-linearities (sigmoid is registered but not used by forward()).
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, z_t, zi_t):
        """Return ``(loc, scale)`` of the emission given shared ``z_t`` and private ``zi_t``."""
        joint = torch.cat((z_t, zi_t), dim=-1)
        hidden = self.relu(self.lin_z_to_hidden(joint))
        hidden = self.relu(self.lin_hidden_to_hidden(hidden))
        loc = self.lin_hidden_to_loc(hidden)
        # The scale head outputs a log-scale; exponentiate to keep it positive.
        log_scale = self.lin_hidden_to_scale(hidden)
        return loc, log_scale.exp()

# p(z_t | z_{t-1})
class GatedTransition(nn.Module):
    """Gated transition ``p(z_t | z_{t-1})`` for up to three independent chains.

    The input latent is laid out on its last axis as ``[z, z_x, z_y]`` with
    widths ``z_dim``, ``zx_dim`` and ``zy_dim``.  Each chain with a non-zero
    width gets its own gated transition network; the per-chain locations and
    scales are concatenated in the same ``[z, z_x, z_y]`` order.

    Args:
        z_dim: width of the shared chain ``z``.
        zx_dim: width of the ``z_x`` chain (0 disables it).
        zy_dim: width of the ``z_y`` chain (0 disables it).
        transition_dim_z: hidden width of the ``z`` transition network.
        transition_dim_zx: hidden width of the ``z_x`` transition network.
        transition_dim_zy: hidden width of the ``z_y`` transition network.
    """

    def __init__(self, z_dim, zx_dim=0, zy_dim=0, transition_dim_z=10, transition_dim_zx=10, transition_dim_zy=10):
        super().__init__()
        # g_t: gating units
        self.lin_gate_z_to_hidden = nn.Linear(z_dim, transition_dim_z)
        self.lin_gate_hidden_to_z = nn.Linear(transition_dim_z, z_dim)
        # h_t: proposed mean
        self.lin_proposed_mean_z_to_hidden = nn.Linear(z_dim, transition_dim_z)
        self.lin_proposed_mean_hidden_to_z = nn.Linear(transition_dim_z, z_dim)
        # S(z_{t-1}): scale
        self.lin_sig = nn.Linear(z_dim, z_dim)
        # MLP(z_{t-1}, I): linear path, initialized so that it starts out as
        # the identity function
        self.lin_z_to_loc = nn.Linear(z_dim, z_dim)
        self.lin_z_to_loc.weight.data = torch.eye(z_dim)
        self.lin_z_to_loc.bias.data = torch.zeros(z_dim)

        if zx_dim != 0:
            self.lin_gate_zx_to_hidden = nn.Linear(zx_dim, transition_dim_zx)
            self.lin_gate_hidden_to_zx = nn.Linear(transition_dim_zx, zx_dim)
            self.lin_proposed_mean_zx_to_hidden = nn.Linear(zx_dim, transition_dim_zx)
            self.lin_proposed_mean_hidden_to_zx = nn.Linear(transition_dim_zx, zx_dim)
            self.lin_sig_x = nn.Linear(zx_dim, zx_dim)
            self.lin_zx_to_loc = nn.Linear(zx_dim, zx_dim)
            self.lin_zx_to_loc.weight.data = torch.eye(zx_dim)
            self.lin_zx_to_loc.bias.data = torch.zeros(zx_dim)

        if zy_dim != 0:
            self.lin_gate_zy_to_hidden = nn.Linear(zy_dim, transition_dim_zy)
            self.lin_gate_hidden_to_zy = nn.Linear(transition_dim_zy, zy_dim)
            self.lin_proposed_mean_zy_to_hidden = nn.Linear(zy_dim, transition_dim_zy)
            self.lin_proposed_mean_hidden_to_zy = nn.Linear(transition_dim_zy, zy_dim)
            self.lin_sig_y = nn.Linear(zy_dim, zy_dim)
            self.lin_zy_to_loc = nn.Linear(zy_dim, zy_dim)
            self.lin_zy_to_loc.weight.data = torch.eye(zy_dim)
            self.lin_zy_to_loc.bias.data = torch.zeros(zy_dim)

        self.z_dim = z_dim
        self.zx_dim = zx_dim
        self.zy_dim = zy_dim
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.softplus = nn.Softplus()

    def _gated_step(self, prev, gate_in, gate_out, mean_in, mean_out, sig, to_loc):
        """Apply one chain's gated transition to ``prev`` and return ``(loc, scale)``."""
        # g_t: gating units
        gate = self.sigmoid(gate_out(self.relu(gate_in(prev))))
        # h_t: proposed mean
        proposed_mean = mean_out(self.relu(mean_in(prev)))
        # loc mixes the linear path and the proposed mean; scale is a
        # softplus of a linear map of the rectified proposed mean.
        loc = (1 - gate) * to_loc(prev) + gate * proposed_mean
        scale = self.softplus(sig(self.relu(proposed_mean)))
        return loc, scale

    def forward(self, z_t_1):
        """Return ``(loc, scale)`` of ``p(z_t | z_{t-1})``.

        ``z_t_1`` holds ``[z, z_x, z_y]`` on its last axis.  Every enabled
        chain is handled, including the ``z_y``-only configuration
        (``zx_dim == 0`` and ``zy_dim != 0``), for which this method used to
        fall through and return ``None``.
        """
        x_start = self.z_dim
        y_start = self.z_dim + self.zx_dim

        # Shared chain: z_t -> z_{t+1}
        loc, scale = self._gated_step(
            z_t_1[..., :x_start],
            self.lin_gate_z_to_hidden, self.lin_gate_hidden_to_z,
            self.lin_proposed_mean_z_to_hidden, self.lin_proposed_mean_hidden_to_z,
            self.lin_sig, self.lin_z_to_loc,
        )
        locs, scales = [loc], [scale]

        # zx_t -> zx_{t+1}
        if self.zx_dim != 0:
            loc, scale = self._gated_step(
                z_t_1[..., x_start:y_start],
                self.lin_gate_zx_to_hidden, self.lin_gate_hidden_to_zx,
                self.lin_proposed_mean_zx_to_hidden, self.lin_proposed_mean_hidden_to_zx,
                self.lin_sig_x, self.lin_zx_to_loc,
            )
            locs.append(loc)
            scales.append(scale)

        # zy_t -> zy_{t+1}
        if self.zy_dim != 0:
            loc, scale = self._gated_step(
                z_t_1[..., y_start:],
                self.lin_gate_zy_to_hidden, self.lin_gate_hidden_to_zy,
                self.lin_proposed_mean_zy_to_hidden, self.lin_proposed_mean_hidden_to_zy,
                self.lin_sig_y, self.lin_zy_to_loc,
            )
            locs.append(loc)
            scales.append(scale)

        # Only the shared chain is enabled: nothing to concatenate.
        if len(locs) == 1:
            return locs[0], scales[0]
        return torch.cat(locs, dim=-1), torch.cat(scales, dim=-1)



# q(z_t|z_{t-1}, h_t^r)
class Combiner(nn.Module):
    """Variational factor ``q(z_t | z_{t-1}, h_t^r)``.

    Averages a tanh projection of the previous latent with the RNN hidden
    state, then maps the result to the location and (softplus) scale of a
    diagonal Gaussian over ``z_t``.
    """

    def __init__(self, z_dim, rnn_dim):
        super().__init__()
        # Projects z_{t-1} into the RNN hidden space.
        self.lin_z_to_hidden = nn.Linear(z_dim, rnn_dim)
        # Heads reading the combined hidden state.
        self.lin_hidden_to_loc = nn.Linear(rnn_dim, z_dim)
        self.lin_hidden_to_scale = nn.Linear(rnn_dim, z_dim)
        self.tanh = nn.Tanh()
        self.softplus = nn.Softplus()

    def forward(self, z_t_1, h_rnn):
        """Return ``(loc, scale)`` given the previous latent and the RNN state."""
        z_features = self.tanh(self.lin_z_to_hidden(z_t_1))
        mixed = 0.5 * (z_features + h_rnn)
        raw_scale = self.lin_hidden_to_scale(mixed)
        return self.lin_hidden_to_loc(mixed), self.softplus(raw_scale)

# z: hidden chain shared by x and y.
# z_x: hidden chain unique for x
# z_y: hidden chain unique for y

# pls: zx -> x <- z -> y

class D2PCCA(nn.Module):
    def __init__(
        self,
        x_dim=100,                # x dimensions
        y_dim=100,                # y dimensions
        z_dim=100,                # z dimensions
        zx_dim=0,                 # z_x dimensions
        zy_dim=0,                 # z_y dimensions
        emission_dim_z_x=100,     # hidden dimensions in emission network from z to x
        emission_dim_z_y=100,     # hidden dimensions in emission network from z to y
        emission_dim_zx_x=0,      # hidden dimensions in emission network from z,z_x to x
        emission_dim_zy_y=0,      # hidden dimensions in emission network from z,z_y to y
        transition_dim_z=100,     # hidden dimension in transition network z
        transition_dim_zx=0,      # hidden dimension in transition network z_x
        transition_dim_zy=0,      # hidden dimension in transition network z_y
        rnn_dim=600,              # RNN hidden dimensions
        num_layers=1,             # RNN layers
        rnn_dropout_rate=0.0,     # RNN dropout rate
        multisteps=1,
        num_iafs=0,
        iaf_dim=50,
        beta_d = .1
    ):
        super().__init__()

        # naive model: x <- z -> y
        if zx_dim == 0 and zy_dim == 0:
            self.emitter_z_x = Emitter_1_input(x_dim, z_dim, emission_dim_z_x) # z -> x
            self.emitter_z_y = Emitter_1_input(y_dim, z_dim, emission_dim_z_y) # z -> y
        # cca: zx -> x <- z -> y <- zy
        if zx_dim != 0 and zy_dim != 0:
            self.emitter_zx_x = Emitter_2_input(x_dim, z_dim, zx_dim, emission_dim_zx_x) # z, z_x -> x
            self.emitter_zy_y = Emitter_2_input(y_dim, z_dim, zy_dim, emission_dim_zy_y) # z, z_y -> y
        # pls: zx -> x <- z -> y
        if zx_dim != 0 and zy_dim == 0:
            self.emitter_zx_x = Emitter_2_input(x_dim, z_dim, zx_dim, emission_dim_zx_x) # z, z_x -> x
            self.emitter_z_y = Emitter_1_input(y_dim, z_dim, emission_dim_z_y) # z -> y

        self.trans = GatedTransition(z_dim, zx_dim, zy_dim, transition_dim_z, transition_dim_zx, transition_dim_zy) # z_t -> z_{t+1}
        self.combiner = Combiner(z_dim + zx_dim + zy_dim, rnn_dim)
        # dropout just takes effect on inner layers of rnn
        rnn_dropout_rate = 0.0 if num_layers == 1 else rnn_dropout_rate
        self.rnn = nn.RNN(
            input_size=x_dim + y_dim,
            hidden_size=rnn_dim,
            nonlinearity="relu",
            batch_first=True,
            bidirectional=False,
            num_layers=num_layers,
            dropout=rnn_dropout_rate,
        )
        # normalizing flows
        self.iafs = [
            affine_autoregressive(z_dim+zx_dim+zy_dim, hidden_dims=[iaf_dim]) for _ in range(num_iafs)
        ]
        self.iafs_modules = nn.ModuleList(self.iafs)
        # p(z_0)
        self.z_0 = nn.Parameter(torch.zeros(z_dim + zx_dim + zy_dim))
        # q(z_0)
        self.z_q_0 = nn.Parameter(torch.zeros(z_dim + zx_dim + zy_dim))
        # initial hidden state of the rnn
        self.h_0 = nn.Parameter(torch.zeros(1, 1, rnn_dim))
        # register attributes
        self.z_dim = z_dim
        self.zx_dim = zx_dim
        self.zy_dim = zy_dim
        self.multisteps = multisteps
        self.beta_d = beta_d

    def model(self, mini_batch_x, mini_batch_y, annealing_factor=1.0):
        """Generative model over the latent chain(s) and both observed views.

        For every time step the joint latent ``[z, z_x, z_y]`` is sampled from
        the gated transition (its log-density scaled by ``annealing_factor``)
        and ``x_t`` / ``y_t`` are observed through the emitters matching the
        configured layout (naive, cca or pls).
        """
        pyro.module("dmm", self)
        batch_size, T_max = mini_batch_x.size(0), mini_batch_x.size(1)

        # Layout of the joint latent on its last axis: [z | z_x | z_y].
        shared_end = self.z_dim
        x_end = self.z_dim + self.zx_dim
        has_zx = self.zx_dim != 0
        has_zy = self.zy_dim != 0

        # p(z_0), replicated batch_size times
        z_prev = self.z_0.expand(batch_size, self.z_0.size(0))

        with pyro.plate("z_minibatch", batch_size):
            for t in pyro.markov(range(1, T_max + 1)):
                # p(z_hat_t | z_hat_{t-1})
                trans_loc, trans_scale = self.trans(z_prev)
                with poutine.scale(scale=annealing_factor):
                    z_all_t = pyro.sample("z_%d" % t, dist.Normal(trans_loc, trans_scale).to_event(1))

                if not has_zx and not has_zy:
                    # naive model: both views read the whole latent
                    x_loc, x_scale = self.emitter_z_x(z_all_t)
                    y_loc, y_scale = self.emitter_z_y(z_all_t)
                elif has_zx and has_zy:
                    # cca: each view reads the shared chain plus its own chain
                    shared = z_all_t[:, :shared_end]
                    x_loc, x_scale = self.emitter_zx_x(shared, z_all_t[:, shared_end:x_end])
                    y_loc, y_scale = self.emitter_zy_y(shared, z_all_t[:, x_end:])
                elif has_zx:
                    # pls: only x has a private chain
                    shared = z_all_t[:, :shared_end]
                    x_loc, x_scale = self.emitter_zx_x(shared, z_all_t[:, shared_end:])
                    y_loc, y_scale = self.emitter_z_y(shared)

                obs_x_t = mini_batch_x[:, t - 1, :]
                obs_y_t = mini_batch_y[:, t - 1, :]
                pyro.sample("obs_x_%d" % t, dist.Normal(x_loc, x_scale).to_event(1), obs=obs_x_t)
                pyro.sample("obs_y_%d" % t, dist.Normal(y_loc, y_scale).to_event(1), obs=obs_y_t)
                # update time step
                z_prev = z_all_t

    def guide(self, mini_batch_x, mini_batch_y, annealing_factor=1.0):
        """Variational guide ``q(z_{1:T} | x_{1:T}, y_{1:T})``.

        The two views are concatenated on the feature axis, fed to the RNN in
        reversed time order, and the re-reversed RNN outputs are combined with
        the previous latent sample to parameterize each ``q(z_t | ...)``.
        When IAFs are configured the Gaussian is pushed through them.

        The initial RNN state ``h_0`` is broadcast to every RNN layer, so
        models built with ``num_layers > 1`` get a hidden state of the shape
        ``nn.RNN`` requires instead of failing on a one-layer state.
        """
        pyro.module("dmm", self)
        T_max = mini_batch_x.size(1) # T
        batch_size = mini_batch_x.size(0) # batch size
        # combining x and y batches
        mini_batch_combined = torch.cat((mini_batch_x, mini_batch_y), dim=-1)
        # expand h_0 to (num_layers, batch, hidden)
        h_0_contig = self.h_0.expand(
            self.rnn.num_layers, batch_size, self.rnn.hidden_size
        ).contiguous()
        # run the RNN over the time-reversed batch, then restore time order
        mini_batch_reversed = torch.flip(mini_batch_combined, dims=[1])
        rnn_output, _ = self.rnn(mini_batch_reversed, h_0_contig)
        rnn_output = torch.flip(rnn_output, dims=[1])

        # q(z_0)
        z_prev = self.z_q_0.expand(batch_size, self.z_q_0.size(0))

        use_iafs = len(self.iafs) > 0

        with pyro.plate("z_minibatch", batch_size):
            for t in pyro.markov(range(1, T_max + 1)):
                # ST-R: q(z_t | z_{t-1}, x_{t:T}, y_{t:T})
                z_loc, z_scale = self.combiner(z_prev, rnn_output[:, t - 1, :])
                if use_iafs:
                    # the transformed distribution is sampled as-is (no to_event)
                    z_dist = TransformedDistribution(dist.Normal(z_loc, z_scale), self.iafs)
                else:
                    z_dist = dist.Normal(z_loc, z_scale).to_event(1)
                with pyro.poutine.scale(scale=annealing_factor):
                    z_t = pyro.sample("z_%d" % t, z_dist)
                # update time step
                z_prev = z_t

    # latent overshooting
    def model_lo(self, mini_batch_x, mini_batch_y):
        """Model for the latent-overshooting objective.

        Keeps a ``(batch, d, latent)`` buffer of previous latents
        (``d = self.multisteps``).  At each step it samples ``z_{t|t}`` from
        the transition of slot 0, emits ``x_t``/``y_t`` from it, and samples
        additional sites named ``z_{t}_{t-j}`` (scaled by ``self.beta_d``) and
        ``z_{t}_{t-j}_r`` from the transitions of the other slots.

        Slots are read with plain indexing (no ``squeeze()``), so a batch of
        one sequence or a one-dimensional latent keeps its ``(batch, latent)``
        shape, which ``self.trans`` requires.  The buffers are allocated with
        the dtype/device of ``self.z_0``.
        """
        pyro.module("dmm", self)
        T_max = mini_batch_x.size(1) # T
        d = self.multisteps # d=1 by default
        batch_size = mini_batch_x.size(0) # batch size
        latent_dim = self.z_0.size(0)

        # p(z_0): slot 0 holds z_0, the remaining slots start at zero
        z_0_expand = self.z_0.expand(batch_size, latent_dim)
        z_prev = self.z_0.new_zeros(batch_size, d, latent_dim)
        z_prev = torch.cat([z_0_expand.unsqueeze(1), z_prev[:, 1:, :]], dim=1)
        z_cur = self.z_0.new_zeros(batch_size, d, latent_dim)

        with pyro.plate("z_minibatch", batch_size):
            for t in pyro.markov(range(1, T_max + 1)):
                # z_{t|t}
                z_tt_prev = z_prev[:, 0, :]
                z_tt_loc, z_tt_scale = self.trans(z_tt_prev)
                z_tt_dist = dist.Normal(z_tt_loc, z_tt_scale)
                #with pyro.poutine.scale(scale=self.beta_d):
                z_tt = pyro.sample(f"z_{t}_{t}", z_tt_dist.to_event(1))
                z_cur = torch.cat([z_tt.unsqueeze(1), z_cur[:, 1:, :]], dim=1)
                # p(obs|z_{t|t})
                # naive model
                if self.zx_dim == 0 and self.zy_dim == 0:
                    x_loc, x_scale = self.emitter_z_x(z_tt)
                    y_loc, y_scale = self.emitter_z_y(z_tt)
                # cca
                if self.zx_dim != 0 and self.zy_dim != 0:
                    z_t = z_tt[:, :self.z_dim]
                    zx_t = z_tt[:, self.z_dim:self.z_dim+self.zx_dim]
                    zy_t = z_tt[:, self.z_dim+self.zx_dim:]
                    x_loc, x_scale = self.emitter_zx_x(z_t, zx_t)
                    y_loc, y_scale = self.emitter_zy_y(z_t, zy_t)
                # pls
                if self.zx_dim != 0 and self.zy_dim == 0:
                    z_t = z_tt[:, :self.z_dim]
                    zx_t = z_tt[:, self.z_dim:]
                    x_loc, x_scale = self.emitter_zx_x(z_t, zx_t)
                    y_loc, y_scale = self.emitter_z_y(z_t)
                # multiplication factor
                #mf = d if t >= d else t
                #with pyro.poutine.scale(scale=1):
                pyro.sample("obs_x_%d" % t, dist.Normal(x_loc, x_scale).to_event(1), obs=mini_batch_x[:, t - 1, :])
                pyro.sample("obs_y_%d" % t, dist.Normal(y_loc, y_scale).to_event(1), obs=mini_batch_y[:, t - 1, :])

                # z_{t|t-1}: written into slot 1 of the buffer
                z_tt1 = pyro.sample(f"z_{t}_{t-1}_r", z_tt_dist.to_event(1))
                z_cur = torch.cat([z_cur[:, :1, :], z_tt1.unsqueeze(1), z_cur[:, 2:, :]], dim=1)

                if t >= d:
                    for j in range(2,d):
                        z_tj_prev = z_prev[:, j-1, :]
                        z_tj_loc, z_tj_scale = self.trans(z_tj_prev)
                        z_tj_dist = dist.Normal(z_tj_loc, z_tj_scale)
                        with pyro.poutine.scale(scale=self.beta_d):
                            pyro.sample(f"z_{t}_{t-j}", z_tj_dist.to_event(1))

                        z_tj_r = pyro.sample(f"z_{t}_{t-j}_r", z_tj_dist.to_event(1))
                        # write the "_r" sample into slot j
                        z_cur = torch.cat([z_cur[:, :j, :], z_tj_r.unsqueeze(1), z_cur[:, j+1:, :]], dim=1)
                    z_td_prev = z_prev[:, d-1, :]
                    z_td_loc, z_td_scale = self.trans(z_td_prev)
                    z_td_dist = dist.Normal(z_td_loc, z_td_scale)
                    with pyro.poutine.scale(scale=self.beta_d):
                        pyro.sample(f"z_{t}_{t-d}", z_td_dist.to_event(1))

                if t > 1 and t < d:
                    for j in range(2,t+1):
                        z_tj_prev = z_prev[:, j-1, :]
                        z_tj_loc, z_tj_scale = self.trans(z_tj_prev)
                        z_tj_dist = dist.Normal(z_tj_loc, z_tj_scale)
                        with pyro.poutine.scale(scale=self.beta_d):
                            pyro.sample(f"z_{t}_{t-j}", z_tj_dist.to_event(1))

                        z_tj_r = pyro.sample(f"z_{t}_{t-j}_r", z_tj_dist.to_event(1))
                        # write the "_r" sample into slot j
                        z_cur = torch.cat([z_cur[:, :j, :], z_tj_r.unsqueeze(1), z_cur[:, j+1:, :]], dim=1)

                # update time step
                z_prev = z_cur

    def guide_lo(self, mini_batch_x, mini_batch_y):
        """Guide for the latent-overshooting objective.

        Mirrors the site names of ``model_lo``: ``z_{t}_{t}`` comes from the
        combiner (previous latent + RNN state), the ``z_{t}_{t-j}`` sites are
        drawn from that same combiner distribution under a ``self.beta_d``
        scale, and the ``z_{t}_{t-j}_r`` sites are drawn from the transition
        network applied to the buffered latents.

        Slots are read with plain indexing (no ``squeeze()``) so a batch of
        one sequence keeps its ``(batch, latent)`` shape; the buffers follow
        the dtype/device of the parameters; and ``h_0`` is broadcast to every
        RNN layer so ``num_layers > 1`` is accepted by ``nn.RNN``.
        """
        pyro.module("dmm", self)
        T_max = mini_batch_x.size(1) # T
        d = self.multisteps # d=1 by default
        batch_size = mini_batch_x.size(0) # batch size
        mini_batch_combined = torch.cat((mini_batch_x, mini_batch_y), dim=-1)
        h_0_contig = self.h_0.expand(
            self.rnn.num_layers, batch_size, self.rnn.hidden_size
        ).contiguous()
        # The RNN runs over the time-reversed sequence and its outputs are
        # flipped back, so rnn_output[:, t - 1] summarizes steps t..T.
        mini_batch_reversed = torch.flip(mini_batch_combined, dims=[1])
        rnn_output, _ = self.rnn(mini_batch_reversed, h_0_contig)
        rnn_output = torch.flip(rnn_output, dims=[1])

        # q(z_0): slot 0 holds z_q_0, the remaining slots start at zero
        z_0_expand = self.z_q_0.expand(batch_size, self.z_q_0.size(0))
        z_prev = self.z_q_0.new_zeros(batch_size, d, self.z_q_0.size(0))
        z_prev = torch.cat([z_0_expand.unsqueeze(1), z_prev[:, 1:, :]], dim=1)
        z_cur = self.z_0.new_zeros(batch_size, d, self.z_0.size(0))

        # p(z_0)
        z_0_batch = self.z_0.expand(batch_size, self.z_0.size(0))

        with pyro.plate("z_minibatch", batch_size):
            for t in pyro.markov(range(1, T_max + 1)):
                # z_{t|t}
                z_tt_prev = z_prev[:, 0, :]
                z_tt_loc, z_tt_scale = self.combiner(z_tt_prev, rnn_output[:, t - 1, :])
                z_tt_dist = dist.Normal(z_tt_loc, z_tt_scale)
                #with pyro.poutine.scale(scale=self.beta_d):
                z_tt = pyro.sample(f"z_{t}_{t}", z_tt_dist.to_event(1))
                z_cur = torch.cat([z_tt.unsqueeze(1), z_cur[:, 1:, :]], dim=1)
                # z_{t|t-1}: the first step transitions from the model's z_0
                if t == 1:
                    z_tt1_loc, z_tt1_scale = self.trans(z_0_batch)
                else:
                    z_tt1_loc, z_tt1_scale = self.trans(z_tt_prev)
                z_tt1_dist = dist.Normal(z_tt1_loc, z_tt1_scale)

                z_tt1 = pyro.sample(f"z_{t}_{t-1}_r", z_tt1_dist.to_event(1))
                # write the "_r" sample into slot 1
                z_cur = torch.cat([z_cur[:, :1, :], z_tt1.unsqueeze(1), z_cur[:, 2:, :]], dim=1)

                if t >= d:
                    for j in range(2,d):
                        with pyro.poutine.scale(scale=self.beta_d):
                            pyro.sample(f"z_{t}_{t-j}", z_tt_dist.to_event(1))
                        z_tj_prev = z_prev[:, j-1, :]
                        z_tj_loc, z_tj_scale = self.trans(z_tj_prev)
                        z_tj_dist = dist.Normal(z_tj_loc, z_tj_scale)

                        z_tj_r = pyro.sample(f"z_{t}_{t-j}_r", z_tj_dist.to_event(1))
                        # write the "_r" sample into slot j
                        z_cur = torch.cat([z_cur[:, :j, :], z_tj_r.unsqueeze(1), z_cur[:, j+1:, :]], dim=1)
                    with pyro.poutine.scale(scale=self.beta_d):
                        pyro.sample(f"z_{t}_{t-d}", z_tt_dist.to_event(1))

                if t > 1 and t < d:
                    for j in range(2,t+1):
                        with pyro.poutine.scale(scale=self.beta_d):
                            pyro.sample(f"z_{t}_{t-j}", z_tt_dist.to_event(1))
                        z_tj_prev = z_prev[:, j-1, :]
                        z_tj_loc, z_tj_scale = self.trans(z_tj_prev)
                        z_tj_dist = dist.Normal(z_tj_loc, z_tj_scale)

                        z_tj_r = pyro.sample(f"z_{t}_{t-j}_r", z_tj_dist.to_event(1))
                        # write the "_r" sample into slot j
                        z_cur = torch.cat([z_cur[:, :j, :], z_tj_r.unsqueeze(1), z_cur[:, j+1:, :]], dim=1)

                # update time step
                z_prev = z_cur


def elbo_lo(model, guide, *args, **kwargs):
    """Custom loss for the latent-overshooting model/guide pair.

    Traces ``guide``, replays ``model`` against that trace, and returns
    ``sum(log q) - sum(log p)`` over all sample sites whose name does not
    contain ``"_r"`` (i.e. the quantity to minimize).

    Each site's log-probability is multiplied by the ``scale`` recorded on the
    trace node, so ``poutine.scale`` factors applied in the model and guide
    (such as ``beta_d``) actually weight the corresponding terms.  Without
    this they would have no effect on a hand-computed loss like this one.

    Args:
        model: callable model, invoked with ``*args, **kwargs``.
        guide: callable guide, invoked with ``*args, **kwargs``.

    Returns:
        The scalar loss (a tensor when any site contributes).
    """
    # Get the trace from the guide, then replay the model against it
    guide_trace = pyro.poutine.trace(guide).get_trace(*args, **kwargs)
    model_trace = pyro.poutine.trace(pyro.poutine.replay(model, trace=guide_trace)).get_trace(*args, **kwargs)

    def _scaled_log_prob_sum(trace):
        # Sum scale * log_prob over sample sites, skipping the "_r" sites.
        total = 0.0
        for name, site in trace.nodes.items():
            if site["type"] != "sample" or "_r" in name:
                continue
            log_prob = site["fn"].log_prob(site["value"]).sum()
            total = total + site["scale"] * log_prob
        return total

    return _scaled_log_prob_sum(guide_trace) - _scaled_log_prob_sum(model_trace)

# find the empirical mean of a transformed distribution
def mean_iaf(z_loc, z_scale, iaf_module, num_samples = 1000):
    """Monte Carlo estimate of the mean of a Gaussian pushed through ``iaf_module``.

    Builds ``Normal(z_loc, z_scale)`` with its last axis as the event
    dimension, wraps it in a ``TransformedDistribution`` with the given
    transform(s), draws ``num_samples`` samples and averages them over the
    sample axis.
    """
    flow_dist = TransformedDistribution(
        dist.Normal(z_loc, z_scale).to_event(1),
        iaf_module,
    )
    draws = flow_dist.sample([num_samples])
    return draws.mean(dim=0)



# draw visualization for y
def _save_series_plots(prefix, data_idx, pred, lower, upper, show_ci, observed=None):
    """Save, download and print one PNG per last-axis component of ``pred``.

    Plots sequence ``data_idx``; ``observed`` (if given) is overlaid, and the
    ``lower``/``upper`` band is drawn only when ``show_ci`` is true.
    """
    for j in range(upper.shape[-1]):
        plt.figure(figsize=(10, 6))
        plt.plot(pred[data_idx, :, j], label='Prediction')
        if observed is not None:
            plt.plot(observed[data_idx, :, j], label='Observations')
        if show_ci:
            plt.fill_between(range(len(upper[data_idx, :, j])), lower[data_idx, :, j], upper[data_idx, :, j], color='gray', alpha=0.2, label='Confidence Interval')
        plt.xlabel("Time Step")
        plt.ylabel("Value")
        plt.legend()
        filename = f'{prefix}_{data_idx}_{j+1}.png'
        plt.savefig(filename)
        plt.close()
        files.download(filename)
        print(filename)


def visualize(dmm_trained, x_data, y_data, data_idx, y_idx, find_RMSE=False, store_pic=False):
    """Reconstruct x/y from the guide's latent means and plot or score them.

    The guide's RNN and combiner are run over the observations, each step's
    latent location (or a sampled IAF mean when flows are configured) is fed
    forward as the next step's previous latent, and the emitters produce the
    predicted locations and scales.

    Args:
        dmm_trained: a ``D2PCCA`` instance.
        x_data, y_data: observations indexed as ``[sequence, time, feature]``.
        data_idx: sequence to plot.
        y_idx: component of ``y`` to plot when ``find_RMSE`` is false.
        find_RMSE: if true, process all sequences and return the RMSE
            (summed squared error of x and y divided by ``N * T``) instead of
            showing a plot.
        store_pic: with ``find_RMSE``, also save and download one PNG per
            component of x, y and z for sequence ``data_idx``.

    Returns:
        The RMSE when ``find_RMSE`` is true, otherwise ``None``.
    """
    if not find_RMSE:
        x_obs = x_data[data_idx:data_idx+1,:,:]
        y_obs = y_data[data_idx:data_idx+1,:,:]
        batch_size = 1
    else:
        x_obs = x_data
        y_obs = y_data
        batch_size = x_data.shape[0]
        N = batch_size
    T = x_obs.shape[1]
    # The +/- 2*scale band is only drawn when no normalizing flow is used.
    show_ci = len(dmm_trained.iafs) == 0

    def _per_step(value):
        # Single-sequence mode drops the leading batch axis.  RMSE mode keeps
        # it even for N == 1 so the stacked result stays (N, T, dim).
        return value if find_RMSE else value.squeeze(0)

    y_pred = []
    y_scales = []
    x_pred = []
    x_scales = []
    z_pred = []
    z_scales = []

    obs = torch.cat((x_obs, y_obs), dim=-1)
    # Broadcast h_0 to every RNN layer so num_layers > 1 is accepted.
    h_0_contig = dmm_trained.h_0.expand(
        dmm_trained.rnn.num_layers, batch_size, dmm_trained.rnn.hidden_size
    ).contiguous()
    obs_reversed = torch.flip(obs, dims=[1])
    rnn_output, _ = dmm_trained.rnn(obs_reversed, h_0_contig)
    rnn_output = torch.flip(rnn_output, dims=[1])
    z_prev = dmm_trained.z_q_0
    for t in range(T):
        z_t_loc, z_t_scale = dmm_trained.combiner(z_prev, rnn_output[:, t, :])
        if not show_ci:
            z_t_loc = mean_iaf(z_t_loc, z_t_scale, dmm_trained.iafs, num_samples = 1000)
        # naive model
        if dmm_trained.zx_dim == 0 and dmm_trained.zy_dim == 0:
            x_loc, x_scale = dmm_trained.emitter_z_x(z_t_loc)
            y_loc, y_scale = dmm_trained.emitter_z_y(z_t_loc)
        # cca
        if dmm_trained.zx_dim != 0 and dmm_trained.zy_dim != 0:
            z_t = z_t_loc[:, :dmm_trained.z_dim]
            zx_t = z_t_loc[:, dmm_trained.z_dim:dmm_trained.z_dim+dmm_trained.zx_dim]
            zy_t = z_t_loc[:, dmm_trained.z_dim+dmm_trained.zx_dim:]
            x_loc, x_scale = dmm_trained.emitter_zx_x(z_t, zx_t)
            y_loc, y_scale = dmm_trained.emitter_zy_y(z_t, zy_t)
        # pls
        if dmm_trained.zx_dim != 0 and dmm_trained.zy_dim == 0:
            z_t = z_t_loc[:, :dmm_trained.z_dim]
            zx_t = z_t_loc[:, dmm_trained.z_dim:]
            x_loc, x_scale = dmm_trained.emitter_zx_x(z_t, zx_t)
            y_loc, y_scale = dmm_trained.emitter_z_y(z_t)
        y_pred.append(_per_step(y_loc))
        y_scales.append(_per_step(y_scale))
        x_pred.append(_per_step(x_loc))
        x_scales.append(_per_step(x_scale))
        z_pred.append(_per_step(z_t_loc))
        z_scales.append(_per_step(z_t_scale))
        z_prev = z_t_loc

    # RMSE mode stacks time on axis 1 -> (N, T, dim); otherwise (T, dim).
    stack_dim = 1 if find_RMSE else 0

    def _stack(steps):
        return torch.stack(steps, dim=stack_dim).detach().numpy()

    y_pred = _stack(y_pred)
    y_scales = _stack(y_scales)
    x_pred = _stack(x_pred)
    x_scales = _stack(x_scales)
    z_pred = _stack(z_pred)
    z_scales = _stack(z_scales)
    # upper and lower CI
    y_upper = y_pred + 2*y_scales
    y_lower = y_pred - 2*y_scales
    x_upper = x_pred + 2*x_scales
    x_lower = x_pred - 2*x_scales
    z_upper = z_pred + 2*z_scales
    z_lower = z_pred - 2*z_scales
    # calculate RMSE:
    if find_RMSE:
        squared_diff_x = (x_data - x_pred) ** 2
        squared_diff_y = (y_data - y_pred) ** 2
        MSE = (squared_diff_x.sum() + squared_diff_y.sum()) / (N * T)
        RMSE = torch.sqrt(MSE)
        # save all the predictions:
        if store_pic:
            _save_series_plots('x', data_idx, x_pred, x_lower, x_upper, show_ci, observed=x_data)
            _save_series_plots('y', data_idx, y_pred, y_lower, y_upper, show_ci, observed=y_data)
            _save_series_plots('z', data_idx, z_pred, z_lower, z_upper, show_ci)
        return RMSE
    else:
        plt.figure(figsize=(10, 6))
        plt.plot(y_pred[:, y_idx], label='Prediction')
        plt.plot(y_data[data_idx, :, y_idx], label='Actual Values')
        # Fill the area between y_upper and y_lower
        if show_ci:
            plt.fill_between(range(len(y_upper[:, y_idx])), y_lower[:, y_idx], y_upper[:, y_idx], color='gray', alpha=0.2, label='Confidence Interval')
        # NOTE(review): the title is fixed text; it does not reflect y_idx or T.
        plt.title("Visualization of the First Dimension over 100 Time Steps")
        plt.xlabel("Time Step")
        plt.ylabel("Value")
        plt.legend()
        plt.show()


# Training one epoch of the training set
def train(svi, train_loader):
    """Run one epoch of ``svi.step`` and return the loss per training sequence.

    The per-batch losses are summed and divided by ``len(train_loader.dataset)``.
    """
    # each (x, y) pair yielded by the loader is one mini-batch
    summed_loss = sum((svi.step(x, y) for x, y in train_loader), 0.)
    return summed_loss / len(train_loader.dataset)

# evaluate model on test set
def evaluate(svi, test_loader):
    """Return the loss per test sequence without taking gradient steps.

    Sums ``svi.evaluate_loss`` over all mini-batches and divides by
    ``len(test_loader.dataset)``.
    """
    batch_losses = (svi.evaluate_loss(x, y) for x, y in test_loader)
    summed_loss = sum(batch_losses, 0.)
    return summed_loss / len(test_loader.dataset)

def train_KL_annealing(svi, train_loader, epoch, annealing_epochs, minimum_annealing_factor):
    """Run one training epoch with a linearly increasing KL annealing factor.

    While ``epoch < annealing_epochs`` the factor passed to ``svi.step`` grows
    linearly with the global mini-batch index from ``minimum_annealing_factor``
    to 1.0 (reached on the last mini-batch of the last annealing epoch);
    afterwards it is fixed at 1.0.

    Args:
        svi: object whose ``step(x, y, annealing_factor)`` returns the batch loss.
        train_loader: iterable of ``(x, y)`` mini-batches exposing ``__len__``
            and a ``dataset`` attribute.
        epoch: zero-based index of the current epoch.
        annealing_epochs: number of epochs over which to anneal (0 disables).
        minimum_annealing_factor: starting value of the factor.

    Returns:
        The summed loss divided by ``len(train_loader.dataset)``.
    """
    N_mini_batches = len(train_loader)
    # Loop-invariant parts of the schedule.
    annealing = annealing_epochs > 0 and epoch < annealing_epochs
    epoch_nll = 0.0
    for which_mini_batch, (x, y) in enumerate(train_loader):
        if annealing:
            annealing_factor = minimum_annealing_factor + (1.0 - minimum_annealing_factor) * (
                float(which_mini_batch + epoch * N_mini_batches + 1)
                / float(annealing_epochs * N_mini_batches)
            )
        else:
            annealing_factor = 1.0
        epoch_nll += svi.step(x, y, annealing_factor)
    # return average epoch loss
    return epoch_nll / len(train_loader.dataset)


# Simulation Data

In [None]:
# Ensure reproducibility: seed PyTorch with a fixed value, then seed NumPy's
# global generator with PyTorch's initial seed reduced modulo 2**32 - 1.
torch.manual_seed(42)
np.random.seed(torch.initial_seed() % (2**32 - 1))

def nonlinear_transition(z_prev, A, nonlinearity, noise_std):
    """Advance a latent state by one noisy, partly nonlinear step.

    The next state is an equal blend of the linear map ``A . z_prev`` and
    ``nonlinearity`` applied to that map, plus zero-mean Gaussian noise.

    Args:
        z_prev: previous state (NumPy array or NumPy scalar; its ``.shape``
            sets the shape of the noise draw).
        A: transition matrix applied with ``np.dot``.
        nonlinearity: elementwise callable, e.g. ``np.sin`` or ``np.tanh``.
        noise_std: standard deviation of the additive Gaussian noise.

    Returns:
        The next state.
    """
    drive = np.dot(A, z_prev)
    blended = 0.5 * drive + 0.5 * nonlinearity(drive)
    # The noise is drawn from NumPy's global RNG after the deterministic part.
    return blended + np.random.normal(0, noise_std, z_prev.shape)

def generate_system(T, A0, Ax, Ay, Wx, Wy, initial_states, noise_std=0.1, epsilon_x=0.05, epsilon_y=0.05):
    """Simulate one trajectory of the latent system and its two observed views.

    A scalar latent ``z0`` enters both observation models, while ``zx`` and
    ``zy`` enter only the x and y observations respectively. Each latent
    evolves through ``nonlinear_transition``; the observations are linear
    read-outs of the stacked latents plus Gaussian noise.

    Args:
        T: number of observed time steps.
        A0: transition matrix of the shared latent (must yield a single value
            per step, e.g. a 1x1 matrix).
        Ax, Ay: square transition matrices of the two private latents.
        Wx: observation matrix for x, shape ``(x_dim, 1 + Ax.shape[0])``.
        Wy: observation matrix for y, shape ``(y_dim, 1 + Ay.shape[0])``.
        initial_states: dict with keys ``'z0'``, ``'zx'`` and ``'zy'`` holding
            the latent values at time 0.
        noise_std: standard deviation of the latent transition noise.
        epsilon_x, epsilon_y: standard deviations of the observation noise.

    Returns:
        ``(Z0, Zx, Zy, X, Y)``: latent paths of length ``T + 1`` (index 0 is
        the initial state) and observations of length ``T``, where row
        ``t - 1`` of ``X``/``Y`` is emitted from the latents at time ``t``.
    """
    Z0 = np.zeros(T + 1)
    Zx = np.zeros((T + 1, Ax.shape[0]))
    Zy = np.zeros((T + 1, Ay.shape[0]))
    X = np.zeros((T, Wx.shape[0]))
    Y = np.zeros((T, Wy.shape[0]))

    # np.squeeze turns size-1 arrays into 0-d arrays: assigning an array with
    # ndim > 0 to a single element is deprecated since NumPy 1.25.
    Z0[0] = np.squeeze(initial_states['z0'])
    Zx[0, :], Zy[0, :] = initial_states['zx'], initial_states['zy']

    for t in range(1, T + 1):
        # np.dot(A0, scalar) keeps A0's 2-D shape, hence the squeeze here too.
        Z0[t] = np.squeeze(nonlinear_transition(Z0[t-1], A0, np.sin, noise_std))
        Zx[t, :] = nonlinear_transition(Zx[t-1, :], Ax, np.tanh, noise_std)
        Zy[t, :] = nonlinear_transition(Zy[t-1, :], Ay, np.tanh, noise_std)

        # Each view observes the shared latent stacked on its private latent.
        x_state = np.concatenate(([Z0[t]], Zx[t, :]))
        y_state = np.concatenate(([Z0[t]], Zy[t, :]))
        X[t-1, :] = np.dot(Wx, x_state) + np.random.normal(0, epsilon_x, Wx.shape[0])
        Y[t-1, :] = np.dot(Wy, y_state) + np.random.normal(0, epsilon_y, Wy.shape[0])

    return Z0, Zx, Zy, X, Y

def generate_multiple_samples(N, T, A0, Ax, Ay, Wx, Wy, noise_std, epsilon_x, epsilon_y):
    """Simulate ``N`` independent trajectories and stack their observations.

    Every trajectory starts from latent states drawn uniformly from [0, 1)
    and is produced by ``generate_system``.

    Args:
        N: number of trajectories.
        T: number of observed time steps per trajectory.
        A0, Ax, Ay: latent transition matrices, forwarded to
            ``generate_system``.
        Wx, Wy: observation matrices, forwarded to ``generate_system``.
        noise_std: standard deviation of the latent transition noise.
        epsilon_x, epsilon_y: standard deviations of the observation noise.

    Returns:
        ``(data_x, data_y)`` with shapes ``(N, T, Wx.shape[0])`` and
        ``(N, T, Wy.shape[0])``.
    """
    data_x = np.zeros((N, T, Wx.shape[0]))
    data_y = np.zeros((N, T, Wy.shape[0]))

    # Size the private initial states from the transition matrices instead of
    # hard-coding 3 and 2, so other latent dimensions work as well.
    zx_dim, zy_dim = Ax.shape[0], Ay.shape[0]

    for n in range(N):
        initial_states = {
            'z0': np.random.rand(1),
            'zx': np.random.rand(zx_dim),
            'zy': np.random.rand(zy_dim),
        }
        _, _, _, X, Y = generate_system(T, A0, Ax, Ay, Wx, Wy, initial_states, noise_std, epsilon_x, epsilon_y)
        data_x[n] = X
        data_y[n] = Y

    return data_x, data_y

# Simulation size: number of sequences and observed time steps per sequence
N = 500
T = 100

# Latent transition matrices: A0 for the shared latent, Ax / Ay for the
# latents that only feed the x / y observations
A0 = np.array([[0.95]])
Ax = np.array([
    [0.9, 0.1, 0.0],
    [0.1, 0.8, 0.1],
    [0.0, 0.1, 0.9],
])
Ay = np.array([
    [0.85, 0.15],
    [0.15, 0.85],
])

# Random observation matrices (drawn from the seeded global NumPy RNG)
Wx = np.random.rand(10, 4)  # 10 observed x dims from 1 shared + 3 private latents
Wy = np.random.rand(5, 3)  # 5 observed y dims from 1 shared + 2 private latents

# Simulate N sequences and convert both views to float32 tensors
data_x, data_y = generate_multiple_samples(
    N, T, A0, Ax, Ay, Wx, Wy,
    noise_std=0.1, epsilon_x=0.05, epsilon_y=0.05,
)
data_x = torch.from_numpy(data_x).float()
data_y = torch.from_numpy(data_y).float()

# Pair the two views per sequence and serve them in shuffled mini-batches
train_dataset = TensorDataset(data_x, data_y)
batch_size = 10
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)

# Loading Financial Dataset and Preprocessing

In [None]:
# Imported here so this cell is self-contained: pd.read_csv is used below.
import pandas as pd
from google.colab import drive
drive.mount('/content/drive')

file_path = '/content/drive/MyDrive/xxxxx.csv'
df = pd.read_csv(file_path)
# Columns 1-10 form the x view and columns 11-20 the y view; column 0 is
# skipped (presumably a date/index column -- confirm against the CSV).
data_X = df.iloc[:, 1:11].astype(float).values
data_Y = df.iloc[:, 11:21].astype(float).values

# Normalize the data up front (per-column z-score over the whole series)
mean_X = data_X.mean(axis=0)
std_X = data_X.std(axis=0)
normalized_X = (data_X - mean_X) / std_X

mean_Y = data_Y.mean(axis=0)
std_Y = data_Y.std(axis=0)
normalized_Y = (data_Y - mean_Y) / std_Y

# Create sequences: overlapping sliding windows with stride 1
sequence_length = 30
num_sequences = len(data_X) - sequence_length + 1
if num_sequences <= 0:
    # Fail here rather than with a division by zero on an empty dataset later.
    raise ValueError(
        "Need at least %d rows to build one sequence, got %d"
        % (sequence_length, len(data_X))
    )
sequences_X = np.array([normalized_X[i:i + sequence_length] for i in range(num_sequences)])
sequences_Y = np.array([normalized_Y[i:i + sequence_length] for i in range(num_sequences)])

# Convert to PyTorch tensors
tensor_X = torch.tensor(sequences_X, dtype=torch.float32)
tensor_Y = torch.tensor(sequences_Y, dtype=torch.float32)

# Create TensorDataset and DataLoader
# NOTE: this rebinds train_loader, replacing the simulation loader above.
dataset = TensorDataset(tensor_X, tensor_Y)
train_loader = DataLoader(dataset, batch_size=10, shuffle=True)

# Usage

In [None]:
# Instantiate D2PCCA for two 10-dimensional observed views (x and y).
my_model = D2PCCA(
    # observed dimensions
    x_dim=10,
    y_dim=10,
    # latent dimensions
    z_dim=1,
    zx_dim=2,
    zy_dim=2,
    # hidden widths of the emission networks
    emission_dim_z_x=20,
    emission_dim_z_y=20,
    emission_dim_zx_x=20,
    emission_dim_zy_y=20,
    # hidden widths of the transition networks
    transition_dim_z=20,
    transition_dim_zx=10,
    transition_dim_zy=10,
    # RNN settings
    rnn_dim=30,
    rnn_dropout_rate=.1,
    multisteps=2,
    # optional IAF settings (disabled):
    #num_iafs=5,
    #iaf_dim=10,
    #beta_d = .01
)

# Start from an empty Pyro parameter store
pyro.clear_param_store()

# Optimizer: ClippedAdam with gradient-norm clipping and learning-rate decay
adam_params = dict(
    lr=0.0003,
    betas=(0.96, 0.999),
    clip_norm=20.0,
    lrd=0.99996,
    weight_decay=2.0,
)
adam = ClippedAdam(adam_params)

# Inference engine for plain D2PCCA
svi = SVI(my_model.model, my_model.guide, adam, Trace_ELBO())
# Variant with Latent Overshooting:
# svi = SVI(my_model.model_lo, my_model.guide_lo, adam, elbo_lo)

In [None]:
NUM_EPOCHS = 1500

# For KL-Annealing Training
#annealing_epochs = 90
#minimum_annealing_factor = 0.2

# Per-epoch training ELBO (the negated average loss), kept for later plotting
train_elbo = []

# training loop
for epoch in range(NUM_EPOCHS):
    total_epoch_loss_train = train(svi, train_loader)
    # KL-annealing alternative (needs the two settings above uncommented):
    # total_epoch_loss_train = train_KL_annealing(svi, train_loader, epoch, annealing_epochs, minimum_annealing_factor)
    train_elbo.append(-total_epoch_loss_train)
    print("[epoch %03d]  average training loss: %.4f" % (epoch, total_epoch_loss_train))
    # Every 50 epochs, run visualize on the full tensors and report its result.
    if epoch % 50 == 0:
        # NOTE(review): the positional arguments 10, 3, True, False are defined
        # by visualize, which is outside this cell -- confirm their meaning there.
        rmse = visualize(my_model, tensor_X, tensor_Y, 10, 3, True, False)
        print("[epoch %03d]  RMSE: %.4f" % (epoch, rmse))
