# Bayesian Inference in the Poisson HMM

In this notebook we fit a hidden Markov model (HMM) with Poisson emissions. The material is based on Professor Scott Linderman's slides for Stanford's STATS 271 (Spring 2021) ([Part 1](https://github.com/slinderman/stats271sp2021/blob/main/slides/lap7_hmms.pdf), [Part 2](https://github.com/slinderman/stats271sp2021/blob/main/slides/lap7_hmms_b.pdf)) and Chapter 23 of "Bayesian Reasoning and Machine Learning" (Barber, 2012).

Mathematically, let $\mathbf{x}_t^{(v)} \in \mathbb{R}^{p}$ and $y_t^{(v)}$ denote the combined features and the case count, respectively, at time step $t$ of the $v$-th event, and let $\mathbf{x}_{1:T_v}^{(v)} = (\mathbf{x}_1^{(v)}, \ldots, \mathbf{x}_{T_v}^{(v)})$ and $y_{1:T_v}^{(v)} = (y_1^{(v)}, \ldots, y_{T_v}^{(v)})$ denote the full sequences of features and case counts for the $v$-th event, where $T_v$ is its number of time steps. Likewise, let $z_{1:T_v}^{(v)} = (z_1^{(v)}, \ldots, z_{T_v}^{(v)})$ denote the sequence of discrete states for the $v$-th event.

Treating events as independent given the parameters $\Theta$, the joint distribution of the complete dataset is

\begin{align}
p(\{(z_{1:T_v}^{(v)}, y_{1:T_v}^{(v)})\}_{v=1}^V \mid \{\mathbf{x}_{1:T_v}^{(v)}\}_{v=1}^V, \Theta) 
&= \prod_{v=1}^V p(z_{1:T_v}^{(v)}, y_{1:T_v}^{(v)} \mid \mathbf{x}_{1:T_v}^{(v)}, \Theta) \\
&= \prod_{v=1}^V \left[p(z_1^{(v)} \mid \Theta) \prod_{t=2}^{T_v} p(z_{t}^{(v)} \mid z_{t-1}^{(v)} , \Theta)  \prod_{t=1}^{T_v} p(y_{t}^{(v)} \mid z_{t}^{(v)}, \mathbf{x}_{t}^{(v)}, \Theta) \right]
\end{align}
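To make the factorization concrete, here is a minimal NumPy sketch that evaluates the bracketed term, $\log p(z_{1:T}, y_{1:T} \mid \mathbf{x}_{1:T}, \Theta)$, for a single event; the full-dataset value is the sum over events. The function names and the inputs `log_pi` (log initial-state probabilities for $p(z_1)$), `log_P` (log transition matrix) and `rates` (a $T \times K$ array holding the Poisson rate of each state at each time step) are hypothetical; how the rates are computed from $\mu_k$ and $\mathbf{x}_t$ is left to the caller.

```python
import math

import numpy as np


def poisson_logpmf(y, rate):
    """Elementwise log Poisson pmf: y*log(rate) - rate - log(y!)."""
    y = np.asarray(y, dtype=float)
    rate = np.asarray(rate, dtype=float)
    log_fact = np.vectorize(math.lgamma, otypes=[float])(y + 1.0)
    return y * np.log(rate) - rate - log_fact


def complete_data_log_joint(z, y, log_pi, log_P, rates):
    """log p(z_{1:T}, y_{1:T} | x_{1:T}, Theta) for one event (illustrative helper).

    z      : (T,) integer state sequence
    y      : (T,) observed counts
    log_pi : (K,) log p(z_1)
    log_P  : (K, K) with log_P[i, j] = log p(z_t = j | z_{t-1} = i)
    rates  : (T, K) Poisson rate of state k at time t (hypothetical input)
    """
    z = np.asarray(z)
    T = len(z)
    lp = log_pi[z[0]]                                       # p(z_1)
    lp += log_P[z[:-1], z[1:]].sum()                        # prod_{t=2}^T p(z_t | z_{t-1})
    lp += poisson_logpmf(y, rates[np.arange(T), z]).sum()   # prod_t p(y_t | z_t, x_t)
    return lp
```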

The goal is to use expectation-maximization (EM) to find the parameters $\Theta$ that maximize the marginal likelihood of the data. We then use cross-validation, holding out a random subset of events, to choose the number of discrete states. Finally, we visualize the inferred states in terms of the distribution over features.

# Model

<img src="../hmm.png" width="400">

\begin{align}
\theta_k &\sim \mathrm{Dirichlet}(\alpha), \quad k = 1, \ldots, K\\
z_t &\sim \mathrm{Discrete}(\theta_{z_{t-1}})\\
\mu_k &\sim \mathrm{Normal}(0, \sigma), \quad k = 1, \ldots, K\\
y_t &\sim \mathrm{Poisson}(\mu_{z_t}^T\mathbf{x}_t)
\end{align}
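The generative story can be sketched as a sampler for one event. This is a minimal NumPy illustration, not the notebook's `HMM` class, and it makes assumptions the model above does not fix: the initial state is drawn uniformly, $\sigma$ is treated as a standard deviation, the Dirichlet is symmetric, and the Poisson rate uses an exponential link, $\exp(\mu_{z_t}^T\mathbf{x}_t)$, so that it is always positive. With the rate $\mu_{z_t}^T\mathbf{x}_t$ exactly as written, the features and weights would have to guarantee positivity. The function name is hypothetical.

```python
import numpy as np


def sample_poisson_hmm(X, K, alpha=1.0, sigma=1.0, seed=0):
    """Draw (z, y, theta, mu) for one event with features X of shape (T, p).

    Assumptions (not fixed by the model above): uniform p(z_1), a symmetric
    Dirichlet(alpha) prior on each transition row, sigma as a standard
    deviation, and an exp link for the Poisson rate.
    """
    rng = np.random.default_rng(seed)
    T, p = X.shape
    theta = rng.dirichlet(alpha * np.ones(K), size=K)  # row k: p(z_t | z_{t-1} = k)
    mu = rng.normal(0.0, sigma, size=(K, p))           # per-state regression weights
    z = np.empty(T, dtype=int)
    y = np.empty(T, dtype=int)
    z[0] = rng.integers(K)                              # uniform initial state (assumed)
    for t in range(T):
        if t > 0:
            z[t] = rng.choice(K, p=theta[z[t - 1]])     # z_t ~ Discrete(theta_{z_{t-1}})
        rate = np.exp(mu[z[t]] @ X[t])                  # exp link keeps the rate positive (assumed)
        y[t] = rng.poisson(rate)                        # y_t ~ Poisson(rate)
    return z, y, theta, mu
```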



# EM for Poisson GLM

The expectation-maximization (EM) algorithm performs coordinate ascent, alternating between updates of the posterior over the latent variables and updates of the parameters. The **E-step** sets $q$ to the posterior over the latent states given the current parameters:

$$q\leftarrow p(z\mid x, y, \Theta)$$

The **M-step** then updates the parameters by maximizing the expected log joint probability under $q$:

\begin{align}
    \Theta\leftarrow\arg \max_\Theta \;\mathbb{E}_{q(z)}[\log p(x, z, y, \Theta)]
\end{align}

## E-step: Running the forward-backward algorithm

Consider the marginal probability of state $k$ at time $t$

\begin{align}
    q(z_t=k) & = \sum_{z_1=1}^K\dots\sum_{z_{t-1}=1}^K\sum_{z_{t+1}=1}^K\dots\sum_{z_T=1}^Kq(z_1,\dots,z_{t-1},z_t=k,z_{t+1},\dots,z_T)\\
    & \propto \left[\sum_{z_1=1}^K\dots\sum_{z_{t-1}=1}^Kp(z_1)\prod_{s=1}^{t-1}p(y_s\mid \mu_{z_s}^T\mathbf{x}_s)p(z_{s+1}\mid z_s)\right]\times \left[p(y_t\mid\mu_k^T\mathbf{x}_t)\right]\\
    & \times \left[\sum_{z_{t+1}=1}^K\dots\sum_{z_{T}=1}^K\prod_{u=t+1}^{T}p(z_{u}\mid z_{u-1})p(y_u\mid \mu_{z_u}^T\mathbf{x}_u)\right]\\
    & \triangleq \alpha_t(z_t=k)\times p(y_t\mid\mu_k^T\mathbf{x}_t) \times \beta_t(z_t=k)
\end{align}

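Here $\alpha_t(z_t)$ and $\beta_t(z_t)$ are the forward and backward messages, respectively: $\alpha_t$ collects the observations before time $t$, $\beta_t$ collects those after time $t$, and the emission at time $t$ enters separately. Starting from $\alpha_1(z_1=k)=p(z_1=k)$ and $\beta_T(z_T=k)=1$, both can be computed by recursion in $O(TK^2)$ time: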

\begin{align}
    \alpha_t(z_t=k)&=\sum_{z_{t-1}=1}^K\alpha_{t-1}(z_{t-1})p(y_{t-1}\mid\mu_{z_{t-1}}^T\mathbf{x}_{t-1})p(z_t=k\mid z_{t-1})\\
    \beta_t(z_t=k)&=\sum_{z_{t+1}=1}^Kp(z_{t+1}\mid z_{t}=k)p(y_{t+1}\mid\mu_{z_{t+1}}^T\mathbf{x}_{t+1})\beta_{t+1}(z_{t+1})
\end{align}
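The recursions translate directly into code. Below is a minimal NumPy sketch in log space, which avoids the underflow that products of many probabilities cause on long sequences. It follows the convention above ($\alpha_t$ excludes the emission at $t$). The names `log_pi`, `log_P` and `log_like` (a $T \times K$ array with `log_like[t, k]` $=\log p(y_t\mid\mu_k^T\mathbf{x}_t)$) are illustrative assumptions, not the notebook's `HMM` API.

```python
import numpy as np


def logsumexp(a, axis):
    """Numerically stable log(sum(exp(a))) along an axis."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


def forward_backward_messages(log_pi, log_P, log_like):
    """Log forward/backward messages (alpha_t excludes the emission at t).

    log_pi   : (K,)   log p(z_1)
    log_P    : (K, K) log_P[i, j] = log p(z_{t+1} = j | z_t = i)
    log_like : (T, K) log p(y_t | mu_k^T x_t)
    """
    T, K = log_like.shape
    log_alpha = np.zeros((T, K))
    log_beta = np.zeros((T, K))   # beta_T(k) = 1, i.e. log beta_T(k) = 0
    log_alpha[0] = log_pi         # alpha_1(k) = p(z_1 = k)
    for t in range(1, T):
        # alpha_t(k) = sum_i alpha_{t-1}(i) p(y_{t-1} | i) p(z_t = k | z_{t-1} = i)
        a = (log_alpha[t - 1] + log_like[t - 1])[:, None] + log_P
        log_alpha[t] = logsumexp(a, axis=0)
    for t in range(T - 2, -1, -1):
        # beta_t(k) = sum_j p(z_{t+1} = j | z_t = k) p(y_{t+1} | j) beta_{t+1}(j)
        b = log_P + (log_like[t + 1] + log_beta[t + 1])[None, :]
        log_beta[t] = logsumexp(b, axis=1)
    return log_alpha, log_beta
```

A useful sanity check: for every $t$, $\sum_k \alpha_t(z_t=k)\,p(y_t\mid\mu_k^T\mathbf{x}_t)\,\beta_t(z_t=k)$ gives the same value, the marginal likelihood $p(y_{1:T}\mid\mathbf{x}_{1:T},\Theta)$.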

Finally, normalizing over $k$ so that the marginals sum to one yields

\begin{align}
q(z_t=k)=\frac{\alpha_t(z_t=k) p(y_t\mid\mu_k^T\mathbf{x}_t)  \beta_t(z_t=k)}{\sum_{j=1}^K\alpha_t(z_t=j) p(y_t\mid\mu_j^T\mathbf{x}_t)  \beta_t(z_t=j)}
\end{align}
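The normalization is safest in log space: subtract the per-row maximum before exponentiating, then divide by the row sum. A minimal sketch, assuming the log forward messages, log emission probabilities and log backward messages are stored as $T \times K$ arrays (the function name is hypothetical):

```python
import numpy as np


def posterior_marginals(log_alpha, log_like, log_beta):
    """q(z_t = k) proportional to alpha_t(k) p(y_t | mu_k^T x_t) beta_t(k), normalized over k.

    All inputs are (T, K) arrays of log values; returns a (T, K) array whose rows sum to 1.
    """
    log_q = log_alpha + log_like + log_beta
    log_q -= log_q.max(axis=1, keepdims=True)  # stabilize before exponentiating
    q = np.exp(log_q)
    return q / q.sum(axis=1, keepdims=True)
```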

The same forward and backward messages also yield the pairwise (transition) marginals $q(z_t=i, z_{t+1}=j)$:

\begin{align}
q(z_t=i, z_{t+1}=j)=\frac{\alpha_t(z_t=i) p(y_t\mid\mu_i^T\mathbf{x}_t)p(z_{t+1}=j \mid z_{t}=i)p(y_{t+1}\mid\mu_j^T\mathbf{x}_{t+1}) \beta_{t+1}(z_{t+1}=j)}{\sum_{i'=1}^K\sum_{j'=1}^K\alpha_t(z_t=i') p(y_t\mid\mu_{i'}^T\mathbf{x}_t)p(z_{t+1}=j' \mid z_{t}=i')p(y_{t+1}\mid\mu_{j'}^T\mathbf{x}_{t+1}) \beta_{t+1}(z_{t+1}=j')}
\end{align}

which are needed to update the transition matrix in the M-step.
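The pairwise formula vectorizes over time with broadcasting. A minimal sketch, assuming $T \times K$ arrays of log forward messages, log emissions and log backward messages plus a $K \times K$ log transition matrix (all names hypothetical); it returns a $(T-1) \times K \times K$ array whose slice `[t]` holds $q(z_t=i, z_{t+1}=j)$:

```python
import numpy as np


def pairwise_marginals(log_alpha, log_like, log_beta, log_P):
    """q(z_t = i, z_{t+1} = j) for each consecutive pair of time steps.

    log_alpha, log_like, log_beta : (T, K) log arrays
    log_P : (K, K) with log_P[i, j] = log p(z_{t+1} = j | z_t = i)
    Returns a (T-1, K, K) array; each [t] slice sums to 1.
    """
    log_xi = ((log_alpha[:-1] + log_like[:-1])[:, :, None]   # alpha_t(i) p(y_t | i)
              + log_P[None, :, :]                            # p(z_{t+1} = j | z_t = i)
              + (log_like[1:] + log_beta[1:])[:, None, :])   # p(y_{t+1} | j) beta_{t+1}(j)
    log_xi -= log_xi.max(axis=(1, 2), keepdims=True)         # stabilize before exponentiating
    xi = np.exp(log_xi)
    return xi / xi.sum(axis=(1, 2), keepdims=True)
```

When the inputs are genuine forward-backward messages, summing a slice over $j$ recovers $q(z_t=i)$ and summing over $i$ recovers $q(z_{t+1}=j)$, which is a convenient consistency check.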

## M-step: Updating the parameters

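For this model, we need to estimate the parameters $\Theta=(\{\theta_k\}_{k=1}^K, \{\mu_k\}_{k=1}^K)$, where $\theta_{ij}=p(z_t=j\mid z_{t-1}=i)$ is the $j$-th entry of $\theta_i$. First, we rewrite the joint distribution of the complete dataset in terms of indicator functions, substituting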

\begin{align}
p(z_{t}^{(v)} \mid z_{t-1}^{(v)} , \Theta) & = \prod_{i=1}^K\prod_{j=1}^K\theta_{ij}^{\mathbb{I}(z_{t-1}^{(v)}=i)\mathbb{I}(z_{t}^{(v)}=j)}\\
p(y_{t}^{(v)} \mid z_{t}^{(v)}, \mathbf{x}_{t}^{(v)}, \Theta) & = \prod_{k=1}^Kp(y_{t}^{(v)} \mid \mu_k^T\mathbf{x}_{t}^{(v)})^{\mathbb{I}(z_{t}^{(v)}=k)}
\end{align}

\begin{align}
p(\{(z_{1:T_v}^{(v)}, y_{1:T_v}^{(v)})\}_{v=1}^V \mid \{\mathbf{x}_{1:T_v}^{(v)}\}_{v=1}^V, \Theta) 
&= \prod_{v=1}^V \left[p(z_1^{(v)} \mid \Theta) \prod_{t=2}^{T_v} p(z_{t}^{(v)} \mid z_{t-1}^{(v)} , \Theta)  \prod_{t=1}^{T_v} p(y_{t}^{(v)} \mid z_{t}^{(v)}, \mathbf{x}_{t}^{(v)}, \Theta) \right]\\
&= \prod_{v=1}^V \left[p(z_1^{(v)} \mid \Theta) \prod_{t=2}^{T_v} \prod_{i=1}^K\prod_{j=1}^K\theta_{ij}^{\mathbb{I}(z_{t-1}^{(v)}=i)\mathbb{I}(z_{t}^{(v)}=j)}  \prod_{t=1}^{T_v} \prod_{k=1}^K p(y_{t}^{(v)} \mid \mu_k^T\mathbf{x}_{t}^{(v)})^{\mathbb{I}(z_{t}^{(v)}=k)} \right]
\end{align}
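Taking the log of the indicator form and the expectation under $q$ replaces each product of indicators by a posterior marginal. The transition part becomes $\sum_v\sum_{t=2}^{T_v}\sum_{i,j} q(z_{t-1}^{(v)}=i, z_t^{(v)}=j)\log\theta_{ij}$, and maximizing it row by row subject to $\sum_j\theta_{ij}=1$ gives a closed form: $\theta_{ij}$ is proportional to the expected number of $i\to j$ transitions, plus $\alpha-1$ pseudo-counts if the $\mathrm{Dirichlet}(\alpha)$ prior is included ($\alpha$ here is the Dirichlet concentration, not a forward message). A minimal sketch, assuming a symmetric scalar $\alpha \ge 1$ and a list of per-event pairwise-marginal arrays of shape $(T_v-1)\times K\times K$ (the function name is hypothetical). The emission weights $\mu_k$ have no closed form; each is a Poisson regression weighted by $q(z_t=k)$ and is not covered here.

```python
import numpy as np


def update_transitions(pairwise, alpha=1.0):
    """Closed-form M-step for the transition matrix (illustrative sketch).

    pairwise : list of (T_v - 1, K, K) arrays, pairwise[v][t, i, j] = q(z_t = i, z_{t+1} = j)
    alpha    : symmetric Dirichlet concentration (assumed >= 1); alpha = 1 gives the MLE
    """
    counts = sum(xi.sum(axis=0) for xi in pairwise)  # expected number of i -> j transitions
    counts = counts + (alpha - 1.0)                   # Dirichlet prior adds alpha - 1 pseudo-counts
    row_sums = counts.sum(axis=1, keepdims=True)
    K = counts.shape[1]
    # Rows with no expected transitions (possible when alpha == 1) fall back to uniform.
    return np.where(row_sums > 0, counts / np.where(row_sums > 0, row_sums, 1.0), 1.0 / K)
```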




## Initialization

In [None]:
import sys
sys.path.append('../')  # make the project's src package importable
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from src.d01_data.dengue_data_api import DengueDataApi
from src.d04_modeling.hmm import HMM

In [None]:
dda = DengueDataApi()
x_train, x_validate, y_train, y_validate = dda.split_data()

In [None]:
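event_data = dict()
# Project the features onto their top right singular vectors (no centering is applied in this cell).
# full_matrices=False avoids building the large, unused n x n matrix of left singular vectors.
_, _, vh = np.linalg.svd(x_train, full_matrices=False)
num_components = 4
new_features = ["pc%i" % i for i in range(num_components)]
projection = vh[:num_components, :].T
z_train = pd.DataFrame(x_train.to_numpy() @ projection, columns=new_features, index=x_train.index)
z_validate = pd.DataFrame(x_validate.to_numpy() @ projection, columns=new_features, index=x_validate.index)

# Fit an HMM with a fixed number of discrete states on the projected training features.
num_states = 3
model = HMM(num_states=num_states)
event_data['x'] = model.format_event_data(z_train)
event_data['y'] = model.format_event_data(y_train)
lls_k, parameters_k = model.fit(event_data=event_data)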

In [None]:
event_data['x'] = model.format_event_data(z_validate)
event_data['y'] = model.format_event_data(y_validate)

In [None]:
expectations, _, _ = model.e_step(event_data, parameters_k)

In [None]:
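# Map the per-event posterior marginals back onto the (city, year, weekofyear) index.
# Assumes expectations are ordered city by city, then year by year.
c = 0
df = y_validate
latent_states = pd.DataFrame(np.nan, index=df.index, columns=["s%i" % i for i in range(num_states)])
for city in df.index.get_level_values('city').unique():
    for year in df.loc[city].index.get_level_values('year').unique():
        # Assign through a single .loc call: chained indexing (.loc[city].loc[year])
        # writes to a temporary copy and can leave latent_states unchanged.
        latent_states.loc[(city, year), :] = expectations[c]
        c += 1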

In [None]:
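# Heatmaps of the posterior probability of each state (at most 3 panels).
n_panels = min(num_states, 3)
fig, axes = plt.subplots(1, n_panels, figsize=(20, 5), squeeze=False)
for i, state in enumerate(latent_states.columns[:n_panels]):
    sns.heatmap(ax=axes[0, i], data=latent_states[state].unstack('weekofyear'))
    axes[0, i].set_title(state)
plt.show()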

In [None]:
sns.heatmap(data=y_validate.unstack('weekofyear'))
plt.show()

In [None]:
event_data_validate = dict()
event_data_validate['x'] = model.format_event_data(z_validate)
event_data_validate['y'] = model.format_event_data(y_validate)

marginal_ll, mae = model.validate_model(event_data=event_data_validate, parameters=parameters_k)
print("MAE: %.6f" % mae)