# Week 3: MFCCs and HMMs
This week we're moving on to applying sequence models to decode speech. To do this, we'll introduce a model that's closely related to the WFST, namely the *Hidden Markov Model* (HMM). We'll use the `pomegranate` library for modeling HMMs, though HMMs can also be converted into WFSTs, which we'll demonstrate using `pynini`.

We'll start by reviewing how MFCCs can represent speech in a numerical form that's useful for machine learning, using `librosa`, a Python library for handling audio, to calculate them. We'll then build intuition for HMMs with a simple analogy involving geography and pathfinding, and finally apply HMMs to speech.

In [None]:
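import librosa
from praatio import textgrid
import pandas as pd
import numpy as np
import sklearn.cluster  # KMeans, used for vector quantization below
import sklearn.manifold  # TSNE, used for visualization below
import matplotlib.pyplot as plt
import seaborn as sns
import pynini
import graphviz
from librikws import *
import os
from typing import List, Optional, Union  # used in type hints below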

## Section 1: Encoding speech with MFCCs
MFCCs (Mel Frequency Cepstral Coefficients) are a representation that encodes speech as a sequence of vectors suitable for machine learning. Computing them involves a series of signal processing steps that condense the waveform into a representation emphasizing the information that is most useful for understanding speech.

It would be great to go into more detail about how they work than the hand-wavy description above, but all we need to know in order to use MFCCs for KWS is that they are a useful way of encoding audio. Here are [some more videos from Herman Kamper](https://www.youtube.com/playlist?list=PLmZlBIcArwhN8nFJ8VL1jLM2Qe7YCcmAb) if you want to learn more about signal processing methods for speech data.
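
To make the "series of signal processing steps" a little more concrete, here is a minimal NumPy sketch of the classic MFCC recipe: split the waveform into overlapping windowed frames, take each frame's power spectrum, pool it with triangular mel-scale filters, take the log, and apply a DCT, keeping the first few coefficients. The function names and parameter values (`n_fft=2048`, `hop_length=512`, `n_mels=40`) are assumptions for illustration. `librosa.feature.mfcc` differs in several details, including the mel-scale variant, dB rather than natural-log scaling, the window shape, and the default number of mel bands, so its values won't match this sketch exactly.

```python
import numpy as np

def hz_to_mel(f):
    # HTK-style mel scale (one of several common variants)
    return 2595.0 * np.log10(1.0 + f / 700.0)

def mel_to_hz(m):
    return 700.0 * (10.0 ** (m / 2595.0) - 1.0)

def mel_filterbank(n_mels, n_fft, sr):
    """Triangular filters spaced evenly on the mel scale, shape (n_mels, n_fft//2 + 1)."""
    mel_points = np.linspace(hz_to_mel(0.0), hz_to_mel(sr / 2), n_mels + 2)
    bins = np.floor((n_fft + 1) * mel_to_hz(mel_points) / sr).astype(int)
    fbank = np.zeros((n_mels, n_fft // 2 + 1))
    for m in range(1, n_mels + 1):
        left, center, right = bins[m - 1], bins[m], bins[m + 1]
        for k in range(left, center):   # rising edge of the triangle
            fbank[m - 1, k] = (k - left) / max(center - left, 1)
        for k in range(center, right):  # falling edge of the triangle
            fbank[m - 1, k] = (right - k) / max(right - center, 1)
    return fbank

def dct_ii(x, n_out):
    """Orthonormal DCT-II along the last axis, keeping the first n_out coefficients."""
    n_in = x.shape[-1]
    k = np.arange(n_out)[:, None]
    n = np.arange(n_in)[None, :]
    basis = np.cos(np.pi * k * (2 * n + 1) / (2 * n_in)) * np.sqrt(2.0 / n_in)
    basis[0] /= np.sqrt(2.0)
    return x @ basis.T

def simple_mfcc(y, sr, n_mfcc=13, n_fft=2048, hop_length=512, n_mels=40):
    """Return an (n_mfcc, n_frames) matrix, the same orientation as librosa.feature.mfcc."""
    # 1. pad so frames are centred, then slice into overlapping Hann-windowed frames
    y = np.pad(y, n_fft // 2, mode="reflect")
    n_frames = 1 + (len(y) - n_fft) // hop_length
    idx = np.arange(n_fft)[None, :] + hop_length * np.arange(n_frames)[:, None]
    frames = y[idx] * np.hanning(n_fft)
    # 2. power spectrum of each frame
    power = np.abs(np.fft.rfft(frames, n=n_fft)) ** 2
    # 3. pool the spectrum into mel bands
    mel_energy = power @ mel_filterbank(n_mels, n_fft, sr).T
    # 4. log compression (the small constant avoids log(0))
    log_mel = np.log(mel_energy + 1e-10)
    # 5. DCT to decorrelate the bands; keep the first n_mfcc coefficients
    return dct_ii(log_mel, n_mfcc).T
```

Calling `simple_mfcc(wav, samplerate)` on the audio loaded below should give a matrix with the same shape as `mfcc_matrix` (13 coefficients by `1 + len(wav) // 512` frames), even though the values themselves will differ.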

To give a brief demonstration of MFCCs, let's load some toy data I made for a previous class project. The data consists only of the words "lawn, lean, gnaw, knee, kneel" and is stored in the `data/` directory of this GitHub project. I've included some helper functions (stored in `librikws.py`) to load the TextGrid data into a `pandas.DataFrame`.

In [None]:
df = textgrid_to_df('data/ailn.TextGrid')
df.head()

In [None]:
df[df['tier']=='word'].head()

Notice that every row in the dataframe corresponds to an interval from the Praat TextGrid, containing start and end timestamps, 'value' (either a phone label or a word label) and tier (phone or word).

Let's now load the wav file.

In [None]:
wav, samplerate = librosa.load('data/ailn.wav')
wav.shape, samplerate

Now that we've loaded the data, let's calculate MFCCs for the wav file. We can use `librosa.feature.mfcc` to do this.

In [None]:
mfcc_matrix = librosa.feature.mfcc(y=wav, sr=samplerate, n_mfcc=13)
mfcc_matrix.shape

Our MFCCs come in a `(13, 373)`-shaped matrix. That means our audio is divided into 373 windows, and each of these windows is represented by a vector of length 13.

Next we want to map the MFCCs to their respective phones. Let's start by adding the columns `start_sample` and `end_sample`, which give the start and end of each interval in samples of the waveform instead of seconds.

In [None]:
df['start_sample'] = (df['start']*samplerate).astype(int)
df['end_sample'] = (df['end']*samplerate).astype(int)
df.head()

Let's then figure out how many samples are in a single window for the MFCC matrix.

In [None]:
mfcc_window_count = mfcc_matrix.shape[1]
num_samples = wav.shape[0]
samples_per_mfcc_window = num_samples//mfcc_window_count
samples_per_mfcc_window

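It turns out that there are about 510 samples per MFCC window. Strictly speaking, this is the spacing between consecutive windows rather than the window length: by default, `librosa.feature.mfcc` uses overlapping windows of 2048 samples that advance by a hop of 512 samples. The estimate comes out slightly below 512 because librosa pads the signal so that frames are centered, which adds one extra frame. It's close enough for our purposes, so we can use it to map each row to the relevant MFCC vectors.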

In [None]:
df['mfcc_index_start']=df['start_sample']//samples_per_mfcc_window
df['mfcc_index_end']=df['end_sample']//samples_per_mfcc_window-1

df.head()
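
If you'd rather not estimate the window spacing from the data, you can compute the mapping directly from the hop length. The sketch below assumes `librosa.feature.mfcc` was called with its defaults (`hop_length=512`, `center=True`), under which the signal is padded so that frame $k$ is centered on sample $k \times 512$ and the number of frames is `1 + len(y) // hop_length`. The helper names are hypothetical; `time_to_frame` is intended to mirror what `librosa.time_to_frames(times, sr=sr, hop_length=hop_length)` computes, returning the last frame whose center is at or before the given time.

```python
import numpy as np

HOP_LENGTH = 512  # librosa's default hop length (assumed unchanged here)

def num_mfcc_frames(num_samples, hop_length=HOP_LENGTH):
    # with center=True, the padded signal yields one frame per hop, plus one
    return 1 + num_samples // hop_length

def time_to_frame(seconds, sr, hop_length=HOP_LENGTH):
    # seconds -> sample index -> last frame whose centre is at or before that sample
    samples = (np.asarray(seconds) * sr).astype(int)
    return samples // hop_length

# e.g. with a hypothetical 190,500-sample recording:
n = num_mfcc_frames(190_500)
print(n, 190_500 // n)  # 373 frames; dividing samples by frames gives 510, not 512
```

This is why `num_samples // mfcc_window_count` lands just under the true hop of 512.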

Now we want to map MFCCs to a particular phone. First let's make a list of all unique phones in the dataset, and then create a vector mapping each MFCC window to its phone, using the phone's index in the `phones` list.

In [None]:
phones = list(df.loc[df['tier']=='phone', 'value'].unique())
phones

In [None]:
mfcc_phone_ids = np.full(mfcc_window_count, fill_value=-1, dtype=int)

def set_mfcc_phoneme_ids(row):
    mfcc_phone_ids[row['mfcc_index_start']:row['mfcc_index_end']]=phones.index(row['value'])

df[df['tier']=='phone'].apply(set_mfcc_phoneme_ids, axis=1)
mfcc_phone_ids

The vector `mfcc_phone_ids` gives the index of the phone corresponding to each MFCC window. The many -1 values in the vector reflect silent intervals between words. Where a value is greater than -1, e.g. `mfcc_phone_ids[85]=0`, it means that MFCC window 85 corresponds to the phone at index 0 in the `phones` list, that is, the phone [l].

In [None]:
mfcc_phone_ids[85], phones[mfcc_phone_ids[85]]

Now we're almost ready to visualize the MFCCs for each phone. There's just one remaining issue: humans can't see in 13 dimensions at once! To visualize the 13-coefficient MFCC vectors, we need to perform dimensionality reduction. Since we're only demonstrating and exploring here, t-SNE works fine.

In [None]:
mfcc_tsne = sklearn.manifold.TSNE().fit_transform(mfcc_matrix.T)
mfcc_tsne.shape

Let's save the t-SNE embeddings to a dataframe, associate them with their phone labels, and make a mask to ignore all rows corresponding to silence.

In [None]:
tsne_df = pd.DataFrame(mfcc_tsne)
tsne_df['phone']=mfcc_phone_ids
tsne_df['phone']=tsne_df['phone'].apply(lambda i: phones[i] if i>=0 else '')
speech_mask = tsne_df['phone']!=''

tsne_df[speech_mask].head()

In [None]:
sns.scatterplot(tsne_df[speech_mask], x=0, y=1, hue='phone')
plt.show()

And voilà! The t-SNE of the MFCCs shows pretty good clustering by phone. The few points that are far from their cluster, e.g. the [l] points close to [a] and [i], likely come from the boundaries between phones, where the transition from one phone to the next is gradual rather than abrupt.

Since MFCCs are such a good way to represent phonetic information, we can use them to calculate *phone probabilities* that we'll then use to decode in order to get an ASR hypothesis. The process of computing a phone probability from an acoustic representation like an MFCC is the task of *acoustic modeling*. In short, we are trying to define an algorithm that models the joint distribution of phonemes $Y$ and audio $X$, i.e.:

$\mathrm{P_{ASR}}(X,Y)$

We'll dive deeper into one particular architecture for acoustic modeling later in this lesson. For now, suffice it to say that we can break the probability above into two components using the [multiplication rule of probability](https://www.geeksforgeeks.org/maths/multiplication-theorem/):

$\mathrm{P_{ASR}}(X,Y)=\mathrm{P_{AM}}(X|Y)\mathrm{P_{LM}}(Y)$

Where $\mathrm{AM}$ stands for 'acoustic model' and $\mathrm{LM}$ 'language model.'
Let's work on calculating the first component, $\mathrm{P_{AM}}(X|Y)$, which intuitively means "how likely is audio $X$ given phone $Y$". We can estimate this naively using counts. Let's say $x_i$ is a particular acoustic frame, encoded as an MFCC vector. $X=x_i$, then, refers to the event that an audio frame corresponding to the MFCC vector $x_i$ is produced. Likewise, $Y=a$ refers to the event that the phone [a] is produced. We can estimate the probability of observing $x_i$ whenever we hear the phone [a] using the following formula:

$\mathrm{P_{AM}}(X=x_i|Y=a)=\frac{\mathrm{count}(X=x_i,Y=a)}{\mathrm{count}(Y=a)}$

That is, the probability that we hear $x_i$ when a speaker utters [a] is equal to the number of times in our training corpus that $x_i$ is observed for the phone [a] over the total number of frames belonging to [a], regardless of what $X$ they are associated with.

How do we calculate these counts for MFCC vectors? Notice that there are no overlapping points in the plot above: with continuous-valued features like MFCCs, essentially *every MFCC vector in our dataset is unique*. If every vector is unique, then the numerator $\mathrm{count}(X=x_i,Y=a)$ will always be either 1 or 0. Furthermore, once the model is trained, we'd have to observe an MFCC vector that is *exactly* the same as some $x_i$ from our dataset for the probability calculated above to be useful. This is a fundamental challenge of working with *continuous data* like audio.

The ideal solution would be to calculate the probability above using a *continuous distribution* rather than discrete counts. However, we can get away with discrete counts by clustering MFCC vectors into discrete groups and then calculating the counts based on these groups. This process of clustering vectors is known as *vector quantization*. For a more mathematical discussion of vector quantization in acoustic modeling, see [Jurafsky and Martin, 2nd edition, Ch. 9, pp 305-308](https://github.com/rain1024/slp2-pdf/blob/0add5260eca38f541909cfa28d17e9ae96008d60/chapter-wise-pdf/%5B09%5D%20Automatic%20Speech%20Recognition.pdf).
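
For comparison, here is a minimal sketch of the "continuous distribution" alternative: fit a diagonal-covariance Gaussian to the MFCC frames of each phone, then score any vector, including one never seen in training, by its log-likelihood under each phone's Gaussian. The random arrays are stand-ins for real MFCC frames and the helper names are hypothetical. Real systems typically use mixtures of Gaussians (GMMs) or neural networks rather than a single Gaussian per phone.

```python
import numpy as np

def fit_diag_gaussian(vectors):
    """Maximum-likelihood mean and per-dimension variance (floored to avoid division by zero)."""
    return vectors.mean(axis=0), np.maximum(vectors.var(axis=0), 1e-6)

def gaussian_log_likelihood(x, mean, var):
    """log N(x; mean, diag(var)) for a single vector x."""
    return -0.5 * np.sum(np.log(2 * np.pi * var) + (x - mean) ** 2 / var)

rng = np.random.default_rng(0)
# stand-ins for 13-dimensional MFCC frames labelled [a] and [i]
frames_a = rng.normal(0.0, 1.0, size=(50, 13))
frames_i = rng.normal(3.0, 1.0, size=(50, 13))
phone_models = {"a": fit_diag_gaussian(frames_a), "i": fit_diag_gaussian(frames_i)}

new_frame = np.full(13, 2.8)  # a vector that appears nowhere in the "training" data
scores = {phone: gaussian_log_likelihood(new_frame, *params)
          for phone, params in phone_models.items()}
print(max(scores, key=scores.get))  # phone whose Gaussian gives the highest log-likelihood
```

Unlike the count-based estimate, every vector gets a non-zero, comparable score, so exact matches are no longer needed.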

To perform vector quantization, let's compute $k$-means over our dataset of MFCC vectors. We can use `KMeans.fit_predict()` to fit the clustering and return the cluster ID for each input vector.

In [None]:
n_clusters = 15
cluster_ids = sklearn.cluster.KMeans(n_clusters=n_clusters, random_state=42).fit_predict(mfcc_matrix.T)
cluster_ids.shape, cluster_ids[:10]
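
`KMeans` hides the details, so here is a minimal sketch of the idea behind vector quantization using plain Lloyd's algorithm. scikit-learn's `KMeans` adds smarter (k-means++) initialization and other refinements on top of this. The *codebook* is the set of cluster centres, and quantizing a vector, including an MFCC frame not seen during training, just means finding its nearest codeword. The function names are hypothetical.

```python
import numpy as np

def quantize(vectors, codebook):
    """Index of the nearest codeword (squared Euclidean distance) for each row of `vectors`."""
    dists = ((vectors[:, None, :] - codebook[None, :, :]) ** 2).sum(axis=-1)
    return dists.argmin(axis=1)

def lloyd_kmeans(vectors, n_clusters, n_iter=100, seed=0):
    """Plain Lloyd's algorithm; returns (codebook, cluster id per vector)."""
    rng = np.random.default_rng(seed)
    # initialise the codebook with randomly chosen data points
    codebook = vectors[rng.choice(len(vectors), size=n_clusters, replace=False)].astype(float)
    for _ in range(n_iter):
        ids = quantize(vectors, codebook)  # assignment step
        # update step: move each codeword to the mean of its members (keep it if it has none)
        new_codebook = np.array([
            vectors[ids == k].mean(axis=0) if np.any(ids == k) else codebook[k]
            for k in range(n_clusters)
        ])
        if np.allclose(new_codebook, codebook):
            break
        codebook = new_codebook
    return codebook, quantize(vectors, codebook)
```

In this notebook, `lloyd_kmeans(mfcc_matrix.T, n_clusters)` would play the same role as the `KMeans` call above, although the cluster numbering (and possibly the clusters themselves) would differ.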

Now let's add the cluster IDs to our dataframe so that we can view the relation between clusters and phones.

In [None]:
tsne_df['cluster_id']=cluster_ids

sns.scatterplot(tsne_df[speech_mask], x=0, y=1, hue='phone', style='cluster_id', palette='tab10')
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0)
plt.show()

Notice that we don't see all 15 clusters in the plot above since we've excluded non-speech audio frames from the visualization. It looks like clusters 1, 2, 4, 5, 8, 10 and 14 all correspond to non-speech.

Let $c_i$ be the $i$-th cluster. Judging by the plot, we can guess that e.g. $\mathrm{P_{AM}}(X=c_3|Y=a)$ will be pretty high, whereas $\mathrm{P_{AM}}(X=c_{11}|Y=a)$ will be much lower, and $\mathrm{P_{AM}}(X=c_0|Y=a)$ will be zero. Let's calculate these below.

In [None]:
cluster_3_mask = tsne_df['cluster_id']==3
cluster_11_mask = tsne_df['cluster_id']==11
cluster_0_mask = tsne_df['cluster_id']==0

a_mask = tsne_df['phone']=='a'

a_and_cluster_3 = np.sum(cluster_3_mask & a_mask).item()
a_and_cluster_11 = np.sum(cluster_11_mask & a_mask).item()
a_and_cluster_0 = np.sum(cluster_0_mask & a_mask).item()
a_count = np.sum(a_mask).item()

print(f"""
Count (X=c_3, Y=a): {a_and_cluster_3}
Count (X=c_11, Y=a): {a_and_cluster_11}
Count (X=c_0, Y=a): {a_and_cluster_0}
Total Count (Y=a):  {a_count}
""")

In [None]:
print(f"""
P(X=c_3|Y=a): {a_and_cluster_3/a_count:.2f}
P(X=c_11|Y=a): {a_and_cluster_11/a_count:.2f}
P(X=c_0|Y=a): {a_and_cluster_0/a_count:.2f}
""")

Our visual intuitions turned out to be correct!

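**EXERCISE 1:** Write a general function `posterior_phone_prob(cluster_id, phone)` that calculates $\mathrm{P_{AM}}(X|Y)$ for the given cluster ID and phone label. Feel free to model your code on the code above. (Strictly speaking, $\mathrm{P_{AM}}(X|Y)$ is a *likelihood* rather than a posterior, but we'll keep the function name used throughout this notebook.)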

In [None]:
def posterior_phone_prob(cluster_id, phone):
    # YOUR CODE HERE
    ...

The following uses the function above to create a heatmap showing the posterior probability of every cluster id for every phone. Lighter values indicate higher probability.

In [None]:
posterior_phone_prob_matrix = np.zeros((n_clusters, len(phones)))
for i in range(n_clusters):
    for j, phone in enumerate(phones):
        posterior_phone_prob_matrix[i, j] = posterior_phone_prob(i, phone)

sns.heatmap(posterior_phone_prob_matrix, xticklabels=phones, yticklabels=[f'c_{i}' for i in range(n_clusters)])
plt.show()

Now that we've got $\mathrm{P_{AM}}(X|Y)$, we're just missing $\mathrm{P_{LM}}(Y)$. We can represent this with an FST mapping phone labels to words. I've copied and adapted the relevant code from the previous assignment below.

In [None]:
lexicon = {
    'lawn': 'l a n',
    'lean': 'l i n',
    'gnaw': 'n a',
    'knee': 'n i',
    'kneel': 'n i l',
}

phones_in_lexicon = set()
for phonestr in lexicon.values():
    phones_in_lexicon.update(phonestr.split())
words = list(lexicon.keys())

symbols = pynini.SymbolTable()
symbols.add_symbol('<eps>')

for word in words:
    symbols.add_symbol(word)

for phone in phones_in_lexicon:
    symbols.add_symbol(phone)

for i in range(n_clusters):
    symbols.add_symbol(f'c_{i}')

def set_symbols(f: pynini.Fst) -> pynini.Fst:
    """
    Set input and output symbols for a FST `f` to the
    user-defined symbol table.
    """
    f=f.set_input_symbols(symbols)
    f=f.set_output_symbols(symbols)
    return f

def fsa(acceptor_str: Union[str, List[str]], weight: Optional[pynini.WeightLike]=None) -> pynini.Fst:
    """
    Create a Finite State Acceptor of the given string using
    the symbols table.
    """
    if type(acceptor_str) is list:
        acceptor_str = ' '.join(acceptor_str)
    f=pynini.accep(acceptor_str, weight=weight, token_type=symbols)
    f=set_symbols(f)
    f=f.optimize()
    return f

def fst_string(f):
    return f.string(token_type=symbols)


def cluster2phone_fst(phone_str):
    phone_fsa = fsa("")
    for phone in phone_str.split():
        phone_fsa += fsa(phone).plus
    phone_fsa.optimize()
    return phone_fsa

def print_fst(f):
    tmp_path = 'tmp.dot'
    f.draw(tmp_path, portrait=True)
    with open(tmp_path) as file:
        return graphviz.Source(file.read())
    
words, phones_in_lexicon

**EXERCISE 2:** Create an FST `phones2word` below. Hint: `pynini.cross()` and `pynini.union()` will be helpful.

In [None]:
phones2word = fsa("")
for word, phone_str in lexicon.items():
    # YOUR CODE HERE
    ...
    
phones2word.optimize()
print_fst(phones2word)

In order to apply our FST to our data, we need to get the sequence of MFCC clusters for a given word. Once that's done, we'll use our function from before to transform those cluster ID sequences into posterior phone probabilities, which we can then decode into words using the FST above.

In [None]:
mfcc_word_ids = np.full(mfcc_window_count, fill_value=-1, dtype=int)

def set_mfcc_word_ids(row):
    mfcc_word_ids[row['mfcc_index_start']:row['mfcc_index_end']]=words.index(row['value'])

df[df['tier']=='word'].apply(set_mfcc_word_ids, axis=1)
tsne_df['word']=mfcc_word_ids
tsne_df['word']=tsne_df['word'].apply(lambda i: words[i] if i>=0 else '')
tsne_df[speech_mask].head()

In [None]:
def get_cluster_id_for_words(word, return_type=str):
    word_mask = tsne_df['word']==word
    cluster_ids = tsne_df.loc[word_mask, 'cluster_id'].tolist()
    if return_type is int:
        return cluster_ids
    cluster_id_strs = [f'c_{i}' for i in cluster_ids]
    return cluster_id_strs

get_cluster_id_for_words('lawn')

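We can then make a WFST transducing cluster IDs to their posterior phone probabilities using the `posterior_phone_prob` function we defined above. An example is given below for transducing $c_6$ to [n] and [i]. Note that we use `1-posterior_phone_prob` as the arc weight: WFST weights are *costs*, where lower is better, so we need a value that shrinks as the probability grows. (A more standard choice of cost is the negative log probability, $-\log P$; `1-P` is a simpler stand-in that we'll use here.)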

In [None]:
c6_to_n_prob = pynini.cross(fsa("c_6"), fsa("n", weight=1-posterior_phone_prob(6, 'n')))
c6_to_i_prob = pynini.cross(fsa("c_6"), fsa("i", weight=1-posterior_phone_prob(6, 'i')))
c6_to_n_or_i = c6_to_n_prob|c6_to_i_prob
c6_to_n_or_i.optimize()

print_fst(c6_to_n_or_i)
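
It's worth seeing why the usual cost in the tropical semiring (pynini's default weight type) is the negative log probability. Summing $-\log P$ along a path is the same as multiplying the probabilities, so the lowest-cost path is also the most probable one. Summing $1-P$ doesn't guarantee that. The plain-Python sketch below uses made-up arc probabilities (no pynini involved) to show a case where the two costs disagree.

```python
import math

def neg_log_cost(probs):
    # tropical-semiring style cost: sum of -log P along the path (zero probability -> infinite cost)
    return sum(-math.log(p) if p > 0 else math.inf for p in probs)

def one_minus_cost(probs):
    # the simpler 1 - P cost used in the cell above
    return sum(1 - p for p in probs)

# two hypothetical paths, each listing the probability on every arc
path_a = [0.5, 0.5]   # product = 0.25
path_b = [0.95, 0.1]  # product = 0.095
for name, path in [("A", path_a), ("B", path_b)]:
    print(name, math.prod(path), neg_log_cost(path), one_minus_cost(path))
```

Path A is more probable and gets the lower negative-log cost, but the `1-P` cost prefers path B. Likewise, a zero-probability arc costs only 1 under `1-P`, whereas $-\log 0 = \infty$ rules the path out entirely.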

**EXERCISE 3**: Following the example above, create a WFST mapping all cluster ids to all phones.

In [None]:
cluster2phones = fsa("")

for phone in phones_in_lexicon:
    for i in range(n_clusters):
        cluster_str = f'c_{i}'
        # construct a WFST mapping cluster_str to phone
        # YOUR CODE HERE
        ...

# right now the cluster2phones WFST only accepts a single cluster id
# it should accept ANY NUMBER of cluster ids
# modify the WFST to do this, it should just take one line
# YOUR CODE HERE
cluster2phones = ...

cluster2phones.optimize()
print_fst(cluster2phones)

**EXERCISE 4**: Now that we have a function that gets the cluster IDs for a given word, and a WFST that transduces cluster IDs to phone probabilities, we have all we need to do ASR on our dataset! The only thing left is to compose the WFSTs so that we get a single graph mapping cluster id $\mapsto$ word. This will require two successive compositions. The first composition will make a graph mapping cluster ids to phones, and the second will complete the chain mapping cluster ids to words. Remember the `cluster2phones` and `phones2word` FSTs we defined above!

In [None]:
def decode_word(word):
    cluster_fsa = fsa(get_cluster_id_for_words(word))
    cluster2phone_fst = ... # YOUR CODE HERE
    cluster2word_fst = ... # YOUR CODE HERE
    cluster2word_fst.optimize()
    return cluster2word_fst

word_decoder_fst = decode_word('lawn')
print_fst(word_decoder_fst)

If you did it right, the output should be a monstrous WFST. Let's not worry about trying to interpret it. Instead, let's print out the cost for each of the 5 words in our lexicon.

In [None]:
word_decoder_fst = decode_word('lawn')

# don't worry about understanding this code
# it's just printing out the top 5 most likely outputs
# from the WFST we constructed above

def decode_byte_str(byte_str):
    return ' '.join([symbols.find(ord(c)) for c in byte_str])

word_decoder_output = pynini.project(word_decoder_fst, 'output')
for path in pynini.shortestpath(word_decoder_output, nshortest=5, unique=True).paths().items():
    output_str = decode_byte_str(path[1])
    weight = path[-1]
    print(f"Output: {output_str}, Cost: {weight}")

As expected, 'lawn' has the lowest cost, followed by 'gnaw', 'lean', 'knee' and then 'kneel.' Feel free to play around with running the cell above while passing different words to `decode_word` to see what scores you get.

We've done it! We now know how to use vector quantization to model $\mathrm{P_{AM}}(X|Y)$ and a phone-to-word FST to model $\mathrm{P_{LM}}(Y)$, so we can put the two together to model $\mathrm{P_{ASR}}(X,Y)$. We then calculate $\hat{Y}=\mathrm{argmax}_Y\mathrm{P_{ASR}}(X,Y)$ by taking the output of the shortest (lowest-cost) path through the resulting composed WFST.

**SIDE NOTE:** I've said twice now that the phone-to-word FST models $P(Y)$ without really explaining what that means. Think about how a typical language model describes $P(W)$, where $W$ is a random variable representing a sequence of words. Since `phones2word` only outputs words in our lexicon, it is essentially a toy language model over the language $\{\mathrm{lawn,lean,gnaw,knee,kneel}\}$. It may seem strange to call this FST a language model since it has no weights, just mappings of phones to words. But any word not in the vocabulary has a probability of 0, which is why there are no paths outputting that word. All words in the vocabulary have equal probability (independent of phone posteriors, of course), which is why they all have the same weight, i.e. none.

Since `phones2word` also maps phones to words, though, it's really doing two tasks: *pronunciation modeling* as well as language modeling. To see this, let $Y$ be the random variable over phone sequences, and let $f_\mathrm{PM}(Y)$ denote the function that maps a phone sequence $\mathbf{y}\sim{Y}$ to the equivalent word sequence $\mathbf{w}\sim{W}$, e.g. $f_\mathrm{PM}(\mathrm{l\,l\,a\,a\,a\,n\,n})=\langle\mathrm{lawn}\rangle$. Our FST models the probability $P(W)$, but since our acoustic model outputs *phone* posterior probabilities $\mathrm{P_{AM}}(X|Y)$ rather than *word* probabilities $\mathrm{P}(X|W)$, we get $\mathrm{P_{LM}}(W)=\mathrm{P_{LM}}(f_\mathrm{PM}(Y))$. So a more accurate way of writing the equation above would be:

$\hat{Y}=\mathrm{argmax}_Y\mathrm{P_{AM}}(X|Y)\mathrm{P_{LM}}(f_\mathrm{PM}(Y))$

And additionally:

$\hat{W}=f_\mathrm{PM}(\hat{Y})$.

## Section 3: Navigating a map with HMMs
HMMs, like WFSTs, are Markov sequence models: both consist of states linked by arcs, where each arc is associated with a weight. The key difference between them is that HMMs separate states from *observations*.

In FSTs, we traverse the graph by consuming a sequence of letters, i.e. the input word. Let $W$ be the input word and $w_i$ be the $i^{th}$ letter in the input word. At each point we can transition from the current state $n$ to state $m$ if there is an arc between $n$ and $m$ that consumes the letter $w_i$. Thus, the letters dictate which states we visit as we draw a path through the FST.

In HMMs, we instead traverse the graph by consuming *observations*. Each state in the HMM has a certain probability of 'emitting' each observation, and a certain probability of *transitioning* to another state (or back to itself). This is best illustrated with an example.

Imagine we are navigating between three states, where here 'state' is a literal state of the United States of America. In this version of America, the states of Texas, California and Alaska are all geographically contiguous, so we can walk from any one straight into another. Also in this America, there are no signs or other overt indicators of which state we are in: we don't see a "Welcome to California" sign when we cross from Texas into California. What we do see, though, are geographical features, in particular snow, beaches or oil rigs.

![HMM example](media/hmm.svg)

There are more oil rigs in Texas than in the other two states, more beaches in California than in the other two, and more snow in Alaska than in the other two. Thus, if we observe a lot of oil rigs, but few beaches and no snow, then there's a good chance we're in Texas. As soon as we start seeing a lot of beaches, there's a good chance we've made it to California. However, if we observe a lot of oil rigs, a single beach, and then more oil rigs, it's likely we stayed in Texas the whole time. That's because, even if we're more likely to see a beach in California, we still need to *transition* from Texas to California, and after seeing a single beach it's more likely that we happened upon a beach in Texas than that we transitioned from Texas to California and back.

These probabilities are represented by the color and thickness of the arrows in the graph above: a thick red line indicates high probability and a thin blue line low probability. The observation probabilities of oil rigs, snow and beaches are not the only thing that differentiates the states; they also have different transition probabilities. The probability of transitioning from California to Texas is very high: Californians want cheaper rent and gas. The probability of transitioning from Texas to California is a bit lower. Alaska is harder to get to, even though, for the sake of this metaphor, it's possible to walk there. Because of that, transition probabilities into and out of Alaska are lower than they are between California and Texas.

In [None]:
from pomegranate.hmm.sparse_hmm import SparseHMM
from pomegranate.hmm.dense_hmm import DenseHMM
from pomegranate.distributions import Categorical
import torch

In [None]:
observations = ["snow", "oil", "beach"]
snow = observations.index('snow')
oilrig = observations.index('oil')
beach = observations.index('beach')

texas = Categorical([[0.1, 0.7, 0.2]])
cali = Categorical([[0.3, 0.1, 0.6]])
alaska = Categorical([[0.7, 0.2, 0.1]])
states = [texas, cali, alaska]
state_names = ['texas', 'cali', 'alska'] # abbreviated Alaska for better spacing

hmm = DenseHMM(distributions=states)

hmm.add_edge(hmm.start, texas, 0.4)
hmm.add_edge(hmm.start, cali, 0.4)
hmm.add_edge(hmm.start, alaska, 0.2)

hmm.add_edge(texas, texas, 0.9)
hmm.add_edge(texas, cali, 0.06)
hmm.add_edge(texas, alaska, 0.04)

hmm.add_edge(cali, cali, 0.5)
hmm.add_edge(cali, texas, 0.3)
hmm.add_edge(cali, alaska, 0.2)

hmm.add_edge(alaska, alaska, 0.85)
hmm.add_edge(alaska, texas, 0.05)
hmm.add_edge(alaska, cali, 0.1)

hmm.edges, hmm.distributions
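
To see how start, transition and emission probabilities combine, here is a small NumPy sketch, independent of pomegranate, that multiplies them out by hand for a single state path using the same numbers as the cell above. It assumes no end-of-sequence probabilities (pomegranate may model those separately), and `joint_prob` is a hypothetical helper. It checks the intuition from earlier: after seeing oil rig, beach, oil rig, staying in Texas the whole time is much more probable than a detour through California.

```python
import numpy as np

# Same parameters as the pomegranate model above (end probabilities ignored).
# States: 0=texas, 1=cali, 2=alaska. Observations: 0=snow, 1=oil, 2=beach.
start = np.array([0.4, 0.4, 0.2])
trans = np.array([[0.90, 0.06, 0.04],
                  [0.30, 0.50, 0.20],
                  [0.05, 0.10, 0.85]])
emit = np.array([[0.1, 0.7, 0.2],
                 [0.3, 0.1, 0.6],
                 [0.7, 0.2, 0.1]])

def joint_prob(states, obs):
    """P(states, obs): start prob and first emission, then a transition and an emission per step."""
    p = start[states[0]] * emit[states[0], obs[0]]
    for prev, cur, o in zip(states, states[1:], obs[1:]):
        p *= trans[prev, cur] * emit[cur, o]
    return p

TEXAS, CALI = 0, 1
SNOW, OIL, BEACH = 0, 1, 2
obs = [OIL, BEACH, OIL]
print(joint_prob([TEXAS, TEXAS, TEXAS], obs))  # ≈ 0.0318
print(joint_prob([TEXAS, CALI, TEXAS], obs))   # ≈ 0.0021
```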

We can use the Viterbi algorithm, via `hmm.viterbi`, to find the most likely state path for a given sequence of observations.

In [None]:
def shape_observations_for_hmm(observation_idcs):
    return torch.tensor(observation_idcs).view(1, -1, 1)

def get_state_sequence(observation_idcs):
    observation_vector = shape_observations_for_hmm(observation_idcs)
    predicted_states = hmm.viterbi(observation_vector)
    predicted_state_labels = [state_names[i] for i in predicted_states[0]]
    return predicted_state_labels

def print_states(observation_idcs):
    states = get_state_sequence(observation_idcs)
    observation_labels = [observations[i] for i in observation_idcs]
    print("->\t".join(states))
    print("->\t".join(observation_labels))
    print()

print_states([beach, beach, snow, beach, oilrig])

Notice that transitioning from beach to snow once does not switch us from California to Alaska, but it only takes one oil rig to transition from California to Texas because, as observed in the transition probabilities above, it is much easier to get to Texas from California than to get to Alaska from California.

Let's see what happens when we see snow twice.

In [None]:
print_states([beach, beach, snow])
print_states([beach, beach, snow, snow])

When we see snow twice, it is most likely that *both* instances of snow are in Alaska, even though if we see snow only once, we stay in California. Why does this happen? HMM transitions only depend on the *immediately previous* state, yet here it looks as if the model is taking into account what comes *next*!

The secret lies in the Viterbi algorithm. Rather than greedily picking the most likely state at each step from left to right, Viterbi finds the most likely path among *all possible state sequences*. Because of that, it can account for the fact that if we observe snow twice in a row, it is more likely we transitioned to Alaska and *stayed there* than that we saw one snow patch in California and *then* transitioned to Alaska.

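**SIDE NOTE**: If you're familiar with the Viterbi algorithm you might object that it does in fact calculate the path going left to right. While this is true, it does not do so *greedily*. Instead, it uses dynamic programming: at each time step it keeps track of the probability of the best path ending in each state, along with a pointer to the previous state on that path. Once it reaches the end of the observation sequence, it follows these backpointers from the most probable final state to recover the single most likely path, so later observations can still change which earlier states end up on it. (Combining forward *and* backward probabilities at each point is what the related forward-backward algorithm does, which is used for Baum-Welch training.) See [Jurafsky and Martin, 2nd edition, Ch. 9, pp 316-326](https://github.com/rain1024/slp2-pdf/blob/0add5260eca38f541909cfa28d17e9ae96008d60/chapter-wise-pdf/%5B09%5D%20Automatic%20Speech%20Recognition.pdf).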
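
Here is a minimal NumPy sketch of Viterbi decoding for the same three-state example, working in log space to avoid underflow. It is a textbook version, not pomegranate's implementation: it ignores end-of-sequence probabilities, so its output isn't guaranteed to match `hmm.viterbi` in every case. The names (`viterbi`, `region_names`, `start_probs`, etc.) are hypothetical and chosen so they don't overwrite the variables used above.

```python
import numpy as np

# Same parameters as the pomegranate model above (end probabilities ignored).
# States: 0=texas, 1=cali, 2=alaska. Observations: 0=snow, 1=oil, 2=beach.
region_names = ["texas", "cali", "alaska"]
SNOW, OIL, BEACH = 0, 1, 2
start_probs = np.array([0.4, 0.4, 0.2])
trans_probs = np.array([[0.90, 0.06, 0.04],
                        [0.30, 0.50, 0.20],
                        [0.05, 0.10, 0.85]])
emit_probs = np.array([[0.1, 0.7, 0.2],
                       [0.3, 0.1, 0.6],
                       [0.7, 0.2, 0.1]])

def viterbi(obs, start, trans, emit):
    """Most likely state sequence for `obs`, as a list of state indices."""
    log_start, log_trans, log_emit = np.log(start), np.log(trans), np.log(emit)
    n_steps, n_states = len(obs), len(start)
    best = np.zeros((n_steps, n_states))                # best[t, j]: log prob of best path ending in j at t
    backptr = np.zeros((n_steps, n_states), dtype=int)  # previous state on that best path
    best[0] = log_start + log_emit[:, obs[0]]
    for t in range(1, n_steps):
        scores = best[t - 1][:, None] + log_trans       # scores[i, j]: best path to i, then i -> j
        backptr[t] = scores.argmax(axis=0)
        best[t] = scores.max(axis=0) + log_emit[:, obs[t]]
    # trace back from the most probable final state
    path = [int(best[-1].argmax())]
    for t in range(n_steps - 1, 0, -1):
        path.append(int(backptr[t, path[-1]]))
    return path[::-1]

for seq in ([BEACH, BEACH, SNOW], [BEACH, BEACH, SNOW, SNOW]):
    print([region_names[s] for s in viterbi(seq, start_probs, trans_probs, emit_probs)])
# ['cali', 'cali', 'cali']
# ['cali', 'cali', 'alaska', 'alaska']
```

The second snow changes the decision made at the first snow only because of the traceback at the end; at each step, the algorithm itself only ever looks one transition back.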

As another example, if we observe a single oilrig in the middle of a bunch of beaches, then it's most likely we found an oilrig in California, rather than transitioning to Texas.

In [None]:
print_states([beach, beach, oilrig, beach, beach, beach])

**EXERCISE 5**: Create a sequence of observations that will result in us starting out in Alaska, going to California and then Texas, with the following two constraints.
1. There should be at least one beach in Texas
2. There should be at least one oilrig in Alaska

In [None]:
print_states([snow,]) # ADD STATES HERE

Now that we've played around with geography enough, let's get back to speech. How do the observation and transition probabilities relate to the WFST we created above? In short, if we let the quantized vectors be the observations and the phones be the states, then the observation probabilities are directly related to the *posterior phone probabilities* and the transition probabilities to the arcs of the `phones2word` FST.

Remember how above we calculated $P_\mathrm{AM}(X|Y)$ by counting the number of audio frames for a given cluster for each phone? This requires that we know ahead of time what frame each phone is aligned with, which requires human annotators to go through and align the data. That's a lot of work, especially when we're dealing with hours of audio. Surely there's a better way to do this!

One useful feature of HMMs is the *Baum-Welch* training algorithm [Jurafsky and Martin, 2nd edition, Ch. 9, pp 326-330](https://github.com/rain1024/slp2-pdf/blob/0add5260eca38f541909cfa28d17e9ae96008d60/chapter-wise-pdf/%5B09%5D%20Automatic%20Speech%20Recognition.pdf). While I won't go into the math of Baum-Welch training here, suffice it to say that it allows us to learn the parameters of the HMM (both the observation probabilities and the transition probabilities) from a speech dataset *without requiring phone alignments*. It does this through a process called Expectation Maximization (EM). At a very high level, EM iteratively updates the parameters of the model so that the probability (likelihood) of the dataset increases, or at least never decreases, with each iteration. Think of it like a language model reducing its perplexity on the ground-truth data as it is trained, except now we're reducing the perplexity of speech and text together.

Feel free to read the Jurafsky and Martin textbook for a more satisfying explanation than the one I give here. Also check out [this YouTube playlist by Herman Kamper](https://www.youtube.com/playlist?list=PLmZlBIcArwhMIRdgNwFUWGqY53h2TC6PH). All we need to know for this assignment, though, is that Baum-Welch training allows us to fit an HMM acoustic model without needing ground-truth alignments.
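
To demystify what happens inside `fit()`, here is a minimal NumPy sketch of one Baum-Welch (EM) iteration for a discrete HMM. The E-step runs the forward-backward algorithm to get the expected state occupancies ($\gamma$) and expected transitions ($\xi$) under the current parameters; the M-step re-estimates the start, transition and emission probabilities from those expected counts. This is a textbook version with per-step scaling to avoid underflow, not pomegranate's actual implementation, and the toy observation sequence is made up. A useful property to check is that the log-likelihood never decreases from one iteration to the next.

```python
import numpy as np

def forward_backward(obs, start, trans, emit):
    """Scaled forward-backward pass for a discrete HMM.

    Returns gamma (T x N state posteriors), xi (T-1 x N x N transition posteriors)
    and the log-likelihood of `obs` under the current parameters.
    """
    obs = np.asarray(obs)
    n_steps, n_states = len(obs), len(start)
    alpha = np.zeros((n_steps, n_states))
    beta = np.ones((n_steps, n_states))
    scale = np.zeros(n_steps)
    # forward pass, normalising alpha at every step to avoid underflow
    alpha[0] = start * emit[:, obs[0]]
    scale[0] = alpha[0].sum()
    alpha[0] /= scale[0]
    for t in range(1, n_steps):
        alpha[t] = (alpha[t - 1] @ trans) * emit[:, obs[t]]
        scale[t] = alpha[t].sum()
        alpha[t] /= scale[t]
    # backward pass, reusing the same scaling factors
    for t in range(n_steps - 2, -1, -1):
        beta[t] = trans @ (emit[:, obs[t + 1]] * beta[t + 1]) / scale[t + 1]
    gamma = alpha * beta
    xi = (alpha[:-1, :, None] * trans[None, :, :]
          * (emit[:, obs[1:]].T * beta[1:])[:, None, :]) / scale[1:, None, None]
    return gamma, xi, np.log(scale).sum()

def baum_welch_step(obs, start, trans, emit):
    """One EM iteration; returns updated (start, trans, emit) and the log-likelihood before the update."""
    obs = np.asarray(obs)
    gamma, xi, loglik = forward_backward(obs, start, trans, emit)
    new_start = gamma[0]
    new_trans = xi.sum(axis=0) / gamma[:-1].sum(axis=0)[:, None]
    new_emit = np.stack([gamma[obs == k].sum(axis=0) for k in range(emit.shape[1])], axis=1)
    new_emit /= gamma.sum(axis=0)[:, None]
    return new_start, new_trans, new_emit, loglik

rng = np.random.default_rng(0)
toy_obs = [0, 0, 1, 1, 1, 2, 2, 0, 0, 1, 2, 2]  # hypothetical cluster-ID sequence
n_hidden, n_symbols = 3, 3
pi = np.full(n_hidden, 1 / n_hidden)
A = rng.random((n_hidden, n_hidden))
A /= A.sum(axis=1, keepdims=True)
B = rng.random((n_hidden, n_symbols))
B /= B.sum(axis=1, keepdims=True)
for it in range(5):
    pi, A, B, loglik = baum_welch_step(toy_obs, pi, A, B)
    print(f"iteration {it}: log-likelihood before update = {loglik:.3f}")
```

Note that no state labels appear anywhere: the expected counts come entirely from the model's own posteriors, which is why Baum-Welch doesn't need alignments.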

To begin, let's make a function that will create a Categorical distribution over cluster IDs for each phone in our inventory. To ensure we're not cheating, each distribution is initialized with equal probability for every cluster.

In [None]:
def create_phone_states(phones: List[str]=phones, n_clusters: int=n_clusters):
    """
    Create a list of Categorical distributions, one for each phone in `phones`.
    Each distribution should have one outcome for each cluster in `n_clusters`.
    Set pseudocount=0.1 to avoid zero probabilities.
    The initial probability for each cluster should be uniform (1/n_clusters).

    Returns:
        phone_states: List of pomegranate Categorical distributions.
    """
    phone_states = []
    initial_prob = 1/n_clusters
    for _ in phones:
        cluster_state = Categorical([[initial_prob]*n_clusters], pseudocount=0.1)
        phone_states.append(cluster_state)
    return phone_states

Now let's set the transition probabilities. We're actually going to do this by creating a separate HMM for each word, with the observation distributions shared across all of them. This means that each word gets a unique set of transition probabilities but the observation probabilities are the same regardless of word. This way the HMM for 'lawn' for example will exclude any transitions to or from the phone [i] which is not found in the word. It also allows us to ensure a monotonic sequence of phones: whatever observation sequence the HMM for 'lawn' sees, it will ensure that the output is some sequence of (to borrow RegEx notation) `l+a+n+`.

Is this cheating? Sort of. But it's the kind of cheating we'd be able to do anyway when training ASR. After all, we know the transcript for each sentence we're training on. Therefore, when training an HMM on an ASR dataset, we can help the training by constraining the possible state sequences for each sentence to match the sequence of phones in that sentence. We simulate that here by constraining the state sequence for each individual word instead.

In [None]:
self_transition_prob = 0.6
transition_prob = (1-self_transition_prob)

def create_word_hmm(word: str, phone_states: List[Categorical]) -> DenseHMM:
    """
    Create a Hidden Markov Model (HMM) for a word using its phoneme states.
    Each phoneme state should have a self-transition with probability `self_transition_prob`
    and a transition to the next phoneme state with probability `transition_prob`.

    Args:
        word: The word for which to create the HMM.
        phone_states: List of pomegranate Categorical distributions for each phoneme.
    Returns:
        word_model: A pomegranate DenseHMM representing the word.
    """
    word_phones = lexicon[word].split()
    states_for_word = [phone_states[phones.index(phone)] for phone in word_phones]
    word_model = DenseHMM(distributions=states_for_word, max_iter=1, inertia=0.8)
    for i, phone in enumerate(word_phones):
        phone_state = phone_states[phones.index(phone)]
        word_model.add_edge(phone_state, phone_state, self_transition_prob)

        if i==0:
            word_model.add_edge(word_model.start, phone_state, 1.0)
        else:
            prev_phone_state = phone_states[phones.index(prev_phone)]
            word_model.add_edge(prev_phone_state, phone_state, transition_prob)

        if i==len(word_phones)-1:
            word_model.add_edge(phone_state, word_model.end, transition_prob)
        prev_phone = phone
    return word_model

Let's make a helper function to get a tensor of observed cluster IDs for each word.

In [None]:
def get_observations_for_word(word: str):
    """
    Get the cluster IDs for all occurrences of `word` in the dataset.
    Returns a tensor of shape (1, num_occurrences, 1) suitable for input to pomegranate HMM.

    Args:
        word: The word for which to get observations.
    Returns:
        observations: A tensor of shape (1, num_occurrences, 1) containing cluster IDs.
    """
    cluster_ids = get_cluster_id_for_words(word, return_type=int)
    observations = torch.tensor(cluster_ids).view(1, -1, 1)
    return observations

Now let's train! We train the HMMs for 10 epochs in total, fitting each word's HMM once per epoch. To show the effect of training, we print the probability each HMM assigns to its word's observations before and after training. We use the `word_model.fit()` method to do Baum-Welch training; because the HMMs were created with `max_iter=1`, each call runs a single Baum-Welch iteration.
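Under the hood, Baum-Welch is expectation-maximization: a forward-backward pass computes how likely each state (and each state-to-state transition) is at every frame, and those expected counts are renormalized into new transition and emission probabilities. The sketch below is plain numpy for a single sequence with discrete (cluster ID) emissions. It is an illustration, not pomegranate's implementation, which works in log space, batches sequences, and also handles start/end states and `inertia`. The function names and toy parameters are made up for illustration.

```python
import numpy as np

def forward(pi, A, B, obs):
    """alpha[t, i] = P(o_1..o_t, state at time t = i)."""
    alpha = np.zeros((len(obs), len(pi)))
    alpha[0] = pi * B[:, obs[0]]
    for t in range(1, len(obs)):
        alpha[t] = (alpha[t - 1] @ A) * B[:, obs[t]]
    return alpha

def backward(A, B, obs):
    """beta[t, i] = P(o_{t+1}..o_T | state at time t = i)."""
    beta = np.ones((len(obs), A.shape[0]))
    for t in range(len(obs) - 2, -1, -1):
        beta[t] = A @ (B[:, obs[t + 1]] * beta[t + 1])
    return beta

def baum_welch_step(pi, A, B, obs):
    """One EM iteration. Returns new (pi, A, B) and P(obs) under the OLD parameters."""
    obs = np.asarray(obs)
    alpha, beta = forward(pi, A, B, obs), backward(A, B, obs)
    likelihood = alpha[-1].sum()
    # gamma[t, i] = P(state at time t = i | obs)
    gamma = alpha * beta / likelihood
    # xi[t, i, j] = P(state t = i, state t+1 = j | obs)
    xi = alpha[:-1, :, None] * A[None] * (B[:, obs[1:]].T * beta[1:])[:, None, :] / likelihood
    new_pi = gamma[0]
    new_A = xi.sum(axis=0) / gamma[:-1].sum(axis=0)[:, None]
    new_B = np.zeros_like(B)
    for k in range(B.shape[1]):
        new_B[:, k] = gamma[obs == k].sum(axis=0)  # expected count of emitting cluster k
    new_B /= gamma.sum(axis=0)[:, None]
    return new_pi, new_A, new_B, likelihood

# Toy left-to-right model: 3 phone states, 4 clusters, uniform initial emissions.
pi = np.array([1.0, 0.0, 0.0])
A = np.array([[0.6, 0.4, 0.0],
              [0.0, 0.6, 0.4],
              [0.0, 0.0, 1.0]])
B = np.full((3, 4), 0.25)
obs = [0, 0, 1, 2, 2, 3, 3, 3]

likelihoods = []
for _ in range(10):
    pi, A, B, lik = baum_welch_step(pi, A, B, obs)
    likelihoods.append(lik)
print(f"P(obs) before: {likelihoods[0]:.3e}, after 9 updates: {likelihoods[-1]:.3e}")
```

Two properties are worth noticing: the likelihood never decreases from one iteration to the next (the EM guarantee), and transitions that start at exactly zero stay zero, which is why Baum-Welch training preserves a left-to-right word topology.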

In [None]:
epochs = 10

phone_states = create_phone_states()
word_hmms = [create_word_hmm(word, phone_states) for word in words]

def get_word_probs(hmm_list):
    """Probability that each word HMM assigns to the observations of its own word."""
    return [
        word_model.probability(get_observations_for_word(word)).item()
        for word, word_model in zip(words, hmm_list)
    ]

untrained_probs = get_word_probs(word_hmms)

for epoch in range(epochs):
    for word, word_model in zip(words, word_hmms):
        # One Baum-Welch iteration per call (max_iter=1); the emission distributions are shared.
        word_model.fit(get_observations_for_word(word))

trained_probs = get_word_probs(word_hmms)
for word, untrained_prob, trained_prob in zip(words, untrained_probs, trained_probs):
    print(f"Word: {word}, Untrained Prob: {untrained_prob:.4e}, "
          f"Trained Prob: {trained_prob:.4e}, Improvement: {trained_prob/untrained_prob:.2e}x")

That's quite the improvement! For every word, the predicted probability increased by several orders of magnitude. This shows that each HMM now fits the observations of the word it models much better. To further illustrate this, let's compare the heatmap of the phone posteriors we computed earlier (using hand-annotated phone alignments) to our Baum-Welch computed emission probabilities.

In [None]:
phone_emission_matrix = np.zeros((n_clusters, len(phones)))
for i, phone_state in enumerate(phone_states):
    # Categorical stores its probabilities as a (1, n_clusters) tensor (one row per feature).
    phone_emission_matrix[:, i] = phone_state.probs[0].detach().numpy()

_, axes = plt.subplots(1, 2, figsize=(15, 5))
cluster_labels = [f'c_{i}' for i in range(n_clusters)]
# Left: emission probabilities learned by Baum-Welch, P(cluster | phone state).
ax_hmm = sns.heatmap(phone_emission_matrix, yticklabels=cluster_labels, xticklabels=phones, ax=axes[0])
ax_hmm.set_title('HMM Emission Probabilities')

# Right: the matrix computed earlier from the hand-annotated phone alignments.
ax_hand = sns.heatmap(posterior_phone_prob_matrix, yticklabels=cluster_labels, xticklabels=phones, ax=axes[1])
ax_hand.set_title('Hand-Computed Posterior Probabilities')
plt.show()

They're pretty close for most of the clusters! Clusters $c_6$ and $c_{11}$, and to a lesser extent $c_0$, are a bit off, though.

If we were to get the most likely path for each word HMM, we wouldn't see anything interesting. Thanks to the transition probabilities, each word HMM *has* to output the phones for its respective word! Instead, let's look at what phone sequences we get when we initialize a single HMM over all phones with naive transition probabilities and reuse our trained emission probabilities.
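The "most likely path" is what the Viterbi algorithm computes: dynamic programming that keeps, for every frame and state, the best-scoring path ending there, plus a back-pointer to recover it. Here is a minimal numpy sketch for a discrete-emission HMM, working in log space so long sequences don't underflow. It is an illustration with made-up parameters, not pomegranate's implementation.

```python
import numpy as np

def viterbi(pi, A, B, obs):
    """Most likely state sequence for a discrete-emission HMM (log-space dynamic programming)."""
    with np.errstate(divide="ignore"):  # log(0) = -inf simply marks impossible transitions
        log_pi, log_A, log_B = np.log(pi), np.log(A), np.log(B)
    T, N = len(obs), len(log_pi)
    delta = np.zeros((T, N))               # delta[t, j]: best log prob of a path ending in j at time t
    backptr = np.zeros((T, N), dtype=int)  # backptr[t, j]: best previous state for that path
    delta[0] = log_pi + log_B[:, obs[0]]
    for t in range(1, T):
        scores = delta[t - 1][:, None] + log_A  # scores[i, j]: come from state i, move to state j
        backptr[t] = scores.argmax(axis=0)
        delta[t] = scores.max(axis=0) + log_B[:, obs[t]]
    # Trace back from the best final state.
    path = [int(delta[-1].argmax())]
    for t in range(T - 1, 0, -1):
        path.append(int(backptr[t, path[-1]]))
    return path[::-1]

# Toy 2-state model where state 0 mostly emits cluster 0 and state 1 mostly emits cluster 1.
pi = np.array([0.5, 0.5])
A = np.array([[0.8, 0.2], [0.2, 0.8]])
B = np.array([[0.9, 0.1], [0.1, 0.9]])
print(viterbi(pi, A, B, [0, 0, 1, 1]))  # → [0, 0, 1, 1]
```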

In [None]:
def get_predicted_phone_sequence(hmm: Union[DenseHMM, SparseHMM], word: str):
    observations = get_observations_for_word(word)
    state_idcs = hmm.viterbi(observations)
    phone_sequence = [phones[i] for i in state_idcs[0]]
    return ' '.join(phone_sequence)

acoustic_model = DenseHMM(distributions=phone_states)
for word in words:
    print(word, get_predicted_phone_sequence(acoustic_model, word))

Not perfect, but much better than chance: 'lean' and 'kneel' are 100% correct, while 'gnaw' has become 'gnlaw', 'lawn' has become 'lawln', and 'knee' has become 'kneen'. Let's try a more intelligent way of setting the transition probabilities: interpolating (here, simply averaging) the transition probabilities across all of the word HMMs.
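Here is a toy numpy version of that averaging, which makes one consequence visible. The pronunciations in `lexicon` below are assumptions for illustration (only 'l a n' for 'lawn' is implied by the text above), and `word_transitions` is a hypothetical helper. Each word's phone-to-phone matrix is placed in a shared phone index space, with zeros for phones the word doesn't contain, and the matrices are averaged. Because transitions into each word's end state aren't phone-to-phone transitions, and words lacking a phone contribute zeros to its row, the averaged rows generally sum to less than 1.

```python
import numpy as np

def word_transitions(pron, phones, self_prob=0.6):
    """Phone-by-phone transitions of one left-to-right word HMM (exit to the end state dropped)."""
    A = np.zeros((len(phones), len(phones)))
    idx = [phones.index(p) for p in pron.split()]
    for k, i in enumerate(idx):
        A[i, i] = self_prob
        if k + 1 < len(idx):
            A[i, idx[k + 1]] = 1 - self_prob
    return A

phones = ['l', 'a', 'n']
lexicon = {'lawn': 'l a n', 'gnaw': 'n a'}  # assumed toy pronunciations
mats = [word_transitions(pron, phones) for pron in lexicon.values()]
interpolated = sum(mats) / len(mats)
print(interpolated.sum(axis=1))  # → approximately [0.5, 0.8, 0.8]
```

Row [l] sums to 0.5 because 'gnaw' has no [l], and rows [a] and [n] lose the 0.4 that went to the end state in whichever word they end. Whether to renormalize rows like these is a design choice; the cell that follows passes the raw averages to `DenseHMM`.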

In [None]:
# Average each phone-to-phone transition probability over all word HMMs.
# A word HMM that doesn't contain both phones contributes 0 for that pair.
interpolated_transition_probs = np.zeros((len(phone_states), len(phone_states)))

for i, state_i in enumerate(phone_states):
    for j, state_j in enumerate(phone_states):
        for hmm in word_hmms:
            distribution_list = list(hmm.distributions)
            if state_i in distribution_list and state_j in distribution_list:
                # Positions of the shared phone distributions within this word's HMM
                state_i_idx = distribution_list.index(state_i)
                state_j_idx = distribution_list.index(state_j)
                edge_weight = hmm.edges[state_i_idx, state_j_idx]
                edge_weight = torch.exp(edge_weight).item()  # convert from log prob to prob
                interpolated_transition_probs[i, j] += edge_weight

interpolated_transition_probs /= len(word_hmms)
interpolated_transition_probs

Let's plug these transition probabilities into an HMM now.

In [None]:
interpolated_acoustic_model = DenseHMM(distributions=phone_states, edges=interpolated_transition_probs)
for word in words:
    print(word, get_predicted_phone_sequence(interpolated_acoustic_model, word))

Voilà! We now have a single HMM that correctly predicts the phone sequence for each word in our lexicon!

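**EXERCISE 6**: Let's try a different strategy. Instead of interpolating all of the HMM transition probabilities into a single model, we'll decide which word HMM we want to use for prediction. As a first step, let's see what probability each word HMM assigns to the observations for 'lawn'.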

In [None]:
word_observations = get_observations_for_word('lawn')
for word, word_model in zip(words, word_hmms):
    word_prob = word_model.probability(word_observations).item()
    print(f"HMM for {word} predicts probability {word_prob:.2e} for word 'lawn'")

**Check your understanding**: How does the probability predicted by the HMM for 'lawn' compare to the probabilities predicted by the other models? How can we use this to pick which model we want to use for prediction without having any more information about the word? Write your answer below.

Your answer here:

**Now implement**: Fill out the function below to implement this new strategy.

In [None]:
def get_best_word_model(word: str, word_hmms: List[DenseHMM]) -> DenseHMM:
    """
    Given a word and a list of word HMMs, return the HMM that predicts the highest
    probability for the given word.

    Args:
        word: The word for which to find the best HMM.
        word_hmms: List of pomegranate DenseHMMs for each word.
    Returns:
        best_model: The DenseHMM that predicts the highest probability for `word`.
    """
    best_model = None
    # YOUR CODE HERE
    return best_model

for word in words:
    best_model = get_best_word_model(word, word_hmms)
    best_model_prediction = get_predicted_phone_sequence(best_model, word)
    print(f"Best model predicts for '{word}': {best_model_prediction}")

## Recap
We've now implemented acoustic models in two ways: by hand-coding WFSTs that represent the observation probabilities of acoustic vectors, and by training HMMs with the Baum-Welch algorithm. Along the way, we also learned how MFCCs can be used to encode audio and how to use these encodings to compute the posterior probability of audio given a phone label. This has given us the tools we need to model actual audio data.

**Check your understanding**: Answer the questions below.

1. What probability does the AM model? The LM? What probability distribution do they describe when we combine them? Give your answer using mathematical notation, e.g., $\mathrm{P}(a)$ or $\mathrm{P}(a|b)$, and explain what the variables mean.

2. Similarly, what probability is modeled by the observation probabilities of an HMM? What do the transition probabilities model?


3. How do the transition probabilities of an HMM relate to the arcs of a WFST?


4. Believe it or not, the 'best model' selection method above is directly equivalent to the WFST graph we compiled in Exercise 4 minus the phone $\mapsto$ words step. In other words, it's the same as the `cluster2phone` WFST. In fact, we could directly create a WFST like `cluster2phone` using the transition and observation probabilities of our HMMs above and then compose it with `phones2word` and get a functionally equivalent graph. Explain how to convert the HMMs to a WFST equivalent to `cluster2phone`. Give your explanation in prose. No need to write code.

