### Comparing OE plots

In [7]:
import os

import numpy as np
import pandas as pd

from netam.framework import (
    trimmed_shm_model_outputs_of_crepe,
)

from shmex.shm_eval import oe_plot_of

import matplotlib
import matplotlib.pyplot as plt


from netam.common import (
    parameter_count_of_model,
)
from netam.framework import (
    load_crepe,
    trimmed_shm_model_outputs_of_crepe,
)

from shmex.shm_data import parent_and_child_differ, train_val_dfs_of_nicknames
from shmex.shm_zoo import standardize_and_optimize_branch_lengths
from shmex.shm_eval import (
    make_n_outside_of_shmoof_region, 
    ragged_np_pcp_encoding,
)


In [9]:
def write_test_accuracy(
    crepe_prefix,
    dataset_name,
    min_log_prob,
    directory=".",
    restrict_evaluation_to_shmoof_region=False,
    optimize_branch_lengths=False,
):
    """Evaluate a saved crepe on a dataset and save the resulting OE plot.

    The crepe is loaded from ``crepe_prefix`` and evaluated on the second
    frame returned by ``train_val_dfs_of_nicknames(dataset_name)``
    (presumably the validation split -- confirm against shmex.shm_data).

    Parameters
    ----------
    crepe_prefix : str
        Path prefix handed to ``load_crepe``; its basename is used in titles
        and output file names.
    dataset_name : str
        Dataset nickname handed to ``train_val_dfs_of_nicknames``.
    min_log_prob : float or None
        If not None, the plot binning is 101 evenly spaced edges from
        ``min_log_prob`` to 0; if None, ``binning=None`` is passed through to
        ``oe_plot_of``.
    directory : str
        Directory that receives the output files. It is created if missing.
    restrict_evaluation_to_shmoof_region : bool
        If True, the ``child`` column is replaced by the output of
        ``make_n_outside_of_shmoof_region`` and only rows for which
        ``parent_and_child_differ`` is truthy are kept.
    optimize_branch_lengths : bool
        If True, the frame is replaced by the output of
        ``standardize_and_optimize_branch_lengths`` and its ``branch_length``
        column is written to
        ``{directory}/{crepe_basename}-ON-{dataset_name}.branch_lengths_csv``.
        If False, the loaded frame must already have a ``branch_length``
        column.

    Returns
    -------
    The figure produced by ``oe_plot_of``; it is also saved to
    ``{directory}/{crepe_basename}-ON-{dataset_name}_all.png``.
    """
    # Switch to the non-interactive Agg backend before any figure is created.
    matplotlib.use("Agg")

    # Create the output directory up front so a missing directory cannot make
    # to_csv/savefig fail after the expensive evaluation has already run.
    if directory:
        os.makedirs(directory, exist_ok=True)

    crepe_basename = os.path.basename(crepe_prefix)
    comparison_title = f"{crepe_basename}-ON-{dataset_name}"
    crepe = load_crepe(crepe_prefix)
    _, pcp_df = train_val_dfs_of_nicknames(dataset_name)

    if restrict_evaluation_to_shmoof_region:
        # assign() builds a new frame, so the frame handed back by the loader
        # is not modified in place.
        pcp_df = pcp_df.assign(
            child=make_n_outside_of_shmoof_region(pcp_df["child"])
        )
        pcp_df = pcp_df[pcp_df.apply(parent_and_child_differ, axis=1)]

    if optimize_branch_lengths:
        pcp_df = standardize_and_optimize_branch_lengths(crepe.model, pcp_df)
        pcp_df.to_csv(
            f"{directory}/{comparison_title}.branch_lengths_csv",
            index=False,
            columns=["branch_length"],
        )

    binning = None if min_log_prob is None else np.linspace(min_log_prob, 0, 101)

    def test_accuracy_for(pcp_df, suffix):
        """Return (figure, one-row summary DataFrame) for ``pcp_df``."""
        # Only the first output (passed on as `ratess`) is needed here.
        ratess, _ = trimmed_shm_model_outputs_of_crepe(crepe, pcp_df["parent"])
        site_count = crepe.encoder.site_count
        mut_indicators, _, masks = ragged_np_pcp_encoding(
            pcp_df["parent"], pcp_df["child"], site_count
        )
        val_bls = pcp_df["branch_length"].values
        df_dict = {
            "crepe_prefix": crepe_prefix,
            "crepe_basename": crepe_basename,
            "parameter_count": parameter_count_of_model(crepe.model),
            "dataset_name": f"{dataset_name}_{suffix}",
        }
        fig, oe_results, _ = oe_plot_of(
            ratess,
            masks,
            val_bls,
            mut_indicators,
            f"{comparison_title}_{suffix}",
            binning=binning,
        )
        # Remove the "counts_twinx_ax" entry before the remaining results are
        # merged into the one-row summary.
        oe_results.pop("counts_twinx_ax")
        df_dict.update(oe_results)
        return fig, pd.DataFrame(df_dict, index=[0])

    # NOTE(review): the summary frame is built but not written or returned
    # here, despite the function name -- confirm whether it should be saved.
    fig_all, _ = test_accuracy_for(pcp_df, "all")

    fig_all.savefig(f"{directory}/{comparison_title}_all.png")

    return fig_all

# Evaluate the fivemer crepe on the val_tangshm dataset; the call is left as
# the cell's last expression so the returned figure is shown as its output.
write_test_accuracy(
    crepe_prefix="../train/trained_models/fivemer-shmoof_notbig-simple-0",
    dataset_name="val_tangshm",
    min_log_prob=-4,
    directory="output",
    restrict_evaluation_to_shmoof_region=True,
    optimize_branch_lengths=True,
)


Loading /Users/matsen/data/v1/tang-deepshm-oof_pcp_2024-04-09_MASKED_NI.csv.gz


Finding optimal branch lengths: 100%|██████████| 9304/9304 [00:50<00:00, 185.86it/s]
