In [74]:
# Notebook setup: imports, working directory, and the list of keyword-search
# evaluation files. Statement order matters here: the chdir must run before
# the relative sys.path entry is added, and that entry must exist before the
# project-local imports.
from glob import glob

import os
# Work from the project root so the relative 'scripts' path below resolves.
# NOTE(review): machine-specific absolute path; consider making it configurable.
os.chdir(r'C:\projects\malachor5')
import sys
import json
import torch
from tqdm import tqdm
# Make modules under scripts/ importable (hmm_utils and kws are imported
# right after this — presumably they live there; confirm).
sys.path.append('scripts')
from hmm_utils import KeySimilarityMatrix, init_keyword_hmm
from kws import textgrid_to_df, timestamp_hits
from copy import deepcopy
from jiwer import cer, wer
# Every JSON file in the eval directory; the loading cell parses each file name
# as '<stem>-<window length>sec.json'.
kws_eval_files = glob(r"C:\projects\malachor5\data\keyword_search\tira_eval_kws\*.json")

In [2]:
def _get_stem(filepath):
    """Return the recording stem: the part of the file name before the first '-'."""
    return os.path.basename(filepath).split('-')[0]


def get_window_len(filepath):
    """Return the window length (in seconds) encoded in an eval file name.

    The file name is parsed as ``<stem>-<length>sec.json``, where the decimal
    point of ``<length>`` is written as an underscore (so a token such as
    ``0_5sec`` is read as ``0.5``).
    """
    length_token = os.path.basename(filepath).split('-')[1]
    length_token = length_token.removesuffix('.json').removesuffix('sec')
    return float(length_token.replace('_', '.'))


# Group the eval files by recording in a single pass. Matching on the parsed
# stem (rather than a substring test against the full path) guarantees that a
# stem which is a prefix of another stem, or which happens to occur in a
# directory name, cannot pull in files of a different recording.
files_by_stem = {}
for filepath in kws_eval_files:
    files_by_stem.setdefault(_get_stem(filepath), []).append(filepath)

stems = set(files_by_stem)

# file2matrices[stem][window_len] -> {'sim_mat', 'oov_probs', 'timestamps'}
# file2matrices[stem]['keywords'] -> keyword list taken from the first file read
file2matrices = {}

for stem in tqdm(stems):
    # Keep the first file seen for each window length.
    file_by_window_len = {}
    for filepath in files_by_stem[stem]:
        file_by_window_len.setdefault(get_window_len(filepath), filepath)

    file2matrices[stem] = {}
    for window_len, file_with_len in tqdm(file_by_window_len.items()):
        with open(file_with_len, encoding='utf8') as f:
            json_obj = json.load(f)
        file2matrices[stem][window_len] = {
            'sim_mat': torch.tensor(json_obj['similarity_matrix']),
            'oov_probs': torch.tensor(json_obj['oov_probs']),
            'timestamps': json_obj['timestamps'],
        }
        # Stored once per recording; assumes every window length of a
        # recording carries the same keyword list — TODO confirm.
        if 'keywords' not in file2matrices[stem]:
            file2matrices[stem]['keywords'] = json_obj['keywords']

100%|██████████| 5/5 [00:03<00:00,  1.34it/s]
100%|██████████| 5/5 [00:02<00:00,  1.79it/s]
100%|██████████| 5/5 [00:06<00:00,  1.25s/it]
100%|██████████| 3/3 [00:12<00:00,  4.26s/it]


In [3]:
# Attach a TextGrid path to every recording entry (the directory name suggests
# these are MFA forced alignments).
tgs = glob(r"C:\projects\malachor5\data\keyword_search\tira_eval_mfa_aligned\*.TextGrid")
for stem in stems:
    # Test the stem against the file name only, so a directory component can
    # never produce a spurious match.
    stem_tgs = [tg for tg in tgs if stem in os.path.basename(tg)]
    if not stem_tgs:
        # Fail with the offending recording instead of a bare IndexError.
        raise FileNotFoundError(f"No TextGrid found for recording {stem!r}")
    stem_tg = stem_tgs[0]
    file2matrices[stem]['textgrid'] = stem_tg

In [4]:
# Attach the keyphrase list (one phrase per line of a text file) to every
# recording entry.
kw_files = glob(r"C:\projects\malachor5\data\keyword_search\keyword_lists\*.txt")
for stem in stems:
    # Test the stem against the file name only, so a directory component can
    # never produce a spurious match.
    stem_kwfiles = [kwfile for kwfile in kw_files if stem in os.path.basename(kwfile)]
    if not stem_kwfiles:
        # Fail with the offending recording instead of a bare IndexError.
        raise FileNotFoundError(f"No keyword list found for recording {stem!r}")
    stem_kwfile = stem_kwfiles[0]
    with open(stem_kwfile, encoding='utf8') as f:
        # Skip blank lines (e.g. a trailing empty line) so that no empty
        # keyphrase is handed on to the HMM construction.
        keyphrases = [line.strip() for line in f if line.strip()]
    file2matrices[stem]['keyphrases'] = keyphrases

In [7]:
# Recording and window length analysed in the cells below. Kept in one place
# so that switching to another recording/window is a single consistent edit.
EVAL_STEM = 'HH20210913'
EVAL_WINDOW_LEN = 0.5  # seconds; must equal a window length parsed from the eval file names

recording = file2matrices[EVAL_STEM]
window_data = recording[EVAL_WINDOW_LEN]

# Per-recording data.
keyphrases = recording['keyphrases']
keywords = recording['keywords']
tg_path = recording['textgrid']

# Per-window-length data.
timestamps = window_data['timestamps']
sim_mat = window_data['sim_mat']
oov_probs = window_data['oov_probs']

# Parse the TextGrid with the project helper (the name suggests it returns a
# DataFrame — confirm against kws.textgrid_to_df).
tg_df = textgrid_to_df(tg_path)

In [8]:
# Target matrix with the same shape and dtype as sim_mat; column i is filled
# with whatever timestamp_hits returns for keyword i over `timestamps`
# (presumably a 0/1 hit flag per timestamp — confirm against kws.timestamp_hits).
# The saved run of this cell took close to two minutes for 23 keywords.
ground_truth = torch.zeros_like(sim_mat)
for i, keyword in enumerate(tqdm(keywords)):
    hits = timestamp_hits(tg_df, keyword, timestamps)
    ground_truth[:, i] = torch.tensor(hits)

100%|██████████| 23/23 [01:47<00:00,  4.67s/it]


In [9]:
# Index of the highest-valued keyword column in each row of ground_truth.
kw_idcs = ground_truth.argmax(dim=1)

# argmax also returns 0 for rows that contain no hit at all, so filtering with
# `kw_idcs > 0` would silently drop genuine hits on the keyword at index 0.
# Select the rows to display with an explicit "row has a hit" mask instead.
has_hit = (ground_truth > 0).any(dim=1)
kw_idcs[has_hit]

tensor([13, 13, 13, 13, 13, 13, 21, 21, 21, 21, 21, 21, 13, 13, 12, 12, 12, 12,
        13, 13, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 12, 12,
        12, 12, 15, 15, 15, 12, 12, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, 13,
        13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13,
        13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 18, 18, 18, 18, 18,
        18,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2, 17, 17, 17,
        17, 17, 17, 17, 17, 17, 17,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2, 17,
        17, 17, 17, 17, 17, 17, 17, 17, 17, 17,  7,  7, 17, 17, 17, 17, 17, 17,
        17, 17, 17, 17, 17, 17, 17, 17, 17, 17,  7,  7, 17, 17,  7,  7, 17, 17,
         7,  7, 17, 17,  7,  7, 17, 17,  7,  7, 17, 17,  7,  7, 17, 17,  7,  7,
        17, 17,  7,  7, 19, 19, 19, 19, 19, 19,  7,  7, 19, 19,  7,  7, 19, 19,
        19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19, 19,  7,  7,
        19, 19,  7,  7, 19, 19,  7,  7, 

In [10]:
# Build the keyword HMM from the keyphrase and keyword lists. `states` is the
# second value returned by init_keyword_hmm; it is not used again in the
# visible cells.
hmm_prefit, states = init_keyword_hmm(keyphrases, keywords)
# Independent deep copy, so changes to hmm_bw leave hmm_prefit untouched —
# presumably meant for Baum-Welch re-estimation ("bw"); confirm.
hmm_bw = deepcopy(hmm_prefit)

In [11]:
# Place the OOV probabilities as one extra trailing column next to the keyword
# similarity columns, then add a leading axis of size 1 (presumably the batch
# axis the HMM methods expect — confirm).
oov_column = oov_probs.unsqueeze(1)
emission_mat = torch.cat([sim_mat, oov_column], dim=1)[None]

In [16]:
# Inspect the HMM's edges. In the saved output each entry is a
# (KeySimilarityMatrix, KeySimilarityMatrix, float) tuple — presumably
# (from-state, to-state, transition probability); confirm.
# NOTE(review): the full list is long and the state repr carries no identity;
# consider displaying a slice or a summary instead.
hmm_prefit.edges

[(KeySimilarityMatrix(), KeySimilarityMatrix(), 0.004347826086956522),
 (KeySimilarityMatrix(), KeySimilarityMatrix(), 0.004347826086956522),
 (KeySimilarityMatrix(), KeySimilarityMatrix(), 0.245),
 (KeySimilarityMatrix(), KeySimilarityMatrix(), 0.245),
 (KeySimilarityMatrix(), KeySimilarityMatrix(), 0.004347826086956522),
 (KeySimilarityMatrix(), KeySimilarityMatrix(), 0.004347826086956522),
 (KeySimilarityMatrix(), KeySimilarityMatrix(), 0.245),
 (KeySimilarityMatrix(), KeySimilarityMatrix(), 0.245),
 (KeySimilarityMatrix(), KeySimilarityMatrix(), 0.004347826086956522),
 (KeySimilarityMatrix(), KeySimilarityMatrix(), 0.004347826086956522),
 (KeySimilarityMatrix(), KeySimilarityMatrix(), 0.49),
 (KeySimilarityMatrix(), KeySimilarityMatrix(), 0.245),
 (KeySimilarityMatrix(), KeySimilarityMatrix(), 0.245),
 (KeySimilarityMatrix(), KeySimilarityMatrix(), 0.004347826086956522),
 (KeySimilarityMatrix(), KeySimilarityMatrix(), 0.004347826086956522),
 (KeySimilarityMatrix(), KeySimilarityMat

In [58]:
# Backward pass of the un-refit HMM over the emission matrix.
# NOTE(review): the saved output is NaN in all but the last displayed rows, so
# the values become NaN a few steps away from the end of the sequence —
# investigate before using `bw`.
# Note the naming: `bw` here is the backward result, while `hmm_bw` is a model copy.
bw=hmm_prefit.backward(emission_mat)
bw

tensor([[[    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         ...,
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         [-3.7621, -3.8908, -3.9489,  ..., -3.7290, -3.4059, -3.4059],
         [-3.2581, -3.2581, -3.2581,  ..., -3.2581, -3.2581, -3.2581]]])

In [30]:
# Forward pass over the emission matrix.
# NOTE(review): this uses hmm_bw, whereas the backward cell uses hmm_prefit.
# hmm_bw is created as a deep copy of hmm_prefit; whether it was modified
# afterwards is not visible here — confirm which model is intended.
# The saved output is finite in the first rows and NaN in the last displayed rows.
fw=hmm_bw.forward(emission_mat)
fw

tensor([[[-4.5236, -5.3604, -5.4243,  ..., -4.3505, -3.3035, -3.3035],
         [-6.4340, -7.9089, -8.0804,  ..., -5.5795, -2.5784, -2.5784],
         [-7.7539, -9.0649, -9.3165,  ..., -6.9986, -2.6426, -2.6426],
         ...,
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
         [    nan,     nan,     nan,  ...,     nan,     nan,     nan]]])

In [56]:
# Number of NaN entries along the last axis of the forward result, one count
# per (batch, step) position; display the first 20 steps to see where the
# NaNs start.
fw_nan = torch.isnan(fw).sum(dim=-1)
fw_nan[:, :20]

tensor([[ 0,  0,  0,  0, 12, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26,
         26, 26]])

In [None]:
# Decode the emission matrix with the un-refit HMM; the saved output shows the
# result has shape (1, 15175). The values are presumably one state index per
# step — confirm against the HMM implementation.
viterbi=hmm_prefit.viterbi(emission_mat)
viterbi.shape

torch.Size([1, 15175])

In [75]:
# Turn both index sequences into space-separated token strings, dropping every
# 0, and score them with jiwer's word error rate (kw_str is passed first, i.e.
# in the reference position).
# NOTE(review): kw_idcs comes from an argmax, where 0 is both "no hit in this
# row" and the index of the first keyword, so hits on keyword 0 are dropped
# from the reference as well. It is also unconfirmed that the Viterbi state
# indices live in the same index space as the keyword columns — verify both
# before trusting this score.
reference_tokens = [str(idx) for idx in kw_idcs.tolist() if idx != 0]
hypothesis_tokens = [str(state) for state in viterbi.squeeze().tolist() if state != 0]
kw_str = ' '.join(reference_tokens)
viterbi_str = ' '.join(hypothesis_tokens)
wer(kw_str, viterbi_str)

33.496688741721854