In [94]:
import numpy as np
import torch
from torch import nn
from torch.distributions import Dirichlet, Normal
import torch.nn.functional as F
from torch.optim import Adam

In [223]:
K = 5  # number of latent state
M = 10  # length of observations
N = 2  # number of observations

sd = 1  # standard deviation on gaussian observations
mu = lambda z: np.log(z + 1)  # the mean of our observation model, i.e. E[p(x|z)]

# Generate data

In [224]:
SEED = 1
np.random.seed(SEED)
torch.manual_seed(SEED)

# Parameters
A = F.softmax(torch.randn(K, K), dim=-1)  # randomly initialize a transition matrix
pi = F.softmax(torch.randn(K), dim=-1)  # randomly initialize a distribution over the initial latent state

# Data
X, Z = [], []
for n in range(N):
    obs = []
    for m in range(M):
        if m == 0:
            z = np.random.choice(range(K), p=pi.numpy())
        else:
            z = np.random.choice(range(K), p=A[z].numpy())
        x = mu(z) + sd * np.random.randn()
        obs.append((z, x))
    z_m, x_m = zip(*obs)
    Z.append(z_m)
    X.append(x_m)

Z = np.array(Z); X = np.array(X)
assert Z.shape == (N, M)
assert X.shape == (N, M)

## Print first observation

In [225]:
m = 0
x, z = X[m], Z[m]

for m, (z_m, x_m) in enumerate(zip(z, x)):
    if m == 0:
        print(f'm: {m} | p(z_m):       {z_m} | p(x_m|{z_m}): {x_m:1.3}')
    else:
        print(f'm: {m} | p(z_m|z_m-1): {z_m} | p(x_m|{z_m}): {x_m:1.3}')

m: 0 | p(z_m):       2 | p(x_m|2): 0.296
m: 1 | p(z_m|z_m-1): 0 | p(x_m|0): -0.449
m: 2 | p(z_m|z_m-1): 1 | p(x_m|1): -0.413
m: 3 | p(z_m|z_m-1): 4 | p(x_m|4): -0.0451
m: 4 | p(z_m|z_m-1): 3 | p(x_m|3): 0.369
m: 5 | p(z_m|z_m-1): 4 | p(x_m|4): 2.25
m: 6 | p(z_m|z_m-1): 0 | p(x_m|0): -0.86
m: 7 | p(z_m|z_m-1): 3 | p(x_m|3): 3.16
m: 8 | p(z_m|z_m-1): 1 | p(x_m|1): 1.26
m: 9 | p(z_m|z_m-1): 4 | p(x_m|4): 1.04


# Create models

In [226]:
class GuassianObservationModel(nn.Module):
    
    def __init__(self, K, embed_dim, hidden_dim):
        super().__init__()
#         self.embed = nn.Embedding(K, embed_dim)
#         self.hidden = nn.Linear(embed_dim, hidden_dim)
#         self.out = nn.Linear(hidden_dim, 1)
        self.mu = nn.Parameter(torch.ones(K))
        
    def forward(self, z):
#         embed = self.embed(z)
#         return torch.relu(self.out(self.hidden(embed)))  # ReLU, b/c we know obs. means are positive
        return self.mu[z]


class HMM:
    
    EMBED_DIM = 10
    HIDDEN_DIM = 5
    OBS_MODEL_STDDEV = 1
    
    def __init__(self, pi, A, X, seed=1):
        self.pi = torch.ones(K) / K
        self.A = torch.ones(K, K) / K
        self.mu = np.zeros(K)
        self.X = torch.FloatTensor(X)
        
    def lik(self, z, x):
        """
        The likelihood of *all* of our observed data, in the m-th index of the sequential
        observation, conditional on latent code `z`, given a Gaussian observation model.
        """
        x = torch.FloatTensor([x])
        return Normal(loc=self.mu[z], scale=self.OBS_MODEL_STDDEV).log_prob(x).exp()

    def factor(self, z_i, z_j, z):
        return self.A[z_i][z_j] * self.lik(z_j, z)
    
    
class NeuralHMM(nn.Module):
    
    EMBED_DIM = 10
    HIDDEN_DIM = 5
    OBS_MODEL_STDDEV = 1
    
    def __init__(self, pi, A, X, seed=1):
        super().__init__()
        self._pi = nn.Parameter(torch.ones(K) / K)
        self._A = nn.Parameter(torch.ones(K, K) / K)
        self.X = torch.FloatTensor(X)
        self.obs_model = GuassianObservationModel(K, self.EMBED_DIM, self.HIDDEN_DIM)
        
    @property
    def pi(self):
        return F.softmax(self._pi, dim=-1)
    
    @property
    def A(self):
        return F.softmax(self._A, dim=-1)
        
    def lik(self, z, m):
        """
        The likelihood of *all* of our observed data, in the m-th index of the sequential
        observation, conditional on latent code `z`, given a Gaussian observation model.
        """
        z = torch.LongTensor([z])
        X_m = self.X[:, m]
        return Normal(loc=self.obs_model(z), scale=self.OBS_MODEL_STDDEV).log_prob(X_m).sum().exp()

    def factor(self, z_i, z_j, x):
        return self.A[z_i][z_j] * self.lik(z_j, m)

# E-step

## Alpha step

The HMM factor is given by: $f(z_{n-1}, z_n) = p(z_n|z_n-1)p(X|z_n)$, where $n$ is
our current index on the chain.

The message from factor to variable, moving forward along the chain, is given by:

$$
\sum\limits_{z_{n-1}} f(z_{n-1}, z_n) * \mu_{f_{n-1} \rightarrow z_{n-1}}
$$

In an efficient implementation, for each value $z_j$, we would compute $p(X|z_n)$ upfront then multiply it by the summation. Below, we recompute it inside each term of the summation unnecessarily for demonstrative purposes.

In [227]:
# K = 2
# M = 5


# class HMM:
    
#     EMBED_DIM = 10
#     HIDDEN_DIM = 5
#     OBS_MODEL_STDDEV = 1
    
#     def __init__(self, pi, A, X, B, seed=1):
#         self.pi = pi
#         self.A = A
#         self.X = X
#         self.B = B
        
#     def lik(self, z, m):
#         """
#         The likelihood of *all* of our observed data, in the m-th index of the sequential
#         observation, conditional on latent code `z`, given a Gaussian observation model.
#         """
#         return self.B[z][self.X[m]]

#     def factor(self, z_i, z_j, m):
#         return self.A[z_i][z_j] * self.lik(z_j, m)

# A = np.array([[.7, .3], [.3, .7]])
# B = np.array([[.9, .1], [.2, .8]])
# X = [0, 0, 1, 0, 0]
# pi = np.array([.5, .5])

# hmm = HMM(pi, A, X, B)

In [228]:
def alpha_step(hmm, x):
    alpha = [hmm.pi]
    for m in range(M):
        a_m = [sum(hmm.factor(z_i, z_j, x[m]) * alpha[-1][z_i] for z_i in range(K)) for z_j in range(K)]
        alpha.append(torch.FloatTensor(a_m))
    return torch.stack(alpha).clone()

## Beta step

In [229]:
def beta_step(hmm, x):
    beta = [torch.ones(K)]
    for m in reversed(range(M)):
        b_m = [sum(hmm.factor(z_i, z_j, x[m]) * beta[-1][z_j] for z_j in range(K)) for z_i in range(K)]
        beta.append(torch.FloatTensor(b_m))
    return torch.stack(beta).flip(0).clone()  # NB: we flip the axes back!

## Posterior marginals

In [230]:
def compute_gamma(alpha, beta):
    gamma_ = alpha * beta
    evidence = gamma_.sum(1)

    # p(X) = \sum_{z_n} \alpha(z_n) * \beta(z_n), for any choice of n!
    assert np.allclose(evidence[0], evidence)

    gamma = gamma_ / evidence[0]
    assert np.allclose(gamma.sum(1), 1.)
    
    return gamma, evidence[0]

## Posterior transition matrices

In [231]:
def compute_zeta(alpha, beta, evidence, x):
    zeta_ = []
    for m in range(M):
        liks = np.array([hmm.lik(z, x[m]).item() for z in range(K)])
        zeta_m = np.outer(alpha[m-1], beta[m]) * hmm.A.detach().numpy() * liks
        zeta_.append(zeta_m)

    zeta = torch.FloatTensor(zeta_).clone() / evidence
#     assert all([np.allclose(zta.sum(), 1.) for zta in zeta])
    
    return zeta

## Altogether

In [232]:
def e_step(hmm, x):
    alpha = alpha_step(hmm, x)
    beta = beta_step(hmm, x)
    gamma, evidence = compute_gamma(alpha, beta)
    zeta = compute_zeta(alpha, beta, evidence, x)
    return gamma, evidence, zeta

# Vanilla HMM

## Train

In [239]:
N_EPOCHS = 20
PRINT_EVERY = True
LR = .001

hmm = HMM(pi, A, X)


for n in range(N_EPOCHS):
    gamma, evidence, zeta = 0, 0, 0
    for x in X:
        g, e, z = e_step(hmm, x)
        gamma += g
        evidence += e
        zeta += z
    gamma /= len(X)
    evidence /= len(X)
    zeta /= len(X)
    
    gamma = gamma[1:, :]
    
    # M-step (update parameters)
    ## Pi
    hmm.pi = gamma[0] / gamma[0].sum()
    ## A
    for z_i in range(K):
        for z_j in range(K):
            hmm.A[z_i][z_j] = sum([zta[z_i, z_j] for zta in zeta]) / sum([zta[z_i].sum() for zta in zeta])
    ## mu
    hmm.mu = [(hmm.X * gamma[:, k]).mean().item() / gamma[:, k].sum().item() for k in range(K)]
    
    assert np.allclose(hmm.pi.sum(), 1.)
    assert np.allclose(hmm.A.sum(1), 1.)
    
    if (n % (N_EPOCHS / 10) == 0 and n != 0) or PRINT_EVERY:
        print(f'Epoch {n} | P(X): {evidence:1.5}')

Epoch 0 | P(X): 3.8057e-09
Epoch 1 | P(X): 7.1645e-09
Epoch 2 | P(X): 7.1645e-09
Epoch 3 | P(X): 7.1645e-09
Epoch 4 | P(X): 7.1645e-09
Epoch 5 | P(X): 7.1645e-09
Epoch 6 | P(X): 7.1645e-09
Epoch 7 | P(X): 7.1645e-09
Epoch 8 | P(X): 7.1645e-09
Epoch 9 | P(X): 7.1645e-09
Epoch 10 | P(X): 7.1645e-09
Epoch 11 | P(X): 7.1645e-09
Epoch 12 | P(X): 7.1645e-09
Epoch 13 | P(X): 7.1645e-09
Epoch 14 | P(X): 7.1645e-09
Epoch 15 | P(X): 7.1645e-09
Epoch 16 | P(X): 7.1645e-09
Epoch 17 | P(X): 7.1645e-09
Epoch 18 | P(X): 7.1645e-09
Epoch 19 | P(X): 7.1645e-09


# M-step

In [None]:
def m_step(hmm, gamma, zeta):
    tgt = 0

    tgt += (gamma[0] * hmm.pi.log()).sum()
    for zta in zeta:
        tgt += (zta * hmm.A.log()).sum()
    for m, gma in enumerate(gamma):
        for z in range(K):
            tgt += gma[z] * hmm.lik(z, m).log()
    
    return tgt

# Neural HMM

## Train

In [None]:
import numpy as np
from hmmlearn import hmm
np.random.seed(42)

model = hmm.GaussianHMM(n_components=3, covariance_type="full")
model.startprob_ = np.array([0.6, 0.3, 0.1])
model.transmat_ = np.array([[0.7, 0.2, 0.1], [0.3, 0.5, 0.2], [0.3, 0.3, 0.4]])
model.means_ = np.array([[0.0, 0.0], [3.0, -3.0], [5.0, 10.0]])
model.covars_ = np.tile(np.identity(2), (3, 1, 1))

In [None]:
X = np.concatenate([X1, X2])
lengths = [len(X1), len(X2)]
model = hmm.GaussianHMM(n_components=3)
model.fit(X, lengths)  

In [None]:
X = np.random.randn(100, 3)

lengths = [3] * 100

In [None]:
model.fit(X, lengths)

In [None]:
N_EPOCHS = 100
LR = .01

hmm = NeuralHMM(pi, A, X)
optim = Adam(hmm.parameters(), lr=LR)

for n in range(N_EPOCHS):
    gamma, evidence, zeta = e_step(hmm)
    tgt = m_step(hmm, gamma, zeta)
    (-tgt).backward()
    optim.step()
    
    assert np.allclose(hmm.pi.sum().detach().numpy(), 1.)
    assert np.allclose(hmm.A.sum(1).detach().numpy(), 1.)
    
    if n % (N_EPOCHS / 10) == 0 and n != 0:
        print(f'Epoch {n} | P(X): {evidence:1.5}')

# Max-sum (Viterbi algorithm)

### Pass max-messages forward

In [None]:
omega = []

for m in range(M):
    if m == 0:
        o_m = [hmm.pi[m].log() + hmm.lik(z, m) for z in range(K)]
        o_m_idx = [None] * len(o_m)
    else:
        o_m = []
        o_m_idx = []
        for z_j in range(K):
            ll = hmm.lik(z_j, m)
            mx, mx_idx = torch.tensor([ll + hmm.A[z_i][z_j] + omega[-1][z_i][0] for z_i in range(K)]).max(0)
            o_m.append(mx)
            o_m_idx.append(mx_idx)
    o = [(t.item(), i.item()) if i is not None else (t.item(), i) for t, i in zip(o_m, o_m_idx)]
    omega.append(o)

### Backtrack

In [None]:
m = M
configs = []
while omega:
    o_m = omega.pop()
    vals, idxs = zip(*o_m)
    if all([i is not None for i in idxs]):
        if m == M:
            max_val = max(vals)
            for i, v in enumerate(vals):
                if v == max_val:
                    configs.append([i, idxs[i]])
        else:
            for c in configs:
                phi = c[-1]
                c.append(idxs[phi])
        m -= 1
        
print(f'Max val: {max_val}')
print(f'Maximizing configs: {configs}')