In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.colors as plt_colors
import numpy as np
import tensorflow as tf
from sklearn.cluster import KMeans

from hybridhmm import HybridHMMTransitions, heuristic_priors_adjustment
from models import SkelFeatureExtractor
from utils import plot_transition_matrix

# Seed numpy's global RNG: every sampling step below (np.random.*) draws from it.
np.random.seed(1234)

matplotlib.rcParams.update({'figure.dpi': 100})

# True (synthetic) distribution

We assume a Laplace-HMM distribution with 5 states such that $p(\boldsymbol{x}_t \mid s_t = i) = \mathcal{L}(\boldsymbol{x}_t; \boldsymbol{\mu}_i, 0.2)$, i.e. a Laplace distribution with location $\boldsymbol{\mu}_i$ and scale $0.2$, where $\boldsymbol{\mu}_i = \left( \cos \left( \frac{2 (i - 1) \pi}{5} \right), \sin \left( \frac{2 (i - 1) \pi}{5} \right) \right)$ for $i = 1, \dots, 5$.

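The length of each sequence is drawn from a Poisson distribution of mean $\lambda = 75$. A sequence may also end early: each time step spent in the last state ends it with probability $0.13$.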

Finally, the system always starts in state 1 (index 0 in the code, where states are numbered from 0).

In [None]:
n_states = 5
input_size = 2

# State i (0-indexed) emits around the i-th of `n_states` points evenly spaced on the unit circle.
mu = np.array([
    [np.cos(angle), np.sin(angle)]
    for angle in (2 * np.pi * i / n_states for i in range(n_states))])

# Ground-truth transition matrix: row i is p(s_{t+1} = . | s_t = i).
true_transitions = [
    [.96, .04, .00, .00, .00],
    [.00, .80, .20, .00, .00],
    [.00, .15, .70, .15, .00],
    [.00, .00, .00, .85, .15],
    [.15, .00, .00, .00, .85]]

# The system always starts in the first state.
true_init_state_priors = [1.] + [0] * (len(true_transitions) - 1)

plt.figure(figsize=(10, 3))

# left panel: true transition matrix
plt.subplot(1, 2, 1)
plot_transition_matrix(true_transitions)

# right panel: emission center of each state
plt.subplot(1, 2, 2)
plt.scatter(*mu.T)
plt.axis('equal')
plt.title("center points")

plt.show()

# Data

Generate (noisy) data from the true model.

In [None]:
# Sample 100 state sequences from the true HMM, then emit one noisy observation per time step.
states = np.arange(n_states)
dataset_y = []

for _ in range(100):
    # initial state drawn from the (degenerate) initial state prior
    current = np.random.choice(states, 1, p=true_init_state_priors)[0]
    sequence = [current]

    n_steps = max(0, np.random.poisson(75) - 1)
    for _ in range(n_steps):
        current = np.random.choice(states, 1, p=true_transitions[current])[0]
        sequence.append(current)
        # while in the last state, the sequence ends early with probability .13 per step
        if current == n_states - 1 and np.random.rand() > .87:
            break

    dataset_y.append(np.array(sequence))

# observations: the state's center point plus Laplace noise of scale .2
dataset_x = [
    mu[y] + np.random.laplace(scale=.2, size=(len(y), input_size))
    for y in dataset_y]

Observe that the noise on the datapoints will generate observations that are difficult to attribute to their corresponding state.
The transition model of the HMM will hopefully lift the ambiguity for most of these situations.

In [None]:
all_observations = np.concatenate(dataset_x)
all_labels = np.concatenate(dataset_y)

# one scatter series per true state, so that each state gets its own color
for k in range(n_states):
    in_state = all_labels == k
    plt.scatter(all_observations[in_state, 0], all_observations[in_state, 1], s=5)

plt.axis('equal')
plt.title('Observation data points')
plt.legend(['$s={}$'.format(k) for k in range(n_states)],
           title="underlying state");

# Model

Below, we create the model and training routines.
The transition matrix encodes several assumptions:

- Our system describes certain events which involve 4 states.
- We have little information about the order of these states, except that the event always begins with the same state and ends in another one
- Inside the event, we do not know if the states appear in a certain order or not, except that we cannot loop back from the last to the first state.
- Before and after an event occurs, the system stays in a separate 'resting' state

In [None]:
# Hand-crafted prior transition matrix: row i is p(s_{t+1} = . | s_t = i).
# State 0 is the resting state and can only be followed by itself or the first event
# state; the last state cannot jump back to the first event state.
transitions = [
    [.90, .10, .00, .00, .00],  # previously [.90, .04, ...], which summed to .94
    [.00, .70, .10, .10, .10],
    [.00, .10, .70, .10, .10],
    [.00, .10, .10, .70, .10],
    [.10, .00, .10, .10, .70]]
# every row must be a probability distribution over next states
assert np.allclose(np.sum(transitions, axis=1), 1.), "transition rows must sum to 1"

# the system always starts in the resting state; sized from *this* matrix, not the true one
init_state_priors = [1.] + [0] * (len(transitions) - 1)
state_priors = np.full([n_states], 1 / n_states, dtype='float32')

plot_transition_matrix(transitions)

Below is the code to create the model parts; everything is implemented with TensorFlow (no wrapper or dependency used).

In [None]:
# Start from an empty default graph so that re-running this cell does not duplicate ops/variables.
tf.reset_default_graph()

# Models and inference

# Scalar flag: fed True during network updates and False for inference (see the feed_dicts below).
training = tf.placeholder(tf.bool, shape=[], name='training')
# One observation sequence per run: shape (time steps, input_size).
inputs = tf.placeholder(dtype='float32', shape=[None, input_size], name='inputs')

# Posterior state model (neural network) with one output score per state; defined in `models`.
net = SkelFeatureExtractor(n_states)

logits = net(inputs, training)
# Log-probabilities log p(s_t | x_t), despite the name "posteriors".
state_posteriors = tf.nn.log_softmax(logits)

hmm = HybridHMMTransitions(transitions, init_state_priors, state_priors)

# NOTE(review): presumably the hybrid "scaled likelihood" log p(s|x) - log p(s) -- confirm in hybridhmm.
pseudo_lkh = hmm.pseudo_lkh(state_posteriors)
# Most likely state sequence under the HMM; used as training targets during refinement.
ml_state_alignment = hmm.viterbi(pseudo_lkh)

Then come the training routines:

In [None]:
# Training

# Per-frame target state index for the sequence fed in `inputs`.
state_tgt = tf.placeholder(dtype='int64', shape=[None], name='state_tgt')
# One weight per state; each frame's cross-entropy is scaled by the weight of its target state.
loss_weights = tf.placeholder(dtype='float32', shape=[n_states], name='loss_weight')
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=state_tgt,
        logits=logits) * tf.gather(loss_weights, state_tgt))

learning_rate = tf.placeholder(dtype='float32', shape=[], name='learning_rate')
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
# Collected after the network was built, so this holds its update ops (e.g. batch norm statistics, if any).
extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  # batch norm updates
with tf.control_dependencies(extra_update_ops):
    # Only network weights are trained by SGD; HMM parameters are updated by the ops below.
    optimization_op = optimizer.minimize(loss, var_list=net.trainable_variables)

# HMM statistics ops: presumably "add" accumulates statistics for one sequence and "update"
# re-estimates the parameters from everything accumulated -- TODO confirm in hybridhmm.
add_transition_stats_op = hmm.add_transition_stats(pseudo_lkh)
update_transitions_op = hmm.update_transitions()
add_prior_stats_op = hmm.add_priors_stats(pseudo_lkh)
update_state_priors = hmm.update_state_priors()

(more Tensorflow boilerplate)

In [None]:
# Close the session left over from a previous run of this cell (if any) before opening a new one.
if 'sess' in globals():
    sess.close()

sess = tf.Session()

# initialize global variables first, then local variables
for init_op in (tf.global_variables_initializer(), tf.local_variables_initializer()):
    sess.run(init_op)

# Training procedure

## First training pass

Training begins with the posterior state model, that is to say the Neural Network.
In order to proceed, target values are needed.
In this case, the targets are the state values which are unfortunately unobserved.
To circumvent this problem, we initially select suboptimal values chosen according to a heuristic adapted to the use case, for example the result of a GMM-HMM model or a k-means clustering.

Once all model parts have been trained at least once, the model is assumed to be good enough to provide sensible state values.

In [None]:
def count(data, n):
    """Count the occurrences of each integer label 0..n-1 in `data`.

    Values outside [0, n) are ignored. Returns an integer array of length `n`.
    """
    one_hot = np.asarray(data)[:, np.newaxis] == np.arange(n)
    return one_hot.sum(axis=0)

# we use a k-Means to generate initial state values
kmeans = KMeans(n_clusters=n_states)
kmeans.fit(all_observations)
kmeans_state_predictions = [kmeans.predict(x) for x in dataset_x]

# to refine the initialization, we use our assumption that states mostly appear in succession:
# centers2state[s] is the k-means cluster index assigned to state s
centers2state = []

# we know the system always starts in state 0: pick the cluster most often seen first
centers2state.append(
    np.argmax(count([p[0] for p in kmeans_state_predictions], n_states)))

# iteratively find the next state: the not-yet-assigned cluster that most often
# directly follows the previously assigned one
for _ in range(n_states - 1):
    successors = []
    for p in kmeans_state_predictions:
        is_successor = np.invert(np.isin(p[1:], centers2state)) & (p[:-1] == centers2state[-1])
        successors.extend(p[1:][is_successor])

    successor_counts = count(successors, n_states).astype(float)
    # Never reassign a cluster: if no unassigned cluster ever follows the previous one,
    # argmax over all-zero counts would otherwise return cluster 0 again (possibly
    # already used), and the reordering below would silently drop a cluster.
    successor_counts[centers2state] = -1
    centers2state.append(np.argmax(successor_counts))

assert sorted(centers2state) == list(range(n_states)), "cluster-to-state mapping must be a permutation"

# reorder clusters to (probably) match states
kmeans.cluster_centers_ = kmeans.cluster_centers_[centers2state]

On this simple example with synthetic data, kmeans provides a strong initialization.
When working with real data, it can become difficult to provide sensible initialization values.

In [None]:
# one color per state: the (reordered) k-means center next to the true mean of that state
for center, true_mean in zip(kmeans.cluster_centers_, mu):
    plt.scatter([center[0], true_mean[0]], [center[1], true_mean[1]])

plt.axis('equal')
plt.legend([str(k) for k in range(n_states)])
plt.title("k-Mean centers vs true state distributions means")

In [None]:
# use the kmeans heuristic to set state posterior targets:
# one array of per-frame state labels for each sequence
state_alignments = list(map(kmeans.predict, dataset_x))

Unfortunately, a bad initialization of state targets might easily send the training procedure down a degenerate and non-recoverable path.
Supervision of the training procedure is critical, and expertise is required to make sure the state alignment does not ignore states or otherwise attribute too many observations to a single state.

This can be monitored by checking how often each state is visited.

Some imbalance is not abnormal, as some states may naturally appear more often, but strong imbalance quickly deteriorates the quality of the posterior state model (the neural network): it overly focuses on the most frequent states while ignoring the others.
To counter this effect, we could resample the dataset but we choose to reweight the loss depending on the label, which has the same effect on average but is simpler to implement.

Check the usage of the `loss_weights` variable in the expression of the loss above.

In [None]:
# how often the current alignment visits each state, as counts and as relative frequencies
state_counts = count(np.concatenate(state_alignments), n_states)
freqs = state_counts / np.sum(state_counts)

plt.figure()
plt.bar(np.arange(hmm.n_states), state_counts)
plt.title("state counts")
plt.show()

In [None]:
# state posterior model training iterations
net_losses = []
# inverse-frequency class weights; they do not change during this pass
class_weights = n_states / freqs

for _ in range(100):
    # one randomly chosen sequence per gradient step
    seq_idx = np.random.randint(len(dataset_x))
    feed = {
        inputs: dataset_x[seq_idx],
        state_tgt: state_alignments[seq_idx],
        loss_weights: class_weights,
        learning_rate: 0.001,
        training: True,
    }
    loss_value, _ = sess.run([loss, optimization_op], feed_dict=feed)
    net_losses.append(loss_value)

Using the posterior model, it is possible to train the transition model.
The accumulated statistics from all the observation sequences are used to update the initial state prior $\pi_i = p(s_1=i)$ and the transition probabilities $a_{ij} = p(s_{t+1}=j \mid s_t=i)$.

In [None]:
# accumulate transition statistics over every training sequence, then apply them in one update
for observations in dataset_x:
    sess.run(add_transition_stats_op,
             feed_dict={inputs: observations, training: False})

sess.run(update_transitions_op)

Because our neural network is trained to be unbiased with respect to state frequencies (reweighting mechanism), the state priors are kept uniform; their re-estimation is therefore disabled in the cell below.
The output of the Neural Network is assumed to be proportional to the likelihood.

In [None]:
# estimate state priors from the state frequencies observed on the training dataset with the current model
#
# NOTE: intentionally disabled. The loss reweighting keeps the network unbiased with respect
# to state frequencies, so the state priors stay at their uniform initial value.
# Uncomment to re-estimate the priors from the data instead.

# for x in dataset_x:
#     sess.run(
#         add_prior_stats_op,
#         feed_dict={
#             inputs: x,
#             training: False
#         })
# 
# sess.run(update_state_priors)

All model parts have now been trained once. Let's recapitulate the training process thus far:

In [None]:
# fetch the current HMM parameters (exponentiated for display, since they are kept in log space)
transition_probs = np.exp(sess.run(hmm.A))
prior_probs = np.exp(sess.run(hmm.state_priors))

plt.figure(figsize=(9, 2), dpi=100)

# left: cross-entropy loss of every network update so far
plt.subplot(1, 3, 1)
plt.scatter(range(len(net_losses)), net_losses, c='red', s=5, alpha=.3)
plt.xlabel("iterations")
plt.title("CE loss")

# middle: current transition matrix
plt.subplot(1, 3, 2)
plot_transition_matrix(transition_probs)

# right: current state priors
plt.subplot(1, 3, 3)
plt.bar(np.arange(hmm.n_states), prior_probs)
plt.title("state priors")
plt.ylim((0.01, 1))

plt.tight_layout()
plt.show()

Following are successive refinement iterations:

In [None]:
# Successive refinement: realign states with the current model, then retrain each model part.
n_refinement_epochs = 5
# Network updates per epoch. This is also the number of most recent losses highlighted in red
# below, so both uses must share this single value.
n_net_iterations = 100

for epoch in range(n_refinement_epochs):
    # realign states: the Viterbi path under the current model becomes the new target
    state_alignments = [
        sess.run(ml_state_alignment, feed_dict={inputs: x, training: False})
        for x in dataset_x]

    freqs = count(np.concatenate(state_alignments), n_states)
    freqs = freqs / np.sum(freqs)
    # Inverse-frequency loss weights, fixed for this epoch. A state absent from the alignment
    # gets an infinite weight; it is never gathered because its label does not occur in the targets.
    class_weights = n_states / freqs

    # state posterior model (Neural Network)
    for _ in range(n_net_iterations):
        seq_idx = np.random.randint(len(dataset_x))
        loss_value, _ = sess.run(
            [loss, optimization_op],
            feed_dict={
                inputs: dataset_x[seq_idx],
                state_tgt: state_alignments[seq_idx],
                loss_weights: class_weights,
                learning_rate: 0.001,
                training: True
            })
        net_losses.append(loss_value)

    # transition model: accumulate statistics over all sequences, then update once
    for x in dataset_x:
        sess.run(
            add_transition_stats_op,
            feed_dict={
                inputs: x,
                training: False
            })

    sess.run(update_transitions_op)

    # state priors: intentionally not re-estimated (the loss reweighting keeps them uniform);
    # uncomment to enable
    #for x in dataset_x:
    #    sess.run(
    #        add_prior_stats_op,
    #        feed_dict={
    #            inputs: x,
    #            training: False
    #        })
    #
    #sess.run(update_state_priors)

    # losses from previous epochs in gray, the ones from this epoch in red
    n_previous = len(net_losses) - n_net_iterations
    plt.figure(figsize=(9, 2), dpi=100)
    plt.subplot(1, 3, 1)
    plt.scatter(np.arange(n_previous), net_losses[:n_previous], c='gray', s=5, alpha=.3)
    plt.scatter(np.arange(n_previous, len(net_losses)), net_losses[n_previous:], c='red', s=5, alpha=.3)
    plt.xlabel("iterations")
    plt.title("CE loss")
    plt.subplot(1, 3, 2)
    plot_transition_matrix(np.exp(sess.run(hmm.A)))
    plt.subplot(1, 3, 3)
    plt.bar(np.arange(hmm.n_states), np.exp(sess.run(hmm.state_priors)))
    plt.title("state priors")
    plt.ylim((0.01, 1))
    plt.tight_layout()
    plt.show()

# Evaluation

Below are some observations of the model in action.
The output of the Neural Network (logits) is sometimes mistaken about which state is active at a given time, but the HMM transition model (Viterbi decoding) brings temporal coherence and eliminates most of these errors.

In [None]:
# Most likely state per frame according to the posterior state model alone. The op is built
# once here, not inside the loop, so each iteration no longer adds a new node to the graph.
posterior_argmax = tf.argmax(logits, axis=1)

# show at most 10 sequences (fewer if the dataset is smaller)
for i in range(min(10, len(dataset_x))):
    p, a = sess.run(
        [posterior_argmax,  # most likely state according to the posterior state model
         ml_state_alignment],  # most likely states according to the HMM (viterbi)
        feed_dict={
            inputs: dataset_x[i],
            training: False
        })
    print("logits:   " + ''.join(map(str, p)))
    print("viterbi:  " + ''.join(map(str, a)))
    print("gndtruth: " + ''.join(map(str, dataset_y[i])))
    print()