In [None]:
import copy
import os
import glob
import ast
import sys

import pandas as pd
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.collections import QuadMesh

sys.path.append("/home/akubaney/projects/na_mpnn/evaluation")
from na_eval_utils import read_json_file

In [None]:
# 1) Register the Arial font file (path is relative to the working directory)
#    with Matplotlib's font manager.
font_path = "./ARIAL.TTF"
matplotlib.font_manager.fontManager.addfont(font_path)

# 2) Use Arial as the default font family for all figures.
plt.rcParams['font.family'] = "Arial"

In [None]:
def parse_model_name(json_path: str, parent_dir: str) -> str:
    """
    Return the model name encoded in a result JSON's location.

    The model name is the first path component of `json_path` once it is
    expressed relative to `parent_dir`, e.g.
    parent_dir/model_name/some_subdir/design_json/foo.json → model_name
    """
    relative_path = os.path.relpath(json_path, parent_dir)
    first_component, _, _ = relative_path.partition(os.sep)
    return first_component

def get_polymer_group(row: dict) -> str | None:
    """
    Your original grouping logic. Returns one of
      "DNA", "Protein-DNA", "RNA", "Protein-RNA"
    or None if it's a hybrid (which we'll skip).
    """
    nac = ast.literal_eval(row["nucleic_acid_chain_cluster_ids_chain_types"])
    pcc = ast.literal_eval(row["protein_chain_cluster_ids_chain_types"])

    has_protein = len(pcc) > 0
    has_dna     = "polydeoxyribonucleotide" in nac
    has_rna     = "polyribonucleotide" in nac
    has_hybrid  = "polydeoxyribonucleotide/polyribonucleotide hybrid" in nac

    if has_protein and has_dna and not has_rna and not has_hybrid:
        return "DNA (protein context)"
    if has_protein and has_rna and not has_dna and not has_hybrid:
        return "RNA (protein context)"
    if has_dna and not has_protein and not has_rna and not has_hybrid:
        return "DNA"
    if has_rna and not has_protein and not has_dna and not has_hybrid:
        return "RNA"
    return None

def get_ppm_group(row: dict) -> str | None:
    """
    Classify the source of the PPMs in the row.
    """
    ppm_paths = ast.literal_eval(row["ppm_paths"])

    has_ppm = len(ppm_paths) > 0
    ppm_from_crystal = row["dataset_name"] == "rcsb_cif_na"
    ppm_from_distillation = (row["dataset_name"] == "rf2na_distillation_cis_bp") or (row["dataset_name"] == "rf2na_distillation_transfac")

    if has_ppm and ppm_from_crystal:
        return "Crystal"
    if has_ppm and ppm_from_distillation:
        return "Distillation"
    if has_ppm and not ppm_from_crystal and not ppm_from_distillation:
        return "Other"
    return None

def extract_records(
    df: pd.DataFrame,
    json_paths: list[str],
    parent_dir: str,
    metric_fns: dict,        # map metric_name → function(js) or None
    orig_fn,
    name_fn = None,
    path_column: str = "structure_path",
    group_type: str = None
) -> pd.DataFrame:
    """
    For each JSON in json_paths:
      - read it, extract orig = orig_fn(js)
      - match it to df[path_column]
      - classify it (polymer/ppm)
      - parse model name
      - pull out metrics
    Return a DataFrame of all records *but* only for those origs
    that show up under *every* model.
    """
    recs = []
    for jp in json_paths:
        js = read_json_file(jp)
        orig = orig_fn(js)
        if orig is None:
            continue

        if name_fn is not None:
            # get the name from the JSON
            name = name_fn(js)
            if name is None:
                continue

        # find the CSV row
        match = df[df[path_column] == orig]
        if match.empty:
            continue

        row = match.iloc[0].to_dict()
        # classify
        if group_type == "polymer":
            group = get_polymer_group(row)
            if group is None:
                continue
        elif group_type == "ppm":
            group = get_ppm_group(row)
            if group is None:
                continue
        else:
            group = None

        model = parse_model_name(jp, parent_dir)

        # build the record
        row_dict = {
            path_column: orig,
            "Name":     name if name_fn is not None else None,
            "Model":    model,
            "Group":    group,
            "dataset_name": row["dataset_name"],
        }
        for mname, mfn in metric_fns.items():
            row_dict[mname] = js[mname] if mfn is None else mfn(js)

        recs.append(row_dict)

    records = pd.DataFrame(recs)
    if records.empty:
        return records

    # 1) find the full set of models
    all_models = set(records["Model"].unique())

    # 2) for each orig, get the set of models it appears under
    orig_to_models = (
        records
        .groupby(path_column)["Model"]
        .apply(set)
    )

    # 3) keep only those origs whose set == all_models
    valid_origs = orig_to_models[orig_to_models == all_models].index

    # 4) filter
    return records[records[path_column].isin(valid_origs)]

In [None]:
# Evaluation CSVs for the validation runs (design and specificity).
design_valid_df = pd.read_csv("/home/akubaney/projects/na_mpnn/evaluation/evaluation_csvs/design_valid.csv")
specificity_valid_df = pd.read_csv("/home/akubaney/projects/na_mpnn/evaluation/evaluation_csvs/specificity_valid.csv")

# Evaluation CSVs for the design test runs (general, RNA monomer, pseudoknot).
design_test_df = pd.read_csv("/home/akubaney/projects/na_mpnn/evaluation/evaluation_csvs/design_test.csv")
design_rna_monomer_test_df = pd.read_csv("/home/akubaney/projects/na_mpnn/evaluation/evaluation_csvs/design_rna_monomer_test.csv")
design_pseudoknot_test_df = pd.read_csv("/home/akubaney/projects/na_mpnn/evaluation/evaluation_csvs/design_pseudoknot_test.csv")

# Evaluation CSV for the specificity test run.
specificity_test_df = pd.read_csv("/home/akubaney/projects/na_mpnn/evaluation/evaluation_csvs/specificity_test.csv")

In [None]:
# Subset the pseudoknot test set to the selected entries (matched on "id").
pseudoknot_pdb_ids = [
    '1drz', '2m8k', '2miy', '3q3z', '4oqu',
    '4plx', '4znp', '7kd1', '7kga', '7qr4',
]

design_pseudoknot_test_df = design_pseudoknot_test_df.loc[
    design_pseudoknot_test_df["id"].isin(pseudoknot_pdb_ids)
]

In [None]:
# glob.glob returns matches in arbitrary order; sorting makes the record order
# (and therefore the saved summary CSV) reproducible across runs.
design_valid_plot_df = extract_records(
    design_valid_df,
    sorted(glob.glob(os.path.join("/home/akubaney/projects/na_mpnn/evaluation/evaluation_outputs/design_valid", '*', '*', "design_json", '*.json'))),
    "/home/akubaney/projects/na_mpnn/evaluation/evaluation_outputs/design_valid",
    group_type = "polymer",
    metric_fns = {
        "Sequence Recovery": lambda json_dict: float(json_dict["tool_reported_sequence_recovery"])
    },
    orig_fn = lambda json_dict: json_dict.get("original_input_structure_path")
)
len(design_valid_plot_df)

In [None]:
# glob.glob returns matches in arbitrary order; sorting makes the record order
# (and therefore the saved summary CSV) reproducible across runs.
specificity_valid_plot_df = extract_records(
    specificity_valid_df,
    sorted(glob.glob(os.path.join("/home/akubaney/projects/na_mpnn/evaluation/evaluation_outputs/specificity_valid_scores", '*', '*', '*.json'))),
    "/home/akubaney/projects/na_mpnn/evaluation/evaluation_outputs/specificity_valid_scores",
    group_type = "ppm",
    metric_fns = {
        "Mean Absolute Error": lambda json_dict: float(json_dict["mean_absolute_error_dna"]["mean_absolute_error"]),
        "Cross-Entropy": lambda json_dict: float(json_dict["cross_entropy_dna"]["cross_entropy"])
    },
    # The score JSON points at the subject JSON, which records the original structure path.
    orig_fn = lambda json_dict: read_json_file(json_dict["subject_path"])["original_input_structure_path"]
)
len(specificity_valid_plot_df)

In [None]:
# glob.glob returns matches in arbitrary order; sorting makes the record order
# (and therefore the saved summary CSV) reproducible across runs.
specificity_valid_hypersweep_plot_df = extract_records(
    specificity_valid_df,
    sorted(glob.glob(os.path.join("/home/akubaney/projects/na_mpnn/evaluation/evaluation_outputs/specificity_valid_hypersweep_scores", '*', '*', '*.json'))),
    "/home/akubaney/projects/na_mpnn/evaluation/evaluation_outputs/specificity_valid_hypersweep_scores",
    group_type = "ppm",
    metric_fns = {
        "Mean Absolute Error": lambda json_dict: float(json_dict["mean_absolute_error_dna"]["mean_absolute_error"]),
        "Cross-Entropy": lambda json_dict: float(json_dict["cross_entropy_dna"]["cross_entropy"])
    },
    # The score JSON points at the subject JSON, which records the original structure path.
    orig_fn = lambda json_dict: read_json_file(json_dict["subject_path"])["original_input_structure_path"]
)
len(specificity_valid_hypersweep_plot_df)

In [None]:
# glob.glob returns matches in arbitrary order; sorting makes the record order
# (and therefore the saved summary CSV) reproducible across runs.
design_test_plot_df = extract_records(
    design_test_df,
    sorted(glob.glob(os.path.join("/home/akubaney/projects/na_mpnn/evaluation/evaluation_outputs/design_test", '*', '*', "design_json", '*.json'))),
    "/home/akubaney/projects/na_mpnn/evaluation/evaluation_outputs/design_test",
    group_type = "polymer",
    metric_fns = {
        "Sequence Recovery": lambda json_dict: float(json_dict["tool_reported_sequence_recovery"])
    },
    orig_fn = lambda json_dict: json_dict.get("original_input_structure_path")
)
len(design_test_plot_df)

In [None]:
# glob.glob returns matches in arbitrary order; sorting makes the record order
# (and therefore the saved summary CSV) reproducible across runs.
design_rna_monomer_test_plot_df = extract_records(
    design_rna_monomer_test_df,
    sorted(glob.glob(os.path.join("/home/akubaney/projects/na_mpnn/evaluation/evaluation_outputs/design_rna_monomer_test", '*', '*', "design_json", '*.json'))),
    "/home/akubaney/projects/na_mpnn/evaluation/evaluation_outputs/design_rna_monomer_test",
    group_type = "polymer",
    metric_fns = {
        "Sequence Recovery": lambda json_dict: float(json_dict["tool_reported_sequence_recovery"])
    },
    orig_fn = lambda json_dict: json_dict.get("original_input_structure_path")
)
len(design_rna_monomer_test_plot_df)

In [None]:
# glob.glob returns matches in arbitrary order; sorting makes the record order
# (and therefore the saved summary CSV) reproducible across runs.
design_pseudoknot_test_plot_df = extract_records(
    design_pseudoknot_test_df,
    sorted(glob.glob(os.path.join("/home/akubaney/projects/na_mpnn/evaluation/evaluation_outputs/design_pseudoknot_test", '*', '*', "design_json", '*.json'))),
    "/home/akubaney/projects/na_mpnn/evaluation/evaluation_outputs/design_pseudoknot_test",
    group_type = "polymer",
    metric_fns = {
        "Sequence Recovery": lambda json_dict: float(json_dict["tool_reported_sequence_recovery"])
    },
    orig_fn = lambda json_dict: json_dict.get("original_input_structure_path")
)
len(design_pseudoknot_test_plot_df)

In [None]:
# glob.glob returns matches in arbitrary order; sorting makes the record order
# (and therefore the saved summary CSV) reproducible across runs.
design_pseudoknot_test_scores_plot_df = extract_records(
    design_pseudoknot_test_df,
    sorted(glob.glob(os.path.join("/home/akubaney/projects/na_mpnn/evaluation/evaluation_outputs/design_pseudoknot_test_scores", '*', '*', '*.json'))),
    "/home/akubaney/projects/na_mpnn/evaluation/evaluation_outputs/design_pseudoknot_test_scores",
    group_type = "polymer",
    # .get() leaves a metric as None when the score JSON does not contain it.
    metric_fns = {
        "AF3 C1'-RMSD": lambda json_dict: json_dict.get("alphafold3_c1_prime_rmsd"),
        "AF3 C1'-LDDT": lambda json_dict: json_dict.get("alphafold3_c1_prime_lddt"),
        "AF3 C1'-GDDT": lambda json_dict: json_dict.get("alphafold3_c1_prime_gddt"),
        "AF3 pLDDT": lambda json_dict: json_dict.get("alphafold3_plddt"),
        "AF3 pTM": lambda json_dict: json_dict.get("alphafold3_ptm"),
        "AF3 pAE": lambda json_dict: json_dict.get("alphafold3_pae"),
        "RibonanzaNet OKS": lambda json_dict: json_dict.get("ribonanza_net_openknot_score"),
        "RibonanzaNet Pair F1": lambda json_dict: json_dict.get("ribonanza_net_f1_score_pairs"),
        "RibonanzaNet Loop F1": lambda json_dict: json_dict.get("ribonanza_net_f1_score_loops")
    },
    # Two hops: score JSON → subject JSON → design-input JSON, which records
    # the original structure path.
    orig_fn = lambda json_dict: read_json_file(read_json_file(json_dict["subject_path"])["design_input_path"])["original_input_structure_path"],
    name_fn = lambda json_dict: json_dict.get("subject_name"),
)
len(design_pseudoknot_test_scores_plot_df)

In [None]:
# glob.glob returns matches in arbitrary order; sorting makes the record order
# (and therefore the saved summary CSV) reproducible across runs.
specificity_test_plot_df = extract_records(
    specificity_test_df,
    sorted(glob.glob(os.path.join("/home/akubaney/projects/na_mpnn/evaluation/evaluation_outputs/specificity_test_scores", '*', '*', '*.json'))),
    "/home/akubaney/projects/na_mpnn/evaluation/evaluation_outputs/specificity_test_scores",
    group_type = "ppm",
    metric_fns = {
        "Mean Absolute Error": lambda json_dict: float(json_dict["mean_absolute_error_dna"]["mean_absolute_error"]),
        "Cross-Entropy": lambda json_dict: float(json_dict["cross_entropy_dna"]["cross_entropy"])
    },
    # The score JSON points at the subject JSON, which records the original structure path.
    orig_fn = lambda json_dict: read_json_file(json_dict["subject_path"])["original_input_structure_path"]
)
len(specificity_test_plot_df)

In [None]:
# Save every per-record plot dataframe as a CSV summary (row index omitted),
# one file per dataframe, named after the dataframe variable.
design_valid_plot_df.to_csv("/home/akubaney/projects/na_mpnn/evaluation/evaluation_summaries/design_valid_plot.csv", index=False)
specificity_valid_plot_df.to_csv("/home/akubaney/projects/na_mpnn/evaluation/evaluation_summaries/specificity_valid_plot.csv", index=False)
specificity_valid_hypersweep_plot_df.to_csv("/home/akubaney/projects/na_mpnn/evaluation/evaluation_summaries/specificity_valid_hypersweep_plot.csv", index=False)
design_test_plot_df.to_csv("/home/akubaney/projects/na_mpnn/evaluation/evaluation_summaries/design_test_plot.csv", index=False)
design_rna_monomer_test_plot_df.to_csv("/home/akubaney/projects/na_mpnn/evaluation/evaluation_summaries/design_rna_monomer_test_plot.csv", index=False)
design_pseudoknot_test_plot_df.to_csv("/home/akubaney/projects/na_mpnn/evaluation/evaluation_summaries/design_pseudoknot_test_plot.csv", index=False)
design_pseudoknot_test_scores_plot_df.to_csv("/home/akubaney/projects/na_mpnn/evaluation/evaluation_summaries/design_pseudoknot_test_scores_plot.csv", index=False)
specificity_test_plot_df.to_csv("/home/akubaney/projects/na_mpnn/evaluation/evaluation_summaries/specificity_test_plot.csv", index=False)

In [None]:
# Colors keyed by the group labels returned by get_polymer_group
# (reds for DNA, blues for RNA; lighter shade for the protein-context variant).
polymer_type_palette = {
    "DNA": "#FF4B4B",
    "DNA (protein context)": "#FF7F7F",
    "RNA": "#4B4BFF",
    "RNA (protein context)": "#7F7FFF",
}

# Colors keyed by the group labels returned by get_ppm_group.
ppm_type_palette = {
    "Crystal": "#E0BBE4",
    "Distillation": "#5D3A9B",
    "Other": "#A9A9A9"
}
# Colors keyed by raw model name for the design comparisons
# (na_mpnn highlighted in green, the other models in grey).
design_model_palette = {
    "na_mpnn": "#2ECC71",
    "grnade": "#D3D3D3",
    "rhodesign": "#A9A9A9"
}
# Colors keyed by raw model name for the specificity comparisons.
specificity_model_palette = {
    "na_mpnn": "#2ECC71",
    "deeppbs": "#D3D3D3",
}
# Raw model name → display label.
model_name_to_label = {
    "na_mpnn": "NA-MPNN",
    "grnade": "gRNAde",
    "rhodesign": "RhoDesign",
    "deeppbs": "DeepPBS",
}
# Raw dataset name → display label for the two distillation data sources.
data_source_name_to_label = {
    "rf2na_distillation_cis_bp": "CIS-BP",
    "rf2na_distillation_transfac": "TRANSFAC"
}

# Style preset for the metric-vs-model-step line plots
# (presumably passed as `style` to plot_aggregate_metric_vs_model_step — confirm at the call site).
# A value of None means "do not override the library default".
metric_vs_model_step_style = {
    "title_fontsize": None,
    "axis_title_fontsize": 8,
    "tick_labelsize": 6,
    "legend_title_fontsize": None,
    "legend_fontsize": 6,
    "dpi": 300,
    "hide_top_right_spines": True,
    "spine_linewidth": 0.5,
    "tick_width": 0.5,
    "line_width": 0.5,
    # Separate legend figure: saved next to the plot with this file-name suffix
    "save_legend_separately": True,
    "legend_save_suffix": "_legend",
    "legend_figsize": (2, 2),
    "show_legend": True,
    "legend_fill_figure": True,
    "legend_frameon": False,
    # Figure size in inches, written as millimetres / 25.4 (125 mm x 50 mm)
    "figsize": (125 / 25.4, 50 / 25.4),
    "legend_handlelength": 1.5,
    "legend_handleheight": 0.3,
    "legend_labelspacing": 0.5,
    "legend_borderpad": None,
    "legend_frame_linewidth": None,
    "legend_handle_linewidth": 0.5,
    # Legend marker scaling and handle-to-text padding
    "legend_markerscale": None,
    "legend_handletextpad": 0.2,
    "markersize": 3
}

# Style preset for the aggregate-metric heatmaps
# (presumably passed as `style` to plot_metric_heatmap — confirm at the call site).
# A value of None means "do not override the library default".
metric_heatmap_style = {
    "title_fontsize": 8,
    "axis_title_fontsize": 8,
    "tick_labelsize": 6,
    "legend_title_fontsize": None,
    "legend_fontsize": 6,
    "annot_fontsize": 6,
    "dpi": 300,
    # Cell annotations: on, formatted with two decimals
    "annot": True,
    "fmt": ".2f",
    # Legend saving defaults (no-op if no legend)
    "save_legend_separately": True,
    "legend_save_suffix": "_legend",
    "legend_figsize": (2, 2),
    # Figure size in inches, written as millimetres / 25.4 (120 mm x 100 mm)
    "figsize": (120 / 25.4, 100 / 25.4),
    # Colormap name handed to the heatmap
    "palette": "viridis",
    "legend_fill_figure": True,
    "legend_frameon": False,
    "legend_handlelength": 1.5,
    "legend_handleheight": 0.3,
    "legend_labelspacing": 0.5,
    "legend_borderpad": None,
    "legend_frame_linewidth": None,
    "legend_handle_linewidth": 0.5,
    "spine_linewidth": 0.5,
    "tick_width": 0.5,
    "line_width": 0.5,
    # Legend marker scaling and handle-to-text padding
    "legend_markerscale": None,
    "legend_handletextpad": 0.2,
}

# Style preset for the box plots
# (presumably passed as `style` to plot_box — confirm at the call site).
# A value of None means "do not override the library default".
box_style = {
    "title_fontsize": None,
    "axis_title_fontsize": 8,
    "tick_labelsize": 6,
    "legend_title_fontsize": None,
    "legend_fontsize": 6,
    "dpi": 300,
    "hide_top_right_spines": True,
    "spine_linewidth": 0.5,
    "tick_width": 0.5,
    "box_linewidth": 0.5,
    # Separate legend figure: saved next to the plot with this file-name suffix.
    # Sizes are in inches, written as millimetres / 25.4
    # (legend 18 mm x 9 mm, figure 40 mm x 40 mm).
    "save_legend_separately": True,
    "legend_save_suffix": "_legend",
    "legend_figsize": (18 / 25.4, 9 / 25.4),
    "figsize": (40 / 25.4, 40 / 25.4),
    # Control outlier markers in boxplots
    "flier_markersize": 2.5,
    "flier_linewidth": 0.5,
    "median_color": None,
    "legend_fill_figure": True,
    "legend_frameon": False,
    "legend_handlelength": 1.5,
    "legend_handleheight": 0.3,
    "legend_labelspacing": 0.5,
    "legend_borderpad": None,
    "legend_frame_linewidth": None,
    "legend_handle_linewidth": 0.5,
    # Legend marker scaling and handle-to-text padding
    "legend_markerscale": None,
    "legend_handletextpad": 0.2,
}

In [None]:
def extend_ylim_without_new_ticks(ax: "plt.Axes", headroom: float):
    """
    Extend the upper y-axis limit by a fraction `headroom` of the current
    y-range while keeping the ticks that were visible before the extension.

    Parameters
    ----------
    ax : Axes whose y-limits are extended in place.
    headroom : fraction of the current y-range to add above the upper limit;
        None, 0 or a negative value leaves the axes untouched.
    """
    if not (headroom and headroom > 0):
        return

    orig_ticks = ax.get_yticks()
    ymin, ymax = ax.get_ylim()
    ax.set_ylim(ymin, ymax + (ymax - ymin) * headroom)

    # get_yticks() may report locations outside the view, and set_yticks()
    # expands the limits to include every tick it is given. Re-apply only the
    # ticks inside the original limits so the headroom stays tick-free and
    # the lower limit does not move.
    lower, upper = min(ymin, ymax), max(ymin, ymax)
    tolerance = 1e-9 * (upper - lower)
    visible_ticks = [
        tick for tick in orig_ticks
        if lower - tolerance <= tick <= upper + tolerance
    ]
    ax.set_yticks(visible_ticks)

def subset_df(df: pd.DataFrame, groups=None, models=None) -> pd.DataFrame:
    """
    Return an independent copy of df filtered by optional groups and models.

    Parameters
    ----------
    df : frame with a "Group" column (needed when `groups` is given) and a
        "Model" column (needed when `models` is given).
    groups : values of "Group" to keep, or None to keep all.
    models : values of "Model" to keep, or None to keep all.

    Returns
    -------
    A new DataFrame; modifying it never affects `df`.
    """
    subset = df
    if groups is not None:
        subset = subset[subset["Group"].isin(groups)]
    if models is not None:
        subset = subset[subset["Model"].isin(models)]
    # Copy last: only the kept rows are copied, and the result is a standalone
    # frame, so later column assignments do not raise SettingWithCopyWarning.
    return subset.copy()

def _save_legend_from_axes(ax, parent_save_name, style_dict):
    """
    Move the legend of `ax` into its own figure and save that figure.

    The legend is removed from `ax`, rebuilt from the axes' labelled artists
    on a new figure, and written to `<root><legend_save_suffix><ext>` derived
    from `parent_save_name` ('.svg' when it has no extension).

    Returns the legend figure, or None when `parent_save_name` is empty or
    `ax` has no legend.
    """
    if not parent_save_name:
        return None

    existing_legend = ax.get_legend()
    if existing_legend is None:
        return None

    handles, labels = ax.get_legend_handles_labels()

    try:
        existing_legend.remove()
    except Exception:
        pass

    opt = style_dict.get

    # Output location: parent file name with the suffix inserted before the extension.
    dpi = opt('dpi', 300)
    stem, extension = os.path.splitext(parent_save_name)
    save_path = f"{stem}{opt('legend_save_suffix', '_legend')}{extension or '.svg'}"

    # Options that are needed again after the legend has been created.
    loc = opt('legend_loc', 'center')
    mode = opt('legend_mode', None)
    title = opt('legend_title', None)
    title_fontsize = opt('legend_title_fontsize', None)
    text_fontsize = opt('legend_fontsize', None)
    frame_linewidth = opt('legend_frame_linewidth', None)
    handle_linewidth = opt('legend_handle_linewidth', None)

    # Always-present keywords first, then every optional one that was configured.
    legend_kw = {'ncol': opt('legend_ncol', 1), 'loc': loc}
    optional_kw = {
        'handlelength': opt('legend_handlelength', None),
        'handleheight': opt('legend_handleheight', None),
        'labelspacing': opt('legend_labelspacing', None),
        'borderpad': opt('legend_borderpad', None),
        'mode': mode,
        'title': title,
        'markerscale': opt('legend_markerscale', None),
        'handletextpad': opt('legend_handletextpad', None),
    }
    legend_kw.update({key: value for key, value in optional_kw.items() if value is not None})

    fig_leg = plt.figure(
        figsize=opt('legend_figsize', (6, 2)),
        dpi=dpi,
        constrained_layout=True,
    )

    if opt('legend_fill_figure', False):
        # The legend spans the whole figure; the frame defaults to off here.
        ax_leg = fig_leg.add_axes([0, 0, 1, 1])
        legend_kw.update(
            loc=loc or 'center',
            mode=mode or 'expand',
            frameon=opt('legend_frameon', False),
        )
    else:
        ax_leg = fig_leg.add_subplot(111)
        legend_kw['frameon'] = opt('legend_frameon', True)
    ax_leg.axis('off')
    legend = ax_leg.legend(handles, labels, **legend_kw)

    # Font sizes
    if title and title_fontsize:
        legend.get_title().set_fontsize(title_fontsize)
    if text_fontsize:
        for text in legend.get_texts():
            text.set_fontsize(text_fontsize)

    # Frame line width
    if frame_linewidth is not None and legend.get_frame() is not None:
        legend.get_frame().set_linewidth(frame_linewidth)

    # Handle line width, best effort for every handle that supports it
    if handle_linewidth is not None:
        for handle in getattr(legend, 'legend_handles', []) or []:
            if hasattr(handle, 'set_linewidth'):
                try:
                    handle.set_linewidth(handle_linewidth)
                except Exception:
                    pass

    fig_leg.savefig(save_path, dpi=dpi)
    return fig_leg


def configure_figure(ax: plt.Axes, style: dict | None = None):
    """
    Apply title, axis labels, tick styling, legend and y-axis settings to
    `ax`, all driven by the optional `style` dict.

    Style keys read by this function:
      # Text
      "show_title": bool (default False); "title" is required when True
      "title": str
      "x_label": str; when absent the x label is cleared and x ticks removed
      "y_label": str

      # Legend (built from the labelled artists on `ax`)
      "show_legend": bool (default False)
      "legend_title": str
      "legend_loc": str (default 'upper left')
      "legend_ncol": int (default 1)
      "legend_mode": str | None (e.g. 'expand')
      "legend_frameon": bool (default True)
      "legend_handlelength": float
      "legend_handleheight": float
      "legend_labelspacing": float
      "legend_borderpad": float
      "legend_markerscale": float
      "legend_handletextpad": float
      "legend_frame_linewidth": float
      "legend_handle_linewidth": float
      "legend_headroom": float (default 0.0); fraction of the y-range added
          above the upper limit, applied only when the legend is shown

      # Font sizes
      "title_fontsize": float
      "legend_title_fontsize": float
      "legend_fontsize": float
      "axis_title_fontsize": float
      "tick_labelsize": float

      # Y-axis limits and ticks
      "ymin": float, lower y-limit
      "ytick_range": tuple(float, float), keep only ticks inside this range
      "eps": float (default 1e-6), tolerance of the ytick_range comparison

      # Line / spine controls
      "spine_linewidth": float
      "tick_width": float
      "hide_top_right_spines": bool (default False)

    Legend-saving keys ("save_legend_separately", "legend_save_suffix",
    "legend_figsize", "legend_fill_figure", "save_name") are not read here;
    the plot functions and _save_legend_from_axes use them.

    Raises ValueError when "show_title" is True but no "title" is given.
    """
    style = copy.copy(style) if style else {}
    opt = style.get

    def font_kw(size):
        # Pass a fontsize only when one is configured.
        return {} if size is None else {"fontsize": size}

    # Title
    if opt("show_title", False):
        title = opt("title", None)
        if title is None:
            raise ValueError("If show_title is True, 'title' must be provided in style.")
        ax.set_title(title, **font_kw(opt("title_fontsize", None)))

    # Axis labels; without an x label the x axis gets neither label nor ticks.
    label_font = font_kw(opt("axis_title_fontsize", None))
    x_label = opt("x_label", None)
    y_label = opt("y_label", None)
    if x_label is None:
        ax.set_xlabel("")
        ax.set_xticks([])
    else:
        ax.set_xlabel(x_label, **label_font)
    if y_label is not None:
        ax.set_ylabel(y_label, **label_font)

    # Tick labels
    tick_fs = opt("tick_labelsize", None)
    if tick_fs is not None:
        ax.tick_params(axis="both", labelsize=tick_fs)

    # Legend using artists' labels
    if opt("show_legend", False):
        legend_kwargs = {
            "loc": opt("legend_loc", "upper left"),
            "ncol": opt("legend_ncol", 1),
            "frameon": opt("legend_frameon", True),
        }
        optional_kwargs = {
            "mode": opt("legend_mode", None),
            "handlelength": opt('legend_handlelength', None),
            "handleheight": opt('legend_handleheight', None),
            "labelspacing": opt('legend_labelspacing', None),
            "borderpad": opt('legend_borderpad', None),
            "markerscale": opt('legend_markerscale', None),
            "handletextpad": opt('legend_handletextpad', None),
            "title": opt("legend_title", None),
        }
        legend_kwargs.update(
            {key: value for key, value in optional_kwargs.items() if value is not None}
        )
        legend = ax.legend(**legend_kwargs)

        # Frame and handle line widths
        frame_lw = opt('legend_frame_linewidth', None)
        if frame_lw is not None and legend.get_frame() is not None:
            legend.get_frame().set_linewidth(frame_lw)

        handle_lw = opt('legend_handle_linewidth', None)
        if handle_lw is not None:
            for handle in getattr(legend, 'legend_handles', []) or []:
                if hasattr(handle, 'set_linewidth'):
                    try:
                        handle.set_linewidth(handle_lw)
                    except Exception:
                        pass

        # Legend font sizes
        legend_title_fs = opt("legend_title_fontsize", None)
        if legend_title_fs is not None and legend.get_title():
            legend.get_title().set_fontsize(legend_title_fs)
        legend_fs = opt("legend_fontsize", None)
        if legend_fs is not None:
            for text in legend.get_texts():
                text.set_fontsize(legend_fs)

        # Headroom above the data for the legend
        extend_ylim_without_new_ticks(ax, opt("legend_headroom", 0.0))

    # y-axis lower limit
    ymin = opt("ymin", None)
    if ymin is not None:
        ax.set_ylim(ymin=ymin)

    # Keep only the y ticks that fall inside ytick_range (within eps)
    ytick_range = opt("ytick_range", None)
    if ytick_range is not None:
        eps = opt("eps", 1e-6)
        lowest, highest = ytick_range
        ticks = ax.get_yticks()
        in_range = (ticks >= (lowest - eps)) & (ticks <= (highest + eps))
        ax.set_yticks(ticks[in_range])

    # Spine line width
    spine_lw = opt("spine_linewidth", None)
    if spine_lw is not None:
        for spine in ax.spines.values():
            spine.set_linewidth(spine_lw)

    # Tick line width
    tick_width = opt("tick_width", None)
    if tick_width is not None:
        ax.tick_params(width=tick_width)

    # Optionally hide the top and right spines
    if opt("hide_top_right_spines", False):
        for side in ("top", "right"):
            if side in ax.spines:
                ax.spines[side].set_visible(False)

def plot_aggregate_metric_vs_model_step(
    df,
    metric,
    agg="median",
    groups=None,
    models=None,
    style=None,
):
    """
    Creates a line plot of the aggregate metric vs. model step, one line per
    group.

    Parameters
    ----------
    df : records with "Group", "Model" and `metric` columns. The step is the
        integer after the last "_" of each "Model" value.
    metric : name of the column to aggregate.
    agg : aggregation applied per (Group, Model); a string such as "median".
    groups : groups to keep (also used as hue order), or None for all.
    models : models to keep, or None for all.
    style : style dict; reads "dpi", "figsize", "palette", "save_name",
        "line_width", "markersize", "legend_fontsize",
        "save_legend_separately" and everything configure_figure reads.
        "x_label" and "y_label" are set here.

    The figure is saved to style["save_name"] when given, then shown.
    """
    df = subset_df(df, groups, models)
    style = copy.copy(style) if style else {}
    dpi = style.get("dpi", None)
    figsize = style.get("figsize", (10, 6))
    palette = style.get("palette", None)
    save_name = style.get("save_name", None)

    # aggregate by group and model
    agg_df = (
        df
        .groupby(["Group", "Model"])[metric]
        .agg(agg)
        .reset_index(name=f"{agg}_{metric}")
    )
    agg_df["step"] = (
        agg_df["Model"].str.rsplit("_", n=1)
                        .str[-1].astype(int)
    )
    agg_df = agg_df.sort_values("step")

    # Set folded style parameters
    style["x_label"] = "Number of Batches"
    style["y_label"] = f"{agg.title()} {metric}"
    style["save_name"] = save_name

    # Necessary for legend handle/text alignment. The global rcParam is
    # restored in `finally`, so a failing plot does not leak the override
    # into later cells.
    base_fs = matplotlib.rcParams['legend.fontsize']
    try:
        if style.get("legend_fontsize", None) is not None:
            matplotlib.rcParams['legend.fontsize'] = style["legend_fontsize"]

        fig, ax = plt.subplots(figsize=figsize, dpi=dpi, constrained_layout=True)
        ax = sns.lineplot(
            data=agg_df,
            x="step",
            y=f"{agg}_{metric}",
            hue="Group",
            hue_order=groups,
            palette=palette,
            marker="o",
            linewidth=style.get("line_width", None),
            ax=ax,
            markersize=style.get("markersize", None)
        )

        configure_figure(ax=ax, style=style)

        # Save legend separately if requested
        if style.get("save_legend_separately", False) and save_name:
            _save_legend_from_axes(ax, save_name, style)

        if save_name:
            fig.savefig(save_name, dpi=dpi)

        plt.show()
    finally:
        matplotlib.rcParams['legend.fontsize'] = base_fs

def plot_metric_heatmap(
    df,
    metric,
    agg="median",
    groups=None,
    models=None,
    style=None,
):
    """
    Creates a heatmap of the aggregate metric by sample size and temperature.

    Parameters
    ----------
    df : records with "Model" and `metric` columns ("Group" when `groups` is
        given). Each "Model" value is split on "__": the sample size is the
        integer after the last "_" of the first part, and the temperature is
        the text after "t_" in the second part with "_" read as the decimal
        point.
    metric : name of the column to aggregate.
    agg : aggregation applied per model.
    groups : groups to keep, or None for all.
    models : models to keep, or None for all.
    style : style dict; reads "dpi", "annot", "fmt", "figsize", "palette"
        (colormap), "save_name", "annot_fontsize", "tick_labelsize",
        "tick_width", "rasterize_heatmap", "legend_fontsize",
        "save_legend_separately" and everything configure_figure reads.
        "x_label" and "y_label" are set here.

    The figure is saved to style["save_name"] when given, then shown.
    """
    df = subset_df(df, groups, models)
    style = copy.copy(style) if style else {}
    dpi = style.get("dpi", None)
    annot = style.get("annot", True)
    fmt = style.get("fmt", ".2f")
    figsize = style.get("figsize", (10, 8))
    palette = style.get("palette", "viridis")
    save_name = style.get("save_name", None)

    agg_df = (
        df
        .groupby("Model")[metric]
        .agg(agg)
        .reset_index(name=f"{agg}_{metric}")
    )
    agg_df["n"] = (
        agg_df["Model"].str.rsplit("__").str[0]
                   .str.rsplit("_").str[-1].astype(int)
    )
    agg_df["t"] = (
        agg_df["Model"].str.rsplit("__").str[1]
                   .str.rsplit("t_").str[-1].str.replace("_", ".").astype(float)
    )
    pivot = agg_df.pivot(index="n", columns="t", values=f"{agg}_{metric}")

    style["x_label"] = "Temperature"
    style["y_label"] = "Sample Size"
    style["save_name"] = save_name

    # Necessary for legend handle/text alignment. The global rcParam is
    # restored in `finally`, so a failing plot does not leak the override
    # into later cells.
    base_fs = matplotlib.rcParams['legend.fontsize']
    try:
        if style.get("legend_fontsize", None) is not None:
            matplotlib.rcParams['legend.fontsize'] = style["legend_fontsize"]

        fig, ax = plt.subplots(figsize=figsize, dpi=dpi, constrained_layout=True)
        ax = sns.heatmap(
            pivot,
            annot=annot,
            annot_kws = {
                "size": style.get("annot_fontsize", 10)
            },
            fmt=fmt,
            cmap=palette,
            ax=ax
        )

        # Colorbar control.
        cbar = ax.collections[0].colorbar
        cbar.ax.tick_params(
            labelsize = style.get('tick_labelsize', 10),
            width = style.get('tick_width', None)
        )

        # Rasterization control; to not show grid lines when vectorized in svg.
        for coll in ax.collections:
            if isinstance(coll, QuadMesh):
                coll.set_linewidth(0)
                coll.set_edgecolor('face')
                coll.set_antialiased(False)
                if style.get('rasterize_heatmap', True):
                    coll.set_rasterized(True)

        configure_figure(ax=ax, style=style)
        if style.get("save_legend_separately", False) and save_name:
            _save_legend_from_axes(ax, save_name, style)  # no-op if no legend

        if save_name:
            fig.savefig(save_name, dpi=dpi)

        plt.show()
    finally:
        matplotlib.rcParams['legend.fontsize'] = base_fs

def plot_box(
    df,
    metric,
    groups=None,
    models=None,
    style=None,
    print_stats=True,
):
    """
    Draw a seaborn box plot of ``metric``, hued either by "Group" or by "Model".

    Args:
        df: Source DataFrame. It is filtered with
            ``subset_df(df, groups, models)`` and then copied, so the
            relabelling done here never writes into the caller's frame.
        metric: Name of the column plotted on the y axis. It is also used as
            the y-axis label (``style["y_label"]`` is overwritten with it).
        groups: Optional list of "Group" values. When plotting by group it
            also defines the hue order.
        models: Optional list of raw "Model" names. When plotting by model it
            also defines the hue order (after relabelling).
        style: Optional dict of plot options; a shallow copy is modified.
            Keys read in this function: ``dpi``, ``figsize``, ``palette``,
            ``model_name_to_label``, ``data_source_name_to_label``,
            ``save_name``, ``plot_by``, ``group_by``, ``box_linewidth``,
            ``whisker_linewidth``, ``cap_linewidth``, ``median_linewidth``,
            ``median_color``, ``flier_markersize``, ``flier_linewidth``,
            ``legend_fontsize`` and ``save_legend_separately``. The whole
            dict is also forwarded to ``configure_figure``.
        print_stats: When True, print count and median of ``metric`` per hue
            level (and per hue level / x category when those differ).

    Returns:
        None. The figure is shown and, if ``save_name`` is set, saved.
    """
    # Work on a private copy: the label mapping below assigns columns.
    df = subset_df(df, groups, models).copy()
    style = copy.copy(style) if style else {}
    dpi = style.get("dpi", None)
    figsize = style.get("figsize", (10, 6))
    palette = style.get("palette", None)
    model_name_to_label = style.get("model_name_to_label", None)
    data_source_name_to_label = style.get("data_source_name_to_label", None)
    save_name = style.get("save_name", None)

    # Infer plot_by (the hue variable) if not provided or not recognised.
    plot_by = style.get("plot_by", None)
    if plot_by not in {"Group", "Model"}:
        plot_by = "Group" if groups else "Model"
    # group_by is the x-axis variable; it defaults to the hue variable.
    group_by = style.get("group_by", None)
    if group_by is None:
        group_by = plot_by

    # Replace raw model names by display labels everywhere they are used
    # (data, palette keys, requested order); unmapped names are kept as-is.
    if model_name_to_label is not None and "Model" in df.columns:
        df["Model"] = df["Model"].map(model_name_to_label).fillna(df["Model"])
        if palette is not None:
            palette = {model_name_to_label.get(k, k): v for k, v in palette.items()}
        if models is not None:
            models = [model_name_to_label.get(m, m) for m in models]

    if data_source_name_to_label is not None and "dataset_name" in df.columns:
        df["dataset_name"] = df["dataset_name"].map(data_source_name_to_label).fillna(df["dataset_name"])

    if group_by:
        df = df.sort_values(group_by)

    # Determine hue order and sort rows accordingly. The sort must be stable,
    # otherwise it can scramble the group_by ordering established just above.
    hue_order = groups if plot_by == "Group" else models
    if hue_order:
        hue_rank = {level: i for i, level in enumerate(hue_order)}
        df = df.sort_values(
            plot_by,
            key=lambda col: col.map(hue_rank),
            kind="stable",
        )

    style["y_label"] = metric
    style["save_name"] = save_name

    # Box line width controls
    box_lw = style.get("box_linewidth", None)
    whisker_lw = style.get("whisker_linewidth", box_lw if box_lw is not None else 1)
    cap_lw = style.get("cap_linewidth", whisker_lw)
    median_lw = style.get("median_linewidth", whisker_lw)
    median_color = style.get("median_color", None)

    # Outlier marker size and edge width
    flier_ms = style.get("flier_markersize", None)
    flier_lw = style.get("flier_linewidth", None)

    boxprops = {"linewidth": box_lw} if box_lw is not None else None
    whiskerprops = {"linewidth": whisker_lw}
    capprops = {"linewidth": cap_lw}
    medianprops = {"linewidth": median_lw}
    if median_color is not None:
        medianprops["color"] = median_color
    flierprops = {}
    if flier_lw is not None:
        flierprops["markeredgewidth"] = flier_lw

    # Necessary for legend handle/text alignment.
    base_fs = matplotlib.rcParams['legend.fontsize']
    if style.get("legend_fontsize", None) is not None:
        matplotlib.rcParams['legend.fontsize'] = style["legend_fontsize"]

    try:
        fig, ax = plt.subplots(figsize=figsize, dpi=dpi, constrained_layout=True)
        sns.boxplot(
            data=df,
            x=group_by,
            y=metric,
            hue=plot_by,
            hue_order=hue_order,
            palette=palette,
            ax=ax,
            showcaps=True,
            showfliers=True,
            legend=True,
            boxprops=boxprops,
            whiskerprops=whiskerprops,
            capprops=capprops,
            medianprops=medianprops,
            fliersize=flier_ms,
            flierprops=flierprops or None,
        )

        configure_figure(ax=ax, style=style)

        if style.get("save_legend_separately", False) and save_name:
            _save_legend_from_axes(ax, save_name, style)

        if save_name:
            fig.savefig(save_name, dpi=dpi)

        if print_stats:
            # Finer breakdown first (hue x category), then the per-hue summary.
            groupings = [plot_by]
            if group_by != plot_by:
                groupings.insert(0, [plot_by, group_by])
            for grouping in groupings:
                stats = (
                    df
                    .groupby(grouping)[metric]
                    .agg(count='count', median='median')
                    .reset_index()
                )
                print(stats)

        plt.show()
    finally:
        # Always undo the global rcParams override, even if plotting failed,
        # so one failed call cannot change the legend size of later figures.
        matplotlib.rcParams['legend.fontsize'] = base_fs

In [None]:
# Load the evaluation summary tables (one CSV per benchmark) used by the plots below.
# The directory is defined once so the notebook can be pointed at another checkout.
EVALUATION_SUMMARY_DIR = "/home/akubaney/projects/na_mpnn/evaluation/evaluation_summaries"

design_valid_plot_df = pd.read_csv(os.path.join(EVALUATION_SUMMARY_DIR, "design_valid_plot.csv"))
specificity_valid_plot_df = pd.read_csv(os.path.join(EVALUATION_SUMMARY_DIR, "specificity_valid_plot.csv"))
specificity_valid_hypersweep_plot_df = pd.read_csv(os.path.join(EVALUATION_SUMMARY_DIR, "specificity_valid_hypersweep_plot.csv"))
design_test_plot_df = pd.read_csv(os.path.join(EVALUATION_SUMMARY_DIR, "design_test_plot.csv"))
design_rna_monomer_test_plot_df = pd.read_csv(os.path.join(EVALUATION_SUMMARY_DIR, "design_rna_monomer_test_plot.csv"))
design_pseudoknot_test_plot_df = pd.read_csv(os.path.join(EVALUATION_SUMMARY_DIR, "design_pseudoknot_test_plot.csv"))
design_pseudoknot_test_scores_plot_df = pd.read_csv(os.path.join(EVALUATION_SUMMARY_DIR, "design_pseudoknot_test_scores_plot.csv"))
specificity_test_plot_df = pd.read_csv(os.path.join(EVALUATION_SUMMARY_DIR, "specificity_test_plot.csv"))

In [None]:
# Rescale the RibonanzaNet OKS column from 0-1 to 0-100.
# (The processing script mistakenly divided the scores by 100; that is undone here.)
# The flag stored on the frame makes this cell idempotent: re-running it does not
# multiply again, while re-running the load cell yields a fresh, unflagged frame.
if not design_pseudoknot_test_scores_plot_df.attrs.get("ribonanzanet_oks_rescaled", False):
    design_pseudoknot_test_scores_plot_df["RibonanzaNet OKS"] = design_pseudoknot_test_scores_plot_df["RibonanzaNet OKS"] * 100.0
    design_pseudoknot_test_scores_plot_df.attrs["ribonanzanet_oks_rescaled"] = True

In [None]:
# Design validation set: median Sequence Recovery vs. model step for the four polymer groups.
plot_aggregate_metric_vs_model_step(
    design_valid_plot_df,
    metric="Sequence Recovery",
    agg="median",
    groups=["DNA", "DNA (protein context)", "RNA", "RNA (protein context)"],
    style=dict(
        metric_vs_model_step_style,
        show_title=False,
        title="Design Validation Set: Sequence Recovery vs. Model Step",
        palette=polymer_type_palette,
        save_name="/home/akubaney/projects/na_mpnn/figures/matplotlib/design_valid_sequence_recovery_vs_model_step.svg",
        legend_ncol=2,
        legend_figsize=(53 / 25.4, 5 / 25.4),
    ),
)

In [None]:
# Specificity validation set: median Cross-Entropy vs. model step for Crystal and Distillation groups.
plot_aggregate_metric_vs_model_step(
    specificity_valid_plot_df,
    metric="Cross-Entropy",
    agg="median",
    groups=["Crystal", "Distillation"],
    style=dict(
        metric_vs_model_step_style,
        show_title=False,
        title="Specificity Validation Set: Cross Entropy vs. Model Step",
        palette=ppm_type_palette,
        save_name="/home/akubaney/projects/na_mpnn/figures/matplotlib/specificity_valid_cross_entropy_vs_model_step.svg",
        legend_figsize=(16 / 25.4, 8 / 25.4),
    ),
)

In [None]:
# Specificity validation set: median Mean Absolute Error vs. model step for Crystal and Distillation groups.
plot_aggregate_metric_vs_model_step(
    specificity_valid_plot_df,
    metric="Mean Absolute Error",
    agg="median",
    groups=["Crystal", "Distillation"],
    style=dict(
        metric_vs_model_step_style,
        show_title=False,
        title="Specificity Validation Set: Mean Absolute Error vs. Model Step",
        palette=ppm_type_palette,
        save_name="/home/akubaney/projects/na_mpnn/figures/matplotlib/specificity_valid_mean_absolute_error_vs_model_step.svg",
        legend_figsize=(16 / 25.4, 8 / 25.4),
    ),
)

In [None]:
# Hyperparameter sweep, Distillation group: heatmap of median Cross-Entropy.
plot_metric_heatmap(
    specificity_valid_hypersweep_plot_df,
    metric="Cross-Entropy",
    agg="median",
    groups=["Distillation"],
    style=dict(
        metric_heatmap_style,
        show_title=True,
        title="Median Cross-Entropy",
        save_name="/home/akubaney/projects/na_mpnn/figures/matplotlib/specificity_valid_hypersweep_distillation_cross_entropy_heatmap.svg",
    ),
)

In [None]:
# Hyperparameter sweep, Distillation group: heatmap of median Mean Absolute Error.
plot_metric_heatmap(
    specificity_valid_hypersweep_plot_df,
    metric="Mean Absolute Error",
    agg="median",
    groups=["Distillation"],
    style=dict(
        metric_heatmap_style,
        show_title=True,
        title="Median Mean Absolute Error",
        save_name="/home/akubaney/projects/na_mpnn/figures/matplotlib/specificity_valid_hypersweep_distillation_mean_absolute_error_heatmap.svg",
    ),
)

In [None]:
# Hyperparameter sweep, Crystal group: heatmap of median Cross-Entropy.
plot_metric_heatmap(
    specificity_valid_hypersweep_plot_df,
    metric="Cross-Entropy",
    agg="median",
    groups=["Crystal"],
    style=dict(
        metric_heatmap_style,
        show_title=True,
        title="Median Cross-Entropy",
        save_name="/home/akubaney/projects/na_mpnn/figures/matplotlib/specificity_valid_hypersweep_crystal_cross_entropy_heatmap.svg",
    ),
)

In [None]:
# Hyperparameter sweep, Crystal group: heatmap of median Mean Absolute Error.
plot_metric_heatmap(
    specificity_valid_hypersweep_plot_df,
    metric="Mean Absolute Error",
    agg="median",
    groups=["Crystal"],
    style=dict(
        metric_heatmap_style,
        show_title=True,
        title="Median Mean Absolute Error",
        save_name="/home/akubaney/projects/na_mpnn/figures/matplotlib/specificity_valid_hypersweep_crystal_mean_absolute_error_heatmap.svg",
    ),
)

In [None]:
# Design test set: Sequence Recovery of na_mpnn, one box per polymer group.
plot_box(
    design_test_plot_df,
    metric="Sequence Recovery",
    groups=["DNA", "DNA (protein context)", "RNA", "RNA (protein context)"],
    models=["na_mpnn"],
    style=dict(
        box_style,
        show_legend=True,
        ytick_range=(0, 1),
        ymin=-0.05,
        plot_by="Group",
        palette=polymer_type_palette,
        model_name_to_label=model_name_to_label,
        save_name="/home/akubaney/projects/na_mpnn/figures/matplotlib/design_test_sequence_recovery_by_group.svg",
        figsize=(55 / 25.4, 33 / 25.4),
        legend_figsize=(53 / 25.4, 5 / 25.4),
        legend_ncol=2,
    ),
)

In [None]:
# RNA monomer test set: Sequence Recovery compared across the three design models.
plot_box(
    design_rna_monomer_test_plot_df,
    metric="Sequence Recovery",
    models=["na_mpnn", "grnade", "rhodesign"],
    style=dict(
        box_style,
        show_legend=True,
        ytick_range=(0, 1),
        ymin=-0.05,
        plot_by="Model",
        palette=design_model_palette,
        model_name_to_label=model_name_to_label,
        save_name="/home/akubaney/projects/na_mpnn/figures/matplotlib/design_rna_monomer_test_sequence_recovery_by_model.svg",
    ),
)

In [None]:
# Pseudoknot test set: Sequence Recovery compared across the three design models.
plot_box(
    design_pseudoknot_test_plot_df,
    metric="Sequence Recovery",
    models=["na_mpnn", "grnade", "rhodesign"],
    style=dict(
        box_style,
        plot_by="Model",
        palette=design_model_palette,
        model_name_to_label=model_name_to_label,
        save_name="/home/akubaney/projects/na_mpnn/figures/matplotlib/design_pseudoknot_test_sequence_recovery_by_model.svg",
        legend_ncol=3,
        legend_figsize=(52 / 25.4, 4.5 / 25.4),
    ),
)

In [None]:
# Pseudoknot test set scores: AF3 pLDDT by model.
plot_box(
    design_pseudoknot_test_scores_plot_df,
    metric="AF3 pLDDT",
    models=["na_mpnn", "grnade", "rhodesign"],
    style=dict(
        box_style,
        plot_by="Model",
        palette=design_model_palette,
        model_name_to_label=model_name_to_label,
        save_name="/home/akubaney/projects/na_mpnn/figures/matplotlib/design_pseudoknot_test_af3_plddt_by_model.svg",
    ),
)

In [None]:
# Pseudoknot test set scores: AF3 pTM by model.
plot_box(
    design_pseudoknot_test_scores_plot_df,
    metric="AF3 pTM",
    models=["na_mpnn", "grnade", "rhodesign"],
    style=dict(
        box_style,
        plot_by="Model",
        palette=design_model_palette,
        model_name_to_label=model_name_to_label,
        save_name="/home/akubaney/projects/na_mpnn/figures/matplotlib/design_pseudoknot_test_af3_ptm_by_model.svg",
    ),
)

In [None]:
# Pseudoknot test set scores: AF3 pAE by model.
plot_box(
    design_pseudoknot_test_scores_plot_df,
    metric="AF3 pAE",
    models=["na_mpnn", "grnade", "rhodesign"],
    style=dict(
        box_style,
        plot_by="Model",
        palette=design_model_palette,
        model_name_to_label=model_name_to_label,
        save_name="/home/akubaney/projects/na_mpnn/figures/matplotlib/design_pseudoknot_test_af3_pae_by_model.svg",
    ),
)

In [None]:
# Pseudoknot test set scores: AF3 C1'-RMSD by model.
plot_box(
    design_pseudoknot_test_scores_plot_df,
    metric="AF3 C1'-RMSD",
    models=["na_mpnn", "grnade", "rhodesign"],
    style=dict(
        box_style,
        plot_by="Model",
        palette=design_model_palette,
        model_name_to_label=model_name_to_label,
        save_name="/home/akubaney/projects/na_mpnn/figures/matplotlib/design_pseudoknot_test_af3_c1_prime_rmsd_by_model.svg",
    ),
)

In [None]:
# Pseudoknot test set scores: AF3 C1'-LDDT by model.
plot_box(
    design_pseudoknot_test_scores_plot_df,
    metric="AF3 C1'-LDDT",
    models=["na_mpnn", "grnade", "rhodesign"],
    style=dict(
        box_style,
        plot_by="Model",
        palette=design_model_palette,
        model_name_to_label=model_name_to_label,
        save_name="/home/akubaney/projects/na_mpnn/figures/matplotlib/design_pseudoknot_test_af3_c1_prime_lddt_by_model.svg",
    ),
)

In [None]:
# Pseudoknot test set scores: AF3 C1'-GDDT by model.
plot_box(
    design_pseudoknot_test_scores_plot_df,
    metric="AF3 C1'-GDDT",
    models=["na_mpnn", "grnade", "rhodesign"],
    style=dict(
        box_style,
        plot_by="Model",
        palette=design_model_palette,
        model_name_to_label=model_name_to_label,
        save_name="/home/akubaney/projects/na_mpnn/figures/matplotlib/design_pseudoknot_test_af3_c1_prime_gddt_by_model.svg",
    ),
)

In [None]:
# Pseudoknot test set scores: RibonanzaNet OKS by model.
plot_box(
    design_pseudoknot_test_scores_plot_df,
    metric="RibonanzaNet OKS",
    models=["na_mpnn", "grnade", "rhodesign"],
    style=dict(
        box_style,
        plot_by="Model",
        palette=design_model_palette,
        model_name_to_label=model_name_to_label,
        save_name="/home/akubaney/projects/na_mpnn/figures/matplotlib/design_pseudoknot_test_ribonanza_net_predicted_openknot_score_by_model.svg",
    ),
)

In [None]:
# Pseudoknot test set scores: RibonanzaNet Pair F1 by model.
plot_box(
    design_pseudoknot_test_scores_plot_df,
    metric="RibonanzaNet Pair F1",
    models=["na_mpnn", "grnade", "rhodesign"],
    style=dict(
        box_style,
        plot_by="Model",
        palette=design_model_palette,
        model_name_to_label=model_name_to_label,
        save_name="/home/akubaney/projects/na_mpnn/figures/matplotlib/design_pseudoknot_test_ribonanza_net_f1_score_pairs_by_model.svg",
        legend_ncol=3,
        legend_figsize=(52 / 25.4, 4.5 / 25.4),
    ),
)

In [None]:
# Pseudoknot test set scores: RibonanzaNet Loop F1 by model.
plot_box(
    design_pseudoknot_test_scores_plot_df,
    metric="RibonanzaNet Loop F1",
    models=["na_mpnn", "grnade", "rhodesign"],
    style=dict(
        box_style,
        plot_by="Model",
        palette=design_model_palette,
        model_name_to_label=model_name_to_label,
        save_name="/home/akubaney/projects/na_mpnn/figures/matplotlib/design_pseudoknot_test_ribonanza_net_f1_score_loops_by_model.svg",
    ),
)

In [None]:
# Specificity test set, Distillation group: Cross-Entropy by model, split on the x axis by data source.
plot_box(
    specificity_test_plot_df,
    metric="Cross-Entropy",
    groups=["Distillation"],
    models=["na_mpnn", "deeppbs"],
    style=dict(
        box_style,
        show_legend=True,
        plot_by="Model",
        group_by="dataset_name",
        x_label="Data Source",
        data_source_name_to_label=data_source_name_to_label,
        palette=specificity_model_palette,
        model_name_to_label=model_name_to_label,
        save_name="/home/akubaney/projects/na_mpnn/figures/matplotlib/specificity_test_distillation_cross_entropy_by_model.svg",
    ),
)

In [None]:
# Specificity test set, Distillation group: Mean Absolute Error by model, split on the x axis by data source.
plot_box(
    specificity_test_plot_df,
    metric="Mean Absolute Error",
    groups=["Distillation"],
    models=["na_mpnn", "deeppbs"],
    style=dict(
        box_style,
        plot_by="Model",
        group_by="dataset_name",
        x_label="Data Source",
        data_source_name_to_label=data_source_name_to_label,
        palette=specificity_model_palette,
        model_name_to_label=model_name_to_label,
        save_name="/home/akubaney/projects/na_mpnn/figures/matplotlib/specificity_test_distillation_mean_absolute_error_by_model.svg",
        legend_ncol=2,
        legend_figsize=(36 / 25.4, 4.5 / 25.4),
    ),
)

In [None]:
# Specificity test set, Crystal group: Cross-Entropy by model.
plot_box(
    specificity_test_plot_df,
    metric="Cross-Entropy",
    groups=["Crystal"],
    models=["na_mpnn", "deeppbs"],
    style=dict(
        box_style,
        plot_by="Model",
        palette=specificity_model_palette,
        model_name_to_label=model_name_to_label,
        save_name="/home/akubaney/projects/na_mpnn/figures/matplotlib/specificity_test_crystal_cross_entropy_by_model.svg",
    ),
)

In [None]:
# Specificity test set, Crystal group: Mean Absolute Error by model.
plot_box(
    specificity_test_plot_df,
    metric="Mean Absolute Error",
    groups=["Crystal"],
    models=["na_mpnn", "deeppbs"],
    style=dict(
        box_style,
        plot_by="Model",
        palette=specificity_model_palette,
        model_name_to_label=model_name_to_label,
        save_name="/home/akubaney/projects/na_mpnn/figures/matplotlib/specificity_test_crystal_mean_absolute_error_by_model.svg",
    ),
)

In [None]:
# Structure IDs to keep (presumably the DeepPBS test set, going by the variable name).
deeppbs_test_ids = {'6byy', '5emq', '2wbu', '6mg2', '3v79', '1pzu', '7cli', '5vmv', '1a1g', '5yef', '6fj5', '5hdn', '4x9j', '5vmu', '2ady', '1le8', '1gu4', '3kov', '2e42', '3wtw', '4cn3', '6a8r', '6dfy', '2i9t', '2lt7', '1r0n', '6vg8', '5ke8', '5ke7', '6fbq', '2erg', '6dks', '1ram', '1e3o', '5lty', '6vge', '6vg2', '3coa', '6b0q', '6dfc', '7vuq', '5vpf', '2evi', '5kl3', '6u81', '4r2p', '1hjb', '1c7u', '6fbr', '5und', '1a1j', '6mg3', '4f6n', '7jm4', '7oh9', '2evf', '6lbi', '3wts', '1srs', '6vgg', '5ke6', '7dcj', '1vkx', '4l0z', '6od3', '1gu5', '6wqu', '5z2t', '4r2d', '2euz', '4xrm', '5emp', '3gut', '1n6j', '2etw', '5bng', '1a1i', '2wty', '3q05', '3iag', '2evj', '3wtu', '6x6e', '5kl4'}

# Structure ID = last "/"-separated component of `structure_path`, cut at its first ".".
_structure_file_names = specificity_test_plot_df["structure_path"].str.rsplit("/", n=1).str[-1]
_structure_ids = _structure_file_names.str.split(".").str[0]

# Keep only the rows whose structure ID is in the set above, then show how many remain.
deep_pbs_overlap_plot_df = specificity_test_plot_df[_structure_ids.isin(deeppbs_test_ids)]
len(deep_pbs_overlap_plot_df)

In [None]:
# DeepPBS-overlap subset, Crystal group: Cross-Entropy by model.
plot_box(
    deep_pbs_overlap_plot_df,
    metric="Cross-Entropy",
    groups=["Crystal"],
    models=["na_mpnn", "deeppbs"],
    style=dict(
        box_style,
        plot_by="Model",
        palette=specificity_model_palette,
        model_name_to_label=model_name_to_label,
        save_name="/home/akubaney/projects/na_mpnn/figures/matplotlib/specificity_test_crystal_deeppbs_overlap_cross_entropy_by_model.svg",
    ),
)

In [None]:
# DeepPBS-overlap subset, Crystal group: Mean Absolute Error by model.
plot_box(
    deep_pbs_overlap_plot_df,
    metric="Mean Absolute Error",
    groups=["Crystal"],
    models=["na_mpnn", "deeppbs"],
    style=dict(
        box_style,
        plot_by="Model",
        palette=specificity_model_palette,
        model_name_to_label=model_name_to_label,
        save_name="/home/akubaney/projects/na_mpnn/figures/matplotlib/specificity_test_crystal_deeppbs_overlap_mean_absolute_error_by_model.svg",
    ),
)

In [None]:
# For each model, keep the design with the lowest AF3 C1'-RMSD per PDB id.
# best_by_rmsd[model][pdb_id] is a (design_id, af3_rmsd) tuple.
best_by_rmsd = {
    "na_mpnn": dict(),
    "grnade":  dict(),
    "rhodesign": dict()
}
for name, model, af3_rmsd in zip(
    design_pseudoknot_test_scores_plot_df["Name"],
    design_pseudoknot_test_scores_plot_df["Model"],
    design_pseudoknot_test_scores_plot_df["AF3 C1'-RMSD"]
):
    # Split on the last underscore only, so a prefix that itself contains "_"
    # does not break the two-way unpacking.
    pdb_id, design_id = name.rsplit("_", 1)
    incumbent = best_by_rmsd[model].get(pdb_id)
    if (
        incumbent is None
        # A NaN incumbent would otherwise never be replaced, because every
        # "<" comparison against NaN is False.
        or (pd.isna(incumbent[1]) and not pd.isna(af3_rmsd))
        or af3_rmsd < incumbent[1]
    ):
        best_by_rmsd[model][pdb_id] = (design_id, af3_rmsd)

# One row per PDB id seen for na_mpnn; a baseline without any design for that
# id gets None instead of raising KeyError.
best_by_rmsd_rows = [
    {
        "PDB ID": pdb_id,
        "na_mpnn": best_by_rmsd["na_mpnn"][pdb_id],
        "grnade":  best_by_rmsd["grnade"].get(pdb_id),
        "rhodesign": best_by_rmsd["rhodesign"].get(pdb_id)
    }
    for pdb_id in best_by_rmsd["na_mpnn"]
]
best_by_rmsd_df = pd.DataFrame(best_by_rmsd_rows)
best_by_rmsd_df


In [None]:
# Per structure, compare na_mpnn against deeppbs on both specificity metrics.
# Each entry is:
#   (structure_path, mae_delta, na_mpnn_mae, cross_entropy_delta, na_mpnn_cross_entropy)
# where each delta is na_mpnn minus deeppbs.
entries_and_deltas = []
# groupby visits every row once instead of re-scanning the whole frame with a
# boolean mask for each structure; sort=False keeps first-appearance order.
for structure_path, structure_df in specificity_test_plot_df.groupby("structure_path", sort=False):
    na_mpnn_rows = structure_df[structure_df["Model"] == "na_mpnn"]
    deeppbs_rows = structure_df[structure_df["Model"] == "deeppbs"]

    # The first row per model is used; this raises IndexError if a model has
    # no row for the structure.
    na_mpnn_mae = na_mpnn_rows["Mean Absolute Error"].values[0]
    deeppbs_mae = deeppbs_rows["Mean Absolute Error"].values[0]
    mae_delta = na_mpnn_mae - deeppbs_mae

    na_mpnn_cross_entropy = na_mpnn_rows["Cross-Entropy"].values[0]
    deeppbs_cross_entropy = deeppbs_rows["Cross-Entropy"].values[0]
    cross_entropy_delta = na_mpnn_cross_entropy - deeppbs_cross_entropy

    entries_and_deltas.append((
        structure_path,
        mae_delta,
        na_mpnn_mae,
        cross_entropy_delta,
        na_mpnn_cross_entropy
    ))

# Sorted ascending by na_mpnn's Mean Absolute Error (tuple index 2), not by the delta.
entries_and_deltas = sorted(entries_and_deltas, key=lambda x: x[2])
for entry in entries_and_deltas:
    print(entry)