In [1]:
from src.dataset import ConnTextULDataset
import pandas as pd
import torch as pt
import glob
import json
import tqdm
import torch
import csv
from time import time
from src.model import Model
from addict import Dict as AttrDict

%reload_ext autoreload
%autoreload 2

[nltk_data] Downloading package cmudict to /root/nltk_data...
[nltk_data]   Package cmudict is already up-to-date!


In [2]:
# Load in a dataset to access Matt's Traindata class which
# is embedded inside the phonology tokenizer (consider changing this
from pathlib import Path

config = type(
    "config",
    (object,),
    {"dataset_filename": Path("data/random_kid_10140_021624.csv")},
)
ds = ConnTextULDataset(config)

Cache folder: /workspaces/ConnTextUL/data/.cache already exists


In [4]:
# read in the checkpoint file names. We will load one at a time
# and generate predictions of the wj3 assessments one at at ime
checkpoints = glob.glob("models/root_2024-04-16_04h11m47948ms_chkpt*.pth")
checkpoints.sort()
print(f"{checkpoints=}")

checkpoints=['models/root_2024-04-16_04h11m47948ms_chkpt001.pth', 'models/root_2024-04-16_04h11m47948ms_chkpt002.pth', 'models/root_2024-04-16_04h11m47948ms_chkpt003.pth', 'models/root_2024-04-16_04h11m47948ms_chkpt004.pth', 'models/root_2024-04-16_04h11m47948ms_chkpt005.pth', 'models/root_2024-04-16_04h11m47948ms_chkpt006.pth', 'models/root_2024-04-16_04h11m47948ms_chkpt007.pth', 'models/root_2024-04-16_04h11m47948ms_chkpt008.pth', 'models/root_2024-04-16_04h11m47948ms_chkpt009.pth', 'models/root_2024-04-16_04h11m47948ms_chkpt010.pth', 'models/root_2024-04-16_04h11m47948ms_chkpt011.pth', 'models/root_2024-04-16_04h11m47948ms_chkpt012.pth', 'models/root_2024-04-16_04h11m47948ms_chkpt013.pth', 'models/root_2024-04-16_04h11m47948ms_chkpt014.pth', 'models/root_2024-04-16_04h11m47948ms_chkpt015.pth', 'models/root_2024-04-16_04h11m47948ms_chkpt016.pth', 'models/root_2024-04-16_04h11m47948ms_chkpt017.pth', 'models/root_2024-04-16_04h11m47948ms_chkpt018.pth', 'models/root_2024-04-16_04h11m479

In [5]:
batch_size = 64
all_words=set()
all_words.update(ds.words)
batches = [list(all_words)[i:i+batch_size] for i in range(0, len(all_words), batch_size)]

In [7]:
from src.model import Model
from addict import Dict as AttrDict
import pandas as pd
import tqdm
import torch

csv_name = f"model_{checkpoints[-1][-42:-13]}.xlsx"
writer = pd.ExcelWriter(csv_name, engine='openpyxl')

with pd.ExcelWriter(csv_name, engine='openpyxl') as writer:
    for checkpoint in tqdm.tqdm(checkpoints[:1]):
        chkpt = torch.load(checkpoint)
        dfa = pd.DataFrame(columns=["phon_prediction"])
        model = Model(AttrDict(chkpt["config"]), ds)
        model.load_state_dict(chkpt["model_state_dict"])
        model.eval()  # Set the model to evaluation mode
        
        sheet_name = "epoch " + checkpoint[-7:-4]
        start_row = 0  # Initialize starting row for each new sheet

        with torch.no_grad():
            for batch_idx in range(len(batches)):
                batch = batches[batch_idx]
                new_row = {}
                datum = ds.character_tokenizer.encode(batch)
                pred = model.generate(
                    "o2p",
                    datum['enc_input_ids'],
                    datum['enc_pad_mask'],
                    None,
                    None,
                    deterministic=True,
                )
                for idx, orth in enumerate(batch):
                    # Save the original input orthography
                    new_row["word_raw"] = orth
                    # Save the target phonology for the above input orthography
                    phon = ds.cmudict[orth]
                    new_row["phon_target"] = ":".join(phon)
                    # Remove the start and end tokens from each phonological vector
                    # and convert them from tensors to lists
                    phon_pred_features = [tensor.tolist() for tensor in pred["phon_tokens"][idx][1:-1]]
                    # Convert the phonological vectors to phonemes using Matt's handy dandy routine
                    phon_pred = ds.phonology_tokenizer.traindata.convert_numeric_prediction(
                        phon_pred_features, phonology=True, hot_nodes=True
                    )
                    # Save the model's predicted pronunciation for this word. Phonemes are
                    # separated by colons
                    phon_pred = ["None" if p == None else p for p in phon_pred]
                    new_row["phon_prediction"] = ":".join(phon_pred)
                    # Save a boolean indicating whether the prediction was correct
                    new_row["correct"] = new_row["phon_target"] == new_row["phon_prediction"]
                    # Save the phonological features for the target phonology
                    phon_target_features = ds.phonology_tokenizer.encode([orth])
                    if phon_target_features:
                        phon_target_features = ";".join(
                            [
                                ":".join([str(v.item()) for v in vector])
                                for vector in phon_target_features["targets"][0]
                            ]
                        )
                    else:
                        phon_target_features = "None"
                    new_row["phon_target_features"] = phon_target_features
                    # Save the phonological features for the predicted phonology
                    phon_prediction_features = ";".join(
                        [
                            ":".join([str(int(v.item())) for v in vector])
                            for vector in pred["phon_vecs"][idx][1:-1]
                        ]
                    )
                    new_row["phon_prediction_features"] = phon_prediction_features
                    # Save the phonological probabilities for the predicted phonology
                    phon_prediction_probabilities = ";".join(
                        [
                            ":".join([str(v.item()) for v in vector])
                            for vector in pred["phon_probs"][idx][1:-1]
                        ]
                    )
                    new_row["phon_prediction_probabilities"] = phon_prediction_probabilities
                    # Save the global encoding vector for the predicted phonology
                    global_encoding = ":".join(
                        [str(v.item()) for v in pred["global_encoding"][idx].squeeze()]
                    )
                    new_row["global_encoding"] = global_encoding
                    dfa = pd.DataFrame([new_row])
                    dfa.to_excel(writer, sheet_name=sheet_name, startrow=start_row, header=(start_row == 0), index=False)
                    start_row += 1  # Increment the start row for the next batch

  gen_phon_tokes.append(torch.tensor(new_phon_tokes))


In [47]:
datum = ds.character_tokenizer.encode(["hero", "timer"])
model.generate("o2p", datum['enc_input_ids'], datum['enc_pad_mask'], None, None, deterministic=True)

{'orth_probs': None,
 'orth_tokens': None,
 'phon_probs': [[tensor([0.1235, 0.0301, 0.2751, 0.0970, 0.0729, 0.0163, 0.2480, 0.1374, 0.0356,
           0.0988, 0.0781, 0.0327, 0.0741, 0.0366, 0.6443, 0.1211, 0.1016, 0.0618,
           0.0997, 0.0355, 0.0472, 0.1412, 0.0292, 0.0535, 0.0992, 0.0576, 0.0549,
           0.0453, 0.0322, 0.1114, 0.0231, 0.0152, 0.0651]),
   tensor([0.0766, 0.0369, 0.2764, 0.1017, 0.0662, 0.0142, 0.1386, 0.0974, 0.0427,
           0.0874, 0.0668, 0.0321, 0.0987, 0.0274, 0.7055, 0.1951, 0.0829, 0.0603,
           0.1323, 0.0367, 0.0512, 0.1212, 0.0306, 0.0557, 0.1013, 0.0492, 0.0413,
           0.0560, 0.0364, 0.2061, 0.0261, 0.0188, 0.1734]),
   tensor([0.1126, 0.0468, 0.3562, 0.0800, 0.0828, 0.0142, 0.1907, 0.1525, 0.0349,
           0.0954, 0.0736, 0.0458, 0.0647, 0.0358, 0.6809, 0.1508, 0.1023, 0.0733,
           0.1044, 0.0353, 0.0527, 0.1193, 0.0277, 0.0495, 0.0818, 0.0642, 0.0388,
           0.0501, 0.0252, 0.1292, 0.0215, 0.0144, 0.1131]),
   tensor([0.

In [62]:
pred['global_encoding'][0].squeeze()

tensor([-0.2310,  0.1568,  1.0121,  0.0091, -0.9715,  1.6031, -0.3786, -1.5646,
        -1.4997, -1.8750,  0.6249, -1.6910, -0.8417,  0.2227, -0.0109, -1.4188,
        -0.1484, -0.9539,  0.4672, -0.2058,  0.3121,  1.7313, -1.3862,  1.4183,
         1.0050, -1.6837,  1.0909,  0.7093,  0.6097,  0.9900,  1.5349,  0.9439])

In [56]:
batches[0]

['unprecedented',
 'causal',
 'walks',
 'cupboard',
 'greed',
 'chartered',
 'player',
 'yards',
 'outpost',
 'mentally',
 'intervening',
 'paints',
 'canada',
 'professions',
 'transported',
 'unconscious',
 'flew',
 'nate',
 'carpet',
 'durable',
 'tagged',
 'tip',
 'crow',
 'happier',
 'reside',
 'cursing',
 'mobs',
 'accordance',
 'forehead',
 'abruptly',
 'newer',
 'washington',
 'peanut',
 'zebras',
 'reproductive',
 'finances',
 'determines',
 'ten',
 'defective',
 'beauty',
 'minnie',
 'battled',
 'smokey',
 'strains',
 'bedding',
 'thriving',
 'occupational',
 'governed',
 'yugoslavia',
 'steady',
 'beating',
 'matrix',
 'vertebrae',
 'dotted',
 'boa',
 'sutton',
 'roberta',
 'ineffective',
 'quixote',
 'fetus',
 'deficit',
 'fragments',
 'eccentric',
 'hordes']