In [1]:
import os
import numpy as np
from itertools import groupby
from collections import defaultdict
import pickle
import yaml

from bokeh.plotting import output_notebook, figure, show, gridplot
from bokeh.models import ColumnDataSource, Range1d, LabelSet, Label
from bokeh.palettes import Category10 as palette
output_notebook()

In [2]:
mapping = {}
with open('data/timit/lang/phones_61_to_39.txt') as f:
    for line in f:
        tokens = line.strip().split()
        if len(tokens) == 2:
            mapping[tokens[0]] = tokens[1]
            
def map_trans(utt_trans, mapfile):                                                     
    tmp = []                                                                   
    for token in utt_trans:                                                    
        try:                                                                   
            new_token = mapfile[token]                                         
            tmp.append(new_token)                                              
        except KeyError:                                                       
            pass                                                               
    return tmp

In [3]:
counts = defaultdict(int)
with open('data/timit/train/trans') as f:
    for line in f:
        tokens = line.strip().split()[1:]
        tokens =  [x[0] for x in groupby(tokens)] 
        tokens = map_trans(tokens, mapping)
        for token in tokens:
            counts[token] += 1
            
true_counts = np.array([value for _, value in counts.items()])
weights = np.zeros(101)
weights[:len(true_counts)] = np.sort(true_counts)[::-1] 
true_weights = weights / weights.sum()

fig2 = figure(width=400, height=400, title='Mixing weights')
fig2.vbar(range(len(true_weights)), width=.5, top=true_weights)
fig2.xgrid.visible = False
show(gridplot([[fig2]]))

In [4]:
with open('exp/timit/aud_mfcc_4g_dirichlet/30.mdl', 'rb') as f:
    model_d = pickle.load(f)
d_weights = model_d.categorical.mean.sort(descending=True)[0].numpy()
    
with open('exp/timit/aud_mfcc_4g_dirichlet_process/20.mdl', 'rb') as f:
    model_dp = pickle.load(f)
dp_weights = model_dp.categorical.mean.sort(descending=True)[0].numpy()
    
with open('exp/timit/aud_mfcc_4g_gamma_dirichlet_process/20.mdl', 'rb') as f:
    model_gdp = pickle.load(f)
gdp_weights = model_gdp.categorical.mean.sort(descending=True)[0].numpy()

In [5]:
fig = figure(title='Mixing weights')
fig.line(range(len(true_weights)), true_weights, color='black', line_width=2, line_dash='dashed', legend='True')
fig.line(range(len(d_weights)), d_weights, line_width=2, color='blue', legend='Dir')
fig.line(range(len(dp_weights)), dp_weights, line_width=2, color='red', legend='DirProc')
fig.line(range(len(dp_weights)), gdp_weights, line_width=2, color='green', legend='GamDirProc')
fig.xgrid.visible = False
show(fig)

In [218]:
with open('exp/timit/aud_mfcc_4g_dirichlet_process/decode_perframe/train/score/au_phone', 'r') as f:
    for line in f:
        print(line.strip())

sil h#
au49 q
au74 kcl
au100 iy
au68 ih
au8 m
au3 n
au11 n
au71 ae
au50 n
au41 h#
au86 h#
au75 ix
au80 ih
au12 sh
au24 tcl
au9 ix
au47 h#
au21 ay
au23 s
au13 dh
au30 f
au6 aa
au95 aw
au31 tcl
au5 ae
au88 ix
au18 k
au40 kcl
au91 s
au98 h#
au92 w
au58 bcl
au38 ah
au15 s
au14 s
au59 m
au37 n
au22 s
au33 ay
au54 el
au96 n
au46 w
au69 k
au73 pcl
au76 q
au63 h#
au89 ae
au90 ow
au97 ow
au1 n
au70 ey
au34 kcl
au56 t
au61 axr
au67 m
au28 k
au42 er
au66 tcl
au17 s
au48 s
au16 m
au20 r
au57 dcl
au64 p
au93 ih
au39 r
au43 er
au72 n
au36 w
au94 er
au53 s
au35 iy
au7 iy
au52 f
au81 s
au27 ah
au19 tcl
au2 r
au26 h#
au84 iy
au44 ay
au99 s
au62 ey
au4 iy
au82 l
au83 axr
au60 r
au79 l
au45 h#
au87 ae
au25 r
au10 n
au65 t
au78 s
au29 aa
au85 sh
au32 kcl
au51 z
au55 ao
au77 ix


In [209]:
with open('exp/timit/aud_mfcc_4g_dirichlet_process/decode_perframe/train/score/au_phone_counts.yml', 'rb') as f:
    counts = yaml.load(f)
counts
aus = list(counts.keys())
phone_counts = defaultdict(lambda: defaultdict(int))
for au, au_counts in counts.items():
    for phone, pcounts in au_counts.items():
        if phone != 'q':
            phone_counts[mapping[phone]][au] += pcounts

In [210]:
sphones = '''
iy
ih
ey
eh
y
ae
ay
aw
aa
ah
uh
uw
ow
oy
w
l
er
r
m
n
ng
z
s
sh
ch
jh
hh
v
f
dh
th
d
b
dx
g
t
p
k
sil
'''.split()

len(sphones)

39

In [211]:
scounts = {}
#for key, val in counts.items():
#    print(key)

In [212]:
mat = np.zeros((39, 101))
for i, phone in enumerate(sphones):
    for j, au in enumerate(aus):
        mat[i, j] += phone_counts[phone][au]

#mat = mat / (1e-6 + mat.sum(axis=0, keepdims=True))
#mat = np.log(1e-6 + mat)
mat.sum(axis=0).shape

(101,)

In [215]:
sidx = np.argmax(mat, axis=0)
sidx
nmat = np.zeros_like(mat)
nmat[sidx, range(101)] = 1
#mat = nmat

In [216]:
dh, dw = nmat.T.shape
fig = figure(y_range=(0, dh), x_range=(0, dw))
fig.image(image=[nmat.T], x=0, y=0, dh=dh, dw=dw, palette='Viridis256')

ticks = list(range(len(sphones)))
labels = {i: phone for i, phone in enumerate(sphones)}
fig.xaxis.ticker = ticks
fig.xaxis.major_label_overrides = labels

show(fig)

In [86]:
with open('exp/timit/aud_mfcc_4g_gamma_dirichlet_process/decode_perframe/train/score/au_phone_counts.yml', 'rb') as f:
    model = pickle.load(f)
model.categorical

UnpicklingError: unpickling stack underflow

In [84]:
model.categorical.concentration.prior.expected_value()

AttributeError: 'SBCategorical' object has no attribute 'concentration'

In [176]:
def gamma_lh(x, pdf):
    from scipy.special import gammaln
    shape = float(pdf.params.shape)
    rate = float(pdf.params.rate)
    return np.exp(-rate * x + (shape - 1) * np.log(x) - gammaln(shape) + shape * np.log(rate))

weights = model.categorical.mean.sort(descending=True)[0].numpy()
print(len(weights))

fig2 = figure(width=400, height=400, title='Mixing weights')
fig2.line(range(len(true_weights)), true_weights, color='black', line_width=2, line_dash='dashed')
fig2.line(range(len(weights)), weights, line_width=2)
#fig2.xaxis.ticker = list(range(len(weights)))
fig2.xgrid.visible = False

fig3 = figure(width=400, height=400, title='Prior/posterior over the concentration')
c = np.linspace(3, 9, 1000)
prior_lh = gamma_lh(c, model.categorical.concentration.prior)
posterior_lh = gamma_lh(c, model.categorical.concentration.posterior)
fig3.line(c, prior_lh, color='blue')
fig3.line(c, posterior_lh, color='green')

show(gridplot([[fig2, fig3]]))

101


In [124]:
weights

array([8.82951021e-02, 7.06105530e-02, 6.13297299e-02, 5.51275015e-02,
       5.06085865e-02, 4.68704626e-02, 4.33884598e-02, 4.00728099e-02,
       3.71540450e-02, 3.41713205e-02, 3.20717618e-02, 3.01194396e-02,
       2.87494641e-02, 2.73091495e-02, 2.63039600e-02, 2.52284035e-02,
       2.44471952e-02, 2.36020032e-02, 2.29295548e-02, 2.21483167e-02,
       2.12774556e-02, 2.04897653e-02, 1.96956657e-02, 1.86967980e-02,
       1.74995959e-02, 1.61616523e-02, 1.45806530e-02, 1.28334099e-02,
       1.12589560e-02, 9.82530881e-03, 8.50046892e-03, 7.28445407e-03,
       6.23477437e-03, 5.30028064e-03, 4.38513281e-03, 3.54687171e-03,
       2.98337196e-03, 2.40728469e-03, 1.90797821e-03, 1.46631896e-03,
       9.99810174e-04, 6.80246041e-04, 5.00649330e-04, 3.15561716e-04,
       2.00520313e-04, 1.11726273e-04, 7.87333411e-05, 3.55761367e-05,
       1.72952550e-05, 1.06409998e-05, 1.90699848e-06, 1.53165001e-06,
       1.23018151e-06, 9.88048782e-07, 7.93575111e-07, 6.37378264e-07,
      

In [18]:
def wer(r, h):
    """
    Calculation of WER with Levenshtein distance.

    Works only for iterables up to 254 elements (uint8).
    O(nm) time ans space complexity.

    Parameters
    ----------
    r : list
    h : list

    Returns
    -------
    int

    Examples
    --------
    >>> wer("who is there".split(), "is there".split())
    1
    >>> wer("who is there".split(), "".split())
    3
    >>> wer("".split(), "who is there".split())
    3
    """
    # initialisation
    d = np.zeros((len(r)+1)*(len(h)+1), dtype=int)
    d = d.reshape((len(r)+1, len(h)+1))
    for i in range(len(r)+1):
        for j in range(len(h)+1):
            if i == 0:
                d[0][j] = j
            elif j == 0:
                d[i][0] = i

    # computation
    for i in range(1, len(r)+1):
        for j in range(1, len(h)+1):
            if r[i-1] == h[j-1]:
                d[i][j] = d[i-1][j-1]
            else:
                substitution = d[i-1][j-1] + 1
                insertion    = d[i][j-1] + 1
                deletion     = d[i-1][j] + 1
                d[i][j] = min(substitution, insertion, deletion)

    return d[len(r)][len(h)]

In [19]:
mapfile = 'data/timit/lang/phones_48_to_39.txt'
phonemap = {}
with open(mapfile, 'r') as f:
    for line in f:
        p1, p2 = line.strip().split()
        phonemap[p1] = p2
phonemap

{'sil': 'sil',
 'aa': 'aa',
 'ae': 'ae',
 'ah': 'ah',
 'ao': 'aa',
 'aw': 'aw',
 'ax': 'ah',
 'ay': 'ay',
 'b': 'b',
 'ch': 'ch',
 'cl': 'sil',
 'd': 'd',
 'dh': 'dh',
 'dx': 'dx',
 'eh': 'eh',
 'el': 'l',
 'en': 'n',
 'epi': 'sil',
 'er': 'er',
 'ey': 'ey',
 'f': 'f',
 'g': 'g',
 'hh': 'hh',
 'ih': 'ih',
 'ix': 'ih',
 'iy': 'iy',
 'jh': 'jh',
 'k': 'k',
 'l': 'l',
 'm': 'm',
 'n': 'n',
 'ng': 'ng',
 'ow': 'ow',
 'oy': 'oy',
 'p': 'p',
 'r': 'r',
 's': 's',
 'sh': 'sh',
 't': 't',
 'th': 'th',
 'uh': 'uh',
 'uw': 'uw',
 'v': 'v',
 'vcl': 'sil',
 'w': 'w',
 'y': 'y',
 'z': 'z',
 'zh': 'sh'}

In [27]:
ref = 'data/timit/test/trans'
hyp = f'exp/timit/monophone_mbn_babel/decode_ac1.0/test/trans'

def load_trans(transfile, mapping=None):
    with open(transfile, 'r') as f:
        utt_trans = {}
        for line in f:
            tokens = line.strip().split()
            uttid, trans = tokens[0], tokens[1:]
            if mapping is not None:
                trans = [mapping[token] for token in trans]
            utt_trans[uttid] = trans
            #utt_trans[uttid] = list(filter(lambda a: a != 'sil', trans))
    return utt_trans

ref_trans = load_trans(ref, phonemap)
ref_trans = {utt: [x[0] for x in groupby(trans)] for utt, trans in ref_trans.items()}
hyp_trans = load_trans(hyp, phonemap)
hyp_trans = {utt: [x[0] for x in groupby(trans)] for utt, trans in hyp_trans.items()}
acc_wer = 0
nwords = 0
for utt in ref_trans:
    ref_t, hyp_t = ref_trans[utt], hyp_trans[utt]
    acc_wer += wer(ref_t, hyp_t)
    nwords += len(ref_t)
print(f'WER = {100 * acc_wer/nwords:.2f} %')

WER = 36.82 %


In [28]:
print(' '.join(ref_trans[utt]))
print(' '.join(hyp_trans[utt]))

sil b ae s sil k ih sil b aa l sil k ih n sil b iy ih n ih n er sil t ey n ih ng s sil p aa r sil
sil b ae s k ih sil b aa k n b iy n ih n ih t hh ng uw ng k s f w r sil


In [88]:
ref = 'data/timit/test/trans'

ldims = [2, 5, 10, 15, 20, 25, 30, 35, 40, 45]
colors = ['blue', 'orange', 'green', 'red', 'brown', 'purple', 'salmon', 'black', 'darkblue', 'grey']
fig = figure(y_range=(0, 100))
dim_wers = {}
for ldim, color in zip(ldims, colors):
    wers = []
    for i in range(0, 31):
        hyp = f'exp/timit/dsubspace_monophone_mbn_babel_ldim{ldim}/decode_e{i}_ac1.0/test/trans'
        if not os.path.isfile(hyp):
            continue
        def load_trans(transfile, mapping=None):
            with open(transfile, 'r') as f:
                utt_trans = {}
                for line in f:
                    tokens = line.strip().split()
                    uttid, trans = tokens[0], tokens[1:]
                    if mapping is not None:
                        trans = [mapping[token] for token in trans]
                    utt_trans[uttid] = list(filter(lambda a: a != 'sil', trans))
            return utt_trans

        ref_trans = load_trans(ref, phonemap)
        hyp_trans = load_trans(hyp, phonemap)

        acc_wer = 0
        nwords = 0
        for utt in ref_trans:
            ref_t, hyp_t = ref_trans[utt], hyp_trans[utt]
            acc_wer += wer(ref_t, hyp_t)
            nwords += len(ref_t)
        wers.append(100 * acc_wer/nwords )
    if len(wers) > 0:
        print(wers[-1])
    dim_wers[ldim] = wers
    fig.line(range(len(wers)), wers, color=color, line_dash='dashed')
    fig.circle(range(len(wers)), wers, color=color)

show(fig)

80.54610341107146
66.21296373143
47.755000414972194
41.75450244833596
37.72097269482945
36.998921072288155
36.13577890281351
35.911693916507595
35.3224333969624
34.69167565773093


In [85]:
ref = 'data/timit/test/trans'

ldims = [2, 5, 10, 15, 20, 25, 30, 35, 40, 45]
colors = ['blue', 'orange', 'green', 'red', 'brown', 'purple', 'salmon', 'black', 'darkblue', 'grey']
fig = figure(y_range=(0, 100))
dim_wers = {}
for ldim, color in zip(ldims, colors):
    wers = []
    for i in range(0, 31):
        hyp = f'exp/timit/subspace_monophone_mbn_babel_ldim{ldim}/decode_e{i}_ac1.0/test/trans'
        if not os.path.isfile(hyp):
            continue
        def load_trans(transfile, mapping=None):
            with open(transfile, 'r') as f:
                utt_trans = {}
                for line in f:
                    tokens = line.strip().split()
                    uttid, trans = tokens[0], tokens[1:]
                    if mapping is not None:
                        trans = [mapping[token] for token in trans]
                    utt_trans[uttid] = list(filter(lambda a: a != 'sil', trans))
            return utt_trans

        ref_trans = load_trans(ref, phonemap)
        hyp_trans = load_trans(hyp, phonemap)

        acc_wer = 0
        nwords = 0
        for utt in ref_trans:
            ref_t, hyp_t = ref_trans[utt], hyp_trans[utt]
            acc_wer += wer(ref_t, hyp_t)
            nwords += len(ref_t)
        wers.append(100 * acc_wer/nwords )
    print(wers[-1])
    dim_wers[ldim] = wers
    fig.line(range(len(wers)), wers, color=color)
    fig.circle(range(len(wers)), wers, color=color)

show(fig)

84.30575151464852
70.08880405012864
51.92132127147481
44.37712673250892
39.52195202921404
37.306000497966636
36.06938335131546
35.46352394389576
35.47182338783301
35.14814507428002


In [50]:
dims = [dim for dim in dim_wers]
ws = [w[-1] for w in dim_wers.values()]
fig = figure(y_range=(0, 100), x_range=(1, 45))
fig.line(dims, ws)
fig.circle(dims, ws)
fig.line(range(46), np.zeros(46) + 36.44, line_dash='dashed', color='black')
show(fig)

In [35]:
print('ref:', ' '.join(ref_trans['fadg0_si1909']))
print('hyp:', ' '.join(hyp_trans['fadg0_si1909']))

ref: f ae sh ow d ih n l uw s r ow l z b ih n iy th ih sh er
hyp: f ae sh ow d ih n l uw s r l f dh ih dx iy th ah sh er


In [20]:
import pickle
with open('exp/timit/monophone_mbn_babel/final.mdl', 'rb') as f:
    model = pickle.load(f)
model.weights.value()

tensor([0.0208, 0.0208, 0.0208, 0.0208, 0.0208, 0.0208, 0.0208, 0.0208, 0.0208,
        0.0208, 0.0208, 0.0208, 0.0208, 0.0208, 0.0208, 0.0208, 0.0208, 0.0208,
        0.0208, 0.0208, 0.0208, 0.0208, 0.0208, 0.0208, 0.0208, 0.0208, 0.0208,
        0.0208, 0.0208, 0.0208, 0.0208, 0.0208, 0.0208, 0.0208, 0.0208, 0.0208,
        0.0208, 0.0208, 0.0208, 0.0208, 0.0208, 0.0208, 0.0208, 0.0208, 0.0208,
        0.0208, 0.0208, 0.0208])

In [167]:
keys = [key for key, val in model.start_pdf.items()]

In [168]:
model.start_pdf.values()

dict_values([0, 5, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 43, 46, 49, 52, 55, 58, 61, 64, 67, 70, 73, 76, 79, 82, 85, 88, 91, 94, 97, 100, 103, 106, 109, 112, 115, 118, 121, 124, 127, 130, 133, 136, 139, 142, 145])

In [169]:
from collections import defaultdict

counts = defaultdict(int)
with open('data/timit/train/trans', 'r') as f:
    for line in f:
        tokens = line.strip().split()[1:]
        for token in tokens:
            counts[token] += 1

In [172]:
import torch
stats = torch.FloatTensor([counts[key] for key in keys])
model.weights.posterior.params.concentrations = model.weights.posterior.params.concentrations + stats 
model._on_weights_update()

In [201]:
sorted(model.weights.expected_value())

[tensor(0.0011),
 tensor(0.0022),
 tensor(0.0036),
 tensor(0.0045),
 tensor(0.0052),
 tensor(0.0053),
 tensor(0.0058),
 tensor(0.0065),
 tensor(0.0068),
 tensor(0.0071),
 tensor(0.0072),
 tensor(0.0085),
 tensor(0.0087),
 tensor(0.0094),
 tensor(0.0118),
 tensor(0.0118),
 tensor(0.0133),
 tensor(0.0133),
 tensor(0.0138),
 tensor(0.0139),
 tensor(0.0142),
 tensor(0.0156),
 tensor(0.0158),
 tensor(0.0158),
 tensor(0.0161),
 tensor(0.0162),
 tensor(0.0162),
 tensor(0.0163),
 tensor(0.0169),
 tensor(0.0173),
 tensor(0.0185),
 tensor(0.0234),
 tensor(0.0254),
 tensor(0.0263),
 tensor(0.0271),
 tensor(0.0278),
 tensor(0.0282),
 tensor(0.0295),
 tensor(0.0303),
 tensor(0.0316),
 tensor(0.0330),
 tensor(0.0334),
 tensor(0.0440),
 tensor(0.0492),
 tensor(0.0515),
 tensor(0.0526),
 tensor(0.0591),
 tensor(0.0893)]

In [174]:
with open('exp/timit/monophone_mbn_babel//final_unigram.mdl', 'wb') as f:
    pickle.dump(model, f)