[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mx60s/factorial-hmm/blob/main/factorial_hmm_clean.ipynb)

# Factorial HMM: Belief Geometry and Explaining Away

I went looking for HMMs that satisfy the project's conditions: ergodic, naturalistic for LLMs, and requiring "representations to track structures more elaborate than square and ring graphs". To that end I found this paper on factorial HMMs (Ghahramani & Jordan, 1997): https://www.ee.columbia.edu/~sfchang/course/svia-F03/papers/factorial-HMM-97.pdf

The coolest part of the paper is the modelling of Bach's chorales, but I think factorial HMMs could be implicated in LLM pretraining in quite a few different ways, so the model is specific enough to find structures in but general enough to be expanded upon in future work.

A factorial HMM has multiple independent Markov chains that interact only through shared observations. Chains that are independent in the prior become coupled in the posterior once a shared observation is conditioned on; this effect is called "explaining away".

Here's a simple model:
- **Chain 1 (Formality)**: {Formal, Informal}
- **Chain 2 (Topic)**: {Technical, Casual}
- **Joint states**: (F,T), (F,C), (I,T), (I,C)

In [1]:
# Colab setup
import sys
if 'google.colab' in sys.modules:
    !pip install -q plotly

import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

class FactorialHMM:
    """Two-chain factorial HMM (formality x topic) with a shared emission.

    Each chain has two states and its own transition matrix (``A1`` for
    formality, ``A2`` for topic).  The joint transition matrix is their
    Kronecker product, so joint states are ordered chain-1-major:
    ``(F,T), (F,C), (I,T), (I,C)``.  All transition matrices are
    row-stochastic: row = current state, column = next state.

    Attributes:
        A1, A2: per-chain 2x2 transition matrices.
        A_joint: 4x4 joint transition matrix, ``kron(A1, A2)``.
        pi: initial distribution over the four joint states (uniform).
        B: emission matrix, one row per joint state, one column per token.
        state_names, token_names: display labels for states and tokens.
        vocab_size: the value passed to the constructor.
    """

    def __init__(self, vocab_size='8-token'):
        """Build the model.

        Args:
            vocab_size: ``'8-token'`` selects the 8-word vocabulary; any
                other value selects the 4-token vocabulary in which each
                joint state strongly prefers its own token.
        """
        self.A1 = np.array([[0.8, 0.2], [0.3, 0.7]])
        self.A2 = np.array([[0.7, 0.3], [0.4, 0.6]])
        self.A_joint = np.kron(self.A1, self.A2)
        self.pi = np.array([0.25, 0.25, 0.25, 0.25])
        self.state_names = ['(F,T)', '(F,C)', '(I,T)', '(I,C)']
        self.vocab_size = vocab_size
        if vocab_size == '8-token':
            self.token_names = ['shall', 'gonna', 'algorithm', 'stuff', 'the', 'optimize', 'hey', 'pursuant']
            self.B = np.array([
                [0.15, 0.02, 0.25, 0.02, 0.20, 0.20, 0.01, 0.15],
                [0.20, 0.05, 0.05, 0.10, 0.30, 0.05, 0.05, 0.20],
                [0.02, 0.15, 0.25, 0.10, 0.20, 0.18, 0.05, 0.05],
                [0.02, 0.25, 0.03, 0.20, 0.20, 0.02, 0.20, 0.08],
            ])
        else:
            self.token_names = ['formal_tech', 'formal_casual', 'informal_tech', 'informal_casual']
            self.B = np.array([
                [0.70, 0.15, 0.10, 0.05],
                [0.15, 0.70, 0.05, 0.10],
                [0.10, 0.05, 0.70, 0.15],
                [0.05, 0.10, 0.15, 0.70],
            ])

    def sample(self, n_steps, seed=None):
        """Sample a joint-state path and the tokens it emits.

        Args:
            n_steps: number of time steps to generate; ``0`` yields two
                empty arrays.
            seed: optional integer seed.  When given, a private
                ``RandomState`` is used, so the draws match what
                ``np.random.seed(seed)`` would have produced but the
                process-wide NumPy RNG is left untouched.  When ``None``,
                the global ``np.random`` stream is consumed.

        Returns:
            ``(states, observations)``: two integer arrays of length
            ``n_steps`` holding joint-state indices and token indices.
        """
        rng = np.random if seed is None else np.random.RandomState(seed)
        n_states, n_tokens = self.B.shape
        states = np.zeros(n_steps, dtype=int)
        observations = np.zeros(n_steps, dtype=int)
        for t in range(n_steps):
            # The first state comes from pi; later ones from the row of the
            # joint transition matrix selected by the previous state.
            state_dist = self.pi if t == 0 else self.A_joint[states[t - 1]]
            states[t] = rng.choice(n_states, p=state_dist)
            observations[t] = rng.choice(n_tokens, p=self.B[states[t]])
        return states, observations

    def forward(self, observations):
        """Run the normalised forward (filtering) recursion.

        Args:
            observations: sequence of token indices (columns of ``B``).

        Returns:
            Array of shape ``(T, n_states)`` whose row ``t`` is the
            posterior over joint states given observations ``0..t``.
            An empty input returns an array with zero rows.
        """
        obs = np.asarray(observations)
        n_obs = len(obs)
        beliefs = np.zeros((n_obs, len(self.pi)))
        if n_obs == 0:
            return beliefs

        transition_T = self.A_joint.T
        alpha = self.pi * self.B[:, obs[0]]
        beliefs[0] = alpha / alpha.sum()
        for t in range(1, n_obs):
            # Predict with the transition model, then weight by the
            # likelihood of the observed token and renormalise.
            alpha = (transition_T @ beliefs[t - 1]) * self.B[:, obs[t]]
            beliefs[t] = alpha / alpha.sum()
        return beliefs

    def compute_coupling(self, joint_beliefs):
        """Measure how far each joint belief is from a product distribution.

        For every row ``p`` the two chain marginals are formed, their outer
        product ``P(S1) * P(S2)`` is subtracted from ``p``, and the
        Euclidean norm of the difference is returned.  A value of ``0``
        means the chains are independent under that belief.

        Args:
            joint_beliefs: array-like of shape ``(T, 4)`` in the joint-state
                order ``(F,T), (F,C), (I,T), (I,C)``.

        Returns:
            Array of length ``T`` with one coupling value per row.
        """
        # View each belief as a 2x2 table indexed [formality, topic].
        table = np.asarray(joint_beliefs, dtype=float).reshape(-1, 2, 2)
        p_formality = table.sum(axis=2, keepdims=True)  # shape (T, 2, 1)
        p_topic = table.sum(axis=1, keepdims=True)      # shape (T, 1, 2)
        deviation = table - p_formality * p_topic
        return np.linalg.norm(deviation.reshape(len(table), 4), axis=1)

# Model instance shared by the later cells; uses the 8-token vocabulary.
hmm = FactorialHMM(vocab_size='8-token')

In [2]:
# Print the two per-chain transition matrices side by side
# (rows = current state, columns = next state).
print("Chain 1 (Formality):        Chain 2 (Topic):")
print("         F     I                   T     C")
print(f"    F [{hmm.A1[0,0]:.1f}   {hmm.A1[0,1]:.1f}]         T [{hmm.A2[0,0]:.1f}   {hmm.A2[0,1]:.1f}]")
print(f"    I [{hmm.A1[1,0]:.1f}   {hmm.A1[1,1]:.1f}]         C [{hmm.A2[1,0]:.1f}   {hmm.A2[1,1]:.1f}]")

Chain 1 (Formality):        Chain 2 (Topic):
         F     I                   T     C
    F [0.8   0.2]         T [0.7   0.3]
    I [0.3   0.7]         C [0.4   0.6]


## The Belief Simplex

Beliefs over 4 states live on a 3-simplex (tetrahedron). The **product manifold** is where P(S₁,S₂) = P(S₁)×P(S₂) — i.e., where chains are independent.

- Prior dynamics preserve products: if you start on the manifold, you stay on it
- Observations make the posterior move off the manifold

In [3]:
# Corners of a regular tetrahedron (alternating corners of a cube), scaled to
# unit length.  Row k is the 3D position of joint state k on the 4-state simplex.
VERTICES = np.asarray(
    [
        [1, 1, 1],
        [1, -1, -1],
        [-1, 1, -1],
        [-1, -1, 1],
    ],
    dtype=float,
) / np.sqrt(3.0)

# Display labels for the joint states, in the same order as the rows of VERTICES.
STATE_NAMES = [
    '(F,T)',
    '(F,C)',
    '(I,T)',
    '(I,C)',
]

def belief_to_3d(belief):
    """Map a belief over the four joint states to 3D simplex coordinates.

    The result is the belief-weighted sum of the tetrahedron corners stored
    in ``VERTICES``.
    """
    return np.matmul(belief, VERTICES)

def create_simplex_figure():
    """Build the interactive 3D belief simplex with the product manifold.

    The returned plotly figure contains, in order: the six tetrahedron
    edges, a text label next to each vertex, and a translucent green
    surface made of joint beliefs that factor as P(S1) * P(S2).
    """
    figure = go.Figure()

    # One grey segment for every pair of distinct vertices.
    vertex_pairs = [(a, b) for a in range(4) for b in range(a + 1, 4)]
    for a, b in vertex_pairs:
        start, end = VERTICES[a], VERTICES[b]
        figure.add_trace(go.Scatter3d(
            x=[start[0], end[0]],
            y=[start[1], end[1]],
            z=[start[2], end[2]],
            mode='lines',
            line={'color': 'gray', 'width': 2},
            showlegend=False,
            hoverinfo='skip',
        ))

    # State names, pushed slightly outward so they do not sit on the corners.
    label_x, label_y, label_z = (VERTICES[:, axis] * 1.15 for axis in range(3))
    figure.add_trace(go.Scatter3d(
        x=label_x, y=label_y, z=label_z,
        mode='text',
        text=STATE_NAMES,
        textfont={'size': 12},
        showlegend=False,
        hoverinfo='skip',
    ))

    # Sweep both chain marginals over a grid; each grid point gives one
    # factorised joint belief, projected into 3D.
    n = 20
    marginals = np.linspace(0.01, 0.99, n)
    surface_xyz = np.array([
        [
            belief_to_3d(np.array([pf*pt, pf*(1-pt), (1-pf)*pt, (1-pf)*(1-pt)]))
            for pt in marginals
        ]
        for pf in marginals
    ])
    surf_x, surf_y, surf_z = (
        np.ascontiguousarray(surface_xyz[..., axis]) for axis in range(3)
    )

    figure.add_trace(go.Surface(
        x=surf_x, y=surf_y, z=surf_z,
        opacity=0.3,
        colorscale=[[0, 'green'], [1, 'green']],
        showscale=False,
        name='Product Manifold',
        hovertemplate='Product Manifold<extra></extra>',
    ))

    figure.update_layout(
        scene={
            'xaxis_title': '',
            'yaxis_title': '',
            'zaxis_title': '',
            'aspectmode': 'cube',
        },
        margin={'l': 0, 'r': 0, 't': 30, 'b': 0},
        height=500,
    )
    return figure

## Prior vs Posterior Trajectories

**Prior** (green): Evolve beliefs using only transition dynamics.
**Posterior** (blue): Update beliefs with actual observations.

The prior trajectory stays exactly on the product manifold (coupling = 0).  
The posterior trajectory moves off it (coupling > 0) — this is explaining away.

In [4]:
# Draw one fixed-seed trajectory from the model.
n_steps = 40
states, observations = hmm.sample(n_steps, seed=42)

# Posterior: filtered beliefs conditioned on the sampled tokens.
posterior = hmm.forward(observations)

# Prior: push the initial distribution through the joint transition matrix
# without ever conditioning on a token.
prior = np.zeros((n_steps, 4))
belief = hmm.pi.copy()
step_matrix = hmm.A_joint.T
for t in range(n_steps):
    # Record the current belief, then advance it by one transition.
    prior[t], belief = belief, step_matrix @ belief

# Distance from independence along each trajectory.
prior_coupling, posterior_coupling = (
    hmm.compute_coupling(trajectory) for trajectory in (prior, posterior)
)

In [5]:
fig = create_simplex_figure()

# Project every belief vector into simplex coordinates, one row per time step.
prior_3d = np.array(list(map(belief_to_3d, prior)))
post_3d = np.array(list(map(belief_to_3d, posterior)))

# Prior trajectory: solid green path.
prior_x, prior_y, prior_z = prior_3d.T
fig.add_trace(go.Scatter3d(
    x=prior_x, y=prior_y, z=prior_z,
    mode='lines+markers',
    marker={'size': 3, 'color': 'darkgreen'},
    line={'color': 'green', 'width': 4},
    name='Prior (no obs)',
))

# Posterior trajectory: markers shaded by time step.
post_x, post_y, post_z = post_3d.T
time_colorbar = {'title': 'Time', 'x': 1.0, 'len': 0.5}
fig.add_trace(go.Scatter3d(
    x=post_x, y=post_y, z=post_z,
    mode='lines+markers',
    marker={
        'size': 4,
        'color': np.arange(n_steps),
        'colorscale': 'Blues',
        'showscale': True,
        'colorbar': time_colorbar,
    },
    line={'color': 'blue', 'width': 2},
    name='Posterior (with obs)',
))

fig.update_layout(title='Belief Trajectories: Prior Stays on Manifold, Posterior Moves Off')
fig.show()

In [6]:
# Plot the distance from independence at each time step for both trajectories.
fig = go.Figure()
coupling_series = (
    (prior_coupling, 'Prior', 'green'),
    (posterior_coupling, 'Posterior', 'blue'),
)
for values, label, colour in coupling_series:
    fig.add_trace(go.Scatter(y=values, name=label, line={'color': colour, 'width': 2}))
fig.update_layout(
    title='Coupling: Distance from Independence',
    xaxis_title='Time',
    yaxis_title='||P(S₁,S₂) - P(S₁)P(S₂)||',
    height=300,
)
fig.show()

# Summary numbers: the prior should sit on the product manifold.
print(f"Prior coupling:  mean={prior_coupling.mean():.6f} (should be ≈0)")
print(f"Posterior coupling: mean={posterior_coupling.mean():.4f}")

Prior coupling:  mean=0.000000 (should be ≈0)
Posterior coupling: mean=0.0233
