# Imports

In [None]:
import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter
import seaborn as sns
import pandas as pd
from pathlib import Path
from harbor.analysis import cross_docking as cd

# Load Data

In [None]:
pose_data_path = Path("/Users/alexpayne/Scientific_Projects/mers-drug-discovery/sars2-retrospective-analysis/full_cross_dock_v2_combined_results/ALL_1_poses.parquet")
pose_data = cd.DockingDataModel.deserialize(pose_data_path)

In [None]:
fpd_path = Path("/Users/alexpayne/Scientific_Projects/mers-drug-discovery/sars2-retrospective-analysis/full_cross_dock_v2_combined_results/FRED_1_poses.parquet")
fpd = cd.DockingDataModel.deserialize(fpd_path)

In [None]:
fdf = fpd.dataframe
fdf["Method"] = "FRED"

In [None]:
pdf = pose_data.dataframe
pdf["Method"] = "POSIT"

In [None]:
raw_df = pd.concat([fdf, pdf])

In [None]:
tc_df = raw_df[raw_df["TanimotoComboData_Aligned"] == False] 

In [None]:
df = tc_df.copy()

In [None]:
sns.displot(df, x="TanimotoComboData_Tanimoto", hue="Method", kind="ecdf")

In [None]:
sns.displot(df[df["Method"] == "POSIT"], x="TanimotoComboData_Tanimoto", hue="PoseData_POSIT_Method", kind="ecdf")

In [None]:
tc_df = raw_df[raw_df["TanimotoComboData_Aligned"] == True] 

In [None]:
df = tc_df.copy()

In [None]:
sns.displot(df, x="TanimotoComboData_Tanimoto", hue="Method", kind="ecdf")

In [None]:
sns.displot(df[df["Method"] == "POSIT"], x="TanimotoComboData_Tanimoto", hue="PoseData_POSIT_Method", kind="ecdf")

In [None]:
mvar = "PoseData_POSIT_Method"
df["complex_id"] = df["Query_Ligand"] + "_" + df["Reference_Structure"] 
dfs = {}
for method in df[mvar].unique():
    # filter data by all the results for POSIT that 
    filtered_posit_data = df[(df[mvar] == method)&(df["Method"] == "POSIT")]
    filtered_data = df[df["complex_id"].isin(filtered_posit_data["complex_id"].unique())]
    dfs[method] = filtered_data.copy()

In [None]:
dfs.keys()

In [None]:
def plot_ecdf(method, data=None):
    """Plot FRED-vs-POSIT ECDFs of ligand similarity for one POSIT sub-method.

    Parameters
    ----------
    method : str
        Key into the per-sub-method frames (e.g. "FRED", "HYBRID").
    data : dict, optional
        Mapping of sub-method -> DataFrame. Defaults to the notebook-level
        ``dfs`` dictionary.

    Returns
    -------
    seaborn.FacetGrid
        The grid returned by ``sns.displot``, with both axes limited to [0, 1].
    """
    frames = dfs if data is None else data
    grid = sns.displot(
        frames[method], x="TanimotoComboData_Tanimoto", hue="Method", kind="ecdf"
    )
    # Set limits on the grid's own axes rather than on whatever pyplot treats
    # as "current", so the call is robust to other active figures.
    grid.set(xlim=(0, 1), ylim=(0, 1))
    return grid

In [None]:
plot_ecdf("FRED")

In [None]:
plot_ecdf("HYBRID")

## Make evaluators with a minimal number of bootstraps

In [None]:
fred_df = dfs["FRED"]
fred_fred_df = fred_df[fred_df["Method"] == "FRED"]
posit_fred_df = fred_df[fred_df["Method"] == "POSIT"]

In [None]:
len(fred_fred_df)

In [None]:
len(posit_fred_df)

In [None]:
import copy

# Baseline evaluator settings shared by the split below.
# n_bootstraps is kept at 10 so this exploratory run stays fast.
default = cd.EvaluatorFactory(name="default")
default.n_bootstraps = 10

default.success_rate_evaluator_settings.use = True
default.success_rate_evaluator_settings.success_rate_column = "PoseData_RMSD"

default.scorer_settings.rmsd_scorer_settings.use = True
default.scorer_settings.rmsd_scorer_settings.rmsd_column_name = "PoseData_RMSD"

default.scorer_settings.posit_scorer_settings.use = True
default.scorer_settings.posit_scorer_settings.posit_score_column_name = (
    "PoseData_docking-confidence-POSIT"
)
# Similarity split: start from an independent deep copy of the defaults.
# copy.deepcopy is the public entry point to the __deepcopy__ protocol; calling
# the dunder directly skips the memo bookkeeping deepcopy provides.
sim_split = copy.deepcopy(default)
sim_split.name = "increasing_similarity_tanimoto_combo_aligned"
sim_split.pairwise_split_settings.use = True
sim_split.pairwise_split_settings.similarity_split_settings.use = True
sim_split.pairwise_split_settings.similarity_split_settings.similarity_column_name = (
    "TanimotoComboData_Tanimoto"
)
sim_split.pairwise_split_settings.similarity_split_settings.include_similar = False
# NOTE(review): presumably restricts the similarity values to aligned
# TanimotoCombo rows — confirm against harbor's similarity split settings.
sim_split.pairwise_split_settings.similarity_split_settings.similarity_groupby_dict = {
    "TanimotoComboData_Type": "TanimotoCombo",
    "TanimotoComboData_Aligned": True,
}

In [None]:
evs = sim_split.create_evaluators()

In [None]:
len(evs)

In [None]:
ev_df = pd.DataFrame.from_records([ev.get_records() for ev in evs])

In [None]:
ffd = cd.DockingDataModel(dataframe=fred_fred_df, **fpd.model_dump())

In [None]:
pfd = cd.DockingDataModel(dataframe=posit_fred_df, **pose_data.model_dump())

In [None]:
fred_fred_results = cd.Results.calculate_results(ffd, evs)

In [None]:
posit_fred_results = cd.Results.calculate_results(pfd, evs)

In [None]:
fred_fred_results_df = cd.Results.df_from_results(fred_fred_results)
fred_fred_results_df["Method"] = "FRED"
posit_fred_results_df = cd.Results.df_from_results(posit_fred_results)
posit_fred_results_df["Method"] = "POSIT"

In [None]:
fred_results = pd.concat([fred_fred_results_df, posit_fred_results_df])

## Plot the results

# Plotting Params

In [None]:
# Global configuration: output directory used by save_fig for every figure.
fig_path = Path("./20250722_why_is_posit_not_better")
fig_path.mkdir(exist_ok=True, parents=True)

# Running figure counter. Only a former numbered-filename variant of save_fig
# incremented it; it is kept so any existing references still resolve.
FIGNUM_GLOBAL = 0

def save_fig(fig, filename, dpi=200, suffix=".pdf", directory=None):
    """Save ``fig`` as ``<directory>/<filename><suffix>``.

    Parameters
    ----------
    fig : object with a ``savefig`` method
        A matplotlib Figure, or the ``matplotlib.pyplot`` module itself
        (which saves the current figure).
    filename : str
        Base name of the output file. Dots in the name are preserved
        (``"sim_0.5"`` -> ``"sim_0.5.pdf"``); if the name already ends with
        ``suffix`` it is not appended again.
    dpi : int
        Resolution passed to ``savefig``.
    suffix : str
        File extension including the leading dot.
    directory : str or Path, optional
        Output directory; defaults to the notebook-level ``fig_path``.
        Created if it does not exist.

    Returns
    -------
    Path
        The path the figure was written to.
    """
    out_dir = fig_path if directory is None else Path(directory)
    out_dir.mkdir(parents=True, exist_ok=True)
    name = str(filename)
    # Append instead of Path.with_suffix: with_suffix treats everything after
    # the last dot as an extension and replaces it ("sim_0.5" -> "sim_0.pdf").
    if not name.endswith(suffix):
        name += suffix
    out_path = out_dir / name
    fig.savefig(out_path, bbox_inches="tight", dpi=dpi)
    return out_path
    

sns.set_style("white")

# Display names, used both to rename result columns and to relabel cell
# values before plotting.
label_map = {
    "Reference_Split": "Dataset Split Type",
    "Score": "Scoring Method",
    "RandomSplit": "Randomly Ordered",
    "DateSplit": "Ordered by Date",
    "RMSD": "RMSD (Positive Control)",
    "POSIT_Probability": "POSIT Probability",
    # "PairwiseSplit": "Similarity Metric",
    "Similarity_Threshold": "Similarity Threshold",
    "ECFP4_2048": "ECFP4 2048",
    "MCS": "MCS",
    "TanimotoCombo_True": "Tanimoto Combo (Aligned)",
    # "N_Reference_Structures": "Number of Randomly Chosen Reference Structures",
    "N_Reference_Structures": "Number of Reference Structures Available to Use \n(Log Scale)",
    "Fraction": "Fraction of Ligands Posed \n<2Å from Reference",
    "CI_Lower": "Confidence Interval Lower Bound",
    "CI_Upper": "Confidence Interval Upper Bound",
}
# Any POSIT column without an explicit display name maps to itself.
for column in pdf.columns:
    label_map.setdefault(column, column)

# Default column / axis-label choices for plot_filled_in_error_bars.
X_VAR = X_LABEL = label_map["N_Reference_Structures"]
Y_VAR = Y_LABEL = label_map["Fraction"]
# QUERY_SCAFFOLD_ID = label_map["Query_Scaffold_ID_Subset_1"]
# REF_SCAFFOLD_ID = label_map["Reference_Scaffold_ID_Subset_1"]
COLOR_VAR = label_map["Reference_Split"]
STYLE_VAR = label_map["Score"]
CI_LOWER, CI_UPPER = label_map["CI_Lower"], label_map["CI_Upper"]

# Figure sizes as (width, height), passed to plt.figure(figsize=...)
LARGE_FIG_SIZE = (12, 8)
SMALL_FIG_SIZE = (8, 6)
FONT_SIZES = dict(
    xlabel=24,
    ylabel=24,
    ticks=18,
    legend_title=24,
    legend_text=18,
)
# Opacity of the shaded confidence bands
ALPHA = 0.1

## functions

In [None]:
def plot_filled_in_error_bars(
    raw_df,
    x_var=X_VAR,
    y_var=Y_VAR,
    color_var=COLOR_VAR,
    style_var=STYLE_VAR,
    ci_lower=CI_LOWER,
    ci_upper=CI_UPPER,
    x_label=X_LABEL,
    y_label=Y_LABEL,
    reverse_hue_order=False,
    reverse_style_order=False,
    log_x=False,
):
    """Line plot of ``y_var`` against ``x_var`` with shaded confidence bands.

    One line is drawn per (``color_var``, ``style_var``) combination: colour
    encodes ``color_var`` and line style encodes ``style_var``. For every
    combination the band between ``ci_lower`` and ``ci_upper`` is filled in
    the line's colour.

    Parameters
    ----------
    raw_df : pandas.DataFrame
        Long-format results containing all of the columns named below.
    x_var, y_var : str
        Columns for the x and y axes.
    color_var, style_var : str
        Columns mapped to line colour (hue) and line style.
    ci_lower, ci_upper : str
        Columns holding the lower / upper bound of the shaded band.
    x_label, y_label : str
        Axis labels.
    reverse_hue_order, reverse_style_order : bool
        Sort hue / style levels in descending instead of ascending order.
    log_x : bool
        If True, use a log-scaled x axis with plain (non-scientific) tick
        labels. Off by default.

    Returns
    -------
    module
        ``matplotlib.pyplot``; the new figure is the current figure, so the
        return value can be passed directly to ``save_fig``.
    """
    # Sort by x first so each group's points are in x order for fill_between.
    raw_df = raw_df.sort_values(by=[x_var, style_var, color_var])
    plt.figure(figsize=LARGE_FIG_SIZE)

    hue_order = sorted(raw_df[color_var].unique(), reverse=reverse_hue_order)
    style_order = sorted(raw_df[style_var].unique(), reverse=reverse_style_order)

    # One colour per hue level, shared by each line and its confidence band.
    color_map = dict(zip(hue_order, sns.color_palette(n_colors=len(hue_order))))

    ax = sns.lineplot(
        data=raw_df,
        x=x_var,
        y=y_var,
        hue=color_var,
        style=style_var,
        palette=color_map,
        hue_order=hue_order,
        style_order=style_order,
    )

    # Groups are keyed (hue value, style value); only the hue picks the colour.
    for (hue_value, _style_value), group in raw_df.groupby([color_var, style_var]):
        ax.fill_between(
            group[x_var],
            group[ci_lower],
            group[ci_upper],
            color=color_map[hue_value],
            alpha=ALPHA,
        )

    if log_x:
        ax.set_xscale("log")
        ax.xaxis.set_major_formatter(ScalarFormatter())

    # Apply the configured tick size to both axes; x tick labels would
    # otherwise keep matplotlib's default size.
    ax.tick_params(axis="both", labelsize=FONT_SIZES["ticks"])
    ax.set_xlabel(x_label, fontsize=FONT_SIZES["xlabel"], fontweight="bold")
    ax.set_ylabel(y_label, fontsize=FONT_SIZES["ylabel"], fontweight="bold")

    legend = ax.legend()
    plt.setp(legend.get_title(), fontsize=FONT_SIZES["legend_title"], fontweight="bold")
    plt.setp(legend.get_texts(), fontsize=FONT_SIZES["legend_text"])
    return plt

In [None]:
df = fred_results.copy()

In [None]:
df = df.rename(columns=label_map)

In [None]:
for column in df.columns:
    df[column] = df[column].apply(lambda x: label_map.get(x,x))

In [None]:
plot_filled_in_error_bars(df, x_var=label_map["Similarity_Threshold"], color_var="Method", style_var=label_map["Score"], x_label="Query to Reference Ligand Similarity")

In [None]:
def calculate_results(method, dfs, factory=None, expected_n_evaluators=42):
    """Evaluate FRED and POSIT on the complexes selected for one POSIT sub-method.

    Parameters
    ----------
    method : str
        Key into ``dfs`` (e.g. "FRED", "HYBRID", "SHAPEFIT").
    dfs : dict[str, pandas.DataFrame]
        Per-sub-method frames holding both FRED and POSIT rows, told apart by
        the "Method" column.
    factory : optional
        Evaluator factory whose ``create_evaluators()`` is used; defaults to
        the notebook-level ``sim_split``.
    expected_n_evaluators : int or None
        Sanity check on how many evaluators the factory creates (42 for
        ``sim_split``); pass None to skip the check.

    Returns
    -------
    pandas.DataFrame
        Concatenated results for both methods, with a "Method" column set to
        "FRED" or "POSIT".

    Raises
    ------
    ValueError
        If the FRED and POSIT subsets differ in row count, or the number of
        evaluators differs from ``expected_n_evaluators``.
    """
    filtered_df = dfs[method]
    factory = sim_split if factory is None else factory

    fred_rows = filtered_df[filtered_df["Method"] == "FRED"]
    posit_rows = filtered_df[filtered_df["Method"] == "POSIT"]
    # Both methods are expected to cover the same complexes; raise explicitly
    # rather than assert so the check survives `python -O`.
    if len(fred_rows) != len(posit_rows):
        raise ValueError(
            f"{method}: {len(fred_rows)} FRED rows vs {len(posit_rows)} POSIT rows; "
            "expected the same number"
        )

    evs = factory.create_evaluators()
    if expected_n_evaluators is not None and len(evs) != expected_n_evaluators:
        raise ValueError(
            f"expected {expected_n_evaluators} evaluators, got {len(evs)}"
        )

    # Each subset is wrapped with the other fields dumped from its source model.
    per_method = [
        ("FRED", fred_rows, fpd),
        ("POSIT", posit_rows, pose_data),
    ]
    frames = []
    for label, rows, source_model in per_method:
        data_model = cd.DockingDataModel(dataframe=rows, **source_model.model_dump())
        results = cd.Results.calculate_results(data_model, evs)
        results_df = cd.Results.df_from_results(results)
        results_df["Method"] = label
        frames.append(results_df)
    return pd.concat(frames)

In [None]:
hybrid_results = calculate_results("HYBRID", dfs)

In [None]:
fred_results = calculate_results("FRED", dfs)

In [None]:
shapefit_results = calculate_results("SHAPEFIT", dfs)

In [None]:
df = fred_results.copy()
df = df.rename(columns=label_map)
for column in df.columns:
    df[column] = df[column].apply(lambda x: label_map.get(x,x))
fig = plot_filled_in_error_bars(df, x_var=label_map["Similarity_Threshold"], color_var="Method", style_var=label_map["Score"], x_label="Query to Reference Ligand Similarity")
save_fig(fig, "fred_fred_vs_posit_all_available_structures")

In [None]:
df = shapefit_results.copy()
df = df.rename(columns=label_map)
for column in df.columns:
    df[column] = df[column].apply(lambda x: label_map.get(x,x))
fig = plot_filled_in_error_bars(df, x_var=label_map["Similarity_Threshold"], color_var="Method", style_var=label_map["Score"], x_label="Query to Reference Ligand Similarity")
save_fig(fig, "shapefit_fred_vs_posit")

In [None]:
df = hybrid_results.copy()
df = df.rename(columns=label_map)
for column in df.columns:
    df[column] = df[column].apply(lambda x: label_map.get(x,x))
fig = plot_filled_in_error_bars(df, x_var=label_map["Similarity_Threshold"], color_var="Method", style_var=label_map["Score"], x_label="Query to Reference Ligand Similarity")
save_fig(fig, "hybrid_fred_vs_posit")