In [259]:
from argparse import ArgumentParser, Namespace
from collections import Counter
import itertools
from pathlib import Path
import pickle
import sys
from typing import List, Tuple

from matplotlib.animation import FuncAnimation
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from tqdm.auto import tqdm, trange

In [2]:
sns.set_context("talk")

In [3]:
from IPython.display import HTML

In [4]:
%load_ext autoreload
%autoreload 2

In [5]:
sys.path.append(str(Path(".").resolve().parent))
import berp.models.reindexing_regression as rr

In [197]:
natural_language_stimulus_path = "heilbron2022/old-man-and-the-sea/run1.pkl"
confusion_path = "heilbron2022/confusion.npz"
lambda_ = 1.0
threshold = 0.25

In [198]:
with Path(natural_language_stimulus_path).open("rb") as f:
    stim = pickle.load(f)
confusion = np.load(confusion_path)

In [199]:
assert confusion["phonemes"].tolist() == stim.phonemes

In [243]:
confusion_matrix = torch.tensor(confusion["confusion"])
torch.testing.assert_close(
    confusion_matrix.sum(0),
    torch.ones(confusion_matrix.shape[1]).to(confusion_matrix))

In [201]:
p_candidates = rr.predictive_model(
    stim.p_candidates, stim.candidate_phonemes,
    confusion=confusion_matrix,
    lambda_=torch.tensor(lambda_),
    return_gt_only=False)

In [234]:
# Nested array. Outer list of words and inner list of phonemes per word.
plot_data: List[List[Tuple[str, float]]] = []

# Keep top-k candidates from each posterior predictive.
k = 10
for p_candidates_i, candidate_ids, gt_length in zip(tqdm(p_candidates), stim.candidate_ids.numpy(), stim.word_lengths):
    candidate_strs = [stim.candidate_vocabulary[idx] for idx in candidate_ids]
    plot_data_i = []
    
    # First p_candidates_i column corresponds to prior; rest correspond to incremental posteriors
    assert p_candidates_i.shape[1] == stim.max_n_phonemes + 1
    # Take just the phonemes for which we have ground truth phonemes
    p_candidates_i = p_candidates_i[:, :gt_length + 1].T.numpy()
    
    topk = (-p_candidates_i).argsort(1)[:, :k * 2]
    p_candidates_i = np.take_along_axis(p_candidates_i, topk, 1)
    gt_str = candidate_strs[0]
    
    # print(gt_str, p_candidates_i.shape, topk.shape)

    for t, (p_candidates_t, candidates_t) in enumerate(zip(p_candidates_i, topk)):
        # Merge candidates with same form.
        # HACK
        weights = Counter()
        # print([candidate_strs[idx] for idx in candidates_t])
        for p, idx in zip(p_candidates_t, candidates_t):
            weights[candidate_strs[idx]] += p
        
        plot_data_i.append((gt_str[:t], list(weights.most_common(k))))
        
    plot_data.append(plot_data_i)
    # break

  0%|          | 0/557 [00:00<?, ?it/s]

In [235]:
plot_data[:3]

[[((),
   [(('ð', 'ʌ'), 0.0852500144392252),
    (('ʌ',), 0.03402941208332777),
    (('aɪ',), 0.027632102370262146),
    (('ɪ', 't'), 0.01590776862576604),
    (('ð', 'ɪ', 's'), 0.015373277477920055),
    (('ɛ', 's'), 0.00927803386002779),
    (('ɪ', 'n'), 0.008502335287630558),
    (('w', 'i'), 0.008139224722981453),
    (('dʒ', 'i'), 0.007861110381782055),
    (('d', 'i'), 0.00779521930962801)]),
  (('w',),
   [(('w', 'i'), 0.1796080507338047),
    (('w', 'ɛ', 'n'), 0.08992338925600052),
    (('w', 'ɪ', 'ð'), 0.07762743532657623),
    (('w', 'ʌ', 't'), 0.0761676374822855),
    (('w', 'ʌ', 'n'), 0.06388868764042854),
    (('w', 'aɪ'), 0.062180422246456146),
    (('w', 'ʌ', 'z'), 0.030637424439191818),
    (('w', 'aɪ', 'l'), 0.02580820769071579),
    (('w', 'ɛ', 'l'), 0.015921398997306824),
    (('w', 'ɑ', 'tʃ'), 0.015686525031924248)]),
  (('w', 'ʌ'),
   [(('w', 'ʌ', 't'), 0.3326660320162773),
    (('w', 'ʌ', 'n'), 0.314370721578598),
    (('w', 'ʌ', 'z'), 0.2044962290674448),
    (('

In [244]:
def to_gt_words(word_ids):
    return ["".join(stim.get_candidate_strs(idx, top_k=1)[0])
            for idx in word_ids]

In [245]:
" ".join(to_gt_words(np.arange(50)))

'wʌz ʌn oʊld mæn hu fɪʃt ʌloʊn ɪn ʌ skɪf ɪn ðʌ gʌlf stɹim ʌnd hi hæd gɔn ɛɪti fɔɹ dɛɪz naʊ wɪðaʊt tɛɪkɪŋ ʌ fɪʃ ɪn fɚst fɔɹti dɛɪz ʌ bɔɪ hæd bɪn wɪθ hɪm bʌt æftɚ fɔɹti dɛɪz wɪðaʊt ʌ fɪʃ ðʌ bɔɪz pɛɹʌnts hæd toʊld ɪm ðæt'

In [246]:
" ".join(to_gt_words(np.arange(160, 170)))

'nɛk ðʌ bɹaʊn blɑttʃʌz ʌv bʌnɛvʌlʌnt skɪn kænsɚ ðʌ sʌn'

In [247]:
class PredictiveAnimation:
    
    def __init__(self, plot_data, k=3, threshold=0.7):
        self.plot_data = plot_data
        # Mapping between frame index and plot_data nested index.
        self.plot_idxs_flat = [(i, j) for i in range(len(plot_data)) for j in range(len(plot_data[i]))]

        self.k = k
        self.threshold = threshold
        
    def _prepare_figure(self):
        self.fig, self.ax = plt.subplots(dpi=200)

        self.ax.set_xlim((0, 1))
        self.ax.set_xlabel("Evidence for word")
        self.ax.set_ylabel("Word")

        self.bar = self.ax.barh([str(i) * 7 for i in range(self.k)], [0] * self.k)
        self.threshold_line = self.ax.axvline(self.threshold, color="gray", linestyle="--")
        self.prev_label = self.ax.annotate("", (0.49, 1.05), xycoords="axes fraction", ha="right")
        self.incremental_label = self.ax.annotate("abc", (0.5, 1.05), xycoords="axes fraction")
        self.prev_incremental_text = [None]
        self.acc_text = []
        
        plt.tight_layout()
        
    def get_first_frame_for_word_idx(self, word_idx):
        return next(i for i, (word_idx_i, _) in enumerate(self.plot_idxs_flat) if word_idx_i == word_idx)
        
    def animate(self, i):
        token_idx, incremental_idx = self.plot_idxs_flat[i]
        incremental_text, incremental_dist = self.plot_data[token_idx][incremental_idx]
        incremental_text = "".join(incremental_text)
        print(incremental_text)
        print(incremental_dist)

        xticklabels = self.ax.get_xticklabels()
        artists = list(self.bar) + xticklabels

        for i, ((option_i, p_i), bar_i, label_i) in enumerate(zip(incremental_dist, self.bar, xticklabels)):
            bar_i.set_width(p_i)

            # Doesn't work .. do it differently below
            # label_i.set_text(option_i)

        self.ax.set_yticklabels(["".join(option_i) for option_i, _ in incremental_dist])
        self.ax.tick_params("y", width=10)

        if self.prev_incremental_text[0] is not None and not incremental_text.startswith(self.prev_incremental_text[0]):
            print("----", self.prev_incremental_text)
            self.acc_text.append(self.prev_incremental_text[0])
            self.prev_label.set_text(" ".join(self.acc_text))
            artists.append(self.prev_label)

        self.incremental_label.set_text(incremental_text)
        artists.append(self.incremental_label)
        self.prev_incremental_text[0] = incremental_text

        return artists
    
    def plot(self, start_frame=None, n_frames=100, **kwargs):
        plt.ioff()
        if start_frame is None: start_frame = 0
        
        default_kwargs = dict(blit=True, repeat=True, interval=750)
        default_kwargs.update(kwargs)
        
        self._prepare_figure()
        anim = FuncAnimation(self.fig, self.animate, frames=list(range(start_frame, start_frame + n_frames)),
                             **default_kwargs)
        return anim
    
    def plot_for_word(self, word_idx=None, num_prior_frames=0, **kwargs):
        start_frame = self.get_first_frame_for_word_idx(word_idx)
        start_frame = max(0, start_frame - num_prior_frames)
        return self.plot(start_frame, n_frames=stim.word_lengths[word_idx].item() + num_prior_frames + 1,
                         **kwargs)

In [248]:
# anim = PredictiveAnimation(plot_data).plot(n_frames=1, start_frame=1)
# # HTML(anim.to_html5_video())
# anim.save("test.mp4", dpi=200)

### MAP due to high-prior

In [249]:
high_prior_words = [(i, word_dat[-1][0], word_dat)
                    for i, word_dat in enumerate(plot_data)
                    if word_dat[0][1][0][0] == word_dat[-1][0]
                    # exclude words which are high prior for uninteresting reasons (because the prior space is not big enough .. need top-p sampling)
                    and len(word_dat[-1][0]) < 5]

In [250]:
sorted(high_prior_words, key=lambda x: -len(x[1]))[1]

(14,
 ('ʌ', 'n', 'd'),
 [((),
   [(('ʌ', 'n', 'd'), 0.16660763323307037),
    (('ɪ', 'n'), 0.10714977979660034),
    (('w', 'ɛ', 'n'), 0.03971933200955391),
    (('ɑ', 'n'), 0.03572690114378929),
    (('ʌ', 'v'), 0.028572333976626396),
    (('n', 'ɪ', 'ɹ'), 0.028044182807207108),
    (('æ', 't'), 0.027583500370383263),
    (('f', 'ɔ', 'ɹ'), 0.024426816031336784),
    (('ɔ', 'f'), 0.02414545789361),
    (('w', 'ɪ', 'ð'), 0.024139564484357834)]),
  (('ʌ',),
   [(('ʌ', 'n', 'd'), 0.6690016388893127),
    (('ʌ', 'v'), 0.11473023891448975),
    (('ʌ',), 0.03751917346380651),
    (('ɑ', 'n'), 0.023077700287103653),
    (('ɔ', 'f'), 0.021370472386479378),
    (('ʌ', 'b', 'aʊ', 't'), 0.0205172561109066),
    (('ʌ', 'n', 't', 'ɪ', 'l'), 0.01971033774316311),
    (('ʌ', 'l', 'ɔ', 'ŋ'), 0.013498319312930107),
    (('ʌ', 'n', 'd', 'ɚ'), 0.011841410771012306),
    (('ɔ', 'ɹ'), 0.008652024902403355)]),
  (('ʌ', 'n'),
   [(('ʌ', 'n', 'd'), 0.9200675111060264),
    (('ɑ', 'n'), 0.03165260702371597),
 

In [251]:
# Plot up through the next word
anim = PredictiveAnimation(plot_data, k=5, threshold=threshold).plot_for_word(14, num_prior_frames=13, interval=1000)
anim.save("high_prior.gif", dpi=200)
plt.close()

ð
[(('ð', 'ʌ'), 0.900890713557601), (('l', 'ʌ', 'n', 'd', 'ʌ', 'n'), 0.017709124833345413), (('ð', 'ɪ', 's'), 0.008264764212071896), (('ð', 'ɛ', 'ɹ'), 0.006689589936286211), (('ð', 'æ', 't'), 0.003266195999458432), (('h', 'ɪ', 'z'), 0.002229166915640235), (('v', 'ɛ', 'n', 'ʌ', 's'), 0.0020196062978357077), (('d', 'ʌ', 'b', 'l', 'ɪ', 'n'), 0.0018624125514179468), (('l', 'ɛɪ', 'k'), 0.0017465573037043214), (('v', 'i', 'ɛ', 'n', 'ʌ'), 0.0013644139980897307)]
ð
[(('ð', 'ʌ'), 0.900890713557601), (('l', 'ʌ', 'n', 'd', 'ʌ', 'n'), 0.017709124833345413), (('ð', 'ɪ', 's'), 0.008264764212071896), (('ð', 'ɛ', 'ɹ'), 0.006689589936286211), (('ð', 'æ', 't'), 0.003266195999458432), (('h', 'ɪ', 'z'), 0.002229166915640235), (('v', 'ɛ', 'n', 'ʌ', 's'), 0.0020196062978357077), (('d', 'ʌ', 'b', 'l', 'ɪ', 'n'), 0.0018624125514179468), (('l', 'ɛɪ', 'k'), 0.0017465573037043214), (('v', 'i', 'ɛ', 'n', 'ʌ'), 0.0013644139980897307)]
ð
[(('ð', 'ʌ'), 0.900890713557601), (('l', 'ʌ', 'n', 'd', 'ʌ', 'n'), 0.017709124

  self.ax.set_yticklabels(["".join(option_i) for option_i, _ in incremental_dist])
  self.ax.set_yticklabels(["".join(option_i) for option_i, _ in incremental_dist])


g
[(('g', 'ɑ', 'ɹ', 'd', 'ʌ', 'n'), 0.1688644140958786), (('g', 'ɹ', 'ɛɪ', 't'), 0.06848350167274475), (('g', 'ɹ', 'æ', 's'), 0.06609975546598434), (('b', 'æ', 'k'), 0.028335537761449814), (('v', 'ɪ', 'l', 'ʌ', 'dʒ'), 0.0282503142952919), (('g', 'ɚ', 'ɑ', 'ʒ'), 0.02580440230667591), (('g', 'ʌ', 'l', 'f'), 0.025570306926965714), (('m', 'ɪ', 'd', 'ʌ', 'l'), 0.01897267811000347), (('b', 'ʊ', 'ʃ'), 0.018866106867790222), (('g', 'ɹ', 'æ', 'n', 'd'), 0.0166217889636755)]
gʌ
[(('g', 'ɑ', 'ɹ', 'd', 'ʌ', 'n'), 0.268437790684402), (('g', 'ʌ', 'l', 'f'), 0.24076923727989197), (('k', 'ʌ', 'n', 't', 'ɹ', 'i', 's', 'aɪ', 'd'), 0.044455721974372864), (('g', 'ʌ', 't'), 0.03067513182759285), (('g', 'ɑ', 'ɹ', 'd', 'ʌ', 'n', 'z'), 0.021426072344183922), (('b', 'æ', 'k'), 0.0198093019425869), (('d', 'ɑ', 'ɹ', 'k'), 0.019784588366746902), (('f', 'ɔ', 'ɹ', 'ʌ', 's', 't'), 0.018962906673550606), (('k', 'ʌ', 'n', 't', 'ɹ', 'i'), 0.01613122597336769), (('m', 'ɔ', 'ɹ', 'n', 'ɪ', 'ŋ'), 0.014577100984752178)]
gʌl

### Low-prior but high posterior

In [263]:
not_in_top_k_for_n = 3
low_prior_high_post_words = [
    (i, word_dat[-1][0], word_dat)
     for i, word_dat in enumerate(tqdm(plot_data))
     if len(word_dat) >= 3
     # not in top K for first two phonemes
     and not any(opt_name == word_dat[-1][0] for opt_name, _ in itertools.chain.from_iterable(word_dat[j][1] for j in range(not_in_top_k_for_n)))
     and any(opt_name == word_dat[-1][0] for opt_name, _ in word_dat[not_in_top_k_for_n][1])
     # exclude words which are high prior for uninteresting reasons (because the prior space is not big enough .. need top-p sampling)
     # and len(word_dat[-1][0]) < 5
]

  0%|          | 0/557 [00:00<?, ?it/s]

In [265]:
len(low_prior_high_post_words)

25

In [264]:
low_prior_high_post_words[0]

(28,
 ('f', 'ɔ', 'ɹ', 't', 'i'),
 [((),
   [(('t', 'aɪ', 'm'), 0.16776257753372192),
    (('ʌ', 'v'), 0.04247628524899483),
    (('p', 'ɚ', 's', 'ʌ', 'n'), 0.0320553220808506),
    (('t', 'u'), 0.030345959588885307),
    (('s', 't', 'ɛ', 'p'), 0.028222326189279556),
    (('θ', 'ɪ', 'ŋ'), 0.0239461250603199),
    (('d', 'ɛɪ'), 0.021939942613244057),
    (('j', 'ɪ', 'ɹ'), 0.01656157895922661),
    (('ʌ', 'n', 'd'), 0.01628647744655609),
    (('ɪ', 'n'), 0.012438971549272537)]),
  (('f',),
   [(('t', 'aɪ', 'm'), 0.12580832839012146),
    (('f', 'ɔ', 'ɹ'), 0.07003900781273842),
    (('f', 'j', 'u'), 0.06915374845266342),
    (('θ', 'ɪ', 'ŋ'), 0.06681916117668152),
    (('p', 'ɚ', 's', 'ʌ', 'n'), 0.059258654713630676),
    (('f', 'l', 'aɪ', 't'), 0.027344483882188797),
    (('f', 'ɛɪ', 'z'), 0.024579696357250214),
    (('f', 'l', 'ɔ', 'ɹ'), 0.024571167305111885),
    (('f', 'aɪ', 'v'), 0.024346400052309036),
    (('f', 'ʊ', 'l'), 0.02395486831665039)]),
  (('f', 'ɔ'),
   [(('f', 'ɔ', 'ɹ'), 

In [257]:
# Plot up through the next word
anim = PredictiveAnimation(plot_data, k=5, threshold=threshold).plot_for_word(12, num_prior_frames=21, interval=1000)
anim.save("low_prior.gif", dpi=200)
plt.close()


[(('ð', 'ʌ'), 0.11670882999897003), (('ʌ', 'n', 'd'), 0.11062169820070267), (('ɪ', 'n'), 0.07812955230474472), (('h', 'ɪ', 'z'), 0.06144588813185692), (('aʊ', 't'), 0.058351825922727585), (('f', 'ɔ', 'ɹ'), 0.05287881940603256), (('æ', 't'), 0.045396625995635986), (('ʌ',), 0.042579859495162964), (('w', 'ɪ', 'ð'), 0.03896272927522659), (('ɑ', 'n'), 0.029892463237047195)]

[(('ð', 'ʌ'), 0.11670882999897003), (('ʌ', 'n', 'd'), 0.11062169820070267), (('ɪ', 'n'), 0.07812955230474472), (('h', 'ɪ', 'z'), 0.06144588813185692), (('aʊ', 't'), 0.058351825922727585), (('f', 'ɔ', 'ɹ'), 0.05287881940603256), (('æ', 't'), 0.045396625995635986), (('ʌ',), 0.042579859495162964), (('w', 'ɪ', 'ð'), 0.03896272927522659), (('ɑ', 'n'), 0.029892463237047195)]

[(('ð', 'ʌ'), 0.11670882999897003), (('ʌ', 'n', 'd'), 0.11062169820070267), (('ɪ', 'n'), 0.07812955230474472), (('h', 'ɪ', 'z'), 0.06144588813185692), (('aʊ', 't'), 0.058351825922727585), (('f', 'ɔ', 'ɹ'), 0.05287881940603256), (('æ', 't'), 0.0453966259

  self.ax.set_yticklabels(["".join(option_i) for option_i, _ in incremental_dist])
  self.ax.set_yticklabels(["".join(option_i) for option_i, _ in incremental_dist])


ʌloʊn
[(('ʌ', 'l', 'oʊ', 'n'), 0.9995334148406982), (('ʌ', 'l', 'ɔ', 'ŋ'), 0.00034413972753100097), (('ʌ', 'l', 'ɔ', 'ŋ', 's', 'aɪ', 'd'), 8.658699516672641e-05), (('ɔ', 'l'), 2.2353835447574966e-05), (('ʌ',), 5.852026333741378e-06), (('ʌ', 'm', 'ʌ', 'ŋ'), 1.4719450973643688e-06), (('ɔ', 'f'), 1.1584675121412147e-06), (('ʌ', 'n', 'd'), 8.001819651326514e-07), (('ʌ', 'p', 'ɑ', 'n'), 7.291773158613069e-07), (('ɑ', 'n'), 6.608887019865506e-07)]

[(('ɪ', 'n'), 0.23478636145591736), (('ʌ', 'n', 'd'), 0.2185746431350708), (('ɑ', 'n'), 0.07124397903680801), (('f', 'ɔ', 'ɹ'), 0.07055483758449554), (('æ', 't'), 0.06559246778488159), (('w', 'ɪ', 'ð'), 0.05746229365468025), (('ɔ', 'l'), 0.01464034803211689), (('ɔ', 'ɹ'), 0.011018280871212482), (('t', 'u'), 0.010955079458653927), (('b', 'ʌ', 't'), 0.010897809639573097)]
---- ['ʌloʊn']
ɪ
[(('ɪ', 'n'), 0.9687198996543884), (('ʌ', 'n', 'd'), 0.008705920539796352), (('ɪ', 'n', 't', 'u'), 0.0049550230614840984), (('ɪ', 'n', 's', 'aɪ', 'd'), 0.003647950

## Index candidates by feature

### Find low-prior ground truth words

In [267]:
k = 5
low_prior_word_ids = stim.p_candidates[:, 0].argsort()[:k]
low_prior_words = to_gt_words(low_prior_word_ids)
low_prior_words

['sælɑu', 'bʌnɛvʌlʌnt', 'sæntiɑgoʊ', 'fɪʃlʌs', 'mɑɹlʌn']

### Find words which are still low-posterior at phoneme $k$

In [268]:
k = 3
n = 20

# Only include words which are still running at phoneme $k$
mask = stim.word_lengths > k
p_candidates_masked = p_candidates.clone()
p_candidates_masked[~mask] = torch.inf

low_posterior_word_ids = p_candidates_masked[:, 0, k].argsort()[:n]
low_posterior_words = to_gt_words(low_posterior_word_ids)
list(zip(low_posterior_word_ids, low_posterior_words))

[(tensor(165), 'bʌnɛvʌlʌnt'),
 (tensor(244), 'ʌndɪfitɪd'),
 (tensor(205), 'kɔɹdz'),
 (tensor(486), 'stægɚɪŋ'),
 (tensor(537), 'tækʌl'),
 (tensor(114), 'laɪnz'),
 (tensor(57), 'sælɑu'),
 (tensor(28), 'fɔɹti'),
 (tensor(491), 'plæŋk'),
 (tensor(138), 'lʊkt'),
 (tensor(55), 'dɛfʌnʌtli'),
 (tensor(166), 'skɪn'),
 (tensor(319), 'ɹɪmɛmbɚ'),
 (tensor(396), 'fɪʃɚmɪn'),
 (tensor(395), 'bɪtwin'),
 (tensor(193), 'hænz'),
 (tensor(56), 'faɪnʌli'),
 (tensor(556), 'sɔltɪŋ'),
 (tensor(442), 'dɛpθs'),
 (tensor(472), 'mɑɹlʌn')]

In [269]:
p_candidates[low_posterior_word_ids, 0, k]

tensor([0.0002, 0.0006, 0.0030, 0.0034, 0.0052, 0.0056, 0.0066, 0.0097, 0.0107,
        0.0128, 0.0159, 0.0235, 0.0285, 0.0289, 0.0338, 0.0364, 0.0424, 0.0493,
        0.0569, 0.0582])

In [270]:
anim_word_idx(1338)

NameError: name 'anim_word_idx' is not defined

### Find words which are not the MAP at phoneme $k$

In [61]:
p_candidates.shape

torch.Size([2187, 1000, 16])

In [69]:
k = 5

# Only include words which are still running at phoneme $k$
mask = stim.word_lengths > k
p_candidates_masked = p_candidates.clone()
p_candidates_masked[~mask] = -torch.inf
p_candidates_argmax = p_candidates_masked[:, k].argmax(dim=1)

not_argmax_idxs = torch.where(p_candidates_argmax != 0)[0]
list(zip(not_argmax_idxs, to_gt_words(not_argmax_idxs)))

[(tensor(15), 'mojstə'),
 (tensor(27), '#dipər'),
 (tensor(37), 'heləbul'),
 (tensor(63), 'zemənsə#x'),
 (tensor(89), 'bladrə'),
 (tensor(114), '#ɑləmal#x#'),
 (tensor(175), '#idərə'),
 (tensor(228), '#ustərs'),
 (tensor(233), 'ɑndərə'),
 (tensor(238), 'zɛs#ɪs'),
 (tensor(239), 'vɛrdər'),
 (tensor(240), 'vərdində'),
 (tensor(264), 'ɑləmal#r'),
 (tensor(282), 'merə#x'),
 (tensor(456), '#pʏrpərən'),
 (tensor(506), '#maktə'),
 (tensor(529), 'xəkert#xə'),
 (tensor(592), 'xəkomə'),
 (tensor(628), 'bəwox#xə'),
 (tensor(655), 'lœkərst'),
 (tensor(865), 'vərlɑŋənt'),
 (tensor(869), 'dəxenə'),
 (tensor(898), 'kɛikə#h#'),
 (tensor(950), 'zwɔm#h'),
 (tensor(967), 'zemermɪnətjə'),
 (tensor(974), '#œyjtstrɛktə'),
 (tensor(980), 'vɛiftin'),
 (tensor(990), 'hɔndərt'),
 (tensor(996), 'herləkstə'),
 (tensor(1007), '#kɑlmə'),
 (tensor(1010), 'lɪxə#xə'),
 (tensor(1043), 'rɛijtœyjxən'),
 (tensor(1082), 'darnax'),
 (tensor(1357), 'hɔndərt'),
 (tensor(1359), 'lɑntarnsh'),
 (tensor(1423), 'zɛstin'),
 (tensor

In [71]:
anim_word_idx(37)

h
[(0.2843737857954816, 'hel'), (0.11503773708401324, 'helə'), (0.04161625407498444, 'ht'), (0.030969074349699508, 'ht'), (0.027556656173453757, 'heləbul'), (0.02208031511063773, 'hɑrt'), (0.01865259578798935, 'h'), (0.014280507235832797, 'ha'), (0.012843420209193216, 'hɑnt'), (0.010211893493392134, 'hoxə')]
h
[(0.2843737857954816, 'hel'), (0.11503773708401324, 'helə'), (0.04161625407498444, 'ht'), (0.030969074349699508, 'ht'), (0.027556656173453757, 'heləbul'), (0.02208031511063773, 'hɑrt'), (0.01865259578798935, 'h'), (0.014280507235832797, 'ha'), (0.012843420209193216, 'hɑnt'), (0.010211893493392134, 'hoxə')]
h
[(0.2843737857954816, 'hel'), (0.11503773708401324, 'helə'), (0.04161625407498444, 'ht'), (0.030969074349699508, 'ht'), (0.027556656173453757, 'heləbul'), (0.02208031511063773, 'hɑrt'), (0.01865259578798935, 'h'), (0.014280507235832797, 'ha'), (0.012843420209193216, 'hɑnt'), (0.010211893493392134, 'hoxə')]
he
[(0.6092247775663425, 'hel'), (0.24644971965575924, 'helə'), (0.059

  self.ax.set_yticklabels([option_i for _, option_i in incremental_dist])
  self.ax.set_yticklabels([option_i for _, option_i in incremental_dist])


helə
[(0.7967883806543204, 'helə'), (0.19086626706381407, 'heləbul'), (0.009561477370025783, 'hel'), (0.0007058671381728866, 'hɛiləxə'), (0.0005689674464068114, 'hekəl'), (0.0004017861372421666, 'hɛiləx'), (0.00020622559791568684, 'heməl'), (0.00010486451961581173, 'hetə'), (6.861690441805815e-05, 'relixjœzə'), (6.518453824514022e-05, 'penəs')]
heləb
[(0.9925745083818459, 'heləbul'), (0.0073079214554912335, 'helə'), (8.76952115707618e-05, 'hel'), (6.47401710293987e-06, 'hɛiləxə'), (5.218411199292192e-06, 'hekəl'), (3.685070721613382e-06, 'hɛiləx'), (2.3126329349481686e-06, 'xəlʏk'), (1.8914438366206418e-06, 'heməl'), (1.1957089163794259e-06, 'penəs'), (9.617882130646226e-07, 'hetə')]
heləbu
[(0.9999942516993445, 'heləbul'), (5.628860856788864e-06, 'helə'), (6.75464489793558e-08, 'hel'), (1.99462140795383e-08, 'hɛiləxə'), (5.695692165012416e-09, 'relixjœzə'), (4.6872222371597065e-09, 'relixjœs'), (4.0194343512344055e-09, 'hekəl'), (2.838392621721737e-09, 'hɛiləx'), (1.7812847446340663e-

In [34]:
p_candidates[:, 0, 3].argmin()

tensor(1824)

In [35]:
p_candidates[1824, 0, :]

tensor([4.2436e-06, 5.6206e-06, 5.6396e-06, 5.6396e-06, 5.6396e-06, 5.6396e-06,
        5.6396e-06, 5.6396e-06, 5.6396e-06, 5.6396e-06, 5.6396e-06, 5.6396e-06,
        5.6396e-06, 5.6396e-06], dtype=torch.float64)

In [39]:
stim.get_candidate_strs(1824, 1)

['ɑls']

In [32]:
p_candidates[:, 0, 0].min()

tensor(9.5920e-07, dtype=torch.float64)