In [None]:
import pandas as pd
import scipy.stats
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
from transformers import EsmTokenizer

In [None]:

# ESM model checkpoint directories, keyed by model id.
# Every run lives under the same root, so the root is named once and only the
# parts that differ per run (size tag, run date, checkpoint step) are listed.
CHECKPOINT_ROOT = (
    "/home/jovyan/shared/mahdi/1_projects/model_optimization/"
    "01all_esm_models/deepspeed/esm/all_checkpoints_4good"
)
# (size tag, run date, checkpoint step). Note the 650M entry points at
# checkpoint-395000 while the other four point at checkpoint-500000.
_CHECKPOINT_RUNS = [
    ("8M", "2025-02-10", 500000),
    ("35M", "2025-02-10", 500000),
    ("150M", "2025-02-11", 500000),
    ("350M", "2025-01-29", 500000),
    ("650M", "2025-01-29", 395000),
]
model_dict_full = {
    f"m_{size}_F": f"{CHECKPOINT_ROOT}/esm_{size}_full_batch_128_{run_date}/checkpoint-{step}"
    for size, run_date, step in _CHECKPOINT_RUNS
}

# Tokenizer, loaded from the pretrained identifier named below.
# NOTE(review): nothing else in this cell references `tokenizer` - confirm it is still needed.
TOKENIZER_NAME = "facebook/esm2_t12_35M_UR50D"
tokenizer = EsmTokenizer.from_pretrained(TOKENIZER_NAME)


# Per-model inference result files, keyed "<model id>_<subset>".
# All files sit in one directory and follow one naming pattern
# ("<model id>_<subset>_2000.json"), so the mapping is generated rather than
# spelled out ten times. Key order is: all germline entries, then all mutated.
RESULTS_DIR = (
    "/home/jovyan/shared/mahdi/1_projects/model_optimization/"
    "04Per_residue_inference/results_full"
)
model_stats = {
    f"{model_id}_{subset}": f"{RESULTS_DIR}/{model_id}_{subset}_2000.json"
    for subset in ("germline", "mutated")
    for model_id in ("m_8M_F", "m_35M_F", "m_150M_F", "m_350M_F", "m_650M_F")
}

def make_stats_df(m_8M_F_path, m_35M_F_path, m_150M_F_path, m_350M_F_path, m_650M_F_path):
    """Load per-region stats for the five models into one long-format DataFrame.

    Each file is read with ``pd.read_json``. Rows whose ``cdr_indices`` entry
    does not contain exactly 15 elements are dropped. The ``accuracy_by_region``,
    ``score_by_region``, ``loss_by_region`` and ``perplexity_by_region`` columns
    are then expanded so the result holds one row per (model, region, sequence).

    Parameters
    ----------
    m_8M_F_path, m_35M_F_path, m_150M_F_path, m_350M_F_path, m_650M_F_path
        Path of the JSON stats file for the correspondingly named model.

    Returns
    -------
    pandas.DataFrame
        Columns ``seq_id``, ``region``, ``model_id``, ``accuracy``, ``score``,
        ``loss`` and ``perplexity``. ``seq_id`` is the row position *after*
        filtering, not the position in the file. Rows are ordered by model,
        then region position, then sequence. Region positions 0-13 are
        relabelled with region names; several positions share the labels
        "FRH" and "FRL", so those labels carry more than one row per sequence.
        If no rows survive filtering, an empty frame with these columns is
        returned.
    """
    model_dict = {
        "m_8M_F": m_8M_F_path,
        "m_35M_F": m_35M_F_path,
        "m_150M_F": m_150M_F_path,
        "m_350M_F": m_350M_F_path,
        "m_650M_F": m_650M_F_path,
    }

    # Output column -> column of the JSON file that holds its per-region values.
    metric_sources = {
        "accuracy": "accuracy_by_region",
        "score": "score_by_region",
        "loss": "loss_by_region",
        "perplexity": "perplexity_by_region",
    }

    # Column-wise accumulators; building the frame once at the end avoids
    # creating one dict (and four scalar .iloc lookups) per output row.
    collected = {name: [] for name in ("seq_id", "region", "model_id", *metric_sources)}

    for _id, data_path in tqdm(model_dict.items()):
        model_df = pd.read_json(data_path)
        model_df = model_df[model_df["cdr_indices"].map(len) == 15]

        # One wide table per metric: rows are sequences, columns are regions.
        wide = {
            metric: pd.DataFrame(list(model_df[source]))
            for metric, source in metric_sources.items()
        }
        n_seqs = len(wide["accuracy"])

        # Track the column *position* separately from its label so that the
        # positional .iloc access stays correct whatever the labels are.
        for position, region in enumerate(wide["accuracy"].columns):
            collected["seq_id"].extend(range(n_seqs))
            collected["region"].extend([region] * n_seqs)
            collected["model_id"].extend([_id] * n_seqs)
            for metric, table in wide.items():
                collected[metric].extend(table.iloc[:, position].tolist())

    stats_df = pd.DataFrame(collected)

    # Region position -> region name. Values not listed here are left untouched.
    collapse_mapping = {
        0: "FRH", 1: "CDRH1", 2: "FRH", 3: "CDRH2", 4: "FRH", 5: "CDRH3", 6: "FRH",
        7: "FRL", 8: "CDRL1", 9: "FRL", 10: "CDRL2", 11: "FRL", 12: "CDRL3", 13: "FRL"
    }
    return stats_df.replace({"region": collapse_mapping})

def dep_ttest(data, col_name, models, save_tstats=False, alternative="two-sided"):
    """Run a paired t-test between every pair of models, separately per region.

    For each distinct value of ``data["region"]`` and each pair
    ``(models[i], models[j])`` with ``i < j``, the ``col_name`` values of the
    two models are compared with ``scipy.stats.ttest_rel``. Observations are
    paired by row order within each (region, model) subset, so both models
    must contribute the same number of rows in the same order.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain the columns ``region``, ``model_id`` and ``col_name``.
    col_name : str
        Column holding the metric to compare.
    models : sequence
        Model ids to compare; pairs are formed in the order given.
    save_tstats : bool, default False
        When True, also return the test statistics.
    alternative : str, default "two-sided"
        Forwarded unchanged to ``scipy.stats.ttest_rel``.

    Returns
    -------
    dict or tuple of dict
        ``pvals`` keyed by ``(region, model_i, model_j)``; when ``save_tstats``
        is True, the tuple ``(pvals, tstats)`` with identically keyed dicts.
    """
    pvals = {}
    tstats = {}

    for r in data["region"].unique():
        region_rows = data[data["region"] == r]
        # Select each model's sample once per region instead of re-filtering
        # the whole frame with two boolean masks for every model pair.
        samples = {
            model: region_rows.loc[region_rows["model_id"] == model, col_name]
            for model in models
        }
        for i, model_i in enumerate(models):
            for model_j in models[i + 1:]:
                stat, pval = scipy.stats.ttest_rel(
                    samples[model_i],
                    samples[model_j],
                    alternative=alternative,
                )
                pvals[(r, model_i, model_j)] = pval
                if save_tstats:
                    tstats[(r, model_i, model_j)] = stat

    return (pvals, tstats) if save_tstats else pvals

# Global seaborn theme.
# NOTE(review): set_context("talk") is called with its default font_scale, so
# it replaces the font sizes requested by font_scale=2.0 on the first call.
# The figure below sets every text size explicitly, so the calls are kept as
# they were to leave the rendering unchanged - confirm which scale is intended.
sns.set(font_scale=2.0)
sns.set_style("whitegrid")
sns.set_context("talk")

# Metric shown on the y axis; metrics without an entry in the mapping fall
# back to their column name as the axis label.
metric_to_use = "accuracy"
y_label_mapping = {
    "accuracy": "Accuracy (%)",
    "loss": "Loss",
    "perplexity": "Per-Position Perplexity"
}
y_label = y_label_mapping.get(metric_to_use, metric_to_use)

models_list = ["m_8M_F", "m_35M_F", "m_150M_F", "m_350M_F", "m_650M_F"]

# Long-format stats and pairwise paired t-tests for the germline subset.
g_data = make_stats_df(
    model_stats["m_8M_F_germline"],
    model_stats["m_35M_F_germline"],
    model_stats["m_150M_F_germline"],
    model_stats["m_350M_F_germline"],
    model_stats["m_650M_F_germline"]
)
g_pvals = dep_ttest(g_data, metric_to_use, models_list)

# Same for the mutated subset.
m_data = make_stats_df(
    model_stats["m_8M_F_mutated"],
    model_stats["m_35M_F_mutated"],
    model_stats["m_150M_F_mutated"],
    model_stats["m_350M_F_mutated"],
    model_stats["m_650M_F_mutated"]
)
m_pvals = dep_ttest(m_data, metric_to_use, models_list)

# One entry per panel, keyed by the panel label used in the mosaic below.
# The p-values are carried along but are not drawn on the figure.
plots = {
    "a.": {"title": "Unmutated", "data": g_data, "pvals": g_pvals},
    "b.": {"title": "Mutated", "data": m_data, "pvals": m_pvals},
}

# Fixed colour per model so both panels use the same colours.
palette_dict = dict(zip(models_list, sns.color_palette("colorblind", n_colors=len(models_list))))

x_order = ["FRH", "CDRH1", "CDRH2", "CDRH3", "FRL", "CDRL1", "CDRL2", "CDRL3"]

fig, ax = plt.subplot_mosaic([["a."], ["b."]], layout='constrained', figsize=(20, 16))

# Accuracy is displayed as a percentage; other metrics are plotted as stored.
display_scale = 100 if metric_to_use == "accuracy" else 1

for label, panel in plots.items():
    title, data = panel["title"], panel["data"]
    axis = ax[label]

    # y-limits: full data range padded by 10% of the magnitude of each bound.
    metric_min, metric_max = data[metric_to_use].min(), data[metric_to_use].max()
    y_min = (metric_min - abs(metric_min) * 0.1) * display_scale
    y_max = (metric_max + abs(metric_max) * 0.1) * display_scale

    # Scale a copy for plotting. Multiplying the column in place would change
    # g_data / m_data themselves, so re-running this loop would rescale them
    # again and any later reader would see percentages instead of fractions.
    plot_data = data.assign(**{metric_to_use: data[metric_to_use] * display_scale})

    sns.boxplot(
        x="region", y=metric_to_use, hue="model_id", data=plot_data, ax=axis,
        order=x_order, hue_order=models_list,
        palette=palette_dict, whis=[5, 95], showfliers=False, linewidth=2.5
    )

    axis.set_title(title, fontsize=34)
    axis.set_xlabel("", size=34)
    axis.set_ylabel(y_label, size=34)
    axis.tick_params(labelsize=24)
    axis.annotate(label, xy=(-0.1, 1.05), xycoords="axes fraction", fontsize=30, weight="bold")
    axis.set_ylim(bottom=y_min, top=y_max)
    axis.set_xticklabels(axis.get_xticklabels(), rotation=45, ha='right', fontsize=24)

    # Only the first panel keeps a legend; it applies to both panels.
    if label == "a.":
        handles, labels_ = axis.get_legend_handles_labels()
        axis.legend(handles=handles, labels=labels_, loc="lower right", title="Model", fontsize=22, title_fontsize=24)
    else:
        legend = axis.get_legend()
        if legend is not None:  # nothing to remove if no legend was drawn
            legend.remove()

sns.despine(fig=fig)
#plt.savefig("Accuracy_perposition.png", dpi=300, bbox_inches='tight')
plt.show()
