# Loading and evaluating checkpoints

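This notebook demonstrates how two trained checkpoints — a reconstruction model (d2p) and a reflex-prediction model (p2d) — can be evaluated individually and then combined into a reranking system, in which the p2d model reranks the d2p model's beam-search candidates. We use WikiHan as an example.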

In [12]:
import torch
from einops import rearrange, repeat

import pytorch_lightning as pl
from lib.analysis_utils import *
from specialtokens import *
from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import itertools
import torchshow as ts
import os
import argparse
import json
import os
from dotenv import load_dotenv
load_dotenv()
import warnings

# Suppress warnings whose messages match these patterns; they repeat on every
# evaluation run and drown out the metric outputs below.
for _noisy_message_pattern in (
    ".*num_workers.*",
    ".*negatively affect performance.*",
    ".*MPS available.*",
):
    warnings.filterwarnings(action="ignore", message=_noisy_message_pattern)
del _noisy_message_pattern

In [14]:
# === load d2p (reconstruction) model, run id 7jff7nhf ===
d2p_run = SubmodelRunFromFile(
    checkpoint_path='demo_checkpoints/model-7jff7nhf:v0',
    config_path='best_hparams/d2p-wikihan-GRU_beam.pkl',
)
# Guard against pairing this checkpoint with a config for the wrong submodel;
# the message shows what was actually loaded.
assert d2p_run.config_class.submodel == 'd2p', (
    f"expected a d2p submodel config, got {d2p_run.config_class.submodel!r}"
)
d2p_dm = d2p_run.get_dm()
d2p_model = d2p_run.get_model(d2p_dm)

In [15]:
# === load p2d (reflex prediction) model, run id bkdvx767 ===
p2d_run = SubmodelRunFromFile(
    checkpoint_path='demo_checkpoints/model-bkdvx767:v0',
    config_path='best_hparams/p2d-wikihan-GRU.pkl',
)
# Guard against pairing this checkpoint with a config for the wrong submodel;
# the message shows what was actually loaded.
assert p2d_run.config_class.submodel == 'p2d', (
    f"expected a p2d submodel config, got {p2d_run.config_class.submodel!r}"
)
p2d_dm = p2d_run.get_dm()
p2d_model = p2d_run.get_model(p2d_dm)

In [16]:
# === Make sure the two models use the same dataset ===
# The reranking cell evaluates the p2d model on d2p's data module (only d2p_dm
# is passed to it), so the two runs must have been configured for one dataset.
assert d2p_run.config_class.dataset == p2d_run.config_class.dataset, (
    f"dataset mismatch: d2p={d2p_run.config_class.dataset!r}, "
    f"p2d={p2d_run.config_class.dataset!r}"
)

In [17]:
# === results from grid search ===
def get_hparam():
    match (d2p_run.config_class.dataset, p2d_run.config_class.architecture):
        # this is average from reranking_eval2
        case ('chinese_baxter', 'GRU'):
            return round(7.2), 1.7550000000000003
        case ('chinese_baxter', 'Transformer'):
            return round(6.50000001), 2.43
        case ('chinese_baxter', 'JambuTransformer'):
            return round(6.50000001), 2.415
        case ('chinese_wikihan2022', 'GRU'):
            return round(5.9), 1.395
        case ('chinese_wikihan2022', 'Transformer'):
            return round(6.3), 1.275
        case ('chinese_wikihan2022', 'JambuTransformer'):
            return round(6.6), 1.2600000000000002
        case ('chinese_wikihan2022_augmented', 'GRU'):
            return round(7.0), 1.6199999999999997
        case ('chinese_wikihan2022_augmented', 'Transformer'):
            return round(8.3), 1.7550000000000001
        case ('chinese_wikihan2022_augmented', 'JambuTransformer'):
            return round(6.6), 1.5749999999999997
        case ('Nromance_ipa', 'GRU'):
            return round(5.4), 0.41999999999999993
        case ('Nromance_ipa', 'Transformer'):
            return round(6.1), 0.5549999999999999
        case ('Nromance_ipa', 'JambuTransformer'):
            return round(6.0), 0.5850000000000001
        case ('Nromance_orto', 'GRU'):
            return round(6.1), 0.8699999999999999
        case ('Nromance_orto', 'Transformer'):
            return round(5.5000001), 0.99
        case ('Nromance_orto', 'JambuTransformer'):
            return round(5.2), 0.915
        case _:
            raise NotImplemented

## Evaluate Submodels

### Evaluate the GRU reconstruction model with beam search (beam size k = 10)

In [18]:
BASELINE_BEAM_SIZE = 10  # k for the plain (non-reranked) beam-search baseline
beam_search_eval(d2p_model, d2p_dm, beam_size=BASELINE_BEAM_SIZE, split='test')

Testing DataLoader 0: 100%|██████████| 17/17 [00:03<00:00,  4.50it/s]



Testing DataLoader 0: 100%|██████████| 17/17 [00:04<00:00,  3.85it/s]


{'d2p/test/loss': 0.5365998148918152,
 'd2p/test/recon_loss': 0.5365998148918152,
 'd2p/test/kl_loss': 0.0,
 'd2p/test/accuracy': 0.5285575985908508,
 'd2p/test/char_edit_distance': 0.8818973898887634,
 'd2p/test/phoneme_edit_distance': 0.7676669955253601,
 'd2p/test/phoneme_error_rate': 0.17981860041618347,
 'd2p/test/feature_error_rate': 0.07038774341344833,
 'd2p/test/bcubed_f_score': 0.7298532128334045,
 'd2p/test/avg_target_phoneme_len': 4.2691192626953125,
 'd2p/test/avg_prediction_phoneme_len': 4.254598140716553,
 'd2p/test/avg_t_rank_in_beam_search': 1.0334448160535117,
 'd2p/test/std_t_rank_in_beam_search': 1.8629152827312523,
 'd2p/test/target_in_beam': 0.8683446049690247}

### Evaluate Reflex Prediction

In [19]:
# Score the p2d (reflex prediction) model on its own test split.
eval_on_set(
    p2d_model,
    p2d_dm,
    split='test',
)

Testing: 0it [00:00, ?it/s]

Testing DataLoader 0: 100%|██████████| 17/17 [00:02<00:00,  7.14it/s]



Testing DataLoader 0: 100%|██████████| 17/17 [00:05<00:00,  3.16it/s]


{'p2d/test/loss': 0.45579683780670166,
 'p2d/test/recon_loss': 0.45579683780670166,
 'p2d/test/kl_loss': 0.0,
 'p2d/test/Mandarin/accuracy': 0.7299128770828247,
 'p2d/test/Mandarin/char_edit_distance': 0.6292352080345154,
 'p2d/test/Mandarin/phoneme_edit_distance': 0.39399805665016174,
 'p2d/test/Mandarin/phoneme_error_rate': 0.12041420489549637,
 'p2d/test/Mandarin/feature_error_rate': 0.03036404587328434,
 'p2d/test/Mandarin/bcubed_f_score': 0.7924480438232422,
 'p2d/test/Gan/accuracy': 0.7351598143577576,
 'p2d/test/Gan/char_edit_distance': 0.6392694115638733,
 'p2d/test/Gan/phoneme_edit_distance': 0.3652968108654022,
 'p2d/test/Gan/phoneme_error_rate': 0.10840108245611191,
 'p2d/test/Gan/feature_error_rate': 0.02858753129839897,
 'p2d/test/Gan/bcubed_f_score': 0.8253531455993652,
 'p2d/test/Jin/accuracy': 0.7120622396469116,
 'p2d/test/Jin/char_edit_distance': 0.774319052696228,
 'p2d/test/Jin/phoneme_edit_distance': 0.47081711888313293,
 'p2d/test/Jin/phoneme_error_rate': 0.140371

## Evaluate Reranked Reconstruction

In [20]:
best_beam_size, best_beam_reranker_weight_ratio = get_hparam()

# The p2d model reranks the d2p beam; the rescorer presumably mixes the original
# d2p log-probability (weight 1.0) with the reranker score (tuned weight).
correct_rate_reranker = BatchedCorrectRateReranker(p2d_model)
linear_rescorer = BatchedLinearRescorer(
    original_log_prob_weight=1.0,
    reranker_weight=best_beam_reranker_weight_ratio,
)
batched_reranking_eval(
    d2p_model,
    p2d_model,
    d2p_dm,
    reranker=correct_rate_reranker,
    rescorer=linear_rescorer,
    beam_size=best_beam_size,
    split='test',
)

Testing: 0it [00:00, ?it/s]

Testing DataLoader 0: 100%|██████████| 17/17 [00:09<00:00,  1.78it/s]



Testing DataLoader 0: 100%|██████████| 17/17 [00:10<00:00,  1.69it/s]


{'d2p/test/loss': 0.5365998148918152,
 'd2p/test/recon_loss': 0.5365998148918152,
 'd2p/test/kl_loss': 0.0,
 'd2p/test/accuracy': 0.5682477951049805,
 'd2p/test/char_edit_distance': 0.7909002900123596,
 'd2p/test/phoneme_edit_distance': 0.7008712291717529,
 'd2p/test/phoneme_error_rate': 0.16417233645915985,
 'd2p/test/feature_error_rate': 0.06428731232881546,
 'd2p/test/bcubed_f_score': 0.750727117061615,
 'd2p/test/avg_target_phoneme_len': 4.2691192626953125,
 'd2p/test/avg_prediction_phoneme_len': 4.254598140716553}