In [1]:
from collections import Counter, defaultdict
import itertools
from pathlib import Path
import pickle
import sys
from typing import List

import numpy as np
import pandas as pd
import torch
from tqdm.auto import tqdm

%load_ext autoreload
%autoreload 2
sys.path.append(str(Path(".").resolve().parent.parent))
from berp.datasets import NaturalLanguageStimulusProcessor
from berp.languages import english

In [2]:
tokenized_path = "oms.txt"
aligned_words_path = "word.csv"
aligned_phonemes_path = "phoneme.csv"
story_name = "old-man-and-the-sea"

cmu_ipa_dict_path = "cmudict_ipa.csv"
vocab_path = "../../workflow/heilbron2022/data/frequency/subtlexus2.csv"

output_dir = "old-man-and-the-sea"

model = "distilgpt2"
n_candidates = 1000

In [3]:
Path(output_dir).mkdir(exist_ok=True)

## Prepare tokenized data and aligned data

In [4]:
tokens = Path(tokenized_path).read_text().split(" ")

In [5]:
words_df = pd.read_csv(aligned_words_path, index_col=[0, 1])
phonemes_df = pd.read_csv(aligned_phonemes_path, index_col=[0, 1])

In [6]:
phonemes_df.phoneme.value_counts()

ʌ     3376
t     2378
n     2374
d     1957
ɪ     1850
ð     1505
l     1388
i     1334
s     1270
ɹ     1142
h     1033
m      962
ɛ      950
z      936
k      932
w      930
æ      814
ɚ      763
b      722
f      699
aɪ     618
oʊ     610
ɛɪ     536
ɔ      517
u      514
v      499
ɑ      478
p      442
ŋ      387
g      295
ʃ      284
aʊ     263
θ      258
ʊ      230
j      172
ɔɪ     117
tʃ     116
dʒ      79
ʒ        5
Name: phoneme, dtype: int64

## Prepare frequency data

In [7]:
frequency_df = pd.read_csv(vocab_path, sep="\t")

frequency_df["Word"] = frequency_df.Word.str.lower()
assert frequency_df.Word.value_counts().max() == 1

frequency_df["log_freq"] = -np.log2(frequency_df.FREQcount / frequency_df.FREQcount.sum())

In [8]:
words_df["word_lower"] = words_df.word.str.lower()
old_size = len(words_df)
words_df = pd.merge(words_df.reset_index(), frequency_df[["Word", "log_freq"]], left_on="word_lower", right_on="Word",
                    how="left")
assert len(words_df) == old_size

In [9]:
# Put words with missing frequency in the lowest 2 percentile.
missing_freq = words_df.log_freq.isna()
print(f"{missing_freq.sum()} ({int(missing_freq.mean() * 1000) / 1000}%) words missing frequency values.")
oov_freq = pd.qcut(words_df.log_freq, 50, retbins=True, duplicates="drop")[1][-1]
print(f"Replacing with 2-percentile log-frequency: {oov_freq}")
words_df.loc[missing_freq, "log_freq"] = oov_freq

244 (0.021%) words missing frequency values.
Replacing with 2-percentile log-frequency: 25.56731019333269


## Prepare phonemizer

In [10]:
phonemizer_df = pd.read_csv(cmu_ipa_dict_path)
phonemizer = english.Phonemizer(phonemizer_df)

  0%|          | 0/125807 [00:00<?, ?it/s]

In [11]:
dict_ipa_chars = set(itertools.chain.from_iterable(phonemizer.mapping.values()))
dict_ipa_chars

{'aɪ',
 'aʊ',
 'b',
 'd',
 'dʒ',
 'f',
 'g',
 'h',
 'i',
 'j',
 'k',
 'l',
 'm',
 'n',
 'oʊ',
 'p',
 's',
 't',
 'tʃ',
 'u',
 'v',
 'w',
 'z',
 'æ',
 'ð',
 'ŋ',
 'ɑ',
 'ɔ',
 'ɔɪ',
 'ɚ',
 'ɛ',
 'ɛɪ',
 'ɪ',
 'ɹ',
 'ʃ',
 'ʊ',
 'ʌ',
 'ʒ',
 'θ'}

In [12]:
assert dict_ipa_chars == set(phonemes_df.phoneme)

## Check phonemization agreement

For words with Heilbron phoneme annotations and a matching CMU annotation, check whether they agree.

In [43]:
misses = Counter()
mismatches = defaultdict(list)
matches = 0
for _, word_data in tqdm(phonemes_df.groupby(["run", "word_idx"])):
    word = word_data.word.iloc[0].lower()
    heilbron_pron = word_data.phoneme.str.cat(sep=" ")
    
    if word not in phonemizer.mapping:
        misses[word] += 1
    pron = " ".join(phonemizer(word))

    if heilbron_pron == pron:
        matches += 1
    else:
        mismatches[word].append(heilbron_pron)

  0%|          | 0/10763 [00:00<?, ?it/s]

In [44]:
sorted({word: len(mismatches_i) for word, mismatches_i in mismatches.items()}.items(),
       key=lambda x: -x[1])

[('the', 231),
 ('and', 165),
 ('with', 80),
 ('as', 73),
 ('him', 57),
 ('them', 52),
 ('when', 42),
 ('for', 40),
 ('can', 34),
 ('will', 29),
 ('into', 25),
 ('then', 19),
 ('from', 16),
 ('are', 15),
 ('get', 13),
 ('because', 12),
 ('to', 11),
 ('has', 10),
 ('an', 9),
 ('just', 9),
 ('without', 8),
 ('they', 8),
 ('your', 8),
 ('asked', 8),
 ('than', 8),
 ('before', 8),
 ('carried', 7),
 ('current', 7),
 ('not', 6),
 ('hands', 5),
 ('going', 5),
 ('too', 5),
 ('hundred', 5),
 ('hours', 5),
 ('circling', 5),
 ('or', 4),
 ('every', 4),
 ('where', 4),
 ('years', 4),
 ('it', 4),
 ('strange', 4),
 ('fisherman', 4),
 ('getting', 4),
 ('africa', 4),
 ('beaches', 4),
 ('strength', 4),
 ('that', 4),
 ('jump', 4),
 ('of', 3),
 ('take', 3),
 ('does', 3),
 ('a', 3),
 ('i', 3),
 ('different', 3),
 ('gently', 3),
 ('difference', 3),
 ('our', 3),
 ('either', 2),
 ('their', 2),
 ('edge', 2),
 ('you', 2),
 ('tell', 2),
 ('we', 2),
 ('put', 2),
 ('needed', 2),
 ('blanket', 2),
 ('dimaggio', 2),
 (

In [45]:
misses

Counter()

In [46]:
sum(len(mismatches_i) for mismatches_i in mismatches.values())

1330

In [47]:
mismatches

defaultdict(list,
            {'an': ['ʌ n',
              'ʌ n',
              'ʌ n',
              'ʌ n',
              'ʌ n',
              'ʌ n',
              'ʌ n',
              'ʌ n',
              'ʌ n'],
             'without': ['w ɪ ð aʊ t',
              'w ɪ ð aʊ t',
              'w ɪ ð aʊ t',
              'w ɪ ð aʊ t',
              'w ɪ ð aʊ t',
              'w ɪ ð aʊ t',
              'w ɪ ð aʊ t',
              'w ɪ ð aʊ t'],
             'with': ['w ɪ θ',
              'w ɪ θ',
              'w ɪ θ',
              'w ɪ θ',
              'w ɪ θ',
              'w ɪ θ',
              'w ɪ θ',
              'w ɪ θ',
              'w ɪ θ',
              'w ɪ θ',
              'w ɪ ð ð',
              'w ɪ θ',
              'w ɪ θ',
              'w ɪ θ',
              'w ɪ θ',
              'w ɪ θ',
              'w ɪ θ',
              'w ɪ θ',
              'w ɪ θ',
              'w ɪ θ',
              'w ɪ θ',
              'w ɪ θ',
              'w ɪ θ',
           

## Compute stimulus representations

### Prepare phoneme-level features

In [29]:
# syllable onsets are technically phoneme-level features!
syllable_tokenizer = english.IPASyllableTokenizer()

def compute_phoneme_syllable_features(word, phonemes):
    syllables = syllable_tokenizer.tokenize(phonemes)
    syllable_onset_idxs = [0] + list(np.cumsum([len(syll) for syll in syllables])[:-1])
    
    # Add syllable onset feature
    ret_features = torch.zeros((len(phonemes), 1))
    ret_features[syllable_onset_idxs, 0] = 1.
    
    return ret_features

### Run stimulus processor

In [30]:
PAD_PHONEME = "_"
proc = NaturalLanguageStimulusProcessor(phonemes=list(dict_ipa_chars) + [PAD_PHONEME],
                                        hf_model=model,
                                        num_candidates=n_candidates,
                                        phonemizer=phonemizer)

In [31]:
%pdb 1

Automatic pdb calling has been turned ON


In [32]:
for run, run_words in tqdm(words_df.groupby("run"), unit="run"):
    run_phonemes = phonemes_df.loc[run]
    
    # Prepare proc metadata input.
    word_to_token = run_words.groupby("word_idx") \
        .apply(lambda x: list(x.token_idx)).to_dict()
    
    ground_truth_phonemes = run_phonemes.groupby("word_idx") \
        .apply(lambda xs: list(xs.phoneme)).to_dict()
    
    # Prepare word-level features.
    word_features = dict(run_words.groupby("word_idx").apply(lambda xs: torch.tensor(xs.iloc[0].log_freq).unsqueeze(0)))
    word_feature_names = ["word_frequency"]
    
    phoneme_features = {}
    phoneme_feature_names = ["syllable_onset"]
    for idx in ground_truth_phonemes:
        syllable_features = compute_phoneme_syllable_features(word_to_token[idx], ground_truth_phonemes[idx])
        phoneme_features[idx] = syllable_features
        assert len(phoneme_feature_names) == phoneme_features[idx].shape[-1]
    
    # NB `tokens` contains tokens from all runs. so we'll actually be 
    # processing way too much per run. but it's okay I think
    
    stim = proc(f"{story_name}/run{run}", tokens, word_to_token,
                word_features, word_feature_names,
                phoneme_features, phoneme_feature_names,
                ground_truth_phonemes)
    
    with (Path(output_dir) / f"run{run}.pkl").open("wb") as f:
        pickle.dump(stim, f)

  0%|          | 0/19 [00:00<?, ?run/s]

  0%|          | 0/4 [00:00<?, ?batch/s]

  0%|          | 0/4 [00:00<?, ?batch/s]

  0%|          | 0/4 [00:00<?, ?batch/s]

  0%|          | 0/4 [00:00<?, ?batch/s]

  0%|          | 0/4 [00:00<?, ?batch/s]

  0%|          | 0/4 [00:00<?, ?batch/s]

  0%|          | 0/4 [00:00<?, ?batch/s]

  0%|          | 0/4 [00:00<?, ?batch/s]

  0%|          | 0/4 [00:00<?, ?batch/s]

  0%|          | 0/4 [00:00<?, ?batch/s]

  0%|          | 0/4 [00:00<?, ?batch/s]

  0%|          | 0/4 [00:00<?, ?batch/s]

  0%|          | 0/4 [00:00<?, ?batch/s]

  0%|          | 0/4 [00:00<?, ?batch/s]

  0%|          | 0/4 [00:00<?, ?batch/s]

  0%|          | 0/4 [00:00<?, ?batch/s]

  0%|          | 0/4 [00:00<?, ?batch/s]

  0%|          | 0/4 [00:00<?, ?batch/s]

  0%|          | 0/4 [00:00<?, ?batch/s]