In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from collections import deque
from copy import deepcopy

import numpy as np

from pomegranate import *
import torch
from torch import nn
from torch.distributions import Dirichlet, Normal
import torch.nn.functional as F
from torch.optim import Adam

In [3]:
K = 5  # number of latent state
M = 10  # length of observations
N = 25  # number of observations

mu = [np.log(z + 1) for z in range(K)]  # the mean of our observation model, i.e. E[p(x|z)]
std = [1.] * K  # standard deviation on gaussian observations

# Generate data

In [4]:
SEED = 1
np.random.seed(SEED)
torch.manual_seed(SEED)

# Parameters
A = F.softmax(torch.randn(K, K), dim=-1)  # randomly initialize a transition matrix
pi = F.softmax(torch.randn(K), dim=-1)  # randomly initialize a distribution over the initial latent state

# Data
X, Z = [], []
for n in range(N):
    obs = []
    for m in range(M):
        if m == 0:
            z = np.random.choice(range(K), p=pi.numpy())
        else:
            z = np.random.choice(range(K), p=A[z].numpy())
        x = mu[z] + std[z] * np.random.randn()
        obs.append((z, x))
    z_m, x_m = zip(*obs)
    Z.append(z_m)
    X.append(x_m)

Z = np.array(Z); X = np.array(X)
assert Z.shape == (N, M)
assert X.shape == (N, M)

## Print first observation

In [5]:
n = 0
x, z = X[n], Z[n]

for m, (z_m, x_m) in enumerate(zip(z, x)):
    if m == 0:
        print(f'm: {m} | p(z_m):       {z_m} | p(x_m|{z_m}): {x_m:1.3}')
    else:
        print(f'm: {m} | p(z_m|z_m-1): {z_m} | p(x_m|{z_m}): {x_m:1.3}')
        
del x, z, m

m: 0 | p(z_m):       2 | p(x_m|2): 0.296
m: 1 | p(z_m|z_m-1): 0 | p(x_m|0): -0.449
m: 2 | p(z_m|z_m-1): 1 | p(x_m|1): -0.413
m: 3 | p(z_m|z_m-1): 4 | p(x_m|4): -0.0451
m: 4 | p(z_m|z_m-1): 3 | p(x_m|3): 0.369
m: 5 | p(z_m|z_m-1): 4 | p(x_m|4): 2.25
m: 6 | p(z_m|z_m-1): 0 | p(x_m|0): -0.86
m: 7 | p(z_m|z_m-1): 3 | p(x_m|3): 3.16
m: 8 | p(z_m|z_m-1): 1 | p(x_m|1): 1.26
m: 9 | p(z_m|z_m-1): 4 | p(x_m|4): 1.04


# Create models

In [50]:
class HMM:
    
    EMBED_DIM = 10
    HIDDEN_DIM = 5
    
    def __init__(self, pi=None, A=None, mu=None, std=None, seed=1):
        torch.manual_seed(seed)
        self.A = A if A is not None else F.softmax(torch.randn(K, K), dim=-1)
        self.pi = pi if pi is not None else F.softmax(torch.randn(K), dim=-1)
        self.mu = mu if mu is not None else torch.randn(K).abs()
        self.std = std if std is not None else torch.randn(K).abs()
        
    def loglik(self, z, x):
        """
        The likelihood of the given observation `x`, conditional on latent code `z`, 
        given a Gaussian observation model.
        """
        x = torch.FloatTensor([x])
        return Normal(loc=self.mu[z], scale=self.std[z]).log_prob(x)

    def factor(self, z_i, z_j, x):
        return self.A[z_i][z_j] * self.loglik(z_j, x).exp()
    
    
class GuassianObservationModel(nn.Module):
    
    def __init__(self, K, embed_dim, hidden_dim):
        super().__init__()
        self.embed = nn.Embedding(K, embed_dim)
        self.hidden = nn.Linear(embed_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, 1)
        self.mu = nn.Parameter(torch.ones(K))
        
    def forward(self, z):
        embed = self.embed(torch.LongTensor([z]))
        return torch.relu(self.out(self.hidden(embed)))  # ReLU, b/c we know obs. means are positive
    
    
class NeuralHMM(nn.Module):
    
    EMBED_DIM = 10
    HIDDEN_DIM = 5
    
    def __init__(self, pi=None, A=None, mu=None, std=None, seed=1):
        super().__init__()
        torch.manual_seed(seed)
        self._A = A if A is not None else nn.Parameter(torch.randn(K, K))
        self._pi = pi if pi is not None else nn.Parameter(torch.randn(K))
        self.mu = mu if mu is not None else nn.Parameter(torch.randn(K).abs())
        self.std = std if std is not None else nn.Parameter(torch.randn(K).abs())
        self.obs_model = GuassianObservationModel(K, self.EMBED_DIM, self.HIDDEN_DIM)
        
    @property
    def pi(self):
        return F.softmax(self._pi, dim=-1)
    
    @property
    def A(self):
        return F.softmax(self._A, dim=-1)
        
    def loglik(self, z, x):
        """
        The likelihood of the given observation `x`, conditional on latent code `z`, 
        given a Gaussian observation model.
        """
        x = torch.FloatTensor([x])
        return Normal(loc=self.obs_model(z), scale=self.std[z]).log_prob(x)
    
    def factor(self, z_i, z_j, x):
        return self.A[z_i][z_j] * self.loglik(z_j, x).exp()

# E-step

## Alpha step

The HMM factor is given by: $f(z_{n-1}, z_n) = p(z_n|z_n-1)p(X|z_n)$, where $n$ is
our current index on the chain.

The message from factor to variable, moving forward along the chain, is given by:

$$
\sum\limits_{z_{n-1}} f(z_{n-1}, z_n) * \mu_{f_{n-1} \rightarrow z_{n-1}}
$$

In an efficient implementation, for each value $z_j$, we would compute $p(X|z_n)$ upfront then multiply it by the summation. Below, we recompute it inside each term of the summation unnecessarily for demonstrative purposes.

In [7]:
def alpha_step(model, x):
    alpha = []
    for m in range(M):
        if m == 0:
            a_m = [model.pi[z] * model.loglik(z, x[m]).exp() for z in range(K)]

        else:
            a_m = [sum(model.factor(z_i, z_j, x[m]) * alpha[-1][z_i] for z_i in range(K)) for z_j in range(K)]
        alpha.append(torch.FloatTensor(a_m))
    return torch.stack(alpha)

## Beta step

In [8]:
def beta_step(model, x):
    beta = []
    for m in reversed(range(M)):
        if m == M - 1:
            b_m = [1. for _ in range(K)]
        else:
            b_m = [sum(model.factor(z_i, z_j, x[m+1]) * beta[-1][z_j] for z_j in range(K)) for z_i in range(K)]
        beta.append(torch.FloatTensor(b_m))
    return torch.stack(beta).flip(0)  # NB: we flip the axes back!

## Posterior marginals

In [9]:
def compute_gamma(alpha, beta):
    gamma_ = alpha * beta
    evd = gamma_.sum(1)
    # p(X) = \sum_{z_n} \alpha(z_n) * \beta(z_n), for any choice of n!
    assert np.allclose(evd[0], evd)
    gamma = gamma_ / evd[0]
    assert np.allclose(gamma.sum(1), 1.)
    return gamma, evd[0]

## Posterior transition matrices

In [10]:
def compute_zeta(model, alpha, beta, evd, x):
    zeta_ = []
    for m in range(M-1):
        liks = np.array([model.loglik(z, x[m+1]).exp().item() for z in range(K)])
        zeta_m = np.outer(alpha[m], beta[m+1]) * model.A.detach().numpy() * liks
        zeta_.append(zeta_m)

    zeta = torch.FloatTensor(zeta_).clone() / evd
    assert all([np.allclose(zta.sum(), 1.) for zta in zeta])
    
    return zeta

## Altogether

In [11]:
def e_step(model, x):
    alpha = alpha_step(model, x)
    beta = beta_step(model, x)
    gamma, evd = compute_gamma(alpha, beta)
    zeta = compute_zeta(model, alpha, beta, evd, x)
    return gamma, evd, zeta

# M-step

In [12]:
def train_via_em(data, model, n_epochs, verbose=False, evd_tolerance=1e-4):
    
    X = data.copy()
    model = deepcopy(model)
    prev_evd = float('-inf')
    
    for n in range(n_epochs):
        # E-step (compute posteriors)
        gamma, evd, zeta = zip(*[e_step(model, x) for x in X])
        evd = np.sum(evd)

        # M-step (update parameters)
        ## Pi
        model.pi = sum(gamma)[0] / sum(gamma)[0].sum()
        ## A
        zeta = sum([zta.sum(0) for zta in zeta])
        model.A = zeta / zeta.sum(1)[:, None]
        ## mu
        for z in range(K):
            model.mu[z] = sum((g[:, z].numpy() * x).sum() for g, x in zip(gamma, X)) / sum(g[:, z].sum() for g in gamma)

        assert np.allclose(model.pi.sum(), 1.)
        assert np.allclose(model.A.sum(1), 1.)
        assert evd >= (prev_evd - evd_tolerance)

        if (n % (n_epochs / 10) == 0 and n != 0) or verbose:
            print(f'Epoch {n} | P(X): {evd:1.5}')

        prev_evd = evd
        
    gamma, evd, zeta = zip(*[e_step(model, x) for x in X])
    evd = np.sum(evd)
        
    return model, evd, gamma, zeta

## Max-sum (Viterbi algorithm)

In [13]:
def compute_max_sum_messages(model, x):
    omega = []

    for m in range(M):
        if m == 0:
            o_m = [model.pi[z].log() + model.loglik(z, x[m]) for z in range(K)]
            o_m_idx = [None] * len(o_m)
        else:
            o_m = []
            o_m_idx = []
            for z_j in range(K):
                ll = model.loglik(z_j, x[m])
                mx, mx_idx = torch.tensor([ll + model.A[z_i][z_j].log() + omega[-1][z_i][0] for z_i in range(K)]).max(0)
                o_m.append(mx)
                o_m_idx.append(mx_idx)
        o = [(t.item(), i.item()) if i is not None else (t.item(), i) for t, i in zip(o_m, o_m_idx)]
        omega.append(o)
    return omega


def backtrack(omega):
    m = M
    configs = []
    while omega:
        o_m = omega.pop()
        vals, idxs = zip(*o_m)
        if M == 1:
            max_val = max(vals)
            configs.append([vals.index(max_val)])
        elif all([i is not None for i in idxs]):
            if m == M:
                max_val = max(vals)
                for i, v in enumerate(vals):
                    if v == max_val:
                        c = deque([idxs[i], i])
                        configs.append(c)
            else:
                for c in configs:
                    phi = c[0]
                    c.appendleft(idxs[phi])
            m -= 1
    configs = [tuple(c) for c in configs]
    return max_val, configs


def viterbi(model, x):
    omega = compute_max_sum_messages(model, x)
    max_val, configs = backtrack(omega)
    return max_val, configs

## Test

In [14]:
x = X[0]

# Theirs
dists = [NormalDistribution(m, sd) for m, sd in zip(mu, std)]
trans_mat = A.numpy()
starts = pi.numpy()
their_model = HiddenMarkovModel.from_matrix(trans_mat, dists, starts)

# Ours
our_model = HMM(pi=pi, A=A, mu=mu, std=std)

# Test

## E-step
gamma, evd, zeta = e_step(our_model, x)

### Log-prob
assert np.allclose(np.exp(their_model.log_probability(x)), evd.item())

### Transitions, emissions
transitions, emissions = their_model.forward_backward(x)
assert np.allclose(np.exp(emissions), gamma)
assert np.allclose(transitions[:K, :K], zeta.sum(0))

## Viterbi
their_max_val, their_states = their_model.viterbi(x)
their_states, _ = zip(*their_states[1:])
our_max_val, our_states = viterbi(our_model, x)
assert their_states in our_states
assert np.allclose(their_max_val, our_max_val)

## M-step
assert np.allclose(our_model.pi.numpy(), their_model.dense_transition_matrix()[their_model.start_index, :][:K])
assert np.allclose(our_model.A.numpy(), their_model.dense_transition_matrix()[:K, :K])
_ = their_model.fit(X, min_iterations=1, max_iterations=1, algorithm='baum-welch', stop_threshold=1e-15)
mdl, evd, gamma, zeta = train_via_em(data=X, model=our_model, n_epochs=1)
assert np.allclose(mdl.pi, their_model.dense_transition_matrix()[their_model.start_index, :][:K])
assert np.allclose(mdl.A, their_model.dense_transition_matrix()[:K, :K])
their_mu, their_std = zip(*[s.distribution.parameters for s in their_model.get_params()['states'][:K]])
assert np.allclose(their_mu, [m.item() for m in mdl.mu])

# Vanilla HMM

## Train

In [15]:
N_EPOCHS = 20

model = HMM(pi=pi, A=A, mu=mu)

model = train_via_em(X, model, N_EPOCHS, verbose=True)

Epoch 0 | P(X): 7.6105e-05
Epoch 1 | P(X): 9.2247e-05
Epoch 2 | P(X): 9.7944e-05
Epoch 3 | P(X): 0.00012207
Epoch 4 | P(X): 0.00014187
Epoch 5 | P(X): 0.00014949
Epoch 6 | P(X): 0.0001483
Epoch 7 | P(X): 0.00014353
Epoch 8 | P(X): 0.00013786
Epoch 9 | P(X): 0.00013223
Epoch 10 | P(X): 0.00012708
Epoch 11 | P(X): 0.00012262
Epoch 12 | P(X): 0.00011897
Epoch 13 | P(X): 0.00011609
Epoch 14 | P(X): 0.00011391
Epoch 15 | P(X): 0.00011232
Epoch 16 | P(X): 0.00011125
Epoch 17 | P(X): 0.0001106
Epoch 18 | P(X): 0.00011033
Epoch 19 | P(X): 0.00011038


# M-step

In [54]:
def m_step(X, model, gamma, zeta):
    tgt = 0
    
    for x in X:
        tgt += (gamma[0] * model.pi.log()).sum()
        for zta in zeta:
            tgt += (zta * model.A.log()).sum()
        for gma in gamma:
            for z in range(K):
                tgt += (gma[z] * model.loglik(z, x)).sum()

    return tgt

# Neural HMM

## WIP: Train

In [58]:
N_EPOCHS = 10
LR = .01
VERBOSE = True

model = NeuralHMM()
optim = Adam(model.parameters(), lr=LR)

for n in range(N_EPOCHS):
    gamma, evd, zeta = zip(*[e_step(model, x) for x in X])
    gamma, zeta, evd = sum(gamma), sum(zeta), sum(evd)
    tgt = m_step(X, model, gamma, zeta)
    (-tgt).backward()
    optim.step()
    
    assert np.allclose(model.pi.sum().detach().numpy(), 1.)
    assert np.allclose(model.A.sum(1).detach().numpy(), 1.)
    
    if n % (N_EPOCHS / 10) == 0 and n != 0 or VERBOSE:
        print(f'Epoch {n} | P(X): {evd:1.5}')

Epoch 0 | P(X): 3.1525e-08
Epoch 1 | P(X): 3.6446e-08
Epoch 2 | P(X): 4.1838e-08
Epoch 3 | P(X): 4.7827e-08
Epoch 4 | P(X): 5.4471e-08
Epoch 5 | P(X): 6.1793e-08
Epoch 6 | P(X): 6.9801e-08
Epoch 7 | P(X): 7.8485e-08
Epoch 8 | P(X): 8.7826e-08
Epoch 9 | P(X): 9.7797e-08
