In [5]:
from argparse import ArgumentParser, Namespace
from collections import Counter
import itertools
from pathlib import Path
import pickle
import sys
from typing import List, Tuple

from matplotlib.animation import FuncAnimation
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from tqdm.auto import tqdm, trange

In [6]:
sns.set_context("talk")

In [7]:
from IPython.display import HTML

In [8]:
%load_ext autoreload
%autoreload 2

In [9]:
sys.path.append(str(Path(".").resolve().parent))
import berp.models.reindexing_regression as rr

In [28]:
natural_language_stimulus_path = "../workflow/heilbron2022/data/stimulus/EleutherAI/gpt-neo-2.7B/n10000/old-man-and-the-sea/run1.pkl"
confusion_path = "../workflow/heilbron2022/data/confusion.npz"
lambda_ = 2.9636077645074517
threshold = 0.8091832825245094

In [29]:
with Path(natural_language_stimulus_path).open("rb") as f:
    stim = pickle.load(f)
confusion = np.load(confusion_path)

In [30]:
assert confusion["phonemes"].tolist() == stim.phonemes

In [31]:
confusion_matrix = torch.tensor(confusion["confusion"])
torch.testing.assert_close(
    confusion_matrix.sum(0),
    torch.ones(confusion_matrix.shape[1]).to(confusion_matrix))

In [32]:
p_candidates = rr.predictive_model(
    stim.p_candidates, stim.candidate_phonemes,
    confusion=confusion_matrix,
    lambda_=torch.tensor(lambda_),
    return_gt_only=False)

In [33]:
# Nested array. Outer list of words and inner list of phonemes per word.
plot_data: List[List[Tuple[str, float]]] = []

# Keep top-k candidates from each posterior predictive.
k = 10
for p_candidates_i, candidate_ids, gt_length in zip(tqdm(p_candidates), stim.candidate_ids.numpy(), stim.word_lengths):
    candidate_strs = [stim.candidate_vocabulary[idx] for idx in candidate_ids]
    plot_data_i = []
    
    # First p_candidates_i column corresponds to prior; rest correspond to incremental posteriors
    assert p_candidates_i.shape[1] == stim.max_n_phonemes + 1
    # Take just the phonemes for which we have ground truth phonemes
    p_candidates_i = p_candidates_i[:, :gt_length + 1].T.numpy()
    
    topk = (-p_candidates_i).argsort(1)[:, :k * 2]
    p_candidates_i = np.take_along_axis(p_candidates_i, topk, 1)
    gt_str = candidate_strs[0]
    
    # print(gt_str, p_candidates_i.shape, topk.shape)

    for t, (p_candidates_t, candidates_t) in enumerate(zip(p_candidates_i, topk)):
        # Merge candidates with same form.
        # HACK
        weights = Counter()
        # print([candidate_strs[idx] for idx in candidates_t])
        for p, idx in zip(p_candidates_t, candidates_t):
            weights[candidate_strs[idx]] += p
        
        plot_data_i.append((gt_str[:t], list(weights.most_common(k))))
        
    plot_data.append(plot_data_i)
    # break

  0%|          | 0/574 [00:00<?, ?it/s]

In [16]:
def to_gt_words(word_ids):
    return ["".join(stim.get_candidate_strs(idx, top_k=1)[0])
            for idx in word_ids]

In [17]:
class PredictiveAnimation:
    
    def __init__(self, plot_data, k=3, threshold=0.7, save_all=None):
        self.plot_data = plot_data
        # Mapping between frame index and plot_data nested index.
        self.plot_idxs_flat = [(i, j) for i in range(len(plot_data)) for j in range(len(plot_data[i]))]

        self.k = k
        self.threshold = threshold
        self.save_all = save_all
        
    def _prepare_figure(self):
        self.fig, self.ax = plt.subplots(dpi=200)

        self.ax.set_xlim((0, 1))
        self.ax.set_xlabel("Evidence for word")
        self.ax.set_ylabel("Word")

        self.bar = self.ax.barh([str(i) * 7 for i in range(self.k)], [0] * self.k)
        self.threshold_line = self.ax.axvline(self.threshold, color="gray", linestyle="--")
        self.prev_label = self.ax.annotate("", (0.49, 1.05), xycoords="axes fraction", ha="right")
        self.incremental_label = self.ax.annotate("abc", (0.5, 1.05), xycoords="axes fraction")
        self.prev_incremental_text = [None]
        self.acc_text = []
        
        plt.tight_layout()
        
    def get_first_frame_for_word_idx(self, word_idx):
        return next(i for i, (word_idx_i, _) in enumerate(self.plot_idxs_flat) if word_idx_i == word_idx)
        
    def animate(self, i):
        token_idx, incremental_idx = self.plot_idxs_flat[i]
        incremental_text, incremental_dist = self.plot_data[token_idx][incremental_idx]
        incremental_text = "".join(incremental_text)
        print(incremental_text)
        print(incremental_dist)

        xticklabels = self.ax.get_xticklabels()
        artists = list(self.bar) + xticklabels

        for j, ((option_j, p_j), bar_j, label_j) in enumerate(zip(incremental_dist, self.bar, xticklabels)):
            bar_j.set_width(p_j)

            # Doesn't work .. do it differently below
            # label_j.set_text(option_j)

        self.ax.set_yticklabels(["".join(option_j) for option_j, _ in incremental_dist])
        self.ax.tick_params("y", width=10)

        if self.prev_incremental_text[0] is not None and not incremental_text.startswith(self.prev_incremental_text[0]):
            print("----", self.prev_incremental_text)
            self.acc_text.append(self.prev_incremental_text[0])
            self.prev_label.set_text(" ".join(self.acc_text))
            artists.append(self.prev_label)

        self.incremental_label.set_text(incremental_text)
        artists.append(self.incremental_label)
        self.prev_incremental_text[0] = incremental_text
        
        if self.save_all is not None:
            plt.savefig(f"{self.save_all}.{i}.pdf")

        return artists
    
    def plot(self, start_frame=None, n_frames=100, **kwargs):
        plt.ioff()
        if start_frame is None: start_frame = 0
        
        default_kwargs = dict(blit=True, repeat=True, interval=750)
        default_kwargs.update(kwargs)
        
        self._prepare_figure()
        anim = FuncAnimation(self.fig, self.animate, frames=list(range(start_frame, start_frame + n_frames)),
                             **default_kwargs)
        return anim
    
    def plot_for_word(self, word_idx=None, num_prior_frames=0, **kwargs):
        start_frame = self.get_first_frame_for_word_idx(word_idx)
        start_frame = max(0, start_frame - num_prior_frames)
        return self.plot(start_frame, n_frames=stim.word_lengths[word_idx].item() + num_prior_frames + 1,
                         **kwargs)

In [18]:
# anim = PredictiveAnimation(plot_data).plot(n_frames=1, start_frame=1)
# # HTML(anim.to_html5_video())
# anim.save("test.mp4", dpi=200)

## Selective plots

In [20]:
!mkdir animations

In [34]:
word_id = 396
name = f"animations/1.{word_id}"
anim = PredictiveAnimation(plot_data, k=5, threshold=threshold, save_all=name).plot_for_word(word_id, num_prior_frames=0, interval=1000)
anim.save(f"{name}.gif", dpi=200)
# plt.close()


[(('w', 'i', 'l'), 0.42574870586395264), (('ʌ',), 0.324208527803421), (('g', 'oʊ'), 0.1473640650510788), (('aɪ',), 0.0741746723651886), (('j', 'u'), 0.014107891358435154), (('t', 'ɛɪ', 'k'), 0.004924137610942125), (('s', 'ʌ', 'm'), 0.0022657027002424), (('k', 'ʌ', 'm'), 0.0018161664484068751), (('ʌ', 'n', 'ʌ', 'ð', 'ɚ'), 0.0010319072753190994), (('m', 'ɛɪ', 'b', 'i'), 0.0008873871411196887)]

[(('w', 'i', 'l'), 0.42574870586395264), (('ʌ',), 0.324208527803421), (('g', 'oʊ'), 0.1473640650510788), (('aɪ',), 0.0741746723651886), (('j', 'u'), 0.014107891358435154), (('t', 'ɛɪ', 'k'), 0.004924137610942125), (('s', 'ʌ', 'm'), 0.0022657027002424), (('k', 'ʌ', 'm'), 0.0018161664484068751), (('ʌ', 'n', 'ʌ', 'ð', 'ɚ'), 0.0010319072753190994), (('m', 'ɛɪ', 'b', 'i'), 0.0008873871411196887)]

[(('w', 'i', 'l'), 0.42574870586395264), (('ʌ',), 0.324208527803421), (('g', 'oʊ'), 0.1473640650510788), (('aɪ',), 0.0741746723651886), (('j', 'u'), 0.014107891358435154), (('t', 'ɛɪ', 'k'), 0.00492413761094

  self.ax.set_yticklabels(["".join(option_j) for option_j, _ in incremental_dist])
  self.ax.set_yticklabels(["".join(option_j) for option_j, _ in incremental_dist])
  self.ax.set_yticklabels(["".join(option_j) for option_j, _ in incremental_dist])


w
[(('w', 'i', 'l'), 0.8524951338768005), (('ʌ',), 0.07411392778158188), (('g', 'oʊ'), 0.04256400465965271), (('aɪ',), 0.016956297680735588), (('j', 'u'), 0.009151800535619259), (('t', 'ɛɪ', 'k'), 0.001125655835494399), (('w', 'ɔ', 'k'), 0.0009761854889802635), (('s', 'ʌ', 'm'), 0.0005179386353120208), (('m', 'ɛɪ', 'b', 'i'), 0.0004942223895341158), (('k', 'ʌ', 'm'), 0.00041517484351061285)]
wi
[(('w', 'i', 'l'), 0.9726270437240601), (('ʌ',), 0.01707884669303894), (('aɪ',), 0.003907416947185993), (('g', 'oʊ'), 0.0038485664408653975), (('j', 'u'), 0.0019138467032462358), (('t', 'ɛɪ', 'k'), 0.00014745366934221238), (('w', 'ɔ', 'k'), 8.826501289149746e-05), (('m', 'ɛɪ', 'b', 'i'), 6.473996472777799e-05), (('s', 'ʌ', 'm'), 4.683117367676459e-05), (('h', 'ɛ', 'l', 'p'), 4.3093630665680394e-05)]


  self.ax.set_yticklabels(["".join(option_j) for option_j, _ in incremental_dist])
  self.ax.set_yticklabels(["".join(option_j) for option_j, _ in incremental_dist])


wil
[(('w', 'i', 'l'), 0.993059515953064), (('ʌ',), 0.004318417515605688), (('aɪ',), 0.0009879975114017725), (('g', 'oʊ'), 0.0009731180034577847), (('j', 'u'), 0.0004839205648750067), (('h', 'ɛ', 'l', 'p'), 4.3998967157676816e-05), (('t', 'ɛɪ', 'k'), 2.2038781025912613e-05), (('s', 'ʌ', 'm'), 1.5598801837768406e-05), (('m', 'ɛɪ', 'b', 'i'), 1.4402146007341798e-05), (('w', 'ɔ', 'k'), 1.3192302503739484e-05)]


  self.ax.set_yticklabels(["".join(option_j) for option_j, _ in incremental_dist])


### MAP due to high-prior

In [17]:
high_prior_words = [(i, word_dat[-1][0], word_dat)
                    for i, word_dat in enumerate(plot_data)
                    if word_dat[0][1][0][0] == word_dat[-1][0]
                    # exclude words which are high prior for uninteresting reasons (because the prior space is not big enough .. need top-p sampling)
                    and len(word_dat[-1][0]) < 5]

In [21]:
sorted(high_prior_words, key=lambda x: -len(x[1]))[2]

(3,
 ('m', 'æ', 'n'),
 [((),
   [(('m', 'æ', 'n'), 0.4217202663421631),
    (('f', 'ɹ', 'ɛ', 'n', 'd'), 0.06591551005840302),
    (('l', 'ɛɪ', 'd', 'i'), 0.051039230078458786),
    (('g', 'aɪ'), 0.043623536825180054),
    (('w', 'ʊ', 'm', 'ʌ', 'n'), 0.033722903579473495),
    (('s', 'k', 'u', 'l'), 0.029956839978694916),
    (('b', 'ɔɪ'), 0.020579513162374496),
    (('dʒ', 'ɛ', 'n', 't', 'ʌ', 'l', 'm', 'ʌ', 'n'), 0.013390779495239258),
    (('ʌ', 'n', 'd'), 0.007232083007693291),
    (('ʌ', 'k', 'w', 'ɛɪ', 'n', 't', 'ʌ', 'n', 's'), 0.006683813873678446)]),
  (('m',),
   [(('m', 'æ', 'n'), 0.9481782205402851),
    (('m', 'ʌ', 'ð', 'ɚ'), 0.005373214837163687),
    (('m', 'ɛ', 'm', 'b', 'ɚ'), 0.003540197154507041),
    (('l', 'ɛɪ', 'd', 'i'), 0.0030101430602371693),
    (('m', 'ɑ', 'm'), 0.0015233983285725117),
    (('m', 'æ', 'n', 'ʌ', 'dʒ', 'ɚ'), 0.0014964542351663113),
    (('m', 'aɪ', 'n', 'ɚ'), 0.001482674852013588),
    (('m', 'ɑ', 'd', 'ʌ', 'l'), 0.0014688374940305948),
    (('m', 

In [22]:
# Plot up through the next word
anim = PredictiveAnimation(plot_data, k=5, threshold=threshold).plot_for_word(3, num_prior_frames=13, interval=1000)
anim.save("high_prior.gif", dpi=200)
plt.close()


[(('ð', 'ʌ'), 0.0852500144392252), (('ʌ',), 0.03402941208332777), (('aɪ',), 0.027632102370262146), (('ɪ', 't'), 0.01590776862576604), (('ð', 'ɪ', 's'), 0.015373277477920055), (('ɛ', 's'), 0.00927803386002779), (('ɪ', 'n'), 0.008502335287630558), (('w', 'i'), 0.008139224722981453), (('dʒ', 'i'), 0.007861110381782055), (('d', 'i'), 0.00779521930962801)]

[(('ð', 'ʌ'), 0.0852500144392252), (('ʌ',), 0.03402941208332777), (('aɪ',), 0.027632102370262146), (('ɪ', 't'), 0.01590776862576604), (('ð', 'ɪ', 's'), 0.015373277477920055), (('ɛ', 's'), 0.00927803386002779), (('ɪ', 'n'), 0.008502335287630558), (('w', 'i'), 0.008139224722981453), (('dʒ', 'i'), 0.007861110381782055), (('d', 'i'), 0.00779521930962801)]

[(('ð', 'ʌ'), 0.0852500144392252), (('ʌ',), 0.03402941208332777), (('aɪ',), 0.027632102370262146), (('ɪ', 't'), 0.01590776862576604), (('ð', 'ɪ', 's'), 0.015373277477920055), (('ɛ', 's'), 0.00927803386002779), (('ɪ', 'n'), 0.008502335287630558), (('w', 'i'), 0.008139224722981453), (('dʒ',

  self.ax.set_yticklabels(["".join(option_i) for option_i, _ in incremental_dist])
  self.ax.set_yticklabels(["".join(option_i) for option_i, _ in incremental_dist])


wʌz
[(('w', 'ʌ', 'z'), 0.9609295725822449), (('w', 'ʌ', 'n'), 0.022430835524573922), (('w', 'ʌ', 't'), 0.006781782256439328), (('w', 'ʌ', 'n', 's'), 0.004500413779169321), (('w', 'ɛ', 'n'), 0.0006285113922785968), (('ð', 'ʌ'), 0.0005701320769730955), (('w', 'ɛ', 'ð', 'ɚ'), 0.00043271741014905274), (('w', 'ɛ', 's', 't'), 0.000376115640392527), (('w', 'ɔ', 'ɹ'), 0.00031992627191357315), (('d', 'ʌ', 'z'), 0.00027577526634559035)]

[(('ʌ',), 0.07272014021873474), (('ɪ', 'n'), 0.04130559414625168), (('ð', 'ʌ'), 0.04056883603334427), (('dʒ', 'ʌ', 's', 't'), 0.027373705059289932), (('v', 'ɛ', 'ɹ', 'i'), 0.020236529409885406), (('ɑ', 'n'), 0.020017100498080254), (('n', 'ɑ', 't'), 0.018870579078793526), (('b', 'ɔ', 'ɹ', 'n'), 0.01646537147462368), (('w', 'ʌ', 'n'), 0.014444197528064251), (('g', 'oʊ', 'ɪ', 'ŋ'), 0.014190247282385826)]
---- ['wʌz']
ʌ
[(('ʌ',), 0.5425959825515747), (('ʌ', 'n'), 0.09217236936092377), (('ʌ', 'b', 'aʊ', 't'), 0.04265241324901581), (('ɑ', 'n'), 0.024026352912187576), 

### Low-prior but high posterior

In [23]:
not_in_top_k_for_n = 3
low_prior_high_post_words = [
    (i, word_dat[-1][0], word_dat)
     for i, word_dat in enumerate(tqdm(plot_data))
     if len(word_dat) >= 3
     # not in top K for first two phonemes
     and not any(opt_name == word_dat[-1][0] for opt_name, _ in itertools.chain.from_iterable(word_dat[j][1] for j in range(not_in_top_k_for_n)))
     and any(opt_name == word_dat[-1][0] for opt_name, _ in word_dat[not_in_top_k_for_n][1])
     # exclude words which are high prior for uninteresting reasons (because the prior space is not big enough .. need top-p sampling)
     # and len(word_dat[-1][0]) < 5
]

  0%|          | 0/574 [00:00<?, ?it/s]

In [24]:
len(low_prior_high_post_words)

16

In [30]:
low_prior_high_post_words[2]

(134,
 ('p', 'æ', 'tʃ', 't'),
 [((),
   [(('ʌ',), 0.05641583725810051),
    (('s', 'oʊ'), 0.05489831790328026),
    (('n', 'ɑ', 't'), 0.046705830842256546),
    (('ð', 'ʌ'), 0.030566466972231865),
    (('ɔ', 'l', 'w', 'ɛɪ', 'z'), 0.02890135534107685),
    (('t', 'u'), 0.02587797213345766),
    (('m', 'ɛɪ', 'd'), 0.01811293326318264),
    (('v', 'ɛ', 'ɹ', 'i'), 0.016279121860861778),
    (('ɪ', 'n'), 0.01607707142829895),
    (('f', 'ʊ', 'l'), 0.015482651069760323)]),
  (('p',),
   [(('p', 'ʊ', 'l', 'd'), 0.06827444583177567),
    (('t', 'u'), 0.038045672699809074),
    (('p', 'ʊ', 't'), 0.030255716294050217),
    (('k', 'æ', 'ɹ', 'i', 'd'), 0.029049573466181755),
    (('k', 'ɔ', 'l', 'd'), 0.02794480137526989),
    (('p', 'ɪ', 'k', 't'), 0.02537950873374939),
    (('f', 'ʊ', 'l'), 0.023123836144804955),
    (('p', 'ɹ', 'ɑ', 'b', 'ʌ', 'b', 'l', 'i'), 0.022180015221238136),
    (('h', 'ɪ', 'z'), 0.020394008606672287),
    (('p', 'ɹ', 'ɪ', 't', 'i'), 0.019156871363520622)]),
  (('p', 'æ')

In [33]:
idx = 134

# Plot up through the next word
anim = PredictiveAnimation(plot_data, k=5, threshold=threshold).plot_for_word(idx, num_prior_frames=21, interval=1000)
anim.save("low_prior.gif", dpi=200)
plt.close()

anim = PredictiveAnimation(plot_data, k=5, threshold=threshold).plot_for_word(idx, num_prior_frames=0, interval=2000)
anim.save("low_prior.no_context.gif", dpi=200)
plt.close()

ɚaʊn
[(('ɚ', 'aʊ', 'n', 'd'), 0.9985718131065369), (('d', 'aʊ', 'n'), 0.0005760067142546177), (('ɹ', 'aʊ', 'n', 'd'), 0.0004918471095152199), (('ɚ',), 0.0003080114838667214), (('w', 'ɛ', 'n'), 1.725154106679838e-05), (('d', 'aʊ', 'n', 'w', 'ɚ', 'd'), 9.466271876590326e-06), (('d', 'aʊ', 'n', 'w', 'ɚ', 'd', 'z'), 8.817076377454214e-06), (('s', 'aʊ', 'n', 'd'), 2.941679440482403e-06), (('ʌ', 'n', 'd'), 2.191731937273289e-06), (('ð', 'ɛ', 'n'), 1.462873683522048e-06)]
ɚaʊn
[(('ɚ', 'aʊ', 'n', 'd'), 0.9985718131065369), (('d', 'aʊ', 'n'), 0.0005760067142546177), (('ɹ', 'aʊ', 'n', 'd'), 0.0004918471095152199), (('ɚ',), 0.0003080114838667214), (('w', 'ɛ', 'n'), 1.725154106679838e-05), (('d', 'aʊ', 'n', 'w', 'ɚ', 'd'), 9.466271876590326e-06), (('d', 'aʊ', 'n', 'w', 'ɚ', 'd', 'z'), 8.817076377454214e-06), (('s', 'aʊ', 'n', 'd'), 2.941679440482403e-06), (('ʌ', 'n', 'd'), 2.191731937273289e-06), (('ð', 'ɛ', 'n'), 1.462873683522048e-06)]
ɚaʊn
[(('ɚ', 'aʊ', 'n', 'd'), 0.9985718131065369), (('d', 'a

  self.ax.set_yticklabels(["".join(option_i) for option_i, _ in incremental_dist])
  self.ax.set_yticklabels(["".join(option_i) for option_i, _ in incremental_dist])



[(('h', 'ɪ', 'm'), 0.35676923394203186), (('ð', 'ʌ'), 0.2853933274745941), (('h', 'ɪ', 'z'), 0.16253718733787537), (('ɪ', 't'), 0.024948526173830032), (('ʌ', 'n', 'd'), 0.022330213338136673), (('ʌ',), 0.01559661515057087), (('ɪ', 'n'), 0.011347698979079723), (('ð', 'ɛ', 'm'), 0.01005197037011385), (('w', 'ɪ', 'θ'), 0.008790658786892891), (('t', 'u'), 0.006895551458001137)]
---- ['ɚaʊnd']
ð
[(('ð', 'ʌ'), 0.9068081378936768), (('ð', 'ɛ', 'm'), 0.031939100474119186), (('ð', 'æ', 't'), 0.014699632301926613), (('h', 'ɪ', 'm'), 0.012525944970548153), (('ð', 'ɪ', 's'), 0.007348621264100075), (('ð', 'ɛ', 'ɹ'), 0.006314359838142991), (('h', 'ɪ', 'z'), 0.0057065789587795734), (('b', 'aɪ'), 0.0028576357290148735), (('l', 'aɪ', 'k'), 0.0015653609298169613), (('ð', 'i', 'z'), 0.0014108422910794616)]
ðʌ
[(('ð', 'ʌ'), 0.9970113378967653), (('ð', 'æ', 't'), 0.0011997428955510259), (('ð', 'ɛ', 'm'), 0.0007874651346355677), (('b', 'ʌ', 't'), 0.0003107173542957753), (('ð', 'ɛ', 'ɹ'), 0.00015568183152936

  self.ax.set_yticklabels(["".join(option_i) for option_i, _ in incremental_dist])
  self.ax.set_yticklabels(["".join(option_i) for option_i, _ in incremental_dist])


pætʃt
[(('p', 'æ', 'tʃ', 't'), 0.8327721357345581), (('p', 'æ', 'k', 't'), 0.10421884059906006), (('p', 'ɪ', 'tʃ', 't'), 0.03480057790875435), (('p', 'æ', 's', 't'), 0.01654942985624075), (('f', 'æ', 's', 't'), 0.0052461703307926655), (('k', 'æ', 's', 't'), 0.001497190329246223), (('p', 'ɪ', 'k', 't'), 0.0009914073161780834), (('k', 'ɛ', 'p', 't'), 0.0005902679404243827), (('k', 'æ', 'tʃ', 'ɪ', 'ŋ'), 0.0005129128694534302), (('p', 'ɛɪ', 'n', 't', 'ɪ', 'd'), 0.0004817460721824318)]


## Index candidates by feature

### Find low-prior ground truth words

In [267]:
k = 5
low_prior_word_ids = stim.p_candidates[:, 0].argsort()[:k]
low_prior_words = to_gt_words(low_prior_word_ids)
low_prior_words

['sælɑu', 'bʌnɛvʌlʌnt', 'sæntiɑgoʊ', 'fɪʃlʌs', 'mɑɹlʌn']

### Find words which are still low-posterior at phoneme $k$

In [268]:
k = 3
n = 20

# Only include words which are still running at phoneme $k$
mask = stim.word_lengths > k
p_candidates_masked = p_candidates.clone()
p_candidates_masked[~mask] = torch.inf

low_posterior_word_ids = p_candidates_masked[:, 0, k].argsort()[:n]
low_posterior_words = to_gt_words(low_posterior_word_ids)
list(zip(low_posterior_word_ids, low_posterior_words))

[(tensor(165), 'bʌnɛvʌlʌnt'),
 (tensor(244), 'ʌndɪfitɪd'),
 (tensor(205), 'kɔɹdz'),
 (tensor(486), 'stægɚɪŋ'),
 (tensor(537), 'tækʌl'),
 (tensor(114), 'laɪnz'),
 (tensor(57), 'sælɑu'),
 (tensor(28), 'fɔɹti'),
 (tensor(491), 'plæŋk'),
 (tensor(138), 'lʊkt'),
 (tensor(55), 'dɛfʌnʌtli'),
 (tensor(166), 'skɪn'),
 (tensor(319), 'ɹɪmɛmbɚ'),
 (tensor(396), 'fɪʃɚmɪn'),
 (tensor(395), 'bɪtwin'),
 (tensor(193), 'hænz'),
 (tensor(56), 'faɪnʌli'),
 (tensor(556), 'sɔltɪŋ'),
 (tensor(442), 'dɛpθs'),
 (tensor(472), 'mɑɹlʌn')]

In [269]:
p_candidates[low_posterior_word_ids, 0, k]

tensor([0.0002, 0.0006, 0.0030, 0.0034, 0.0052, 0.0056, 0.0066, 0.0097, 0.0107,
        0.0128, 0.0159, 0.0235, 0.0285, 0.0289, 0.0338, 0.0364, 0.0424, 0.0493,
        0.0569, 0.0582])

In [270]:
anim_word_idx(1338)

NameError: name 'anim_word_idx' is not defined

### Find words which are not the MAP at phoneme $k$

In [61]:
p_candidates.shape

torch.Size([2187, 1000, 16])

In [69]:
k = 5

# Only include words which are still running at phoneme $k$
mask = stim.word_lengths > k
p_candidates_masked = p_candidates.clone()
p_candidates_masked[~mask] = -torch.inf
p_candidates_argmax = p_candidates_masked[:, k].argmax(dim=1)

not_argmax_idxs = torch.where(p_candidates_argmax != 0)[0]
list(zip(not_argmax_idxs, to_gt_words(not_argmax_idxs)))

[(tensor(15), 'mojstə'),
 (tensor(27), '#dipər'),
 (tensor(37), 'heləbul'),
 (tensor(63), 'zemənsə#x'),
 (tensor(89), 'bladrə'),
 (tensor(114), '#ɑləmal#x#'),
 (tensor(175), '#idərə'),
 (tensor(228), '#ustərs'),
 (tensor(233), 'ɑndərə'),
 (tensor(238), 'zɛs#ɪs'),
 (tensor(239), 'vɛrdər'),
 (tensor(240), 'vərdində'),
 (tensor(264), 'ɑləmal#r'),
 (tensor(282), 'merə#x'),
 (tensor(456), '#pʏrpərən'),
 (tensor(506), '#maktə'),
 (tensor(529), 'xəkert#xə'),
 (tensor(592), 'xəkomə'),
 (tensor(628), 'bəwox#xə'),
 (tensor(655), 'lœkərst'),
 (tensor(865), 'vərlɑŋənt'),
 (tensor(869), 'dəxenə'),
 (tensor(898), 'kɛikə#h#'),
 (tensor(950), 'zwɔm#h'),
 (tensor(967), 'zemermɪnətjə'),
 (tensor(974), '#œyjtstrɛktə'),
 (tensor(980), 'vɛiftin'),
 (tensor(990), 'hɔndərt'),
 (tensor(996), 'herləkstə'),
 (tensor(1007), '#kɑlmə'),
 (tensor(1010), 'lɪxə#xə'),
 (tensor(1043), 'rɛijtœyjxən'),
 (tensor(1082), 'darnax'),
 (tensor(1357), 'hɔndərt'),
 (tensor(1359), 'lɑntarnsh'),
 (tensor(1423), 'zɛstin'),
 (tensor

In [71]:
anim_word_idx(37)

h
[(0.2843737857954816, 'hel'), (0.11503773708401324, 'helə'), (0.04161625407498444, 'ht'), (0.030969074349699508, 'ht'), (0.027556656173453757, 'heləbul'), (0.02208031511063773, 'hɑrt'), (0.01865259578798935, 'h'), (0.014280507235832797, 'ha'), (0.012843420209193216, 'hɑnt'), (0.010211893493392134, 'hoxə')]
h
[(0.2843737857954816, 'hel'), (0.11503773708401324, 'helə'), (0.04161625407498444, 'ht'), (0.030969074349699508, 'ht'), (0.027556656173453757, 'heləbul'), (0.02208031511063773, 'hɑrt'), (0.01865259578798935, 'h'), (0.014280507235832797, 'ha'), (0.012843420209193216, 'hɑnt'), (0.010211893493392134, 'hoxə')]
h
[(0.2843737857954816, 'hel'), (0.11503773708401324, 'helə'), (0.04161625407498444, 'ht'), (0.030969074349699508, 'ht'), (0.027556656173453757, 'heləbul'), (0.02208031511063773, 'hɑrt'), (0.01865259578798935, 'h'), (0.014280507235832797, 'ha'), (0.012843420209193216, 'hɑnt'), (0.010211893493392134, 'hoxə')]
he
[(0.6092247775663425, 'hel'), (0.24644971965575924, 'helə'), (0.059

  self.ax.set_yticklabels([option_i for _, option_i in incremental_dist])
  self.ax.set_yticklabels([option_i for _, option_i in incremental_dist])


helə
[(0.7967883806543204, 'helə'), (0.19086626706381407, 'heləbul'), (0.009561477370025783, 'hel'), (0.0007058671381728866, 'hɛiləxə'), (0.0005689674464068114, 'hekəl'), (0.0004017861372421666, 'hɛiləx'), (0.00020622559791568684, 'heməl'), (0.00010486451961581173, 'hetə'), (6.861690441805815e-05, 'relixjœzə'), (6.518453824514022e-05, 'penəs')]
heləb
[(0.9925745083818459, 'heləbul'), (0.0073079214554912335, 'helə'), (8.76952115707618e-05, 'hel'), (6.47401710293987e-06, 'hɛiləxə'), (5.218411199292192e-06, 'hekəl'), (3.685070721613382e-06, 'hɛiləx'), (2.3126329349481686e-06, 'xəlʏk'), (1.8914438366206418e-06, 'heməl'), (1.1957089163794259e-06, 'penəs'), (9.617882130646226e-07, 'hetə')]
heləbu
[(0.9999942516993445, 'heləbul'), (5.628860856788864e-06, 'helə'), (6.75464489793558e-08, 'hel'), (1.99462140795383e-08, 'hɛiləxə'), (5.695692165012416e-09, 'relixjœzə'), (4.6872222371597065e-09, 'relixjœs'), (4.0194343512344055e-09, 'hekəl'), (2.838392621721737e-09, 'hɛiləx'), (1.7812847446340663e-

In [34]:
p_candidates[:, 0, 3].argmin()

tensor(1824)

In [35]:
p_candidates[1824, 0, :]

tensor([4.2436e-06, 5.6206e-06, 5.6396e-06, 5.6396e-06, 5.6396e-06, 5.6396e-06,
        5.6396e-06, 5.6396e-06, 5.6396e-06, 5.6396e-06, 5.6396e-06, 5.6396e-06,
        5.6396e-06, 5.6396e-06], dtype=torch.float64)

In [39]:
stim.get_candidate_strs(1824, 1)

['ɑls']

In [32]:
p_candidates[:, 0, 0].min()

tensor(9.5920e-07, dtype=torch.float64)