In [29]:
import os
import numpy as np
from itertools import groupby
import pickle

from bokeh.plotting import output_notebook, figure, show, gridplot
from bokeh.models import ColumnDataSource, Range1d, LabelSet, Label
from bokeh.palettes import Category10 as palette
# Send Bokeh output to the notebook so that later show(fig) calls render inline.
output_notebook()

In [31]:
# Unpickle a stored model and display the output dimension of its affine
# transform. pickle.load can execute arbitrary code: only open trusted files.
gsm_path = 'exp/timit/subspace_monophone_mbn_babel_ldim15/gsm_0.mdl'
with open(gsm_path, 'rb') as model_file:
    model = pickle.load(model_file)
model.affine_transform.out_dim

3861

In [18]:
def wer(r, h):
    """
    Levenshtein (edit) distance between a reference and a hypothesis.

    This is the numerator of the word error rate: the minimum number of
    substitutions, insertions and deletions that turn ``r`` into ``h``.
    Divide by ``len(r)`` to obtain the rate itself.

    There is no limit on the sequence length. Time complexity is
    O(len(r) * len(h)); only two rows of the dynamic-programming table are
    kept, so the memory footprint is O(len(h)).

    Parameters
    ----------
    r : sequence
        Reference tokens (must support ``len`` and indexing).
    h : sequence
        Hypothesis tokens (must support ``len`` and indexing).

    Returns
    -------
    int
        Edit distance, as a built-in Python ``int``.

    Examples
    --------
    >>> wer("who is there".split(), "is there".split())
    1
    >>> wer("who is there".split(), "".split())
    3
    >>> wer("".split(), "who is there".split())
    3
    """
    n_ref, n_hyp = len(r), len(h)

    # Row 0 of the table: turning an empty reference into h[:j] costs j
    # insertions.
    prev_row = list(range(n_hyp + 1))

    for i in range(1, n_ref + 1):
        # Column 0: turning r[:i] into an empty hypothesis costs i deletions.
        curr_row = [i] + [0] * n_hyp
        ref_token = r[i - 1]
        for j in range(1, n_hyp + 1):
            if ref_token == h[j - 1]:
                # Matching tokens: no extra cost over the diagonal cell.
                curr_row[j] = prev_row[j - 1]
            else:
                substitution = prev_row[j - 1] + 1
                insertion = curr_row[j - 1] + 1
                deletion = prev_row[j] + 1
                curr_row[j] = min(substitution, insertion, deletion)
        prev_row = curr_row

    return prev_row[n_hyp]

In [19]:
# Build the phone mapping from a two-column text file: the first token of each
# line is the source phone, the second is the phone it is mapped to.
mapfile = 'data/timit/lang/phones_48_to_39.txt'
phonemap = {}
with open(mapfile, 'r') as f:
    for src_phone, dst_phone in (line.split() for line in f):
        phonemap[src_phone] = dst_phone
phonemap

{'sil': 'sil',
 'aa': 'aa',
 'ae': 'ae',
 'ah': 'ah',
 'ao': 'aa',
 'aw': 'aw',
 'ax': 'ah',
 'ay': 'ay',
 'b': 'b',
 'ch': 'ch',
 'cl': 'sil',
 'd': 'd',
 'dh': 'dh',
 'dx': 'dx',
 'eh': 'eh',
 'el': 'l',
 'en': 'n',
 'epi': 'sil',
 'er': 'er',
 'ey': 'ey',
 'f': 'f',
 'g': 'g',
 'hh': 'hh',
 'ih': 'ih',
 'ix': 'ih',
 'iy': 'iy',
 'jh': 'jh',
 'k': 'k',
 'l': 'l',
 'm': 'm',
 'n': 'n',
 'ng': 'ng',
 'ow': 'ow',
 'oy': 'oy',
 'p': 'p',
 'r': 'r',
 's': 's',
 'sh': 'sh',
 't': 't',
 'th': 'th',
 'uh': 'uh',
 'uw': 'uw',
 'v': 'v',
 'vcl': 'sil',
 'w': 'w',
 'y': 'y',
 'z': 'z',
 'zh': 'sh'}

In [27]:
# Transcription files to compare: the reference and the decoder output.
ref = 'data/timit/test/trans'
hyp = 'exp/timit/monophone_mbn_babel/decode_ac1.0/test/trans'

def load_trans(transfile, mapping=None):
    """
    Read a transcription file into a dictionary.

    Each non-empty line has the form ``<uttid> <token> <token> ...``.
    Blank (or whitespace-only) lines are skipped instead of raising an
    ``IndexError``.

    Parameters
    ----------
    transfile : str
        Path to the transcription file.
    mapping : dict, optional
        If given, every token is replaced by ``mapping[token]``.

    Returns
    -------
    dict
        Maps each utterance id to its list of (optionally mapped) tokens.
        If an utterance id occurs several times, the last line wins.

    Raises
    ------
    KeyError
        If ``mapping`` is given and a token is not one of its keys.
    """
    utt_trans = {}
    with open(transfile, 'r') as f:
        for line in f:
            tokens = line.split()
            if not tokens:
                # Blank line (e.g. trailing newline at end of file).
                continue
            uttid, trans = tokens[0], tokens[1:]
            if mapping is not None:
                trans = [mapping[token] for token in trans]
            utt_trans[uttid] = trans
    return utt_trans

# Read both transcriptions with the phone mapping applied.
ref_trans = load_trans(ref, phonemap)
hyp_trans = load_trans(hyp, phonemap)

# Collapse every run of identical consecutive phones into a single phone.
ref_trans = {uttid: [phone for phone, _ in groupby(phones)]
             for uttid, phones in ref_trans.items()}
hyp_trans = {uttid: [phone for phone, _ in groupby(phones)]
             for uttid, phones in hyp_trans.items()}

# Accumulate the edit distance and the reference length over all reference
# utterances. `utt` is deliberately a plain loop variable: it stays defined
# after the loop and is read by the next cell.
acc_wer, nwords = 0, 0
for utt in ref_trans:
    ref_t = ref_trans[utt]
    hyp_t = hyp_trans[utt]
    acc_wer += wer(ref_t, hyp_t)
    nwords += len(ref_t)
print(f'WER = {100 * acc_wer/nwords:.2f} %')

WER = 36.82 %


In [28]:
# Show reference then hypothesis for `utt`, i.e. whatever utterance id the
# last executed scoring loop left in that variable.
for side in (ref_trans, hyp_trans):
    print(' '.join(side[utt]))

sil b ae s sil k ih sil b aa l sil k ih n sil b iy ih n ih n er sil t ey n ih ng s sil p aa r sil
sil b ae s k ih sil b aa k n b iy n ih n ih t hh ng uw ng k s f w r sil


In [88]:
ref = 'data/timit/test/trans'

ldims = [2, 5, 10, 15, 20, 25, 30, 35, 40, 45]
colors = ['blue', 'orange', 'green', 'red', 'brown', 'purple', 'salmon', 'black', 'darkblue', 'grey']


def load_trans(transfile, mapping=None):
    """
    Read a transcription file and drop every 'sil' token.

    Each non-empty line has the form ``<uttid> <token> ...``; blank lines are
    skipped. When ``mapping`` is given, tokens are replaced by
    ``mapping[token]`` before the 'sil' filter is applied.

    Returns a dict mapping each utterance id to its list of tokens.
    """
    utt_trans = {}
    with open(transfile, 'r') as f:
        for line in f:
            tokens = line.split()
            if not tokens:
                continue
            uttid, trans = tokens[0], tokens[1:]
            if mapping is not None:
                trans = [mapping[token] for token in trans]
            utt_trans[uttid] = [token for token in trans if token != 'sil']
    return utt_trans


# The reference does not depend on the latent dimension or the epoch, so it is
# read once instead of once per decoded (ldim, epoch) pair.
ref_trans = load_trans(ref, phonemap)

fig = figure(y_range=(0, 100))
dim_wers = {}
for ldim, color in zip(ldims, colors):
    # WER (in %) of every epoch for which a decode output exists, in epoch
    # order. Missing epochs are skipped, so list positions are not epoch ids.
    wers = []
    for i in range(0, 31):
        hyp = f'exp/timit/dsubspace_monophone_mbn_babel_ldim{ldim}/decode_e{i}_ac1.0/test/trans'
        if not os.path.isfile(hyp):
            continue
        hyp_trans = load_trans(hyp, phonemap)

        acc_wer = 0
        nwords = 0
        for utt in ref_trans:
            ref_t, hyp_t = ref_trans[utt], hyp_trans[utt]
            acc_wer += wer(ref_t, hyp_t)
            nwords += len(ref_t)
        wers.append(100 * acc_wer/nwords )
    if len(wers) > 0:
        # Report the WER of the last available epoch for this dimension.
        print(wers[-1])
    dim_wers[ldim] = wers
    fig.line(range(len(wers)), wers, color=color, line_dash='dashed')
    fig.circle(range(len(wers)), wers, color=color)

show(fig)

80.54610341107146
66.21296373143
47.755000414972194
41.75450244833596
37.72097269482945
36.998921072288155
36.13577890281351
35.911693916507595
35.3224333969624
34.69167565773093


In [85]:
ref = 'data/timit/test/trans'

ldims = [2, 5, 10, 15, 20, 25, 30, 35, 40, 45]
colors = ['blue', 'orange', 'green', 'red', 'brown', 'purple', 'salmon', 'black', 'darkblue', 'grey']


def load_trans(transfile, mapping=None):
    """
    Read a transcription file and drop every 'sil' token.

    Each non-empty line has the form ``<uttid> <token> ...``; blank lines are
    skipped. When ``mapping`` is given, tokens are replaced by
    ``mapping[token]`` before the 'sil' filter is applied.

    Returns a dict mapping each utterance id to its list of tokens.
    """
    utt_trans = {}
    with open(transfile, 'r') as f:
        for line in f:
            tokens = line.split()
            if not tokens:
                continue
            uttid, trans = tokens[0], tokens[1:]
            if mapping is not None:
                trans = [mapping[token] for token in trans]
            utt_trans[uttid] = [token for token in trans if token != 'sil']
    return utt_trans


# The reference does not depend on the latent dimension or the epoch, so it is
# read once instead of once per decoded (ldim, epoch) pair.
ref_trans = load_trans(ref, phonemap)

fig = figure(y_range=(0, 100))
dim_wers = {}
for ldim, color in zip(ldims, colors):
    # WER (in %) of every epoch for which a decode output exists, in epoch
    # order. Missing epochs are skipped, so list positions are not epoch ids.
    wers = []
    for i in range(0, 31):
        hyp = f'exp/timit/subspace_monophone_mbn_babel_ldim{ldim}/decode_e{i}_ac1.0/test/trans'
        if not os.path.isfile(hyp):
            continue
        hyp_trans = load_trans(hyp, phonemap)

        acc_wer = 0
        nwords = 0
        for utt in ref_trans:
            ref_t, hyp_t = ref_trans[utt], hyp_trans[utt]
            acc_wer += wer(ref_t, hyp_t)
            nwords += len(ref_t)
        wers.append(100 * acc_wer/nwords )
    if wers:
        # Guard against dimensions with no decode output at all: wers[-1]
        # would raise IndexError on an empty list.
        print(wers[-1])
    dim_wers[ldim] = wers
    fig.line(range(len(wers)), wers, color=color)
    fig.circle(range(len(wers)), wers, color=color)

show(fig)

84.30575151464852
70.08880405012864
51.92132127147481
44.37712673250892
39.52195202921404
37.306000497966636
36.06938335131546
35.46352394389576
35.47182338783301
35.14814507428002


In [50]:
# Last recorded WER for each latent dimension. Dimensions whose list is empty
# (no decode output was found) are left out rather than raising IndexError
# on w[-1]; dims and ws stay aligned because both come from the same dict.
final_wers = {dim: epoch_wers[-1] for dim, epoch_wers in dim_wers.items() if epoch_wers}
dims = list(final_wers)
ws = list(final_wers.values())

fig = figure(y_range=(0, 100), x_range=(1, 45))
fig.line(dims, ws)
fig.circle(dims, ws)
# Horizontal dashed line at a constant 36.44.
# NOTE(review): presumably the WER of a baseline system - confirm and move the
# value to a named constant.
fig.line(range(46), np.zeros(46) + 36.44, line_dash='dashed', color='black')
show(fig)

In [35]:
# Side-by-side look at one utterance: reference first, then hypothesis.
for tag, transcriptions in (('ref:', ref_trans), ('hyp:', hyp_trans)):
    print(tag, ' '.join(transcriptions['fadg0_si1909']))

ref: f ae sh ow d ih n l uw s r ow l z b ih n iy th ih sh er
hyp: f ae sh ow d ih n l uw s r l f dh ih dx iy th ah sh er


In [20]:
import pickle

# Unpickle the stored model (replacing any previously loaded `model`) and
# display the current value of its weights.
# pickle.load can execute arbitrary code: only open trusted files.
final_model_path = 'exp/timit/monophone_mbn_babel/final.mdl'
with open(final_model_path, 'rb') as model_file:
    model = pickle.load(model_file)
model.weights.value()

tensor([0.0208, 0.0208, 0.0208, 0.0208, 0.0208, 0.0208, 0.0208, 0.0208, 0.0208,
        0.0208, 0.0208, 0.0208, 0.0208, 0.0208, 0.0208, 0.0208, 0.0208, 0.0208,
        0.0208, 0.0208, 0.0208, 0.0208, 0.0208, 0.0208, 0.0208, 0.0208, 0.0208,
        0.0208, 0.0208, 0.0208, 0.0208, 0.0208, 0.0208, 0.0208, 0.0208, 0.0208,
        0.0208, 0.0208, 0.0208, 0.0208, 0.0208, 0.0208, 0.0208, 0.0208, 0.0208,
        0.0208, 0.0208, 0.0208])

In [167]:
# Keys of model.start_pdf, in the mapping's iteration order.
keys = list(model.start_pdf.keys())

In [168]:
# Display the values stored in model.start_pdf (same order as `keys`).
# NOTE(review): presumably the index of the first pdf of each unit - confirm.
model.start_pdf.values()

dict_values([0, 5, 10, 13, 16, 19, 22, 25, 28, 31, 34, 37, 40, 43, 46, 49, 52, 55, 58, 61, 64, 67, 70, 73, 76, 79, 82, 85, 88, 91, 94, 97, 100, 103, 106, 109, 112, 115, 118, 121, 124, 127, 130, 133, 136, 139, 142, 145])

In [169]:
from collections import defaultdict

# Number of occurrences of every token in the training transcriptions.
# The first field of each line is skipped (presumably the utterance id, as in
# the other 'trans' files read in this notebook).
counts = defaultdict(int)
with open('data/timit/train/trans', 'r') as f:
    for line in f:
        for token in line.split()[1:]:
            counts[token] += 1

In [172]:
import torch
# Token counts as a float tensor, one entry per element of `keys` and in the
# same order. `counts` is a defaultdict(int), so a key that was never seen in
# the training transcriptions contributes 0.
stats = torch.FloatTensor([counts[key] for key in keys])
# Add the counts to the current concentrations of the weights posterior.
# WARNING: this reads and overwrites the same attribute, so the cell is not
# idempotent - running it twice adds the counts twice. Reload the model first
# if the cell has to be re-executed.
model.weights.posterior.params.concentrations = model.weights.posterior.params.concentrations + stats 
# NOTE(review): presumably refreshes quantities derived from the weights after
# the manual edit above - confirm against the model class.
model._on_weights_update()

In [201]:
# Display the expected value of the weights, sorted in ascending order.
sorted(model.weights.expected_value())

[tensor(0.0011),
 tensor(0.0022),
 tensor(0.0036),
 tensor(0.0045),
 tensor(0.0052),
 tensor(0.0053),
 tensor(0.0058),
 tensor(0.0065),
 tensor(0.0068),
 tensor(0.0071),
 tensor(0.0072),
 tensor(0.0085),
 tensor(0.0087),
 tensor(0.0094),
 tensor(0.0118),
 tensor(0.0118),
 tensor(0.0133),
 tensor(0.0133),
 tensor(0.0138),
 tensor(0.0139),
 tensor(0.0142),
 tensor(0.0156),
 tensor(0.0158),
 tensor(0.0158),
 tensor(0.0161),
 tensor(0.0162),
 tensor(0.0162),
 tensor(0.0163),
 tensor(0.0169),
 tensor(0.0173),
 tensor(0.0185),
 tensor(0.0234),
 tensor(0.0254),
 tensor(0.0263),
 tensor(0.0271),
 tensor(0.0278),
 tensor(0.0282),
 tensor(0.0295),
 tensor(0.0303),
 tensor(0.0316),
 tensor(0.0330),
 tensor(0.0334),
 tensor(0.0440),
 tensor(0.0492),
 tensor(0.0515),
 tensor(0.0526),
 tensor(0.0591),
 tensor(0.0893)]

In [174]:
# Pickle the model, as modified in this notebook, to a new file.
unigram_model_path = 'exp/timit/monophone_mbn_babel//final_unigram.mdl'
with open(unigram_model_path, 'wb') as model_file:
    pickle.dump(model, model_file)