In [35]:
from pomegranate.hmm import SparseHMM
from pomegranate.distributions import Categorical
from pomegranate.distributions._distribution import Distribution
from pomegranate._utils import _cast_as_tensor
import torch

import sys
import os
# Project checkout that contains the `scripts/` directory. The location can be
# overridden with the MALACHOR5_ROOT environment variable so the notebook runs
# on other machines; the default is the original hard-coded path.
PROJECT_ROOT = os.environ.get("MALACHOR5_ROOT", r'C:\projects\malachor5')
os.chdir(PROJECT_ROOT)
# Guard the append so re-running this cell does not add duplicate entries.
if "scripts" not in sys.path:
    sys.path.append("scripts")
from kws import get_similarity_matrix

# KWS inference with HMMs
Need to implement the following functionality:
- Initialize HMM from a list of keywords, with one state for each word plus silence and non-keyword speech (NKWS)
- Set transition weights based on key phrases with small transition prob to silence/NKWS
- Set observation probability to the cosine similarity value from the similarity matrix. The easiest way is to either:
    1. Set the `probability` function to index the $i^{th}$ item from the vector of similarity values, or
    2. Set the `probability` function to return the cosine similarity between the state embedding and the observed embedding.
Let's stick with the first option, as I'll likely want to keep calculating the similarity matrix ahead of time to compute a soft-prealignment before running HMM inference.

In [None]:
class EmbeddingSimilarity(Distribution):
    """Emission score based on cosine similarity to a fixed state embedding.

    ``log_probability`` returns the log of the cosine similarity between each
    observation and ``state_embed``. This is a similarity score used in place
    of a probability; it is not a normalized density.

    Parameters
    ----------
    state_embed : array-like
        Embedding vector for the state, converted with ``_cast_as_tensor``.
        Integer input is promoted to floating point.
    inertia, frozen, check_data
        Passed through unchanged to ``Distribution.__init__``.
    """

    def __init__(
            self,
            state_embed,
            inertia=0.0,
            frozen=False,
            check_data=True
        ):
        super().__init__(inertia=inertia, frozen=frozen, check_data=check_data)
        self.name = "Embedding similarity"
        state_embed = _cast_as_tensor(state_embed)
        # Embeddings written as integer literals (e.g. [1, 0, 1]) become
        # integer tensors; keep the stored embedding in floating point.
        if not torch.is_floating_point(state_embed):
            state_embed = state_embed.float()
        self.state_embed = state_embed

    def log_probability(self, X):
        """Return the log cosine similarity of each observation to the state.

        Parameters
        ----------
        X : array-like
            Observations whose last axis holds the embedding components
            (broadcast against ``state_embed``).

        Returns
        -------
        torch.Tensor
            ``log(max(cos_sim, 0))`` with the last axis of ``X`` reduced away.
            Observations with zero or negative similarity score ``-inf``.
        """
        X = _cast_as_tensor(X)
        if not torch.is_floating_point(X):
            X = X.to(self.state_embed.dtype)
        # Reduce over the embedding axis explicitly; the default (dim=1) is
        # only the embedding axis when X is exactly 2-D.
        similarity = torch.nn.functional.cosine_similarity(self.state_embed, X, dim=-1)
        # Cosine similarity lies in [-1, 1]; log of a negative value is NaN.
        # Clamp at 0 so anti-correlated observations get -inf, the same score
        # an orthogonal observation already receives.
        return similarity.clamp(min=0).log()

# Sanity check: a state embedding, a vector close to it, and a vector
# orthogonal to it.
state_embed, close_embed, orth_embed = [1, 0, 1], [0.9, 0, 1], [0, 1, 0]

embedsim = EmbeddingSimilarity(state_embed=state_embed)
# The two close observations should score near 0 and the orthogonal one -inf.
embedsim.log_probability([close_embed] * 2 + [orth_embed])

tensor([-0.0014, -0.0014,    -inf])

In [54]:
class KeySimilarityMatrix(Distribution):
    """Emission score that reads one column of a precomputed similarity matrix.

    Each observation is a row of similarity values, one per key;
    ``log_probability`` returns the log of the value in column ``col_i``.
    This is a similarity score used in place of a probability; it is not a
    normalized density.

    Parameters
    ----------
    col_i : int
        Index, along the last axis of the input, of the similarity value
        belonging to this state.
    max_i : int
        Stored as ``self.d``.
        NOTE(review): presumably the number of columns the HMM expects in its
        input; confirm against pomegranate.
    inertia, frozen, check_data
        Passed through unchanged to ``Distribution.__init__``.
    """

    def __init__(
            self,
            col_i: int,
            max_i: int,
            inertia=0.0,
            frozen=False,
            check_data=True
        ):
        super().__init__(inertia=inertia, frozen=frozen, check_data=check_data)
        self.col_i = col_i
        self.d = max_i
        self.name = "KeySimilarityMatrix"

    def log_probability(self, X):
        """Return the log of this state's similarity for each observation.

        Parameters
        ----------
        X : array-like
            Similarity values with the key axis last, e.g. ``(frames, keys)``.

        Returns
        -------
        torch.Tensor
            ``log(max(X[..., col_i], 0))``. Zero or negative similarities
            score ``-inf``.
        """
        X = _cast_as_tensor(X)
        # Index the last axis so the key column is selected whatever the
        # number of leading axes; identical to X[:, col_i] for 2-D input.
        similarity = X[..., self.col_i]
        # log of a negative similarity is NaN; clamp at 0 so it scores -inf,
        # the same as a similarity of exactly zero.
        return similarity.clamp(min=0).log()

    def _reset_cache(self):
        """No-op: this object keeps no cached statistics."""
        return

# One emission object per key state: columns 0 and 1 of a 2-column matrix.
keysim1, keysim2 = (KeySimilarityMatrix(col, 2) for col in range(2))

# Two key embeddings, a near-duplicate of each, and a vector orthogonal to both.
state1_embed, state2_embed = [0.5, 1, 0], [1, 0.5, 0]
close1_embed, close2_embed = [0.5, 0.9, 0], [0.9, 0.5, 0]
orth_embed = [0, 0, 1]

# First argument: the observed vectors; second argument: the key embeddings.
simmat = get_similarity_matrix(
    torch.tensor([close1_embed, close2_embed, orth_embed]),
    torch.tensor([state1_embed, state2_embed]),
)
simmat

tensor([[0.9991, 0.8253],
        [0.8253, 0.9991],
        [0.0000, 0.0000]])

In [51]:
# Per-state log scores for every row of the similarity matrix.
tuple(keysim.log_probability(simmat) for keysim in (keysim1, keysim2))

(tensor([-0.0009, -0.1920,    -inf]), tensor([-0.1920, -0.0009,    -inf]))

In [61]:
# Two-state model in which every transition (including self-loops) has
# weight 0.5, with uniform start and end weights.
# NOTE(review): each state's outgoing edges (0.5 + 0.5) plus its end weight
# (0.5) sum to 1.5; confirm whether SparseHMM expects these to sum to 1.
hmm = SparseHMM(
    distributions=[keysim1, keysim2],
    edges=[
        [src, dst, 0.5]
        for src, dst in (
            (keysim1, keysim1),
            (keysim1, keysim2),
            (keysim2, keysim2),
            (keysim2, keysim1),
        )
    ],
    starts=[0.5] * 2,
    ends=[0.5] * 2,
)

# Observation sequence that is close to state 1, state 2, state 2, state 1.
X = get_similarity_matrix(
    torch.tensor([close1_embed, close2_embed, close2_embed, close1_embed]),
    torch.tensor([state1_embed, state2_embed]),
)
# viterbi is given a leading axis of size 1 (a single sequence).
hmm.viterbi(torch.unsqueeze(X, 0))


tensor([[0, 1, 1, 0]])