In [135]:
import os
import sys
import subprocess
from collections import defaultdict, OrderedDict

import numpy as np
import pandas as pd
from dotenv import load_dotenv

from kaldi_io import *

# Helpers

In [136]:
def softmax(x):
    """Return the row-wise softmax of ``x``.

    For a 2-D array every row is normalised independently so that it sums
    to 1; a 1-D array is treated as a single row. The maximum of each row is
    subtracted before exponentiating so that large inputs do not overflow.
    The input array is not modified.
    """
    # Work on the transpose so that each original row becomes a column and
    # the per-row statistics broadcast along axis 0.
    cols = x.T
    shifted = cols - cols.max(axis=0)
    exps = np.exp(shifted)
    return (exps / exps.sum(axis=0)).T

# GOP (Goodness of Pronunciation) computation

## Set environment

In [137]:
# Set environment: load the variables defined in the local .env file into
# os.environ. override=True lets values from .env replace variables that are
# already set in the process; verbose=True makes python-dotenv report a
# missing .env file.
load_dotenv('./.env', verbose=True, override=True)
# Uncomment to inspect the loaded variables (QA_ROOT is read further down to
# locate the model directory):
#print(os.environ.get('KALDI_ROOT'))
#print(os.environ.get('QA_ROOT'))
#print(os.environ.get('DATA_ROOT'))
#print(os.environ.get('PATH'))

In [138]:
# Experiment directory of the acoustic model, resolved relative to QA_ROOT.
# QA_ROOT must be present in the environment (e.g. loaded from ./.env).
# Fail early with an explicit message: os.path.join(None, ...) would otherwise
# raise an opaque TypeError when the variable is missing.
_qa_root = os.environ.get('QA_ROOT')
if _qa_root is None:
    raise RuntimeError(
        "Environment variable QA_ROOT is not set; define it in ./.env or in the shell."
    )
MODEL_DIR = os.path.join(_qa_root, "s5_mfcc/exp/mono_mfcc_4")

In [139]:
# Paths used by the cells below, all resolved relative to MODEL_DIR.
# Model file handed to the alignment reader and to gmm-compute-likes.
model =  os.path.join(MODEL_DIR, 'final.mdl')
# Data directory from which utt2spk, cmvn.scp and feats.scp are read.
data = os.path.join(MODEL_DIR, '../../data/train/split8/1')
# Phone table parsed by phone_symb2int.
phones_file = os.path.join(MODEL_DIR, 'phones.txt')
# Gzipped alignment archive, decompressed on the fly when read.
ali_file = os.path.join(MODEL_DIR, 'ali.1.gz')

## Get mappings

In [None]:
# Lookup tables between phoneme symbols and their integer codes, in both
# directions, built from the phone table.
symb2int = phone_symb2int(phones_file)
int2symb = {code: symbol for symbol, code in symb2int.items()}

# pdf id -> phoneme symbol, and the same mapping expressed with integer codes.
pdf2symb = pdf2phone(MODEL_DIR)
pdf2int = {pdf_id: symb2int[symbol] for pdf_id, symbol in pdf2symb.items()}

## Get alignments

In [142]:
# Read the alignments: the gzipped archive is decompressed on the fly through
# a piped "ark:" specifier (note the trailing "|").
# The result is looked up by reciter (alis[reciter]) and each entry is indexed
# by frame position further down.
# NOTE(review): presumably each entry holds one phoneme code per frame - confirm
# against read_ali_from_stdout.
alis = read_ali_from_stdout(model, f'ark:"gunzip -c {ali_file}|"')

## Compute GMM probs for pdf_id(s)

In [140]:
# Score every pdf id with gmm-compute-likes. The features are fed through a
# pipeline (apply-cmvn, then add-deltas) and the scores are written as a text
# archive to stdout, from which they are parsed.
feats_rspec = (
    '"ark,s,cs:'
    f'apply-cmvn --utt2spk=ark:{data}/utt2spk scp:{data}/cmvn.scp scp:{data}/feats.scp ark:- '
    '| add-deltas ark:- ark:- |"'
)
gmm_likes_command = f'gmm-compute-likes {model} {feats_rspec} ark,t:-'

gmm_likes = read_feats_from_stdout(gmm_likes_command)

# Turn each score matrix into row-wise probabilities with softmax.
probs = {reciter: softmax(likes) for reciter, likes in gmm_likes.items()}

## Switch from pdf_id to phoneme: sum probs by phoneme

In [141]:
# Column labels for the per-pdf probability matrices: the phoneme code of each
# pdf id. Built once because it is identical for every reciter.
# NOTE(review): this relies on pdf2int iterating in the same order as the
# columns of the probability matrices - confirm.
pdf_phonemes = list(pdf2int.values())

probs_by_phoneme = {}
for reciter, prob in probs.items():
    df = pd.DataFrame(data=prob, columns=pdf_phonemes)
    # Sum the columns that share a phoneme label. The grouping is done on the
    # transposed frame because DataFrame.groupby(axis=1) is deprecated since
    # pandas 2.1; transposing back restores the original row layout with one
    # column per phoneme (sorted by label).
    probs_by_phoneme[reciter] = (
        df.T
          .groupby(level=0)
          .sum()
          .T
          # If you want to return a numpy array just add
          #.values
          # (later cells use DataFrame methods on this result, though)
    )

## Get probs summary: predicted phoneme (most probable per frame) vs. alignment phoneme (reference)

In [153]:
# Per reciter, one [max prob, predicted phoneme, prob of aligned phoneme,
# aligned phoneme] entry per row of the phoneme-probability table.
# Reciters without an alignment are reported and keep an empty list.
# TODO: Check reciters intersection
probs_summary = {}
for reciter, prob in probs_by_phoneme.items():
    probs_summary[reciter] = []
    try:
        ali = alis[reciter]
    except KeyError:
        print(f'Could not find alignment for {reciter}')
        continue

    # Work on the underlying arrays instead of iterrows(), which builds a
    # Series per row and dominated the run time of this cell.
    # NOTE(review): assumes the probabilities contain no NaN (softmax output).
    frame_probs = prob.to_numpy()
    phonemes = prob.columns.to_numpy()
    frames = np.arange(len(prob))

    # Aligned phoneme of every row, looked up through the row's index label.
    aligned = [ali[index] for index in prob.index]
    # Column position of each aligned phoneme; -1 marks a label with no column.
    aligned_pos = prob.columns.get_indexer(aligned)
    unknown = aligned_pos < 0
    if unknown.any():
        missing = {aligned[i] for i in np.flatnonzero(unknown)}
        raise KeyError(
            f'Alignment of {reciter} uses phonemes without a probability column: {missing}'
        )

    # Position of the most probable phoneme per row (first one wins on ties).
    best_pos = frame_probs.argmax(axis=1)
    probs_summary[reciter] = [
        [best_prob, best_phoneme, aligned_prob, aligned_phoneme]
        for best_prob, best_phoneme, aligned_prob, aligned_phoneme in zip(
            frame_probs[frames, best_pos],
            phonemes[best_pos],
            frame_probs[frames, aligned_pos],
            aligned,
        )
    ]