In [21]:
import numpy as np
from collections import defaultdict
from itertools import groupby
import pickle
import os
import glob

from bokeh.plotting import figure, output_notebook, show

output_notebook()

In [42]:
def read_timit_labels(path, frate=100, srate=16000):
    '''Read TIMIT label files.
    
    Args:
        path (str): Path to the TIMIT label file.
        samp_period (int): Features sampling rate in 
            100ns (default is 100 Hz).
        srate (int): Audion sampling rate (default: 16000 Hz)
    
    Returns:
        list of tuple
        
    '''
    factor = frate / srate
    print(factor)
    segmentation = []
    with open(path, 'r') as file_obj:
        for line in file_obj:
            line = line.strip()
            tokens = line.split()
            if len(tokens) == 0:
                continue
            if len(tokens) != 3:
                raise ValueError('File is badly formatted.')
            start = int(int(tokens[0])*factor)
            end = int(int(tokens[1])*factor)
            segmentation += [tokens[2]] * (end - start)
    return segmentation

read_timit_labels('/mnt/matylda2/data/TIMIT/timit/train/dr1/fcjf0/sa1.phn')

0.00625


['h#',
 'h#',
 'h#',
 'h#',
 'h#',
 'h#',
 'h#',
 'h#',
 'h#',
 'h#',
 'h#',
 'h#',
 'h#',
 'h#',
 'h#',
 'h#',
 'h#',
 'h#',
 'h#',
 'sh',
 'sh',
 'sh',
 'sh',
 'sh',
 'sh',
 'sh',
 'sh',
 'sh',
 'ix',
 'ix',
 'ix',
 'ix',
 'ix',
 'ix',
 'ix',
 'hv',
 'hv',
 'hv',
 'hv',
 'hv',
 'hv',
 'eh',
 'eh',
 'eh',
 'eh',
 'eh',
 'eh',
 'eh',
 'eh',
 'eh',
 'eh',
 'eh',
 'eh',
 'eh',
 'dcl',
 'dcl',
 'dcl',
 'jh',
 'jh',
 'jh',
 'jh',
 'jh',
 'jh',
 'jh',
 'ih',
 'ih',
 'ih',
 'ih',
 'ih',
 'ih',
 'ih',
 'dcl',
 'dcl',
 'dcl',
 'dcl',
 'dcl',
 'dcl',
 'dcl',
 'd',
 'ah',
 'ah',
 'ah',
 'ah',
 'ah',
 'ah',
 'ah',
 'ah',
 'ah',
 'ah',
 'ah',
 'ah',
 'kcl',
 'kcl',
 'kcl',
 'kcl',
 'kcl',
 'kcl',
 'kcl',
 'kcl',
 'k',
 'k',
 'k',
 's',
 's',
 's',
 's',
 's',
 's',
 's',
 's',
 's',
 's',
 's',
 'ux',
 'ux',
 'ux',
 'ux',
 'ux',
 'ux',
 'ux',
 'ux',
 'ux',
 'ux',
 'ux',
 'ux',
 'ux',
 'ux',
 'q',
 'q',
 'q',
 'q',
 'q',
 'en',
 'en',
 'en',
 'en',
 'en',
 'en',
 'en',
 'en',
 'en',
 'gcl',
 'gcl',

In [36]:
def load_transcript(path):
    with open(path, 'r') as f:
        trans = {}
        for line in f:
            tokens = line.strip().split()
            trans[tokens[0]] = tokens[1:]
    return trans

def bimap_trans(trans, bimap):
    new_trans = {}
    for utt, utt_trans in trans.items():
        new_utt_trans = []
        prev_sym = '<s>'
        for token in utt_trans:
            try:
                new_utt_trans.append(bimap[prev_sym][token])
            except KeyError as err:
                print(prev_sym, token)
                raise err
            prev_sym = token
        new_trans[utt] = new_utt_trans
    return new_trans

def map_trans(trans, mapfile):
    new_trans = {}
    for utt, utt_trans in trans.items():
        tmp = []
        for token in utt_trans:
            try:
                new_token = mapfile[token]
                tmp.append(new_token)
            except KeyError:
                pass
        new_trans[utt] = tmp
    return new_trans

def wer(r, h):
    """
    Calculation of WER with Levenshtein distance.

    Works only for iterables up to 254 elements (uint8).
    O(nm) time ans space complexity.

    Parameters
    ----------
    r : list
    h : list

    Returns
    -------
    int

    Examples
    --------
    >>> wer("who is there".split(), "is there".split())
    1
    >>> wer("who is there".split(), "".split())
    3
    >>> wer("".split(), "who is there".split())
    3
    """
    # initialisation
    d = np.zeros((len(r)+1)*(len(h)+1), dtype=int)
    d = d.reshape((len(r)+1, len(h)+1))
    for i in range(len(r)+1):
        for j in range(len(h)+1):
            if i == 0:
                d[0][j] = j
            elif j == 0:
                d[i][0] = i

    # computation
    for i in range(1, len(r)+1):
        for j in range(1, len(h)+1):
            if r[i-1] == h[j-1]:
                d[i][j] = d[i-1][j-1]
            else:
                substitution = d[i-1][j-1] + 1
                insertion    = d[i][j-1] + 1
                deletion     = d[i-1][j] + 1
                d[i][j] = min(substitution, insertion, deletion)

    return d[len(r)][len(h)]

In [25]:
ref = load_transcript('/mnt/matylda5/iondel/workspace/2019/gsm/beer/recipes/aud/data/timitfull/all/trans_61')
hyp = load_transcript('exp/timit/aud_mfcc_4g_dirichlet/decode/train/trans')
#ref_align = load_transcript('exp/timitfull/monophone_mbn_babel/align_ac1.0/all/trans')
hyp_align = load_transcript('exp/timit/aud_mfcc_4g_dirichlet/decode_perframe/train/trans')

mapfile = '/mnt/matylda5/iondel/workspace/2019/gsm/beer/recipes/aud/data/timit/lang/phones_61_to_39.txt'
phonemap = {}
with open(mapfile, 'r') as f:
    for line in f:
        try:
            p1, p2 = line.strip().split()
            phonemap[p1] = p2
        except ValueError:
            pass

In [26]:
ref_align = {}
for path in glob.glob('/mnt/matylda2/data/TIMIT/timit/*/dr*/*/*.phn'):
    fname = os.path.basename(path)
    root, ext = os.path.splitext(fname)
    spkname = os.path.basename(os.path.dirname(path))
    uttid = spkname + '_' + root
    trans = []
    for entry in read_timit_labels(path):
        if entry[0] == '#':
            print(uttid)
        trans += [entry[0]] * (entry[2] - entry[1])
    try:
        utt_len = len(hyp_align[uttid])
    except KeyError:
        continue
    if utt_len > len(trans):
        pad = abs(utt_len - len(trans))
        trans += [trans[-1]] * pad
    elif utt_len < len(trans):
        trans = trans[:utt_len]
    ref_align[uttid] = trans

## Acoustic unit to phone mapping

In [43]:
counts = defaultdict(lambda: defaultdict(int))
for utt in ref_align:
    for ref_unit, hyp_unit in zip(ref_align[utt], hyp_align[utt]):
        if ref_unit == '#':
            print(utt, ' '.join(ref_align[utt]))
        counts[hyp_unit][ref_unit] += 1

au_map = {au: max(label_counts, key=label_counts.get) for au, label_counts in counts.items()}
len(au_map)

59

## Equivalent phone error rate

In [28]:
ref_trans = map_trans(ref, phonemap)
hyp_trans = map_trans(hyp, au_map)
hyp_trans = map_trans(hyp_trans, phonemap)
hyp_trans = {utt: [x[0] for x in groupby(trans)] for utt, trans in hyp_trans.items()}

# remove sil
#ref_trans = {utt: list(filter(lambda a: a != 'sil', trans)) for utt, trans in ref_trans.items()}
#hyp_trans = {utt: list(filter(lambda a: a != 'sil', trans)) for utt, trans in hyp_trans.items()}

In [29]:
acc_wer = 0
nwords = 0
for utt in ref_trans:
    try:
        ref_t, hyp_t = ref_trans[utt], hyp_trans[utt]
        acc_wer += wer(ref_t, hyp_t)
        nwords += len(ref_t)
    except KeyError:
        pass
print(f'Phone Error Rate: {100 * acc_wer / nwords:.2f}')

Phone Error Rate: 65.79


In [30]:
import random
for utt in random.choices(list(ref.keys()), k=1):
    print('(ref)', utt, ' '.join(ref_trans[utt]))
    print('(hyp)', utt, ' '.join(hyp[utt]))
    print('(hyp)', utt, ' '.join(hyp_trans[utt]))

(ref) mdks0_si1696 sil n aa sil t w ih n sh iy sil w ey dx ih sil s ow l aa ng aa l r eh sil d iy sil
(hyp) mdks0_si1696 sil au45 au18 au36 au65 au27 au58 au74 au76 au65 au20 au72 au12 au27 au74 au47 au12 au72 au60 au78 au40 au58 au40 au21 au81 au40 au24 au13 au6 au60 au50 au30 au1 sil
(hyp) mdks0_si1696 sil aa n l ay l sil n sh aa iy l ey iy aa s ow ay ow l m ow aa r s iy s f sil


## Normalized Mutual Information

In [141]:
def align(T_ref, T_new, labels, clusters):
    counts_labels = np.zeros(len(labels))
    counts_clusters = np.zeros(len(clusters))
    counts = np.zeros((len(clusters), len(labels))) + 1
    for utt in T_ref.keys():
        data_ref = T_ref[utt]
        data_new = T_new[utt]

        ref_labels = []
        mu = []
        for t in data_ref:
            label, start, stop, _, _ = t
            mu.append(start + 0.5*(stop-start))
            ref_labels.append(label)
            idx = labels.index(label)
            counts_labels[idx] += 1
        mu = np.asarray(mu)

        if len(mu) == 0:
            print(utt)

        for t in data_new:
            cluster, start, stop, _, _ = t
            idx = clusters.index(cluster)
            counts_clusters[idx] += 1
            x = start + 0.5 * (stop-start)
            closest_label = ((x-mu)**2).argmin()
            i = clusters.index(cluster)
            j = labels.index(ref_labels[closest_label])
            counts[i,j] += 1

    return counts, counts_labels, counts_clusters

In [151]:
def timing(align):
    timing_trans = {}
    for utt, trans in align.items():
        label = trans[0]
        start = 0
        timings = []
        for i, next_label in enumerate(trans[1:], 1):
            if label != next_label:
                timings.append((label, start, i, None, None))
                label = next_label
                start = i
        timing_trans[utt] = timings
        timings.append((label, start, len(trans) -1, None, None))
    return timing_trans

ref_t_trans = timing(map_trans(ref_align, phonemap))
hyp_t_trans = timing(hyp_align)

In [152]:
counts = defaultdict(lambda: defaultdict(int))
ref_align39 = map_trans(ref_align, phonemap)
for utt in ref_align:
    for ref_unit, hyp_unit in zip(ref_align39[utt], hyp_align[utt]):
        counts[hyp_unit][ref_unit] += 1
aus = list(counts.keys())
phones = set()
for phonecount in counts.values():
    for phone in phonecount:
        phones.add(phone)
phones = list(phones)

M, counts_labels, counts_clusters = align(ref_t_trans, hyp_t_trans, phones, aus)
p_X_given_Y = probability_matrix(M, alpha=alpha)

In [153]:
def probability_matrix(counts, alpha=1):                                           
    c = np.array(counts + alpha, dtype=float)                                      
    return (c.T/c.sum(axis=1)).T

alpha = 1

In [154]:
#M = np.zeros((len(aus), len(phones)))
#for i in range(len(aus)):
#    for j in range(len(phones)):
#        M[i, j] += counts[aus[i]][phones[j]]
#p_X_given_Y = probability_matrix(M, alpha=alpha)

In [155]:
# Estimate the marginal distribution of the reference cluster labels.       
#-----------------------------------------------------------------------    
p_X = np.asarray(M.sum(axis=0), dtype=float) + alpha            
p_X /= p_X.sum()                                                            
p_Y = np.asarray(M.sum(axis=1), dtype=float) + alpha           
p_Y /= p_Y.sum()                                                            

# Evaluate the conditional and marginal entropy.                            
#-----------------------------------------------------------------------    
H_X_given_Y = -(p_Y.dot((p_X_given_Y*np.log2(p_X_given_Y)).sum(axis=1)))    
H_X = -p_X.dot(np.log2(p_X))                                                
H_Y = -p_Y.dot(np.log2(p_Y))                                                

# Evaluate the mutual information between reference labels and clusters.    
#-----------------------------------------------------------------------    
I_XY = H_X - H_X_given_Y                                                    


#print('H(X):', H_X)                                                        
#print('H(Y):', H_Y)                                                        
#print('2*I(X;Y)/(H(X) + H(Y)):', 100*2*I_XY/(H_X + H_Y), '%')              
print('Normalized Mutual Information')                                      
print('-----------------------------')                                      
print('# refs units:', len(phones))                                       
print('# proposed units:', len(aus))                                 
print('I(X;Y)/ H(x) =', 100 * I_XY/(H_X), '%')                             
print('I(X;Y) =', I_XY)                                                   
print('H(Y) =', H_Y)                                                      
print('H(X) =', H_X)                                                      
print('counts =', M.sum())

Normalized Mutual Information
-----------------------------
# refs units: 39
# proposed units: 100
I(X;Y)/ H(x) = 26.05202641409705 %
I(X;Y) = 1.211704920736424
H(Y) = 6.541155353562759
H(X) = 4.651096622874436
counts = 234343.0


In [129]:
def get_durations(trans, max_duration=100):
    current = trans[0]
    durations = np.zeros(max_duration)
    duration = 1
    for token in trans[1:]:
        if token == current:
            duration += 1
        else:
            current = token
            durations[min(duration, max_duration - 1)] += 1
            duration = 1
    return durations

In [236]:
ref = load_transcript('exp/timit/monophone_mbn_babel/align_ac1.0/train/trans')
durations = np.zeros(100)
for utt, trans in ref.items():
    trans = list(filter(lambda a: a != 'sil', trans))
    durations += get_durations(trans, len(durations))
durations /= durations.sum()

hyp = load_transcript('exp/timit/subspace_aud_mbn_babel_ldim100/decode_perframe_ac1.0/train/trans')
hyp_durations = np.zeros(len(durations))
for utt, trans in hyp.items():
    trans = list(filter(lambda a: a != 'sil', trans))
    hyp_durations += get_durations(trans, len(hyp_durations))
hyp_durations /= hyp_durations.sum()

hyp = load_transcript('exp/timit/aud_8g/decode_perframe_ac1.0/train/trans')
#hyp = load_transcript('exp/timit/aud_4g/decode_perframe_ac1.0/train/trans')
hyp_durations2 = np.zeros(len(durations))
for utt, trans in hyp.items():
    trans = list(filter(lambda a: a != 'sil', trans))
    hyp_durations2 += get_durations(trans, len(hyp_durations2))
hyp_durations2 /= hyp_durations2.sum()

fig = figure(x_range=(0, 40))
fig.vbar(x=range(len(durations)), top=durations, width=0.9, alpha=0.5)
fig.vbar(x=range(len(hyp_durations)), top=hyp_durations, width=0.9, alpha=0.5, color='red')
#fig.vbar(x=range(len(hyp_durations2)), top=hyp_durations2, width=0.9, alpha=0.5, color='green')
show(fig)