# Test the model from checkpoints
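This notebook loads each checkpoint saved from a trained model and generates predictions for a set of test words. Results are written to CSV, one file per checkpoint (where a checkpoint corresponds to an epoch). A companion `_from_units` file is also written for each checkpoint, with accuracies computed from the predicted feature units.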

In [None]:
from src.dataset import ConnTextULDataset
import pandas as pd
import torch
import tqdm
from src.model import Model
from addict import Dict as AttrDict
from pathlib import Path
import os
import json
import nltk
cmu = nltk.corpus.cmudict.dict()
import os


%reload_ext autoreload
%autoreload 2

# Pre-aggregate the word data
Each of these models was trained on a different dataset, all contained in `data/SSSR/`. Word frequencies for each set can be looked up in those files individually during analysis; here we only need the words themselves. Testing on each dataset separately would be onerous: there are several sets, and words repeat both across and within sets. Instead, we read all of them and keep only the unique words that have a CMU pronunciation.

In [18]:
DIRECTORY = "my_sidewalks_75_percent_background_25_percent"
CONDITION = "my_sidewalks"
words = []

# Collect the raw words from every `<CONDITION>*.csv` file in DIRECTORY.
# sorted() makes the read order deterministic across filesystems.
for filename in sorted(os.listdir(DIRECTORY)):
    # Logical `and`, not bitwise `&`: the result is the same on bools, but `&`
    # binds tighter than comparisons and is easy to misuse.
    if filename.startswith(CONDITION) and filename.endswith('.csv'):
        filepath = os.path.join(DIRECTORY, filename)
        # Append this file's `word_raw` column to the running list of words.
        words.extend(pd.read_csv(filepath)['word_raw'].tolist())

# Keep only string entries (empty cells are read by pandas as NaN floats)
# that have a CMU pronunciation.
words = [word for word in words if isinstance(word, str) and word in cmu]


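Add the Woodcock-Johnson III (Form A) and TOWRE-2 sight words to the test set as well. Keep only words that are no longer, orthographically and phonologically, than the longest word already collected.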

In [12]:
# Longest orthographic / phonological (first CMU pronunciation) forms among the
# words collected so far. Added test words must not exceed either length.
maxorth = max(len(word) for word in words)
maxphon = max(len(cmu[word][0]) for word in words)


def load_test_words(path):
    """Return the multi-letter words from a JSON word list that have a CMU entry.

    The words are taken from the top-level keys of the JSON object at ``path``;
    the values are ignored.
    """
    with open(path, 'r') as file:
        entries = json.load(file)
    return [word for word in entries.keys() if len(word) > 1 and word in cmu]


wj3 = load_test_words('data/wj_iii_form_a.json')
towre = load_test_words('data/TOWRE2_sight_words.json')

for word in wj3 + towre:
    # Use `and`, not `&`. Because `&` binds tighter than `<=`, the previous form
    # parsed as the chained comparison `len(word) <= (maxorth & n) <= maxphon`.
    if len(word) <= maxorth and len(cmu[word][0]) <= maxphon:
        # Use append, not extend: extend(word) would add the word's letters one by one.
        words.append(word)


Write the deduplicated, sorted word list to a CSV file, which the dataset config below reads.

In [19]:
# Deduplicate and sort so the test word list is deterministic.
words = sorted(set(words))
OUTFILE = os.path.join(DIRECTORY, "words_for_test.csv")

# Write a one-column CSV with a `word_raw` header. The `with` block closes the
# file, so no explicit close() is needed.
with open(OUTFILE, 'w') as file:
    file.write("word_raw\n")
    file.writelines('{}\n'.format(word) for word in words)


## Config
Load the dataset into the `Traindata` class, which is embedded inside the phonology tokenizer. Later implementations should consider separating these steps so that `Traindata` is initialized before tokenization. Also set up batches: larger batches make processing faster, but your machine may impose an upper limit.

# Trade books
Work through each checkpoint in the trade-books directory and generate model predictions.

In [20]:

# Minimal config object exposing `dataset_filename` as an attribute for the dataset.
config = type("config", (object,), {"dataset_filename": Path(OUTFILE)})
ds = ConnTextULDataset(config)

# Checkpoint files in DIRECTORY, in sorted filename order.
checkpoints = sorted(
    filename for filename in os.listdir(DIRECTORY) if filename.endswith(".pth")
)

batch_size = 1000
all_words = set(ds.words)
# Sort the unique words once. Iterating a set of strings gives a different order
# on each interpreter run (string hashing is randomized), which made batch
# contents and output row order irreproducible. key=str keeps the sort safe if
# the dataset yields any non-str entries.
word_list = sorted(all_words, key=str)
batches = [word_list[i:i + batch_size] for i in range(0, len(word_list), batch_size)]

    

Cache folder: /workspaces/BRIDGE/data/.cache already exists


In [21]:
def insert_substring(x, y, target = '.csv'):
    """Insert ``y`` into ``x`` just before the first occurrence of ``target``.

    Args:
        x: String to modify (e.g. an output file path).
        y: Text to insert.
        target: Substring marking the insertion point; defaults to ``'.csv'``
            so that ``y`` lands in front of the file extension.

    Returns:
        The modified string, or ``x`` unchanged when ``target`` does not occur.
        Only the first occurrence is used, so an earlier ``target`` (e.g. inside
        a directory name) takes precedence over the extension.
    """
    position = x.find(target)
    if position < 0:
        return x
    prefix, suffix = x[:position], x[position:]
    return prefix + y + suffix

In [24]:
# Inspect the training configuration stored with the first checkpoint. The
# checkpoint is loaded explicitly here instead of reusing `chkpt`, which only
# exists after the prediction loop below has run.
torch.load(os.path.join(DIRECTORY, checkpoints[0]), map_location="cpu")['config']

{'num_epochs': 150,
 'batch_size_train': 32,
 'batch_size_val': 64,
 'd_model': 128,
 'd_embedding': 2,
 'nhead': 2,
 'pathway': 'o2p',
 'test': False,
 'train_test_split': 1.0,
 'save_every': 10,
 'seed': 2025,
 'learning_rate': 0.001,
 'project': 'SSSR2024',
 'wandb': True,
 'num_phon_enc_layers': 1,
 'num_orth_enc_layers': 1,
 'num_mixing_enc_layers': 1,
 'num_phon_dec_layers': 1,
 'num_orth_dec_layers': 1,
 'device': 'cpu',
 'model_path': '/workspaces/BRIDGE/./models',
 'dataset_filename': '/workspaces/BRIDGE/data/my_sidewalks_75_percent_background_25_percent.csv',
 'max_nb_steps': 20,
 'sweep_filename': '',
 'test_filenames': ['/workspaces/BRIDGE/data/tests/test1.csv',
  '/workspaces/BRIDGE/data/tests/test2.csv'],
 'model_id': 'root_2024-05-24_01h23m08973ms',
 'model_file_name': 'root_2024-05-24_01h23m08973ms_chkpt000.pth',
 'n_steps_per_epoch': 311}

In [22]:
def _join_vectors(vectors, to_str=lambda v: str(v.item())):
    """Serialize vectors as ``"a:b:c;d:e:f"`` (``:`` within a vector, ``;`` between).

    Each element is converted with ``to_str``; the default ``str(v.item())``
    expects tensor (or numpy) scalars.
    """
    return ";".join(":".join(to_str(v) for v in vector) for vector in vectors)


def _unit_accuracy(target, prediction):
    """Score predicted feature vectors against target vectors position by position.

    Args:
        target: List of per-phoneme feature lists for the target pronunciation.
        prediction: List of per-phoneme feature lists produced by the model.

    Returns:
        ``(phonemes_correct_total, phonemes_correct_mean, units_correct_mean)``.
        A phoneme is correct only if its whole feature list matches. Unit
        accuracy compares features pairwise (``zip`` truncates to the shorter
        list). The means are NaN when there is nothing to compare.

    Raises:
        IndexError: if ``prediction`` has fewer vectors than ``target``.
    """
    by_phoneme = []
    by_unit = []
    for i, target_vec in enumerate(target):
        predicted_vec = prediction[i]
        by_phoneme.append(target_vec == predicted_vec)
        by_unit.extend(p == t for p, t in zip(predicted_vec, target_vec))
    phonemes_correct_total = sum(by_phoneme)
    phonemes_correct_mean = (
        phonemes_correct_total / len(by_phoneme) if by_phoneme else float("nan")
    )
    units_correct_mean = sum(by_unit) / len(by_unit) if by_unit else float("nan")
    return phonemes_correct_total, phonemes_correct_mean, units_correct_mean


def _prediction_row(orth, idx, pred, target_enc):
    """Build the output row for word ``orth`` at position ``idx`` of a generated batch.

    ``pred`` is the dict returned by ``model.generate`` for the batch, and
    ``target_enc`` is ``ds.phonology_tokenizer.encode([orth])`` (may be falsy).
    """
    row = {"word_raw": orth}
    # Target phonology for the input orthography, phonemes joined by colons.
    row["phon_target"] = ":".join(ds.cmudict[orth])
    # Remove the start and end tokens from each predicted phonological vector
    # and convert them from tensors to lists.
    phon_pred_features = [tensor.tolist() for tensor in pred["phon_tokens"][idx][1:-1]]
    # Convert the phonological vectors to phonemes.
    phon_pred = ds.phonology_tokenizer.traindata.convert_numeric_prediction(
        phon_pred_features, phonology=True, hot_nodes=True
    )
    row["phon_prediction"] = ":".join("None" if p is None else p for p in phon_pred)
    row["correct"] = row["phon_target"] == row["phon_prediction"]
    row["phon_target_features"] = (
        _join_vectors(target_enc["targets"][0]) if target_enc else "None"
    )
    row["phon_prediction_features"] = _join_vectors(
        pred["phon_vecs"][idx][1:-1], to_str=lambda v: str(int(v.item()))
    )
    row["phon_prediction_probabilities"] = _join_vectors(pred["phon_probs"][idx][1:-1])
    row["global_encoding"] = _join_vectors(pred["global_encoding"][idx])
    return row


for checkpoint in tqdm.tqdm(checkpoints):
    print(checkpoint, "arrived")

    # Per-word predictions go to <checkpoint>.csv; unit-level accuracies go to
    # <checkpoint>_from_units.csv.
    outpath = os.path.join(DIRECTORY, checkpoint.replace(".pth", ".csv"))
    outpath2 = insert_substring(outpath, "_from_units")

    # map_location="cpu" lets checkpoints saved on a GPU load on a CPU-only machine.
    chkpt = torch.load(os.path.join(DIRECTORY, checkpoint), map_location="cpu")
    # NOTE(review): the recorded run failed in load_state_dict with a
    # position-embedding size mismatch (orth 13 vs 15, phon 11 vs 15). The model
    # sizes presumably come from `ds` (built from the test word list) rather than
    # from the training data in chkpt["config"]["dataset_filename"].
    # TODO: build the model with the training-time maximum lengths.
    model = Model(AttrDict(chkpt["config"]), ds)
    model.load_state_dict(chkpt["model_state_dict"])
    model.eval()  # Set the model to evaluation mode

    rows = []
    print("Checkpoint", checkpoint, "...started")

    # The `with` closes the unit-accuracy file even if a batch raises.
    with torch.no_grad(), open(outpath2, "w") as outfile2:
        for batch_idx, batch in enumerate(batches):
            datum = ds.character_tokenizer.encode(batch)
            pred = model.generate(
                "o2p",
                datum['enc_input_ids'],
                datum['enc_pad_mask'],
                None,
                None,
                deterministic=True,
            )

            for idx, orth in enumerate(batch):
                # Encode the target once and reuse it for both output files.
                target_enc = ds.phonology_tokenizer.encode([orth])
                rows.append(_prediction_row(orth, idx, pred, target_enc))

                # Accuracy from feature units rather than phoneme strings.
                # NOTE(review): `target` is used unsliced, while `prediction`
                # keeps position 0 of phon_vecs, which _prediction_row drops as
                # a start token. Confirm the two sequences are aligned.
                if target_enc:
                    target = target_enc['targets'][0].tolist()
                    prediction = [e.tolist() for e in pred['phon_vecs'][idx]]
                    stats = _unit_accuracy(target, prediction)
                else:
                    stats = (float("nan"),) * 3
                outfile2.write("{}, {}, {}, {}\n".format(orth, *stats))

            # Report progress once per batch (previously printed once per word).
            print("Batch", batch_idx, "of", len(batches), "...done")

    pd.DataFrame(rows).to_csv(outpath, index=False)
    print("Checkpoint done:", checkpoint)

  0%|          | 0/15 [00:00<?, ?it/s]

root_2024-05-24_01h23m08973ms_chkpt001.pth arrived





RuntimeError: Error(s) in loading state_dict for Model:
	size mismatch for orth_position_embedding.weight: copying a param with shape torch.Size([13, 128]) from checkpoint, the shape in current model is torch.Size([15, 128]).
	size mismatch for phon_position_embedding.weight: copying a param with shape torch.Size([11, 128]) from checkpoint, the shape in current model is torch.Size([15, 128]).