In [271]:
import numpy as np

import torch
from torch import nn
from torch.distributions import Dirichlet, Normal
import torch.nn.functional as F

In [272]:
K = 5  # number of latent state
M = 10  # length of observations
N = 50  # number of observations

sd = .05  # standard deviation on gaussian observations
mu = lambda z: np.log(z + 1)  # the mean of our observation model, i.e. E[p(x|z)]

# Create dummy data

In [273]:
# Parameters
A = F.softmax(torch.randn(K, K), dim=1)  # randomly initialize a transition matrix
pi = torch.FloatTensor(K * [1 / K])  # uniform distribution over the initial latent state

# Data
X, Z = [], []
for n in range(N):
    obs = []
    for m in range(M):
        if m == 0:
            z = np.random.choice(range(K), p=pi.numpy())
        else:
            z = np.random.choice(range(K), p=A[z].numpy())
        x = mu(z) + sd * np.random.randn()
        obs.append((z, x))
    z_m, x_m = zip(*obs)
    Z.append(z_m)
    X.append(x_m)

Z = np.array(Z); X = np.array(X)
assert Z.shape == (N, M)
assert X.shape == (N, M)

# Create models

In [274]:
class GuassianObservationModel(nn.Module):
    
    def __init__(self, K, embed_dim, hidden_dim):
        super().__init__()
        self.embed = nn.Embedding(K, embed_dim)
        self.hidden = nn.Linear(embed_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, 1)
        
    def forward(self, z):
        embed = self.embed(z)
        return torch.relu(self.out(self.hidden(embed)))  # ReLU, b/c we know obs. means are positive
    
    
class HMM:
    
    EMBED_DIM = 20
    HIDDEN_DIM = 10
    
    def __init__(self, pi, A, X, seed=1):
        torch.manual_seed(1)
        self.pi = pi
        self.A = A
        self.X = torch.tensor(X, requires_grad=False, dtype=torch.float32)
        self.obs_model = GuassianObservationModel(K, self.EMBED_DIM, self.HIDDEN_DIM)
        
    def log_lik(self, z, m):
        """
        The log-likelihood of *all* of our observed data, in the m-th index of the sequential
        observation, conditional on latent code `z`, given a Gaussian observation model.
        
        This opposite of this quantity is also known as the "mean-squared error" between the 
        observed data and our "prediction" thereof, which can be computed via the L2-norm of 
        the differences between each datum, and the output of our model given `z`.
        """
        z = torch.LongTensor([z])
        return -(self.X[:, m] - self.obs_model(z)).norm()
        
    def factor(self, z_i, z_j, m):
        return self.A[z_i][z_j] * self.log_lik(z_j, m)

# Setup

In [275]:
hmm = HMM(pi, A, X)

# E-step

## Alpha step

The HMM factor is given by: $f(z_{n-1}, z_n) = p(z_n|z_n-1)p(X|z_n)$, where $n$ is
our current index on the chain.

The message from factor to variable, moving forward along the chain, is given by:

$$
\sum\limits_{z_{n-1}} f(z_{n-1}, z_n) * \mu_{f_{n-1} \rightarrow z_{n-1}}
$$

In an efficient implementation, for each value $z_j$, we would compute $p(X|z_n)$ upfront then multiply it by the summation. Below, we recompute it inside each term of the summation unnecessarily for demonstrative purposes.

In [297]:
alpha = []

for m in range(M):
    if m == 0:
        a_m = [hmm.log_lik(z, m) * hmm.pi[z] for z in range(K)]
    else:
        a_m = [sum(hmm.factor(z_i, z_j, m) * alpha[-1][z_i] for z_i in range(K)) for z_j in range(K)]
    alpha.append(a_m)
alpha = torch.FloatTensor(alpha)

## Beta step

In [298]:
beta = []

for m in reversed(range(1, M + 1)):
    if m == M:
        b_m = np.ones(K)
    else:
        b_m = [sum(hmm.factor(z_i, z_j, m) * beta[-1][z_j] for z_j in range(K)) for z_i in range(K)]
    beta.append(b_m)
beta = torch.FloatTensor(beta).flip(0)  # NB: we flip the axes back!

## Posterior marginals

In [299]:
gamma_ = alpha * beta
evidence = gamma_.sum(1)

# p(X) = \sum_{z_n} \alpha(z_n) * \beta(z_n), for any choice of n!
assert np.allclose(evidence[0], evidence)

gamma = gamma_ / evidence[0]
assert np.allclose(gamma.sum(1), 1.)

## Posterior transition matrices

In [300]:
zeta_ = []
for m in range(1, M):
    log_liks = np.array([hmm.log_lik(z, m).item() for z in range(K)])
    zeta_m = np.outer(alpha[m-1], beta[m]) * hmm.A.numpy() * log_liks
    zeta_.append(zeta_m)

zeta = torch.FloatTensor(zeta_) / evidence[0]
assert all([np.allclose(zta.sum(), 1.) for zta in zeta])

# M-step

# Max-sum (Viterbi algorithm)

### Pass max-messages forward

In [301]:
omega = []

for m in range(M):
    if m == 0:
        o_m = [hmm.pi[m].log() + hmm.log_lik(z, m) for z in range(K)]
        o_m_idx = [None] * len(o_m)
    else:
        o_m = []
        o_m_idx = []
        for z_j in range(K):
            ll = hmm.log_lik(z_j, m)
            mx, mx_idx = torch.tensor([ll + hmm.A[z_i][z_j] + omega[-1][z_i][0] for z_i in range(K)]).max(0)
            o_m.append(mx)
            o_m_idx.append(mx_idx)
    o = [(t.item(), i.item()) if i is not None else (t.item(), i) for t, i in zip(o_m, o_m_idx)]
    omega.append(o)

### Backtrack

In [302]:
m = M
configs = []
while omega:
    o_m = omega.pop()
    vals, idxs = zip(*o_m)
    if all([i is not None for i in idxs]):
        if m == M:
            max_val = max(vals)
            for i, v in enumerate(vals):
                if v == max_val:
                    configs.append([i, idxs[i]])
        else:
            for c in configs:
                phi = c[-1]
                c.append(idxs[phi])
        m -= 1
        
print(f'Max val: {max_val}')
print(f'Maximizing configs: {configs}')

Max val: -73.41764831542969
Maximizing configs: [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
