# Hidden Markov Models

In [1]:
import numpy as np

def pred_states_max_product(obs_probs: np.ndarray, transition_probs: np.ndarray, obs_inds: list[int]) -> np.ndarray:
    """
    Predict hidden states using the max product algorithm.
    Finds the hidden states with the greatest probability given the observations.

    Args:
        obs_probs (np.ndarray): Matrix of conditional probabilities of the observation given the state.
            State indexes row, and observation indexes column.
            Probability of observation i given state j = obs_probs[j, i].
        transition_probs (np.ndarray): Matrix of transition probabilities.
            State t indexes row, state t+1 indexes column.
            Probability of transitioning from state i to j = transition_probs[i, j].
        obs_inds (list[int]): List of observation indices. Used to index obs_probs.

    Returns:
        np.ndarray: Predicted state indices.
    """
    obs_log_probs = np.log(obs_probs)
    transition_log_probs = np.log(transition_probs)
    
    N = len(obs_probs)
    T = len(obs_inds)
    P = np.full((N, T), -np.inf, dtype=np.float64)
    inds = np.full((N, T), -1, np.int32)

    for j in range(N):
        P[j, 0] = obs_log_probs[j, obs_inds[0]]

    for t in range(1, T):
        for j in range(N):
            log_probs = P[:, t-1] + transition_log_probs[:, j] + obs_log_probs[j, obs_inds[t]]
            inds[j, t] = np.argmax(log_probs)
            P[j, t] = log_probs[inds[j,t]]

    pred_states = np.full(T, -1, dtype=np.int32)
    pred_states[-1] = P[:, -1].argmax()
    for t in range(T-1, 0, -1):
        pred_states[t-1] = inds[pred_states[t], t]
    return pred_states

In [2]:
def log_sum_exp(x: np.ndarray) -> np.ndarray:
    """Calculate log(sum(exp(x))) using the log sum exp trick to avoid numerical instability."""
    x_max = x.max()
    return np.log(np.exp(x - x_max).sum()) + x_max

def pred_states_marginal_probs(obs_probs: np.ndarray, transition_probs: np.ndarray, obs_inds: list[int]) -> np.ndarray:
    """
    Predict hidden states using forward backward marginal state probabilities.
    Finds the hidden states maximizing the sum of probabilities of all paths reaching every individual state.

    Args:
        obs_probs (np.ndarray): Matrix of conditional probabilities of the observation given the state.
            State indexes row, and observation indexes column.
            Probability of observation i given state j = obs_probs[j, i].
        transition_probs (np.ndarray): Matrix of transition probabilities.
            State t indexes row, state t+1 indexes column.
            Probability of transitioning from state i to j = transition_probs[i, j].
        obs_inds (list[int]): List of observation indices. Used to index obs_probs.

    Returns:
        np.ndarray: Predicted state indices.
    """
    obs_log_probs = np.log(obs_probs)
    transition_log_probs = np.log(transition_probs)

    N = len(obs_probs)
    T = len(obs_inds)
    F = np.full((N, T), -np.inf, dtype=np.float64)
    B = np.full((N, T), -np.inf, dtype=np.float64)

    for j in range(N):
        F[j, 0] = obs_log_probs[j, obs_inds[0]]
        B[j, -1] = obs_log_probs[j, obs_inds[-1]]

    for t in range(1, T):
        for j in range(N):
            F[j, t] = log_sum_exp(F[:, t-1] + transition_log_probs[:, j] + obs_log_probs[j, obs_inds[t]])
            B[j, -t-1] = log_sum_exp(B[:, -t] + transition_log_probs[:, j] + obs_log_probs[j, obs_inds[-t-1]])
    
    marginal_log_probs = F + B
    marginal_log_probs -= log_sum_exp(marginal_log_probs.sum(axis=0, keepdims=True))
    pred_states = np.argmax(marginal_log_probs, axis=0)
    return pred_states

In [3]:
from pathlib import Path

# Fair is state 0, loaded is state 1.
# State indexes row, observation index (dice roll - 1) indexes col.
# Prob of a loaded 6 is at [1, 6-1].
obs_probs = np.array([
    np.ones(6, dtype=np.float64) / 6.0,
    np.array([0.1, 0.1, 0.1, 0.1, 0.1, 0.5], dtype=np.float64)
])

# Current state indexes row, next state indexes col.
# Prob of transition from fair to loaded is at [0, 1].
transition_probs = np.array([
    [0.95, 0.05],
    [0.1, 0.9],
], dtype=np.float64)

def state_inds_to_state_str(state_inds: np.ndarray) -> str:
    """Convert state inds to a human-friendly state string."""
    return "".join(["F" if i == 0 else "L" for i in state_inds])

data_dir = Path("data_and_sols/")
for file in sorted(data_dir.glob("casino*_sols.txt")):
    with open(file) as f:
        observations, true_states = f.read().split("\n")
    obs_inds = [int(i) - 1 for i in observations]
    true_state_inds = np.array([int(s == "L") for s in true_states], dtype=np.int32)

    max_product_pred_states = pred_states_max_product(obs_probs, transition_probs, obs_inds)
    marginal_probs_pred_states = pred_states_marginal_probs(obs_probs, transition_probs, obs_inds)

    max_product_accuracy = (max_product_pred_states == true_state_inds).mean()
    marginal_probs_accuracy = (marginal_probs_pred_states == true_state_inds).mean()

    print(f"File: {file.name}")
    print(f"observations: {observations}")
    print(f"true states:  {true_states}")

    print("\nMAX PRODUCT:")
    print(f"pred states:  {state_inds_to_state_str(max_product_pred_states)}")
    print(f"accuracy: {max_product_accuracy*100:.3f}%")
    
    print("\nMARGINAL PROBS:")
    print(f"pred states:  {state_inds_to_state_str(marginal_probs_pred_states)}")
    print(f"accuracy: {marginal_probs_accuracy*100:.3f}%\n\n")

File: casino1_sols.txt
observations: 23443224462431261412355552456612616666663546661636616563556412441436124342246236511262136656662263243
true states:  FFFFFFFFFFFFFFFFFFFFFFFFFFFFFLLLLLLLLLLLLLLLLLLLLLLLLLLLLLLLLFFFFFFFFFFFFFFFFFFFFFFFFLLLLLLLLLLLLLLLL

MAX PRODUCT:
pred states:  FFFFFFFFFFFFFFFFFFFFFFFFFFFFLLLLLLLLLLLLLLLLLLLLLLLLLLLFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF
accuracy: 77.228%

MARGINAL PROBS:
pred states:  FFFFFFFFFFFFFFFFFFFFFFFFFFFFLLFFLLLLLLLLLLLLLLLLLLLLLFLFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFLLLLLLFFLFFFF
accuracy: 81.188%


File: casino2_sols.txt
observations: 23266225625664356126236666643424254611654161666666524312152423664326463664566661646426653632
true states:  FFFFFFFFFFFLLLLLLLLLLLLLLLLLFFFFFFFFFFFFLLLLLLLLLLLLLFFFFFFFFFFFFFFLLLLLLLLLLLLLLLLLLLLLLLLL

MAX PRODUCT:
pred states:  LLLLLLLLLLLLLLLLLLLLLLLLLLLFFFFFFFFFFFFFFFLLLLLLLLFFFFFFFFFFFFLLLLLLLLLLLLLLLLLLLLLLLLLLLLLL
accuracy: 76.087%

MARGINAL PROBS:
pred states:  FFFLLFFFLFFLLFFFLFFLFLLLLLLFFFFFFFFFF