# Test the model from checkpoints
This notebook allows you to take checkpoints from a trained model and make predictions for a set of words for each checkpoint. The data are written to CSV, one for each checkpoint (where a checkpoint corresponds to an epoch).

In [1]:
from src.dataset import ConnTextULDataset
import pandas as pd
import torch
import glob
import tqdm
from src.model import Model
from addict import Dict as AttrDict
from pathlib import Path

import nltk
cmu = nltk.corpus.cmudict.dict()

%reload_ext autoreload
%autoreload 2

[nltk_data] Downloading package cmudict to /root/nltk_data...
[nltk_data]   Package cmudict is already up-to-date!


# Pre-aggregate the word data
Each of these models was trained with a different dataset, all contained in `data/SSSR/`. For the frequencies for each of those sets for the purposes of analysis, we can reference those files individually later. Right now, we just need the words. To make things simpler we will read in all of them and get only the unique words, rather than testing on all the datasets individually, which would be onerous because there are several different sets, they contain repeated words across sets and repreated words within any given set.

In [7]:
import os

DIR = "data/SSSR2024/"
words = []

for filename in os.listdir(DIR):
    if filename.endswith('.csv'):

        filepath = os.path.join(DIR, filename)
        # read the .csv file into a pandas DataFrame and store it in the dictionary
        words.extend(pd.read_csv(filepath, header=None)[0].tolist())

words = [word for word in words if isinstance(word, str)]
words = sorted(list(set(words)))

with open('data/SSSR2024/all_input_words.csv', 'w') as f:
    f.write("word_raw\n")
    for word in words:
        f.write("{}\n".format(word))
f.close()


Let's aggregate the Woodcock words and use those for testing as well.

In [26]:
import json

with open('data/wj_iii_form_a.json', 'r') as file:
    wj3 = json.load(file)

wj3 = [word for word in wj3.keys() if len(word) > 1 and word in cmu.keys()]


## Config
Load in a dataset to Traindata class which is embedded inside the phonology tokenizer. Later implementations should consider breaking this process apart such that Traindata is initialized prior to tokenization.

In [62]:
os.listdir(trade_book_directories)

['trade_books_75_percent_background_25_percent',
 'trade_books_25_percent_background_75_percent',
 'trade_books_100_percent',
 'trade_books_50_percent_background_50_percent',
 'trade_books_weighted_sample',
 'trade_books_0_percent_background_100_percent']

In [69]:
FILEPATH ='models/SSSR2024/trade_books/trade_books_0_percent_background_100_percent/trade_books_0_percent_background_100_percent.csv'

In [70]:
FILEPATH

'models/SSSR2024/trade_books/trade_books_0_percent_background_100_percent/trade_books_0_percent_background_100_percent.csv'

In [93]:
os.listdir(os.path.join(trade_book_directories, dir))[0]

'root_2024-05-24_02h27m23023ms_chkpt111.pth'

In [94]:
trade_book_directories = "models/SSSR2024/trade_books"

for dir in os.listdir(trade_book_directories):
    filenames = os.listdir(os.path.join(trade_book_directories, dir))
    for filename in filenames:
        checkpoints = []
        FILEPATH = os.path.join(trade_book_directories, dir, filename)



        if filename.startswith("trade_books") & filename.endswith(".csv"):

            words = []
            words.extend(pd.read_csv(FILEPATH)['word_raw'].tolist())
            maxorth = max([len(word) for word in words])
            maxphon = max([len(cmu[word][0]) for word in words])
            for word in wj3:
                if len(word) <= maxorth & len(cmu[word][0]) <= maxphon:
                    words.extend(word)

            words = sorted(list(set(words)))
            outfile = os.path.join(trade_book_directories, dir, "words_for_test.csv")
            with open(outfile, 'w') as file:
                file.write("word_raw\n")
                for word in words:
                    file.write('{}\n'.format(word))
            file.close()
            config = type("config",
                (object,),
                {"dataset_filename": Path(outfile)}, )
            ds = ConnTextULDataset(config)
        if filename.endswith(".pth"):
               checkpoints.append(filename)
        checkpoints.sort()
        print(checkpoints)

['root_2024-05-24_03h25m28092ms_chkpt131.pth']
['root_2024-05-24_03h25m28092ms_chkpt101.pth']
['root_2024-05-24_03h25m28092ms_chkpt091.pth']
['root_2024-05-24_03h25m28092ms_chkpt031.pth']
['root_2024-05-24_03h25m28092ms_chkpt021.pth']
['root_2024-05-24_03h25m28092ms_chkpt051.pth']
['root_2024-05-24_03h25m28092ms_chkpt111.pth']
['root_2024-05-24_03h25m28092ms_chkpt001.pth']
['root_2024-05-24_03h25m28092ms_chkpt041.pth']
['root_2024-05-24_03h25m28092ms_chkpt011.pth']
['root_2024-05-24_03h25m28092ms_chkpt081.pth']
['root_2024-05-24_03h25m28092ms_chkpt071.pth']
Cache folder: /workspaces/BRIDGE/data/.cache already exists
orthpad changed to 0 because onehot encodings were selected for orthography
Representations initialized. Done.
[]
['root_2024-05-24_03h25m28092ms_chkpt141.pth']
['root_2024-05-24_03h25m28092ms_chkpt121.pth']
['root_2024-05-24_03h25m28092ms_chkpt061.pth']
['root_2024-05-24_02h41m32173ms_chkpt091.pth']
['root_2024-05-24_02h41m32173ms_chkpt011.pth']
['root_2024-05-24_02h41m321

## Read in checkpoints
We nee the checkpoint file names in order to iterate through them, load and predict. Note that the data are written back in the same location from which the checkpoints are read.

In [16]:
#PATH = "models/modelresults59355/root_2024-04-24_15h38m59355ms_chkpt*.pth"
PATH = "models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt*.pth"
checkpoints = glob.glob(PATH)
checkpoints.sort()
print(f"{checkpoints=}")

checkpoints=['models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt001.pth', 'models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt011.pth', 'models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt021.pth', 'models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt031.pth', 'models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt041.pth', 'models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt051.pth', 'models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt061.pth', 'models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt071.pth', 'models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt081.pth', 'models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt091.pth', 'models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt101.pth', 'models/SSSR2024/tr

## Establish batches
Larger batches make for faster processing, but your machine may impose an upper limit.

In [17]:
batch_size = 1000
all_words=set()
all_words.update(ds.words)
batches = [list(all_words)[i:i+batch_size] for i in range(0, len(all_words), batch_size)]

## Write predictions
Iterate through and generate, write predictions for the words you've initialized in `config`

In [18]:
for checkpoint in tqdm.tqdm(checkpoints):

    outfile = checkpoint.replace(".pth", ".csv")

    chkpt = torch.load(checkpoint)
    dfa = pd.DataFrame(columns=["phon_prediction"])
    model = Model(AttrDict(chkpt["config"]), ds)
    model.load_state_dict(chkpt["model_state_dict"])
    model.eval()  # Set the model to evaluation mode
    
    dl = []

    start_row = 0  # Initialize starting row for each new sheet

    print("Checkpoint", checkpoint, "...started")

    with torch.no_grad():
        for batch_idx in range(len(batches)):
            batch = batches[batch_idx]
            new_row = {}
            datum = ds.character_tokenizer.encode(batch)
            pred = model.generate(
                "o2p",
                datum['enc_input_ids'],
                datum['enc_pad_mask'],
                None,
                None,
                deterministic=True,
            )
            for idx, orth in enumerate(batch):
                # Save the original input orthography
                new_row["word_raw"] = orth
                # Save the target phonology for the above input orthography
                phon = ds.cmudict[orth]
                new_row["phon_target"] = ":".join(phon)
                # Remove the start and end tokens from each phonological vector
                # and convert them from tensors to lists
                phon_pred_features = [tensor.tolist() for tensor in pred["phon_tokens"][idx][1:-1]]
                # Convert the phonological vectors to phonemes using Matt's handy dandy routine
                phon_pred = ds.phonology_tokenizer.traindata.convert_numeric_prediction(
                    phon_pred_features, phonology=True, hot_nodes=True
                )
                # Save the model's predicted pronunciation for this word. Phonemes are
                # separated by colons
                phon_pred = ["None" if p == None else p for p in phon_pred]
                new_row["phon_prediction"] = ":".join(phon_pred)
                # Save a boolean indicating whether the prediction was correct
                new_row["correct"] = new_row["phon_target"] == new_row["phon_prediction"]
                # Save the phonological features for the target phonology
                phon_target_features = ds.phonology_tokenizer.encode([orth])
                if phon_target_features:
                    phon_target_features = ";".join(
                        [
                            ":".join([str(v.item()) for v in vector])
                            for vector in phon_target_features["targets"][0]
                        ]
                    )
                else:
                    phon_target_features = "None"
                new_row["phon_target_features"] = phon_target_features
                # Save the phonological features for the predicted phonology
                phon_prediction_features = ";".join(
                    [
                        ":".join([str(int(v.item())) for v in vector])
                        for vector in pred["phon_vecs"][idx][1:-1]
                    ]
                )
                new_row["phon_prediction_features"] = phon_prediction_features
                # Save the phonological probabilities for the predicted phonology
                phon_prediction_probabilities = ";".join(
                    [
                        ":".join([str(v.item()) for v in vector])
                        for vector in pred["phon_probs"][idx][1:-1]
                    ]
                )
                new_row["phon_prediction_probabilities"] = phon_prediction_probabilities
                # Save the global encoding vector for the predicted phonology
                global_encoding = ";".join(
                    
                        [
                        ":".join([str(v.item()) for v in vector]) 
                        for vector in pred["global_encoding"][idx]
                        ]
                    
                    )
                new_row["global_encoding"] = global_encoding

                
                dfa = pd.DataFrame([new_row])
                dl.append(dfa)
            print("Batch", batch_idx, "of", len(batches), "...done")
    pd.concat(dl).to_csv(outfile, index=False)
    print("Checkpoint done:", checkpoint)
        

  0%|          | 0/15 [00:00<?, ?it/s]

  output = torch._nested_tensor_from_mask(output, src_key_padding_mask.logical_not(), mask_check=False)


Checkpoint models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt001.pth ...started


  gen_phon_tokes.append(torch.tensor(new_phon_tokes))


Batch 0 of 2 ...done
Batch 1 of 2 ...done


  7%|▋         | 1/15 [00:09<02:18,  9.91s/it]

Checkpoint done: models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt001.pth
Checkpoint models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt011.pth ...started


  gen_phon_tokes.append(torch.tensor(new_phon_tokes))


Batch 0 of 2 ...done
Batch 1 of 2 ...done


 13%|█▎        | 2/15 [00:17<01:52,  8.62s/it]

Checkpoint done: models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt011.pth
Checkpoint models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt021.pth ...started


  gen_phon_tokes.append(torch.tensor(new_phon_tokes))


Batch 0 of 2 ...done
Batch 1 of 2 ...done


 20%|██        | 3/15 [00:24<01:36,  8.02s/it]

Checkpoint done: models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt021.pth
Checkpoint models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt031.pth ...started


  gen_phon_tokes.append(torch.tensor(new_phon_tokes))


Batch 0 of 2 ...done
Batch 1 of 2 ...done


 27%|██▋       | 4/15 [00:33<01:29,  8.11s/it]

Checkpoint done: models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt031.pth
Checkpoint models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt041.pth ...started


  gen_phon_tokes.append(torch.tensor(new_phon_tokes))


Batch 0 of 2 ...done
Batch 1 of 2 ...done


 33%|███▎      | 5/15 [00:41<01:21,  8.11s/it]

Checkpoint done: models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt041.pth
Checkpoint models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt051.pth ...started


  gen_phon_tokes.append(torch.tensor(new_phon_tokes))


Batch 0 of 2 ...done
Batch 1 of 2 ...done


 40%|████      | 6/15 [00:49<01:12,  8.11s/it]

Checkpoint done: models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt051.pth
Checkpoint models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt061.pth ...started


  gen_phon_tokes.append(torch.tensor(new_phon_tokes))


Batch 0 of 2 ...done
Batch 1 of 2 ...done


 47%|████▋     | 7/15 [00:57<01:05,  8.16s/it]

Checkpoint done: models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt061.pth
Checkpoint models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt071.pth ...started


  gen_phon_tokes.append(torch.tensor(new_phon_tokes))


Batch 0 of 2 ...done
Batch 1 of 2 ...done


 53%|█████▎    | 8/15 [01:05<00:57,  8.19s/it]

Checkpoint done: models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt071.pth
Checkpoint models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt081.pth ...started


  gen_phon_tokes.append(torch.tensor(new_phon_tokes))


Batch 0 of 2 ...done
Batch 1 of 2 ...done


 60%|██████    | 9/15 [01:13<00:48,  8.00s/it]

Checkpoint done: models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt081.pth
Checkpoint models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt091.pth ...started


  gen_phon_tokes.append(torch.tensor(new_phon_tokes))


Batch 0 of 2 ...done
Batch 1 of 2 ...done


 67%|██████▋   | 10/15 [01:21<00:39,  7.97s/it]

Checkpoint done: models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt091.pth
Checkpoint models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt101.pth ...started


  gen_phon_tokes.append(torch.tensor(new_phon_tokes))


Batch 0 of 2 ...done
Batch 1 of 2 ...done


 73%|███████▎  | 11/15 [01:29<00:31,  7.94s/it]

Checkpoint done: models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt101.pth
Checkpoint models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt111.pth ...started


  gen_phon_tokes.append(torch.tensor(new_phon_tokes))


Batch 0 of 2 ...done
Batch 1 of 2 ...done


 80%|████████  | 12/15 [01:37<00:24,  8.17s/it]

Checkpoint done: models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt111.pth
Checkpoint models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt121.pth ...started


  gen_phon_tokes.append(torch.tensor(new_phon_tokes))


Batch 0 of 2 ...done
Batch 1 of 2 ...done


 87%|████████▋ | 13/15 [01:48<00:17,  8.89s/it]

Checkpoint done: models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt121.pth
Checkpoint models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt131.pth ...started


  gen_phon_tokes.append(torch.tensor(new_phon_tokes))


Batch 0 of 2 ...done
Batch 1 of 2 ...done


 93%|█████████▎| 14/15 [01:58<00:09,  9.19s/it]

Checkpoint done: models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt131.pth
Checkpoint models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt141.pth ...started


  gen_phon_tokes.append(torch.tensor(new_phon_tokes))


Batch 0 of 2 ...done
Batch 1 of 2 ...done


100%|██████████| 15/15 [02:06<00:00,  8.41s/it]

Checkpoint done: models/SSSR2024/trade_books/fifty_percent/root_2024-05-24_03h05m33587ms_chkpt141.pth



