# Discrete HMMs

## Agenda
* Key ideas and words
* HMM Overview
  * What confused me
* How I'm approaching this
* Going through examples
  * State estimation
  * Parameter estimation
* Where I would use Dynamax in my past career
  * Practical considerations
* Updated Book Club Focus

## Summary
* HMMs can be confusing, especially when moving between implementations and texts
* Dynamax's implementation is quite good
  * Their docs are great for practitioners
* JAX is used liberally throughout the library for speedups
* There are many different implementations
  * More than I had originally expected

## References
* This repo https://github.com/canyon289/ssm_book_club
* Casino HMMs from Dynamax
  * https://probml.github.io/dynamax/notebooks/hmm/casino_hmm_inference.html
  * https://probml.github.io/dynamax/notebooks/hmm/casino_hmm_learning.html
* Complete reading of https://nipunbatra.github.io/hmm/


# Key ideas and words

## Discrete HMM

$$
\begin{align}
p(y_{1:T}, z_{1:T} \mid \theta)
&= \overbrace{\mathrm{Cat}(z_1 \mid \pi)}^{\text{Prior for initial state}}
\underbrace{\prod_{t=2}^T \mathrm{Cat}(z_t \mid A_{z_{t-1}})}_{\text{Transition model}}
\overbrace{\prod_{t=1}^T \mathrm{Cat}(y_t \mid B_{z_t})}^{\text{Observation model}}
\end{align}
$$
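
The three factors map directly onto arrays. Below is a minimal sketch that evaluates $\log p(y_{1:T}, z_{1:T} \mid \theta)$ for a known state sequence. It assumes one categorical emission per time step and the row convention `A[i, j] = p(z_t = j | z_{t-1} = i)`. `joint_log_prob` is a hypothetical helper written for illustration, not a Dynamax API.

```python
import jax.numpy as jnp

def joint_log_prob(initial_probs, transition_matrix, emission_probs, states, emissions):
    """log p(y_{1:T}, z_{1:T} | theta) for a discrete HMM (hypothetical helper).

    transition_matrix[i, j] = p(z_t = j | z_{t-1} = i)
    emission_probs[k, c]    = p(y_t = c | z_t = k)
    """
    # Prior for the initial state: log Cat(z_1 | pi)
    log_prior = jnp.log(initial_probs[states[0]])
    # Transition model: sum over t = 2..T of log A[z_{t-1}, z_t]
    log_trans = jnp.sum(jnp.log(transition_matrix[states[:-1], states[1:]]))
    # Observation model: sum over t = 1..T of log B[z_t, y_t]
    log_obs = jnp.sum(jnp.log(emission_probs[states, emissions]))
    return log_prior + log_trans + log_obs

initial_probs = jnp.array([0.5, 0.5])
transition_matrix = jnp.eye(2)               # the coin never switches
emission_probs = jnp.array([[0.5, 0.5],      # fair:   [p(tails), p(heads)]
                            [0.1, 0.9]])     # biased: [p(tails), p(heads)]
states = jnp.array([1, 1, 1])                # biased coin at every step
emissions = jnp.array([1, 1, 1])             # three heads
lp = joint_log_prob(initial_probs, transition_matrix, emission_probs, states, emissions)
print(lp)  # approximately log(0.5 * 0.9**3) = log(0.3645)
```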


## Terminology
* Forward filter - Estimating the current state probability using only the data seen so far
* Prediction - Estimating the state at the next time step
* Smoothing - Estimating state probabilities using all of the data
  * Forward-backward pass
* Viterbi algorithm - Most likely (max) sequence of hidden states
  * Sequence assignment
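
To make the Viterbi bullet concrete, here is a minimal log-space sketch of max-product decoding. It assumes one categorical emission per time step and `A[i, j] = p(z_t = j | z_{t-1} = i)`. `viterbi` is a hypothetical helper for illustration and is not the Dynamax implementation. The toy call reuses the biased-coin emission probabilities from the example below, with a transition matrix where the fair coin can switch to the biased one but never back.

```python
import jax
import jax.numpy as jnp

def viterbi(initial_probs, transition_matrix, emission_probs, emissions):
    """Most likely state path argmax_z p(z_{1:T} | y_{1:T}) (hypothetical helper)."""
    log_A = jnp.log(transition_matrix)
    log_B = jnp.log(emission_probs)
    # delta_1(k) = log pi_k + log p(y_1 | z_1 = k)
    delta = jnp.log(initial_probs) + log_B[:, emissions[0]]

    def step(delta_prev, y):
        # scores[i, j]: best log prob of a path ending in state i, then moving to j
        scores = delta_prev[:, None] + log_A
        best_prev = jnp.argmax(scores, axis=0)            # backpointer for each j
        delta_new = jnp.max(scores, axis=0) + log_B[:, y]
        return delta_new, best_prev

    delta, backpointers = jax.lax.scan(step, delta, emissions[1:])

    # Follow the backpointers from the best final state back to t = 1
    last_state = jnp.argmax(delta)

    def backtrack(state, bp):
        prev = bp[state]
        return prev, prev

    _, prev_states = jax.lax.scan(backtrack, last_state, backpointers, reverse=True)
    return jnp.append(prev_states, last_state)

initial_probs = jnp.array([0.5, 0.5])
transition_matrix = jnp.array([[0.5, 0.5],
                               [0.0, 1.0]])
emission_probs = jnp.array([[0.5, 0.5],    # fair:   [p(tails), p(heads)]
                            [0.1, 0.9]])   # biased: [p(tails), p(heads)]
path = viterbi(initial_probs, transition_matrix, emission_probs, jnp.array([0, 0, 1, 1]))
print(path)  # tails, tails, heads, heads -> fair, fair, biased, biased
```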

## Symbols

https://github.com/canyon289/ssm_book_club/blob/hmms/SymbolList.md#nipunbatra-article

## Things that were challenging to me
* Change of symbols between sources and within texts
  * Duplication of symbols in the same text
  * Python variable names that correspond to these terms
* Differing usage of words
* Ambiguous terms
  * Especially in the code
* Abstractions in Dynamax
  * Required lots of tracing

## Various Time Dependencies (or Lack Thereof)
* Things that are independent of time
  * Transition Matrix
  * Prior Probability
  * Emission probability given the state
  
* Things dependent on a particular time window but independent of others
  * Observations
  * Log likelihood

* Things that change over time
  * State estimation
  * Emission probability after updating state probability


## How I'm learning
Referencing
* External articles
* ProbML book
* Diving into the code
* Writing my own mini examples

## Applied Example: Filtering and Smoothing - Biased coin toss

In [1]:
import jax.numpy as jnp
import jax.random as jr
# import matplotlib.pyplot as plt
from jax import vmap

In [2]:
%load_ext autoreload
%autoreload 2
import dynamax
from dynamax.hidden_markov_model import CategoricalHMM

This is a modified version of Dynamax for the HMM book club session
Code can be found here https://github.com/canyon289/dynamax/tree/hmm_session




In [3]:
initial_probs = jnp.array([0.5, 0.5])

transition_matrix = jnp.array([[1.0, 0.0],
                               [0.0, 1.0]])

In [5]:
num_states = 2      # two types of coin (fair and biased)
num_emissions = 1   # one coin toss per time step
num_classes = 2     # class 0 = tails, class 1 = heads

In [6]:
emission_probs = jnp.array([[1/2,  1/2],    # fair coin:   [p(tails), p(heads)]
                            [1/10, 9/10]])  # biased coin: [p(tails), p(heads)]

In [7]:
# Define the HMM structure: number of hidden states, emissions per time step, and classes per emission
hmm = CategoricalHMM(num_states, num_emissions, num_classes)

# Initialize the parameters struct with known values
params, _ = hmm.initialize(initial_probs=initial_probs,
                           transition_matrix=transition_matrix,
                           emission_probs=emission_probs.reshape(num_states, num_emissions, num_classes))

In [8]:
num_timesteps = 3
true_states, emissions = hmm.sample(params, jr.PRNGKey(42), num_timesteps)
true_states, emissions

(Array([1, 1, 1], dtype=int32),
 Array([[1],
        [1],
        [1]], dtype=int32))

In [9]:
posterior = hmm.filter(params, emissions)

log_probs; [[-0.6931472  -0.10536057]
 [-0.6931472  -0.10536057]
 [-0.6931472  -0.10536057]]
Iteration: 0
predicted_probs: [0.5 0.5]
Log Likelihood of emission given state: [-0.6931472  -0.10536057]
Filtered probs: [0.35714287 0.64285713]
Predicted probs: [0.35714287 0.64285713]


Iteration: 1
predicted_probs: [0.35714287 0.64285713]
Log Likelihood of emission given state: [-0.6931472  -0.10536057]
Filtered probs: [0.23584908 0.7641509 ]
Predicted probs: [0.23584908 0.7641509 ]


Iteration: 2
predicted_probs: [0.23584908 0.7641509 ]
Log Likelihood of emission given state: [-0.6931472  -0.10536057]
Filtered probs: [0.14637005 0.85362995]
Predicted probs: [0.14637005 0.85362995]




## Log Likelihoods Verification
Verify the per-state emission log likelihoods printed above: log p(heads | fair) = log 0.5 and log p(heads | biased) = log 0.9

In [13]:
# Log likelihood verification
from scipy import stats
stats.bernoulli([.5, .9]).logpmf(1)

array([-0.69314718, -0.10536052])
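
The same table can be read straight off `emission_probs` in JAX. This sketch assumes the column convention used in this notebook (class 0 = tails, class 1 = heads) and a flat `emissions` array. Under those assumptions, it should reproduce the `log_probs` array printed by the instrumented filter, where row `t` holds `log p(y_t | z_t = k)` for each state `k`.

```python
import jax.numpy as jnp

emission_probs = jnp.array([[1/2,  1/2],    # fair coin:   [p(tails), p(heads)]
                            [1/10, 9/10]])  # biased coin: [p(tails), p(heads)]
emissions = jnp.array([1, 1, 1])            # three heads, as sampled above

# log_probs[t, k] = log p(y_t | z_t = k): select the emitted class column for each step
log_probs = jnp.log(emission_probs[:, emissions]).T
print(log_probs)
```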

## Bayesian Update for Filtering
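Calculate the probability that we are in the biased coin state after one coin toss lands heads

$$p(\text{Biased} \mid \text{one heads}) = p(\text{Biased} \mid y_1 = H) = \frac{p(H \mid \text{Biased})\, p(\text{Biased})}{p(H \mid \text{Biased})\, p(\text{Biased}) + p(H \mid \text{Fair})\, p(\text{Fair})}$$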

In [17]:
# p(heads | biased) * p(biased) / p(heads), where p(heads) sums over the biased and fair states
p_biased_state_one_heads = (.9*.5)/(.9*.5 + .5*.5)
p_biased_state_one_heads

0.6428571428571429

In [18]:
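# Second heads: the previous posterior becomes the new prior (the identity transition matrix keeps the state fixed)
(.9*p_biased_state_one_heads)/(.9*p_biased_state_one_heads + .5*(1-p_biased_state_one_heads))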

0.7641509433962265

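The predicted probs equal the filtered probs because the transition matrix is the identity: the coin never switches, so predicting one step ahead leaves the distribution unchanged.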
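
Repeating the update by hand for every toss gets tedious. The small loop below treats each posterior as the next prior and should reproduce the three `Filtered probs` rows printed by the filter. This shortcut holds only because the transition matrix is the identity. `bayes_update` is a hypothetical helper name.

```python
import jax.numpy as jnp

def bayes_update(prior, likelihood):
    """One filtering update: posterior is likelihood * prior, normalized over states."""
    unnormalized = likelihood * prior
    return unnormalized / unnormalized.sum()

p_heads = jnp.array([0.5, 0.9])   # p(heads | state) for [fair, biased]
belief = jnp.array([0.5, 0.5])    # initial_probs

filtered = []
for _ in range(3):  # three heads in a row
    # Identity transition matrix: the predict step leaves the belief unchanged
    belief = bayes_update(belief, p_heads)
    filtered.append(belief)
print(jnp.stack(filtered))
```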

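### Non-identity transition matrix

Now the fair coin switches to the biased coin with probability 0.5 at each step, and the biased coin never switches back.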

In [26]:
transition_matrix = jnp.array([[.5, .5],
                               [0.0, 1.0]])

In [27]:
# Initialize the parameters struct with known values
params, _ = hmm.initialize(initial_probs=initial_probs,
                           transition_matrix=transition_matrix,
                           emission_probs=emission_probs.reshape(num_states, num_emissions, num_classes))

In [28]:
posterior = hmm.filter(params, emissions)

log_probs; [[-0.6931472  -0.10536057]
 [-0.6931472  -0.10536057]
 [-0.6931472  -0.10536057]]
Iteration: 0
predicted_probs: [0.5 0.5]
Log Likelihood of emission given state: [-0.6931472  -0.10536057]
Filtered probs: [0.35714287 0.64285713]
Predicted probs: [0.17857143 0.82142854]


Iteration: 1
predicted_probs: [0.17857143 0.82142854]
Log Likelihood of emission given state: [-0.6931472  -0.10536057]
Filtered probs: [0.10775863 0.8922414 ]
Predicted probs: [0.05387932 0.94612074]


Iteration: 2
predicted_probs: [0.05387932 0.94612074]
Log Likelihood of emission given state: [-0.6931472  -0.10536057]
Filtered probs: [0.03066732 0.96933264]
Predicted probs: [0.01533366 0.9846663 ]
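
The only change from the identity case is the predict step, which pushes the filtered distribution through the transition matrix. Below is a minimal plain-JAX sketch. It assumes a flat integer `emissions` array (for example, `emissions[:, 0]` from the notebook) and the row convention `A[i, j] = p(z_{t+1} = j | z_t = i)`. `forward_filter` is a hypothetical helper, not the Dynamax function, but it should reproduce the `Filtered probs` and `Predicted probs` printed above.

```python
import jax
import jax.numpy as jnp

def forward_filter(initial_probs, transition_matrix, emission_probs, emissions):
    """Normalized forward filter for a categorical HMM (hypothetical helper).

    Returns filtered p(z_t | y_{1:t}), predicted p(z_{t+1} | y_{1:t}),
    and the marginal log likelihood log p(y_{1:T}).
    """
    def step(carry, y):
        predicted, log_lik = carry
        # Update: weight the prediction by p(y_t | z_t) and renormalize
        unnormalized = predicted * emission_probs[:, y]
        norm = unnormalized.sum()                 # p(y_t | y_{1:t-1})
        filtered = unnormalized / norm
        # Predict: p(z_{t+1} = j | y_{1:t}) = sum_i filtered[i] * A[i, j]
        next_predicted = filtered @ transition_matrix
        return (next_predicted, log_lik + jnp.log(norm)), (filtered, next_predicted)

    init = (initial_probs, jnp.zeros((), dtype=initial_probs.dtype))
    (_, log_lik), (filtered, predicted) = jax.lax.scan(step, init, emissions)
    return filtered, predicted, log_lik

initial_probs = jnp.array([0.5, 0.5])
transition_matrix = jnp.array([[0.5, 0.5],
                               [0.0, 1.0]])
emission_probs = jnp.array([[0.5, 0.5],
                            [0.1, 0.9]])
emissions = jnp.array([1, 1, 1])  # flat version of the sampled emissions above

filtered, predicted, log_lik = forward_filter(initial_probs, transition_matrix, emission_probs, emissions)
print(filtered)
print(predicted)
```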




## Things I noticed in the code
* Architecture of Dynamax
* What is coming next
* JAX usage

## Computational complexity
* Forward algorithm: $O(TK^{2})$ for $T$ time steps and $K$ hidden states
  * Linear with time
  * Quadratic with number of states
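
Here is a short derivation of where $O(TK^2)$ comes from. It uses the same notation as the joint distribution in "Key ideas and words", with $A_{ij} = p(z_t = j \mid z_{t-1} = i)$ and $B_{jc} = p(y_t = c \mid z_t = j)$, and follows the standard forward recursion for $\alpha_t(j) = p(z_t = j, y_{1:t} \mid \theta)$:

$$
\begin{align}
\alpha_1(j) &= \pi_j \, B_{j y_1} \\
\alpha_t(j) &= B_{j y_t} \sum_{i=1}^{K} \alpha_{t-1}(i) \, A_{ij}, \qquad t = 2, \dots, T
\end{align}
$$

Each $\alpha_t(j)$ is a sum over $K$ previous states, and there are $K$ values of $j$. One time step therefore costs $O(K^2)$, which is a single vector-matrix product. Repeating it for $T$ steps gives $O(TK^2)$.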

## JAX Speedups

* Parallelization over independent time series (`vmap`)
* Gradients of the log likelihood (`grad`) for gradient-based parameter estimation
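
Both speedups can be sketched on top of a small forward pass. `jax.vmap` maps the log likelihood over a batch of independent sequences that share parameters, and `jax.grad` differentiates the summed log likelihood with respect to the emission probabilities. `log_likelihood` and the three toy sequences are hypothetical. This sketch shows the idea and does not reflect how Dynamax structures it internally. In practice, one would optimize unconstrained parameters (for example, logits) rather than raw probabilities.

```python
import jax
import jax.numpy as jnp

def log_likelihood(initial_probs, transition_matrix, emission_probs, emissions):
    """Marginal log p(y_{1:T}) from a normalized forward pass (hypothetical helper)."""
    def step(predicted, y):
        unnormalized = predicted * emission_probs[:, y]
        norm = unnormalized.sum()                       # p(y_t | y_{1:t-1})
        return (unnormalized / norm) @ transition_matrix, jnp.log(norm)
    _, log_norms = jax.lax.scan(step, initial_probs, emissions)
    return log_norms.sum()

initial_probs = jnp.array([0.5, 0.5])
transition_matrix = jnp.array([[0.5, 0.5],
                               [0.0, 1.0]])
emission_probs = jnp.array([[0.5, 0.5],
                            [0.1, 0.9]])

# Parallelization: one call evaluates every independent sequence (same parameters)
batch = jnp.array([[1, 1, 1],
                   [0, 0, 1],
                   [0, 0, 0]])
batched_ll = jax.vmap(log_likelihood, in_axes=(None, None, None, 0))
batch_ll = batched_ll(initial_probs, transition_matrix, emission_probs, batch)

# Gradient: d/d(emission_probs) of the total log likelihood, usable for SGD-style updates
grads = jax.grad(
    lambda B: batched_ll(initial_probs, transition_matrix, B, batch).sum()
)(emission_probs)
print(batch_ll)
print(grads)
```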

## Real World Example: Parts from a machine shop
https://www.youtube.com/watch?v=OCc2F8KccD4


### Questions asked of me
* What state is the machine in? 
  * Are we going to get good parts or bad parts?
* How often does it tend to switch?
  * Once it's good, does it stay good, or is it totally unreliable?
* What is the probability of bad parts coming off this line?

### Data
* The order parts were produced from the machine
* Which ones were good and which ones weren't
* Parts from many machines

### Use of HMM
* **Smoothing** - What state(s) was the machine in over the last 7 days and when?
  * Look for correlations with shift changes, time of day, etc.
* **Parameter Estimation** - How faulty was the machine and when did it tend to switch?
* **Filtering** - What state is the machine in now?
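
For the machine-shop case, smoothing answers the question "which state was the machine in when each part was made?" using the whole production history. Below is a minimal rescaled forward-backward sketch in plain JAX. The two-state good/bad machine, its probabilities, and the `parts` sequence are all made-up placeholders, not data from the talk. `forward_backward` is a hypothetical helper, not a Dynamax call.

```python
import jax
import jax.numpy as jnp

def forward_backward(initial_probs, transition_matrix, emission_probs, emissions):
    """Smoothed marginals p(z_t | y_{1:T}) for a categorical HMM (hypothetical helper)."""
    likelihoods = emission_probs[:, emissions].T          # (T, K): p(y_t | z_t = k)

    # Forward pass: filtered[t] = p(z_t | y_{1:t})
    def forward(predicted, lik):
        filtered = predicted * lik
        filtered = filtered / filtered.sum()
        return filtered @ transition_matrix, filtered

    _, filtered = jax.lax.scan(forward, initial_probs, likelihoods)

    # Backward pass: beta[t] is proportional to p(y_{t+1:T} | z_t), rescaled each step
    def backward(beta_next, lik_next):
        beta = transition_matrix @ (lik_next * beta_next)
        beta = beta / beta.sum()
        return beta, beta

    ones = jnp.ones_like(initial_probs)                    # beta at the final step
    _, betas = jax.lax.scan(backward, ones, likelihoods[1:], reverse=True)
    betas = jnp.concatenate([betas, ones[None, :]], axis=0)

    # Combine forward and backward messages and renormalize each time step
    smoothed = filtered * betas
    return smoothed / smoothed.sum(axis=1, keepdims=True)

# Hypothetical machine: state 0 = good, state 1 = bad
initial_probs = jnp.array([0.9, 0.1])
transition_matrix = jnp.array([[0.95, 0.05],    # good machines tend to stay good
                               [0.10, 0.90]])   # bad machines tend to stay bad
emission_probs = jnp.array([[0.98, 0.02],       # good state: [p(good part), p(bad part)]
                            [0.60, 0.40]])      # bad state:  [p(good part), p(bad part)]
# Hypothetical log in production order: 0 = good part, 1 = bad part
parts = jnp.array([0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0])

smoothed = forward_backward(initial_probs, transition_matrix, emission_probs, parts)
print(smoothed[:, 1])  # p(machine was bad when each part was made | all parts)
```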

## Useful Extensions for the above case
* Covariate HMM
* Autoregressive HMMs

## Other HMMs in dynamax

* Dynamax includes many more HMM variants
  * https://probml.github.io/dynamax/api.html#high-level-models

## Takeaways

* For state estimation, filtering, smoothing, and Viterbi decoding are all supported
* Multiple methods for parameter estimation
  * SGD
  * Minibatch
  * Expectation Maximization

## (Refined) Book Club Focus
* A practitioner who is looking to use the library
  * Understand what exists
  * How it works from "the driver's seat"

### What you folks said
* How do I convert a well-estimated state space model into a compelling case for action or decision?
* Using this as a forcing function to help learn more about state space models and work up to structural time series and causal impact
* Fluency in expressing diverse state space models in python


## Next Week's Agenda
* Finish off last two notebooks on State Space Models