In [1]:
import glob
import json
from pathlib import Path

import pandas as pd
import torch as pt
import tqdm
from addict import Dict as AttrDict

from src.dataset import ConnTextULDataset
from src.model import Model

# Automatically re-import edited project modules (src.*) before running code,
# so changes to the model/tokenizers are picked up without restarting the
# kernel. Mode 2 reloads all modules, including the ones imported above.
%reload_ext autoreload
%autoreload 2

[nltk_data] Downloading package cmudict to /Users/nathan/nltk_data...
[nltk_data]   Unzipping corpora/cmudict.zip.


In [8]:
# Load the dataset. Besides the tokenizers, vocabulary and cmudict used below,
# it gives us access to Matt's Traindata class, which is embedded inside the
# phonology tokenizer (consider exposing it more directly).
from pathlib import Path

# Minimal stand-in for a config object: a throwaway class whose only
# attribute is the path to the dataset CSV.
config = type(
    "config",
    (object,),
    dict(dataset_filename=Path("data/data.csv")),
)
ds = ConnTextULDataset(config)

Cache folder: /Users/nathan/Dropbox/code/research/fsu_haskins/ConnTextUL/data/.cache already exists
shark's removed from pool because it is missing in cmudict
honeyed removed from pool because it is missing in cmudict
zookeeper removed from pool because it is missing in cmudict
statue's removed from pool because it is missing in cmudict
tia's removed from pool because it is missing in cmudict
we’re removed from pool because it is missing in cmudict
wiggled removed from pool because it is missing in cmudict
yipped removed from pool because it is missing in cmudict
u.s.a removed from pool because it is missing in cmudict
moosling's removed from pool because it is missing in cmudict
gumbles removed from pool because it is missing in cmudict
snuffling removed from pool because it is missing in cmudict
oink removed from pool because it is missing in cmudict
quacked removed from pool because it is missing in cmudict
quacking removed from pool because it is missing in cmudict
“we removed from

In [3]:
# Read the Woodcock-Johnson III Form A assessment items: a mapping from each
# orthographic word to its target phonological form, which is used as the
# target for words that are not in the training data.
# JSON is UTF-8 by specification, so don't depend on the platform's default
# encoding (e.g. cp1252 on Windows).
with open("data/wj_iii_form_a.json", "r", encoding="utf-8") as wj3_file:
    wj3_json = json.load(wj3_file)

In [4]:
# Column layout shared by every per-checkpoint sheet. Each row holds a word,
# its target and predicted pronunciations, whether they match, the serialised
# feature vectors, probabilities and encoding, and which source(s) the word is in.
_SHEET_COLUMNS = """
    word_raw phon_target phon_prediction correct
    phon_target_features phon_prediction_features phon_prediction_probabilities
    global_encoding in_wj3 in_traindata
""".split()
df = pd.DataFrame(columns=_SHEET_COLUMNS)

In [10]:
# Collect the saved checkpoint files in sorted order. Each one is loaded in
# turn further down and used to predict the WJ3 items, one at a time.
checkpoints = sorted(glob.glob("models/*.pth"))
print(f"{checkpoints=}")

checkpoints=['models/nathan_2023-10-12_09h52m23512ms_chkpt020.pth', 'models/root_2024-03-04_17h17m55970ms_chkpt002.pth']


In [28]:
from src.model import Model
from addict import Dict as AttrDict

# Sanity check on the last checkpoint. Rebuild the model from its saved config
# and weights, then generate pronunciations for three inputs: a nonsense
# string, a real word and a single letter.
assert checkpoints, "no checkpoints found under models/"
# map_location="cpu" lets a checkpoint written on a GPU machine load on a
# CPU-only one. load_state_dict then copies the weights into `model`.
chkpt = pt.load(checkpoints[-1], map_location="cpu")
model = Model(AttrDict(chkpt["config"]), ds)
model.load_state_dict(chkpt["model_state_dict"])
# Inference mode: switches train-only layers (e.g. dropout) off, so that
# deterministic=True really gives repeatable output.
model.eval()

datum = ds.character_tokenizer.encode(["aspfdjhqwrpiv", "hiccup", "h"])
print(f"{datum['enc_input_ids']=}")
print(f"{datum['enc_pad_mask']=}")

# Generation needs no autograd graph.
with pt.no_grad():
    pred = model.generate(
        "o2p",
        datum["enc_input_ids"],
        datum["enc_pad_mask"],
        None,
        None,
        deterministic=True,
    )
for k, v in pred.items():
    print(f"{k=}, {v=}")

datum['enc_input_ids']=tensor([[ 0, 11, 29, 26, 16, 14, 20, 18, 27, 33, 28, 26, 19, 32,  1],
        [ 0, 18, 19, 13, 13, 31, 26,  1,  4,  4,  4,  4,  4,  4,  4],
        [ 0, 18,  1,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4]])
--In Generate--
orth_enc_input:  torch.Size([3, 15])
orth_enc_pad_mask:  torch.Size([3, 15])
generated_phon_tokens=[[tensor([31])], [tensor([31])], [tensor([31])]]
--In phonology_decoder_loop--
generated_phon_embeddings.shape=torch.Size([3, 1, 64])
prompt_encoding.shape=torch.Size([3, 1, 64])
len(generated_phon_tokens)=3
	generated_phon_embeddings:  torch.Size([3, 1, 64])
	prompt_encoding:  torch.Size([3, 1, 64])
	step_mask:  torch.Size([1, 1])
--In phono_sample--
last_token_probs.shape=torch.Size([3, 2, 33])
last_token_probs=tensor([[[0.8230, 0.6945, 0.6237, 0.8580, 0.7908, 0.7536, 0.6512, 0.6216,
          0.7205, 0.5940, 0.8924, 0.7799, 0.7046, 0.6498, 0.7362, 0.5607,
          0.8805, 0.7564, 0.8643, 0.7366, 0.6603, 0.5095, 0.6486, 0.6562,
          0.

In [6]:
# Every orthographic word we will score: the union of the WJ3 items (the JSON
# keys) and the training vocabulary. Using a set collapses the duplicates.
all_words = set(wj3_json).union(ds.words)

In [None]:
# Score every checkpoint on every word in `all_words` and write one sheet per
# checkpoint to wj3_predictions.xlsx.


def load_checkpoint_model(checkpoint_path):
    """Load the model saved in ``checkpoint_path`` and switch it to eval mode.

    Two checkpoint layouts are handled. Either a whole model object is stored
    under ``"model"``, or ``"config"`` and ``"model_state_dict"`` are stored
    and a fresh ``Model`` is rebuilt against ``ds``.
    """
    # map_location="cpu" lets checkpoints written on a GPU machine load here.
    # NOTE(review): torch>=2.6 defaults to weights_only=True, which refuses a
    # pickled model object under "model". Such files need weights_only=False,
    # and that should only be used for checkpoints you trust.
    chkpt = pt.load(checkpoint_path, map_location="cpu")
    if "model" in chkpt:
        model = chkpt["model"]
    else:
        model = Model(AttrDict(chkpt["config"]), ds)
        model.load_state_dict(chkpt["model_state_dict"])
    # Turn off train-only layers (e.g. dropout) so generation is repeatable.
    model.eval()
    return model


def _join_vectors(vectors, fmt=str):
    """Serialise a sequence of 1-D tensors as "a:b:c;d:e:f".

    Elements within one vector are joined with ":" and vectors are joined
    with ";". Each element is rendered as ``fmt(element.item())``.
    """
    return ";".join(":".join(fmt(v.item()) for v in vector) for vector in vectors)


def _as_int_str(value):
    """Render a numeric feature value as an integer string (e.g. 1.0 -> "1")."""
    return str(int(value))


def _sheet_name(checkpoint_path, taken):
    """Return a unique Excel sheet name of the form "epoch NNN".

    NNN is the last three characters of the file stem, e.g. "...chkpt020.pth"
    gives "020". If two checkpoints share an epoch number, a " (2)", " (3)",
    ... suffix is added so the two runs don't land on the same sheet.
    ``taken`` is the set of names already used; it is updated in place.
    """
    base = "epoch " + Path(checkpoint_path).stem[-3:]
    name, counter = base, 2
    while name in taken:
        name = f"{base} ({counter})"
        counter += 1
    taken.add(name)
    return name


def predict_row(model, orth, train_words):
    """Run the o2p pathway on one word and return one spreadsheet row as a dict.

    Parameters
    ----------
    model : a loaded model exposing ``generate`` (see ``load_checkpoint_model``).
    orth : the orthographic form of the word.
    train_words : set of words in the training data (``ds.words``).

    ``all_words`` is the union of WJ3 and the training data, so the target
    phonological form comes from ``ds.cmudict`` for training words and from
    ``wj3_json`` for everything else.
    """
    in_wj3 = orth in wj3_json
    in_traindata = orth in train_words
    # `phon` is the phonological form of the word (phonemes), not feature vectors.
    phon = ds.cmudict[orth] if in_traindata else wj3_json[orth]

    # Encode as a one-word batch, matching how the tokenizers are called elsewhere.
    datum = ds.character_tokenizer.encode([orth])
    with pt.no_grad():
        pred = model.generate(
            "o2p",
            datum["enc_input_ids"],
            datum["enc_pad_mask"],
            None,  # o2p has no phonological input (same as the sanity-check cell)
            None,
            deterministic=True,
        )

    # Drop the start and end tokens and convert each feature tensor to a list.
    # NOTE(review): assumes pred["phon_tokens"] is a per-position sequence for
    # this single word. If generate returns one list per batch item, index [0]
    # first. Confirm against Model.generate.
    phon_pred_features = [tensor.tolist() for tensor in pred["phon_tokens"][1:-1]]
    # Convert the feature vectors back to phonemes with Matt's Traindata routine.
    phon_pred = ds.phonology_tokenizer.traindata.convert_numeric_prediction(
        phon_pred_features, phonology=True, hot_nodes=True
    )
    # A position may come back as None; keep it visible as the string "None".
    phon_pred = ["None" if p is None else p for p in phon_pred]

    # Phonemes are separated by colons in both target and prediction.
    phon_target = ":".join(phon)
    phon_prediction = ":".join(phon_pred)

    # Target feature vectors. A falsy encoding is written as "None".
    target_encoding = ds.phonology_tokenizer.encode([orth])
    phon_target_features = (
        _join_vectors(target_encoding["targets"][0]) if target_encoding else "None"
    )

    return {
        "word_raw": orth,
        "phon_target": phon_target,
        "phon_prediction": phon_prediction,
        "correct": phon_target == phon_prediction,
        "phon_target_features": phon_target_features,
        # Predicted features/probabilities, again without start and end tokens.
        "phon_prediction_features": _join_vectors(
            pred["phon_vecs"][1:-1], fmt=_as_int_str
        ),
        "phon_prediction_probabilities": _join_vectors(pred["phon_probs"][1:-1]),
        "global_encoding": ":".join(
            str(v.item()) for v in pred["global_encoding"].squeeze()
        ),
        "in_wj3": in_wj3,
        "in_traindata": in_traindata,
    }


assert checkpoints, "no checkpoints found under models/"
# Set membership is O(1); ds.words may be a list.
train_words = set(ds.words)
# Sort so the row order is identical on every run (str set order is not).
words_in_order = sorted(all_words)
used_sheet_names = set()
# The context manager saves the workbook and closes the file when the block exits.
with pd.ExcelWriter("wj3_predictions.xlsx", engine="openpyxl") as writer:
    for checkpoint in tqdm.tqdm(checkpoints):
        model = load_checkpoint_model(checkpoint)
        records = [predict_row(model, orth, train_words) for orth in words_in_order]
        # Build each sheet in one go, in the template's column order.
        # Row-by-row concat would be quadratic and gave every row index 0.
        df = pd.DataFrame(records, columns=df.columns)
        df.to_excel(writer, sheet_name=_sheet_name(checkpoint, used_sheet_names))