In [1]:
from collections import Counter, defaultdict
import itertools
from pathlib import Path
import pickle
import sys
from typing import List, Tuple, Optional
from typing_extensions import TypeAlias

import numpy as np
import pandas as pd
import torch
from tqdm.auto import tqdm

%load_ext autoreload
%autoreload 2
sys.path.append(str(Path(".").resolve().parent.parent))
from berp.datasets import Phoneme
from berp.datasets import NaturalLanguageStimulusProcessor
from berp.languages import english

In [2]:
# A word form is a variable-length tuple of `Phoneme` items.
WordForm: TypeAlias = Tuple[Phoneme, ...]

In [3]:
# Input files: the tokenized story text plus word- and phoneme-level alignments.
tokenized_path = "oms.txt"
aligned_words_path = "word.csv"
aligned_phonemes_path = "phoneme.csv"
# Used to build each stimulus name as f"{story_name}/run{run}".
story_name = "old-man-and-the-sea"

# Pronunciation dictionary and word-frequency table.
cmu_ipa_dict_path = "cmudict_ipa.csv"
vocab_path = "../../workflow/heilbron2022/data/frequency/subtlexus2.csv"

# Directory that receives one pickle per run.
output_dir = "old-man-and-the-sea"

# Passed to NaturalLanguageStimulusProcessor as `hf_model` and `num_candidates`.
model = "distilgpt2"
n_candidates = 100  # NOTE(review): the original trailing "#0" suggests this was lowered from 1000 — confirm

In [4]:
# Create the output directory up front; exist_ok=True makes re-running this cell harmless.
Path(output_dir).mkdir(exist_ok=True)

## Prepare tokenized data and aligned data

In [5]:
# Split on single spaces only: consecutive spaces yield empty tokens and newlines stay inside tokens.
# NOTE(review): presumably this must match the tokenization behind `token_idx` in the
# word alignment — confirm before changing the split.
tokens = Path(tokenized_path).read_text().split(" ")

In [6]:
# Both alignment tables use their first two CSV columns as a two-level index.
# NOTE(review): presumably the levels are (run, word_idx), as the later groupby calls suggest — confirm.
words_df = pd.read_csv(aligned_words_path, index_col=[0, 1])
phonemes_df = pd.read_csv(aligned_phonemes_path, index_col=[0, 1])

In [7]:
# Inventory of phoneme symbols in the alignment, with their occurrence counts.
phonemes_df.phoneme.value_counts()

ʌ     3376
t     2378
n     2374
d     1957
ɪ     1850
ð     1505
l     1388
i     1334
s     1270
ɹ     1142
h     1033
m      962
ɛ      950
z      936
k      932
w      930
æ      814
ɚ      763
b      722
f      699
aɪ     618
oʊ     610
ɛɪ     536
ɔ      517
u      514
v      499
ɑ      478
p      442
ŋ      387
g      295
ʃ      284
aʊ     263
θ      258
ʊ      230
j      172
ɔɪ     117
tʃ     116
dʒ      79
ʒ        5
Name: phoneme, dtype: int64

## Prepare frequency data

In [8]:
# Load the tab-separated word-frequency table.
frequency_df = pd.read_csv(vocab_path, sep="\t")

# Case-fold the word column and make sure no word occurs more than once afterwards,
# so that merges keyed on lowercase words stay unambiguous.
frequency_df = frequency_df.assign(Word=frequency_df["Word"].str.lower())
word_multiplicity = frequency_df["Word"].value_counts()
assert word_multiplicity.max() == 1

# Negative log2 of the relative frequency: larger values mean rarer words.
relative_freq = frequency_df["FREQcount"] / frequency_df["FREQcount"].sum()
frequency_df["log_freq"] = -np.log2(relative_freq)

In [9]:
# Case-fold the story words so they line up with the lowercase frequency table.
words_df["word_lower"] = words_df["word"].str.lower()
old_size = len(words_df)

# Left-join the frequencies onto the words; the index levels become ordinary columns.
words_df = words_df.reset_index().merge(
    frequency_df[["Word", "log_freq"]],
    how="left",
    left_on="word_lower",
    right_on="Word",
)
# The join must not add rows.
assert len(words_df) == old_size

In [10]:
# Impute a log-frequency for words that are absent from the frequency table,
# placing them in the lowest-frequency 2% bin.
missing_freq = words_df.log_freq.isna()
# `.mean()` of the boolean mask is a fraction in [0, 1]; the `%` format type multiplies
# by 100, so the printed number really is a percentage (previously the raw fraction was
# printed with a "%" sign, e.g. "0.021%" for 2.1%).
print(f"{missing_freq.sum()} ({missing_freq.mean():.1%}) words missing frequency values.")
# Split the observed values into (up to) 50 quantile bins and take the last bin edge,
# i.e. the largest observed `log_freq`. Since `log_freq` is a negative log, that is the
# value of the rarest observed word.
oov_freq = pd.qcut(words_df.log_freq, 50, retbins=True, duplicates="drop")[1][-1]
print(f"Replacing with 2-percentile log-frequency: {oov_freq}")
words_df.loc[missing_freq, "log_freq"] = oov_freq

244 (0.021%) words missing frequency values.
Replacing with 2-percentile log-frequency: 25.56731019333269


## Prepare phonemizer

In [11]:
# Read the IPA pronunciation dictionary and attach the processed word frequencies.
cmu_entries = pd.read_csv(cmu_ipa_dict_path)
phonemizer_df = (
    cmu_entries
    .merge(frequency_df[["Word", "log_freq"]], how="left", left_on="word", right_on="Word")
    .drop(columns=["Word"])
)
# Dictionary words without a frequency entry receive the out-of-vocabulary value.
phonemizer_df["log_freq"] = phonemizer_df["log_freq"].fillna(oov_freq)
# Undo the negative log2 transform to recover a relative frequency.
phonemizer_df["freq"] = np.power(2, -phonemizer_df["log_freq"])
phonemizer_df

Unnamed: 0,word,pronunciation_idx,pronunciation_syllable,pronunciation,log_freq,freq
0,was,0,ˈ w ʌ z,w ʌ z,7.429644,5.800353e-03
1,was,1,ˈ w ɑ z,w ɑ z,7.429644,5.800353e-03
2,was,2,. w ɑ z,w ɑ z,7.429644,5.800353e-03
3,wind,0,ˈ w ɪ n d,w ɪ n d,14.003161,6.090158e-05
4,wind,1,ˈ w aɪ n d,w aɪ n d,14.003161,6.090158e-05
...,...,...,...,...,...,...
134431,{brace,0,ˈ b ɹ ɛɪ s,b ɹ ɛɪ s,25.567310,2.011281e-08
134432,{left-brace,0,ˈ l ɛ f t ˈ b ɹ ɛɪ s,l ɛ f t b ɹ ɛɪ s,25.567310,2.011281e-08
134433,{open-brace,0,ˈ oʊ . p ɛ n ˈ b ɹ ɛɪ s,oʊ p ɛ n b ɹ ɛɪ s,25.567310,2.011281e-08
134434,}close-brace,0,ˈ k l oʊ z ˈ b ɹ ɛɪ s,k l oʊ z b ɹ ɛɪ s,25.567310,2.011281e-08


In [12]:
# Build the phonemizer from the frequency-annotated pronunciation dictionary.
# NOTE(review): how the frequency columns are used is not visible here — see berp.languages.english.
phonemizer = english.Phonemizer(phonemizer_df)

  0%|          | 0/125807 [00:00<?, ?it/s]

In [13]:
# Collect every distinct phoneme symbol used by any entry of the phonemizer mapping.
dict_ipa_chars = {
    phoneme
    for pronunciation in phonemizer.mapping.values()
    for phoneme in pronunciation
}
dict_ipa_chars

{'aɪ',
 'aʊ',
 'b',
 'd',
 'dʒ',
 'f',
 'g',
 'h',
 'i',
 'j',
 'k',
 'l',
 'm',
 'n',
 'oʊ',
 'p',
 's',
 't',
 'tʃ',
 'u',
 'v',
 'w',
 'z',
 'æ',
 'ð',
 'ŋ',
 'ɑ',
 'ɔ',
 'ɔɪ',
 'ɚ',
 'ɛ',
 'ɛɪ',
 'ɪ',
 'ɹ',
 'ʃ',
 'ʊ',
 'ʌ',
 'ʒ',
 'θ'}

In [14]:
# The dictionary and the aligned annotations must use exactly the same phoneme inventory.
assert dict_ipa_chars == set(phonemes_df.phoneme)

## Check phonemization agreement

For words with Heilbron phoneme annotations and a matching CMU annotation, check whether they agree.

In [15]:
# Compare the annotated pronunciation of every aligned word token with the
# pronunciation produced by the dictionary-based phonemizer.
misses = Counter()               # word -> number of tokens with no dictionary entry
mismatches = defaultdict(list)   # word -> annotated pronunciations that disagreed
matches = 0                      # number of tokens whose pronunciations agreed
for _, word_data in tqdm(phonemes_df.groupby(["run", "word_idx"])):
    word = word_data.word.iloc[0].lower()
    heilbron_pron = word_data.phoneme.str.cat(sep=" ")

    if word not in phonemizer.mapping:
        # No dictionary entry means there is nothing to compare against: record the
        # miss and skip the token instead of letting it fall through to the comparison.
        misses[word] += 1
        continue

    pron = " ".join(phonemizer(word))
    if heilbron_pron == pron:
        matches += 1
    else:
        mismatches[word].append(heilbron_pron)

  0%|          | 0/10769 [00:00<?, ?it/s]

In [16]:
# Mismatch count per word, most frequently mismatched first (ties keep insertion order).
sorted(
    ((word, len(mismatches_i)) for word, mismatches_i in mismatches.items()),
    key=lambda item: item[1],
    reverse=True,
)

[('the', 231),
 ('and', 165),
 ('him', 57),
 ('them', 52),
 ('when', 42),
 ('can', 34),
 ('will', 29),
 ('into', 25),
 ('or', 22),
 ('then', 19),
 ('from', 16),
 ('for', 16),
 ('are', 15),
 ('get', 13),
 ('to', 10),
 ('has', 10),
 ('just', 9),
 ('they', 8),
 ('your', 8),
 ('asked', 8),
 ('than', 8),
 ('before', 8),
 ('because', 7),
 ('carried', 7),
 ('current', 7),
 ('not', 6),
 ('going', 5),
 ('too', 5),
 ('hundred', 5),
 ('hours', 5),
 ('circling', 5),
 ('where', 4),
 ('years', 4),
 ('it', 4),
 ('strange', 4),
 ('fisherman', 4),
 ('getting', 4),
 ('africa', 4),
 ('beaches', 4),
 ('strength', 4),
 ('that', 4),
 ('jump', 4),
 ('of', 3),
 ('take', 3),
 ('with', 3),
 ('does', 3),
 ('a', 3),
 ('i', 3),
 ('different', 3),
 ('gently', 3),
 ('difference', 3),
 ('our', 3),
 ('their', 2),
 ('edge', 2),
 ('you', 2),
 ('tell', 2),
 ('we', 2),
 ('put', 2),
 ('needed', 2),
 ('blanket', 2),
 ('dimaggio', 2),
 ('wanted', 2),
 ('great', 2),
 ('horses', 2),
 ('most', 2),
 ('he', 2),
 ('urinated', 2),


In [17]:
# Words that had no entry in the phonemizer mapping, with token counts.
misses

Counter()

In [18]:
# Total number of word tokens whose annotated pronunciation disagreed with the dictionary.
sum(map(len, mismatches.values()))

1137

In [19]:
# Full listing of the disagreeing annotated pronunciations, keyed by word.
mismatches

defaultdict(list,
            {'him': ['ɪ m',
              'ɪ m',
              'ɪ m',
              'ɪ m',
              'ɪ m',
              'ɪ m',
              'ɪ m',
              'ɪ m',
              'ɪ m',
              'ɪ m',
              'ɪ m',
              'ɪ m',
              'ɪ m',
              'ɪ m dʒ',
              'ɪ m',
              'ɪ m',
              'ɪ m',
              'ɪ m',
              'ɪ m',
              'ɪ m',
              'ɪ m',
              'ɪ m',
              'ɪ m',
              'ɪ m',
              'ɪ m',
              'ɪ m',
              'ɪ m m',
              'ɪ m',
              'ɪ m',
              'ɪ m',
              'ɪ m',
              'ɪ m',
              'ɪ m',
              'ɪ m',
              'ɪ m',
              'ɪ m',
              'ɪ m',
              'ɪ m',
              'ɪ m',
              'ɪ m',
              'ɪ m',
              'ɪ m',
              'ɪ m',
              'ɪ m',
              'ɪ m',
              'ɪ m',
    

## Compute stimulus representations

### Prepare phoneme-level features

In [20]:
class CohortPredictiveModel:
    """
    Phoneme probability model that combines a lexical frequency prior with a
    0-1 cohort likelihood.

    NOTE(review): only the cohort index and the unigram phoneme distribution are
    built so far; `cohort_distribution` does not yet produce a distribution.
    """

    def __init__(self, lexicon_items: List[WordForm]):
        """
        Index the lexicon.

        Args:
            lexicon_items: word forms (phoneme sequences). Each form is stored in
                a set, so it must be hashable.
        """
        phoneme_counts = Counter()
        cohort_sets = defaultdict(set)
        for form in lexicon_items:
            # Tally phoneme occurrences across the whole lexicon.
            phoneme_counts.update(form)
            # Register the form under each of its proper prefixes, from the empty
            # prefix up to (but not including) the complete form.
            for prefix_len in range(len(form)):
                cohort_sets[tuple(form[:prefix_len])].add(form)

        # Normalise the counts into a unigram distribution over phonemes.
        unigram_counts = pd.Series(phoneme_counts)
        self._unigram_phoneme_distribution = unigram_counts / unigram_counts.sum()

        # Freeze each cohort into a list, keyed by its prefix tuple.
        self._cohorts = {prefix: list(members) for prefix, members in cohort_sets.items()}

    def cohort_distribution(self, ipa_prefix: WordForm, pad_phoneme="_") -> Optional[pd.DataFrame]:
        """
        Look up the cohort registered for `ipa_prefix`.

        Returns None when no lexicon form is registered under the prefix. The
        non-empty case is not implemented yet and also yields None;
        `pad_phoneme` is currently unused.
        """
        members = self._cohorts.get(ipa_prefix)
        if members is None or len(members) == 0:
            return None
        # TODO: build the distribution over `members`; nothing is computed yet.
        return None

In [21]:
# word, phonemes = word_to_token[1], ground_truth_phonemes[1]
# word, phonemes

# # Mixture parameter for conditional and unigram phoneme distributions when computing phoneme probabilities
# # TODO tune?
# gamma = 0.1

NameError: name 'word_to_token' is not defined

In [None]:
# def cohort_distribution(ipa_prefix, pad_phoneme="_"):
#     cohort_words =

# def cohort_phoneme_distribution(ipa_prefix, pad_phoneme="_"):
#     df = cohort_distribution(ipa_prefix, pad_phoneme=pad_phoneme)
#     if df is None:
#         # Back off to unigram distribution
#         return _unigram_phoneme_distribution
    
#     ps = df.groupby("next").p.sum()
#     # mix with phoneme unigram distribution
#     gamma = 0.1
#     ps = (1 - gamma) * ps
#     ps = ps.add(gamma * _unigram_phoneme_distribution, fill_value=0)
#     ps /= ps.sum()
#     return ps

In [None]:
# surprisals, entropies = [], []
# for prefix_length in range(len(phonemes) + 1):
#     prefix = "".join(phonemes[:prefix_length])
    
#     if prefix_length == len(phonemes):
#         # Only compute posterior entropy
#         surprisal = 0
#         dist_prev = cohort_phoneme_distribution(prefix)
#         entropy = (dist_prev * -np.log2(dist_prev)).sum()
#     else:
#         # Compute surprisal and entropy
#         surprisal, entropy = phoneme_surprisal_entropy(prefix, phonemes[prefix_length])
        
#     surprisals.append(surprisal)
#     entropies.append(entropies)

In [22]:
# syllable onsets are technically phoneme-level features!
# Shared tokenizer instance, used by `compute_syllable_onset_idxs` to split a phoneme sequence into syllables.
syllable_tokenizer = english.IPASyllableTokenizer()

def compute_syllable_onset_idxs(phonemes) -> List[int]:
    """
    Return the index of the first phoneme of every syllable in `phonemes`.

    The sequence is split with the module-level `syllable_tokenizer`. The first
    onset is always 0; each later onset is the running total of the lengths of
    the syllables before it.
    """
    syllable_lengths = [len(syllable) for syllable in syllable_tokenizer.tokenize(phonemes)]
    # Drop the final running total: it is the combined length of all syllables,
    # not the start of another one.
    onsets_after_first = np.cumsum(syllable_lengths)[:-1]
    return [0, *onsets_after_first]

def compute_phoneme_features(phonemes, syllable_onset_idxs):
    """
    Build the per-phoneme feature matrix for one word.

    Args:
        phonemes: the word's phoneme sequence; only its length is used.
        syllable_onset_idxs: positions of the syllable-initial phonemes.

    Returns:
        Tensor of shape (len(phonemes), 1) holding 1 at syllable onsets and 0
        everywhere else.
    """
    n_phonemes = len(phonemes)
    onset_indicator = torch.zeros(n_phonemes, 1)
    # Single feature column: flag the syllable-initial positions.
    onset_indicator[syllable_onset_idxs, 0] = 1.
    return onset_indicator

### Run stimulus processor

In [23]:
PAD_PHONEME = "_"
# Sort the phoneme inventory before handing it to the processor. `dict_ipa_chars` is a
# set of strings, whose iteration order can change between interpreter sessions
# (string hash randomization); `list(...)` would therefore give a different phoneme
# ordering after a kernel restart. The pad phoneme stays last.
proc = NaturalLanguageStimulusProcessor(phonemes=sorted(dict_ipa_chars) + [PAD_PHONEME],
                                        hf_model=model,
                                        num_candidates=n_candidates,
                                        phonemizer=phonemizer)

Using pad_token, but it is not set yet.


In [24]:
# Debugging aid: drop into the post-mortem debugger on any uncaught exception.
# NOTE(review): consider switching this off (`%pdb 0`) before running the notebook non-interactively.
%pdb 1

Automatic pdb calling has been turned ON


In [25]:
# Build one processed stimulus per run and pickle it into `output_dir`.
for run, run_words in tqdm(words_df.groupby("run"), unit="run"):
    # Phoneme rows belonging to this run (selected on the first index level).
    run_phonemes = phonemes_df.loc[run]
    
    # Prepare proc metadata input.
    # word_idx -> list of the token indices aligned with that word
    word_to_token = run_words.groupby("word_idx") \
        .apply(lambda x: list(x.token_idx)).to_dict()
    
    # word_idx -> list of the annotated phonemes of that word
    ground_truth_phonemes = run_phonemes.groupby("word_idx") \
        .apply(lambda xs: list(xs.phoneme)).to_dict()
    
    # Prepare word-level features.
    # word_idx -> 1-element tensor holding the log-frequency of the word's first row
    word_features = dict(run_words.groupby("word_idx").apply(lambda xs: torch.tensor(xs.iloc[0].log_freq).unsqueeze(0)))
    word_feature_names = ["word_frequency"]
    
    # Phoneme-level features, keyed by word_idx; the onset indices are kept
    # separately because they are attached to the stimulus again further down.
    phoneme_features = {}
    syllable_onset_idxs = {}
    phoneme_feature_names = ["syllable_onset"]
    for word_idx in ground_truth_phonemes:
        phonemes_i = ground_truth_phonemes[word_idx]
        syllable_onset_idxs[word_idx] = compute_syllable_onset_idxs(phonemes_i)
        phoneme_features[word_idx] = compute_phoneme_features(phonemes_i, syllable_onset_idxs[word_idx])
        # One feature name per feature column.
        assert len(phoneme_feature_names) == phoneme_features[word_idx].shape[-1]
    
    # NB `tokens` contains tokens from all runs. so we'll actually be 
    # processing way too much per run. but it's okay I think
    
    stim = proc(f"{story_name}/run{run}", tokens, word_to_token,
                word_features, word_feature_names,
                phoneme_features, phoneme_feature_names,
                ground_truth_phonemes)
    
    # Add syllable onset annotation
    # (one list of onset indices per entry of `stim.word_ids`, in the same order)
    if stim.word_events is None:
        stim.word_events = {}
    stim.word_events["syllable_onset"] = [syllable_onset_idxs[word_id.item()]
                                          for word_id in stim.word_ids]
    
    # Persist the processed stimulus for this run.
    with (Path(output_dir) / f"run{run}.pkl").open("wb") as f:
        pickle.dump(stim, f)

  0%|          | 0/19 [00:00<?, ?run/s]

  0%|          | 0/4 [00:00<?, ?batch/s]