In [None]:
# Automatically reload imported modules (e.g. hybridhmm, models, utils) before each cell runs,
# so edits to those files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.colors as plt_colors
import numpy as np
import tensorflow as tf

from hybridhmm import HybridHMMTransitions, heuristic_priors_uniform
from models import SkelFeatureExtractor
from utils import plot_transition_matrix

# Global display / reproducibility settings.
matplotlib.rcParams.update({'figure.dpi': 100})  # sharper inline figures

# Seeds every np.random draw below (synthetic data, minibatch sampling).
np.random.seed(1234)

# True (synthetic) distribution

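We assume a Gaussian HMM. Given the hidden state, each observation frame is Gaussian around that state's mean: $p(\boldsymbol{x}_t \mid s_t = i) = \mathcal{N}(\boldsymbol{x}_t; \boldsymbol{\mu}_i, \sigma^2 \mathbf{I})$ with $\sigma = 0.1$ (the noise level used when generating the data below). The state means are drawn once as $\boldsymbol{\mu}_i \sim \mathcal{N}(\mathbf{0}, 4^2 \mathbf{I})$.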

In [None]:
input_size = 16  # dimensionality of each observation frame
n_states = 6

# One mean vector per state, drawn once (standard deviation 4 per dimension).
mu = np.random.randn(n_states, input_size) * 4

# Ground-truth transition matrix, a_ij = p(s_{t+1} = j | s_t = i).
true_transitions = [
    [.82, .18, .00, .00, .00, .00],
    [.05, .80, .15, .00, .00, .00],
    [.00, .07, .90, .03, .00, .00],
    [.00, .00, .00, .95, .05, .00],
    [.00, .00, .00, .01, .95, .04],
    [.00, .00, .00, .00, .15, .85]]
# Every row must be a probability distribution over next states; catch typos in the
# hand-written matrix here rather than as an obscure np.random.choice error later.
np.testing.assert_allclose(np.sum(true_transitions, axis=1), 1.0)
plot_transition_matrix(true_transitions)

# All sequences start in state 0.
true_init_state_priors = np.array([1.] + [0] * (n_states - 1))

# Data

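Generate 100 noisy sequences from the true model. Each sequence starts in state 0 and follows the true transitions. Whenever it reaches the last state, it stops at that step with probability 0.13. Each observation is the mean of its state plus Gaussian noise.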

In [None]:
def sample_state_sequence():
    """Draw one ground-truth state sequence from the true HMM.

    The first state is drawn from ``true_init_state_priors``; each following state is drawn
    from the row of ``true_transitions`` for the previous state. Each time the last state
    (``n_states - 1``) is drawn, the sequence ends if a uniform draw exceeds .87
    (i.e. with probability 0.13).

    Returns:
        1-D numpy array of state indices.
    """
    states = [np.random.choice(np.arange(n_states), 1, p=true_init_state_priors)[0]]
    while True:
        next_state = np.random.choice(
            np.arange(n_states), 1,
            p=true_transitions[states[-1]])[0]
        states.append(next_state)
        if next_state == n_states - 1 and np.random.rand() > .87:
            return np.array(states)


# 100 ground-truth state sequences first, then their observations: every frame is its
# state's mean plus isotropic Gaussian noise with standard deviation 0.1.
dataset_y = [sample_state_sequence() for _ in range(100)]
dataset_x = [mu[y] + np.random.randn(len(y), input_size) * 0.1 for y in dataset_y]

# Model

Create the model and training routines

In [None]:
# Initial (hand-picked) guess for the transition matrix that the HMM starts training from.
transitions = [
    [.84, .16, .00, .00, .00, .00],
    [.08, .84, .08, .00, .00, .00],
    [.00, .15, .84, .01, .00, .00],
    [.00, .00, .00, .84, .16, .00],
    [.00, .00, .00, .08, .84, .08],
    [.00, .00, .00, .00, .16, .84]]
# Each row must sum to 1 before it is handed to the HMM; fail fast on a typo.
np.testing.assert_allclose(np.sum(transitions, axis=1), 1.0)

# Uniform initial-state and state priors (float32 to match the TF graph).
init_state_priors = np.full([n_states], 1 / n_states, dtype='float32')
state_priors = np.full([n_states], 1 / n_states, dtype='float32')

plot_transition_matrix(transitions)

In [None]:
tf.reset_default_graph()
# The graph-level seed is a property of the default graph, so it must be set *after* the
# reset above (a seed set earlier, e.g. next to np.random.seed, would be discarded).
# Without it, TF-side random ops (presumably the network's weight initializers) differ on
# every Restart & Run All. NOTE(review): op-level seeds inside project modules and
# nondeterministic GPU kernels are not controlled by this.
tf.set_random_seed(1234)

# Models and inference

# Scalar switch fed with each sess.run: True for SGD steps, False for inference.
training = tf.placeholder(tf.bool, shape=[], name='training')
# One observation sequence per run, shape [n_frames, input_size].
inputs = tf.placeholder(dtype='float32', shape=[None, input_size], name='inputs')

net = SkelFeatureExtractor(n_states)

# Frame-wise state scores; log_softmax turns them into log state posteriors log p(s_t | x_t).
logits = net(inputs, training)
state_posteriors = tf.nn.log_softmax(logits)

hmm = HybridHMMTransitions(transitions, init_state_priors, state_priors)

# Hybrid NN/HMM: the posteriors are converted into emission pseudo-likelihoods
# (presumably posterior / prior in the log domain — confirm in hybridhmm), then decoded
# into the most likely state path with Viterbi.
pseudo_lkh = hmm.pseudo_lkh(state_posteriors)
ml_state_alignment = hmm.viterbi(pseudo_lkh)

# Training

# Network: frame-level cross-entropy against an integer state alignment, summed over frames.
state_tgt = tf.placeholder(dtype='int64', shape=[None], name='state_tgt')
loss = tf.reduce_sum(tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=state_tgt,
    logits=logits,
    name='cross_entropy'))

learning_rate = tf.placeholder(dtype='float32', shape=[], name='learning_rate')
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
# Collected after `net` has been built so that its batch-norm update ops are registered;
# the control dependency makes them run with every optimization step.
extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  # batch norm updates
with tf.control_dependencies(extra_update_ops):
    optimization_op = optimizer.minimize(loss, var_list=net.trainable_variables)

# HMM: each `add_*_stats` op is run once per sequence to accumulate statistics; the matching
# update op is then run once to re-estimate the parameters (presumably from the accumulated
# statistics — confirm in hybridhmm).
add_transition_stats_op = hmm.add_transition_stats(pseudo_lkh)
update_transitions_op = hmm.update_transitions()
add_prior_stats_op = hmm.add_priors_stats(pseudo_lkh)
update_state_priors = hmm.update_state_priors()

In [None]:
# Re-running this cell must not leak the session created by a previous run.
try:
    sess.close()
except NameError:
    pass  # first run: no session exists yet

sess = tf.Session()

# Initialize global variables first, then local ones (two separate runs, as before).
for init_op in (tf.global_variables_initializer(), tf.local_variables_initializer()):
    sess.run(init_op)

# Training procedure

## First training pass

Training begins with the posterior state model, that is to say the Neural Network.
In order to proceed, target values are needed.
In this case, the targets are the state values which are unfortunately unobserved.
To circumvent this problem, we initially use sub-optimal targets chosen by a heuristic adapted to the use case, for example the output of a GMM-HMM model or a k-means clustering.

Once all model parts have been trained at least once, the model is assumed to be good enough to provide sensible state values.

In [None]:
# Heuristic targets: the true states are unobserved, so assume every sequence goes through
# the states in order, spending roughly the same number of frames in each one.
state_alignments = []
for x in dataset_x:
    fractional_state = np.linspace(0, n_states - 0.000001, len(x))
    state_alignments.append(np.floor(fractional_state).astype(np.int32))

# Fit the state posterior network to those targets with SGD, one random sequence per step.
net_losses = []
for _ in range(100):
    seq_idx = np.random.randint(len(dataset_x))
    step_feed = {
        inputs: dataset_x[seq_idx],
        state_tgt: state_alignments[seq_idx],
        learning_rate: 0.001,
        training: True,
    }
    loss_value, _ = sess.run([loss, optimization_op], feed_dict=step_feed)
    net_losses.append(loss_value)

Using the posterior model, it is possible to train the transition model.
The accumulated statistics from all the observation sequences are used to update the initial state prior $\pi_i = p(s_1=i)$ and the transition probabilities $a_{ij} = p(s_{t+1}=j \mid s_t=i)$.

In [None]:
# Accumulate transition statistics over every sequence (network in inference mode),
# then apply a single update of the transition model.
for seq in dataset_x:
    sess.run([add_transition_stats_op], feed_dict={inputs: seq, training: False})

sess.run(update_transitions_op)

Let's observe how often each state is visited:

In [None]:
def count_states():
    """Count the frames that the current Viterbi alignment assigns to each HMM state.

    Decodes every sequence of ``dataset_x`` with ``ml_state_alignment`` (inference mode)
    and returns an array of length ``hmm.n_states`` holding the per-state frame counts.
    """
    decoded = [sess.run(ml_state_alignment, feed_dict={inputs: x, training: False})
               for x in dataset_x]
    all_frames = np.concatenate(decoded)
    state_ids = np.arange(hmm.n_states)
    return (all_frames[:, None] == state_ids[None, :]).sum(axis=0)


fig, ax = plt.subplots()
ax.bar(np.arange(hmm.n_states), count_states())
ax.set_title("state counts")
plt.show()

The training process has led the model to skip over some states in favor of more easily recognizable ones.
This is caused by the sub-optimal initial state alignment: the posterior model recognizes some states more easily than others, so the harder states receive a small likelihood, which in turn leads the transition model to avoid them.

To prevent this degradation, we force the model to visit the under-used states more often. The current state counts were obtained under the assumption of uniform state priors; reducing the prior of a rarely visited state increases its likelihood, so the model visits it more often.

In [None]:
# Adjust the state priors from the current state-occupancy counts, as described above, then
# re-plot the counts to check that the under-visited states are now used more often.
# NOTE(review): this cell previously called `heuristic_priors_adjustment`, which is neither
# imported nor defined in this notebook (NameError on Restart & Run All). The only helper
# imported from hybridhmm is `heuristic_priors_uniform`; this assumes it is the intended
# function with signature (hmm, count_fn, sess) — confirm against hybridhmm.
heuristic_priors_uniform(hmm, count_states, sess)

plt.figure()
plt.bar(np.arange(hmm.n_states), count_states())
plt.title("state counts")
plt.show()

Let's recapitulate the training process so far.

In [None]:
# Three-panel summary: NN loss history, current transition matrix, current state priors.
# hmm.A and hmm.state_priors are exponentiated before plotting (presumably stored as logs).
fig, (ax_loss, ax_trans, ax_prior) = plt.subplots(1, 3, figsize=(8, 2), dpi=100)

ax_loss.scatter(range(len(net_losses)), net_losses, c='red', s=10)
ax_loss.set_xlabel("iterations")
ax_loss.set_title("CE loss")

# plot_transition_matrix draws on the current pyplot axes (assumed from its use with subplot).
plt.sca(ax_trans)
plot_transition_matrix(np.exp(sess.run(hmm.A)))

ax_prior.bar(np.arange(hmm.n_states), np.exp(sess.run(hmm.state_priors)))
ax_prior.set_title("state priors")
ax_prior.set_ylim((0.01, 1))

fig.tight_layout()
plt.show()

Following are successive refinement iterations of the model parts.

In [None]:
n_refinement_rounds = 5
# NN steps per round; also used to highlight the current round's losses in the plot, so the
# two can no longer drift apart.
net_iters_per_round = 100


def run_over_dataset(op):
    """Run `op` once for every training sequence with the network in inference mode."""
    for x in dataset_x:
        sess.run(op, feed_dict={inputs: x, training: False})


for refinement_round in range(n_refinement_rounds):
    # 1. Realign: the current Viterbi path of each sequence becomes the new NN target.
    state_alignments = [
        sess.run(ml_state_alignment, feed_dict={inputs: x, training: False})
        for x in dataset_x]

    # 2. State posterior model (Neural Network): SGD on randomly drawn sequences.
    for _ in range(net_iters_per_round):
        i = np.random.randint(len(dataset_x))
        loss_value, _ = sess.run(
            [loss, optimization_op],
            feed_dict={
                inputs: dataset_x[i],
                state_tgt: state_alignments[i],
                learning_rate: 0.0005,
                training: True
            })
        net_losses.append(loss_value)

    # 3. Transition model: accumulate statistics over all sequences, then update once.
    run_over_dataset(add_transition_stats_op)
    sess.run(update_transitions_op)

    # 4. State priors: same accumulate-then-update pattern.
    run_over_dataset(add_prior_stats_op)
    sess.run(update_state_priors)

    # Progress figure: earlier losses in gray, this round's losses in red.
    n_previous = len(net_losses) - net_iters_per_round
    plt.figure(figsize=(8, 2), dpi=100)
    plt.subplot(1, 3, 1)
    plt.scatter(np.arange(n_previous), net_losses[:n_previous], c='gray', s=10)
    plt.scatter(np.arange(n_previous, len(net_losses)), net_losses[n_previous:], c='red', s=10)
    plt.xlabel("iterations")
    plt.title("CE loss")
    plt.subplot(1, 3, 2)
    plot_transition_matrix(np.exp(sess.run(hmm.A)))
    plt.subplot(1, 3, 3)
    plt.bar(np.arange(hmm.n_states), np.exp(sess.run(hmm.state_priors)))
    plt.title("state priors")
    plt.ylim((0.01, 1))
    plt.tight_layout()
    plt.show()

In [None]:
# Groups the 6 HMM states into 2 classes (states 0-2 -> 0, states 3-5 -> 1). The ✓/⨯ line
# compares Viterbi and ground truth at this class level, so confusions within a group are
# not marked as errors.
class_mapping = np.array([0, 0, 0, 1, 1, 1])

# Build the frame-wise argmax op once. Creating tf.argmax inside the loop added a new node
# to the graph on every iteration (and again on every re-run of this cell).
predicted_states_op = tf.argmax(logits, axis=1)

n_examples = min(10, len(dataset_x))
for i in range(n_examples):
    p, a = sess.run(
        [predicted_states_op, ml_state_alignment],
        feed_dict={
            inputs: dataset_x[i],
            training: False
        })
    print("logits:   " + ''.join(map(str, p)))
    print("viterbi:  " + ''.join(map(str, a)))
    print("gndtruth: " + ''.join(map(str, dataset_y[i])))
    err_summary = ['✓' if a_ == y_ else '⨯'
                   for a_, y_ in zip(class_mapping[a], class_mapping[dataset_y[i]])]
    print("          " + ''.join(map(str, err_summary)))
    print()