# Test the model from checkpoints
This notebook allows you to take checkpoints from a trained model and make predictions for a set of words for each checkpoint. The data are written to CSV, one for each checkpoint (where a checkpoint corresponds to an epoch).

In [None]:
from src.dataset import ConnTextULDataset
import pandas as pd
import torch
import glob
import tqdm
from src.model import Model
from addict import Dict as AttrDict
from pathlib import Path

%reload_ext autoreload
%autoreload 2

## Config
Load a dataset into the `Traindata` class, which is embedded inside the phonology tokenizer. Later implementations should consider breaking this process apart so that `Traindata` is initialized prior to tokenization.

In [None]:
# Minimal stand-in configuration: a bare class whose only attribute is the
# dataset path that is handed to ConnTextULDataset below.
class config:
    dataset_filename = Path("data/kidwords_5000000_020724_.csv")


ds = ConnTextULDataset(config)

## Read in checkpoints
We need the checkpoint file names in order to iterate through them, load each one, and predict. Note that the data are written back to the same location from which the checkpoints are read.

In [None]:
# Glob pattern matching every checkpoint file of the run to evaluate.
PATH = "models/modelresults59355/root_2024-04-24_15h38m59355ms_chkpt*.pth"

# Collect the matching file names in lexicographic order.
checkpoints = sorted(glob.glob(PATH))
print(f"{checkpoints=}")

## Establish batches
Larger batches make for faster processing, but your machine may impose an upper limit.

In [None]:
# Number of words sent through the model per forward pass.
batch_size = 1000

# Collect the dataset's words into a set so that each word appears only once.
all_words = set()
all_words.update(ds.words)

# Materialize the set a single time; rebuilding the list for every slice
# would copy the whole vocabulary once per batch.
word_list = list(all_words)

# Consecutive slices of at most `batch_size` words (the last may be shorter).
batches = [
    word_list[start:start + batch_size]
    for start in range(0, len(word_list), batch_size)
]

## Write predictions
Iterate through the checkpoints and, for each one, generate and write predictions for the words in the dataset you specified in `config`.

In [None]:
def _join_vectors(vectors, as_int=False):
    """Serialize a sequence of tensor vectors as a single string.

    Elements within one vector are joined with ":" and the vectors
    themselves with ";". When ``as_int`` is true, every element is passed
    through ``int()`` before being converted to text.
    """
    return ";".join(
        ":".join(
            str(int(v.item())) if as_int else str(v.item()) for v in vector
        )
        for vector in vectors
    )


# For every checkpoint: rebuild the model, run all word batches through the
# "o2p" pathway, and write one CSV row per word next to the checkpoint file.
for checkpoint in tqdm.tqdm(checkpoints):

    # Swap only the file extension, leaving the rest of the path untouched.
    outfile = Path(checkpoint).with_suffix(".csv")

    # map_location="cpu" lets checkpoints saved on a GPU be loaded on a
    # machine without one.
    chkpt = torch.load(checkpoint, map_location="cpu")
    model = Model(AttrDict(chkpt["config"]), ds)
    model.load_state_dict(chkpt["model_state_dict"])
    model.eval()  # Set the model to evaluation mode

    # One dict per word; turned into a single DataFrame after all batches.
    rows = []

    print("Checkpoint", checkpoint, "...started")

    with torch.no_grad():
        for batch_idx, batch in enumerate(batches):
            datum = ds.character_tokenizer.encode(batch)
            pred = model.generate(
                "o2p",
                datum["enc_input_ids"],
                datum["enc_pad_mask"],
                None,
                None,
                deterministic=True,
            )
            for idx, orth in enumerate(batch):
                # Target phonology for this orthography, phonemes separated
                # by colons
                phon_target = ":".join(ds.cmudict[orth])

                # Remove the start and end tokens from each phonological
                # vector and convert them from tensors to lists
                phon_pred_features = [
                    tensor.tolist()
                    for tensor in pred["phon_tokens"][idx][1:-1]
                ]
                # Convert the phonological vectors to phonemes
                phon_pred = ds.phonology_tokenizer.traindata.convert_numeric_prediction(
                    phon_pred_features, phonology=True, hot_nodes=True
                )
                # Phonemes that could not be resolved come back as None;
                # write them out as the literal text "None"
                phon_prediction = ":".join(
                    "None" if p is None else p for p in phon_pred
                )

                # Phonological features for the target phonology
                encoded_target = ds.phonology_tokenizer.encode([orth])
                if encoded_target:
                    phon_target_features = _join_vectors(
                        encoded_target["targets"][0]
                    )
                else:
                    phon_target_features = "None"

                # A fresh dict per word, so rows never share state. Key order
                # determines the CSV column order.
                rows.append(
                    {
                        # Original input orthography
                        "word_raw": orth,
                        "phon_target": phon_target,
                        # Model's predicted pronunciation for this word
                        "phon_prediction": phon_prediction,
                        # Whether the prediction matched the target exactly
                        "correct": phon_target == phon_prediction,
                        "phon_target_features": phon_target_features,
                        # Predicted feature vectors (start/end tokens removed)
                        "phon_prediction_features": _join_vectors(
                            pred["phon_vecs"][idx][1:-1], as_int=True
                        ),
                        # Predicted feature probabilities (start/end removed)
                        "phon_prediction_probabilities": _join_vectors(
                            pred["phon_probs"][idx][1:-1]
                        ),
                        # Global encoding vector for this word
                        "global_encoding": _join_vectors(
                            pred["global_encoding"][idx]
                        ),
                    }
                )
            print("Batch", batch_idx, "of", len(batches), "...done")

    # Build the frame once from all rows instead of concatenating one
    # single-row frame per word.
    pd.DataFrame(rows).to_csv(outfile, index=False)
    print("Checkpoint done:", checkpoint)
    

In [None]:
# Scratch cell: encodes whatever word `orth` was left holding by the
# prediction loop in the previous cell and keeps its first target entry.
tmp = (
    ds.phonology_tokenizer
    .encode([orth])["targets"][0]
)

In [None]:
# Per-word phoneme accuracy. This cell relies on `model` and `batches` left in
# scope by earlier cells, so it scores whichever checkpoint was loaded last.
# Each output line is "word, phonemes_correct_total, phonemes_correct_mean".

# Length of the phonological representation stored under "%".
# Not used in this cell; kept in case other cells read it.
word_end = len(ds.phonology_tokenizer.traindata.phonreps["%"])

# The `with` statement closes the file on exit, so no explicit close is needed.
with open('models/modelresults59355/accuracy.csv', 'w', encoding="utf-8") as f:

    with torch.no_grad():
        for batch_idx, batch in enumerate(batches):
            datum = ds.character_tokenizer.encode(batch)
            pred = model.generate(
                "o2p",
                datum['enc_input_ids'],
                datum['enc_pad_mask'],
                None,
                None,
                deterministic=True,
            )
            for idx, orth in enumerate(batch):

                target = ds.phonology_tokenizer.encode([orth])['targets'][0].tolist()
                prediction = [e.tolist() for e in pred['phon_vecs'][idx]]

                # Compare position by position. zip() stops at the shorter
                # sequence, so a prediction shorter than the target no longer
                # raises IndexError; missing positions count as incorrect
                # because the denominator is the target length.
                phonemes_correct_total = sum(
                    expected == predicted
                    for expected, predicted in zip(target, prediction)
                )
                # An empty target has no defined accuracy; write nan rather
                # than dividing by zero.
                if target:
                    phonemes_correct_mean = phonemes_correct_total / len(target)
                else:
                    phonemes_correct_mean = float("nan")

                f.write("{}, {}, {}\n".format(orth, phonemes_correct_total, phonemes_correct_mean))

            print(batch_idx, "out of", len(batches), "...done")