In [89]:
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.distributions import Dirichlet, Normal

In [78]:
# Problem sizes
K = 5   # number of latent states
M = 10  # length of each observation sequence
N = 50  # number of observation sequences

# Observation model used to simulate data: x | z ~ Normal(mu(z), sd^2)
sd = 0.05  # standard deviation of the Gaussian observations


def mu(z):
    """Mean of the observation model, E[x | z] = log(z + 1)."""
    return np.log(z + 1)

# Create dummy data

In [87]:
# Seed every RNG used below so Restart & Run All reproduces the same data.
SEED = 0
np.random.seed(SEED)     # drives np.random.choice / np.random.randn below
torch.manual_seed(SEED)  # drives torch.randn used to build A

# Parameters
# Randomly initialise a row-stochastic transition matrix. Cast the float32 softmax
# output to float64 and renormalise so every row sums to 1 at double precision
# before it is used as `p` in np.random.choice.
A = F.softmax(torch.randn(K, K), dim=1).double().numpy()
A /= A.sum(axis=1, keepdims=True)
pi = np.full(K, 1 / K)  # uniform distribution over the initial latent state
assert np.allclose(A.sum(axis=1), 1.0)

# Data: N independent chains of length M. The first state of each chain is drawn
# from `pi`; each later state is drawn from row `A[z]` of the previous state.
X, Z = [], []
for n in range(N):
    obs = []
    z = None
    for m in range(M):
        p_next = pi if m == 0 else A[z]
        z = np.random.choice(range(K), p=p_next)
        x = mu(z) + sd * np.random.randn()
        obs.append((z, x))
    z_m, x_m = zip(*obs)
    Z.append(z_m)
    X.append(x_m)

Z = np.array(Z)
X = np.array(X)
assert Z.shape == (N, M)
assert X.shape == (N, M)

# Create models

In [158]:
class GuassianObservationModel(nn.Module):
    """Small network mapping a discrete latent state to a non-negative scalar.

    Used as the mean of the Gaussian observation model p(x | z). The state index
    is embedded, passed through two linear layers, and clipped at zero. The
    class name's spelling is kept as-is because other cells refer to it.

    Args:
        K: number of latent states (rows of the embedding table).
        embed_dim: width of each state embedding.
        hidden_dim: width of the intermediate linear layer.
    """

    def __init__(self, K, embed_dim, hidden_dim):
        super(GuassianObservationModel, self).__init__()
        # Sub-modules are registered in this order: embed -> hidden -> out.
        self.embed = nn.Embedding(K, embed_dim)
        self.hidden = nn.Linear(embed_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, 1)

    def forward(self, z):
        """Return relu(out(hidden(embed(z)))); a length-B index tensor yields shape (B, 1)."""
        state_vectors = self.embed(z)
        hidden_activations = self.hidden(state_vectors)
        # NOTE(review): there is no non-linearity between `hidden` and `out`, so the
        # two layers compose to a single affine map; add one if extra capacity is wanted.
        raw_mean = self.out(hidden_activations)
        # Clip at zero because the simulated means, log(z + 1), are non-negative.
        return torch.relu(raw_mean)
    
    
class HMM:
    """Discrete-state HMM whose observation means come from a learned network.

    Holds the initial-state distribution `pi`, the transition matrix `A`
    (`A[i][j]` is read as p(z_m = j | z_{m-1} = i)), the observed data `X` as a
    float32 tensor, and a `GuassianObservationModel` that maps a latent state to
    a predicted observation mean.
    """

    EMBED_DIM = 20
    HIDDEN_DIM = 10

    def __init__(self, pi, A, X):
        self.pi = pi
        self.A = A
        # assumes X is (num_sequences, sequence_length) -- TODO confirm for other callers
        self.X = torch.tensor(X, requires_grad=False, dtype=torch.float32)
        # Number of latent states comes from `pi`, not from a notebook-level global.
        self.obs_model = GuassianObservationModel(len(pi), self.EMBED_DIM, self.HIDDEN_DIM)

    def log_lik(self, z, m=None):
        """
        L2 distance between a column of observed data and the model's prediction for state `z`.

        NOTE(review): despite the name, this is not a log-likelihood (nor an MSE). It is
        the Euclidean norm ||x - mean(z)||. For a Gaussian with fixed variance, the
        log-likelihood is -||x - mean(z)||^2 / (2 sd^2) + const, a *decreasing* function
        of this value. `.item()` also detaches the result from autograd, so it cannot be
        used as a training loss as written.

        Args:
            z: integer latent state.
            m: time index whose observations (column `m` of `X`) are compared. When
               omitted, the legacy behaviour is kept: column `z` is used, which treats
               the state value as a time index.

        Returns:
            A Python float.
        """
        column = z if m is None else m
        state = torch.tensor([z], dtype=torch.long)
        observed = self.X[:, [column]]     # (N, 1)
        predicted = self.obs_model(state)  # one prediction for the single state
        return (observed - predicted).norm().item()

    def evidence(self):
        """Sum of the observation model's predicted means over all latent states.

        NOTE(review): this is not the marginal likelihood p(X); confirm what this
        quantity is meant to represent.
        """
        # nn.Embedding needs an index tensor, not a bare Python int.
        return sum(
            self.obs_model(torch.tensor([z], dtype=torch.long)).item()
            for z in range(len(self.pi))
        )

    def factor(self, i, j, m):
        """
        The HMM factor is given by: `f(z_{m-1}, z_m) = p(z_m | z_{m-1}) p(x_m | z_m)`, where:

            - `m` gives our current index in the chain
            - `i` gives the value of `z_{m-1}`
            - `j` gives the value of `z_m`

        The transition term is read from `self.A[i][j]`. The observation term is
        `self.log_lik(j, m)`, i.e. it is computed against the observations at step `m`.

        NOTE(review): `log_lik` returns an L2 distance, not a probability, so this
        product is only a stand-in for the factor above -- TODO confirm intent.
        """
        return self.A[i][j] * self.log_lik(j, m)

# Setup

In [159]:
# Wrap the simulated parameters and observations in an HMM instance.
hmm = HMM(pi=pi, A=A, X=X)

# E-step

## Alpha step

In [174]:
# Forward (alpha) pass:
#   alpha[0][z]   = e(z) * pi[z]
#   alpha[m][z_j] = e(z_j) * sum_i alpha[m-1][z_i] * A[z_i][z_j]
# where e(z) = hmm.log_lik(z) serves as the per-state emission weight.
# Bug fix: the previous version computed e(z_j) for m > 0 but never used it, so
# every step after the first ignored the observations.
#
# NOTE(review): `hmm.log_lik(z)` takes no time index, so e(z) is identical at every
# step; it is therefore computed once, outside the loop. It is also an L2 distance
# rather than a probability -- TODO replace it with a per-step Gaussian likelihood
# before interpreting alpha as p(x_1..x_m, z_m).
emission = [hmm.log_lik(z) for z in range(K)]

alpha = []
for m in range(M):
    if m == 0:
        a_m = [emission[z] * hmm.pi[z] for z in range(K)]
    else:
        prev = alpha[-1]
        a_m = [
            emission[z_j] * sum(prev[z_i] * hmm.A[z_i][z_j] for z_i in range(K))
            for z_j in range(K)
        ]
    alpha.append(a_m)
alpha = np.array(alpha)
assert alpha.shape == (M, K)