# Imports

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import ScalarFormatter
import seaborn as sns
import pandas as pd
from pathlib import Path

# Load Data

In [None]:
# Directory holding the analyzed cross-docking result CSVs (machine-specific path).
results_path = Path("/Users/alexpayne/Scientific_Projects/mers-drug-discovery/sars2-retrospective-analysis/full_cross_dock_v2_analyzed_results/")
# File tags of the three similarity-split result families (ECFP4, MCS, TanimotoCombo).
SIMILARITY_FILE_TAGS = ["ecfp4", "mcs", "tc"]
# Only rows whose "Total" equals this value are kept below.
# NOTE(review): presumably the full number of query ligands evaluated — confirm.
EXPECTED_TOTAL = 403

# "ALL_*" files hold the POSIT results, "FRED_*" files the FRED results.
posit_results = [results_path / f"ALL_1_poses_{tag}_combined_results.csv" for tag in SIMILARITY_FILE_TAGS]
raw_pdf = pd.concat([pd.read_csv(csv_path) for csv_path in posit_results], ignore_index=True)
raw_pdf["Method"] = "POSIT"
fred_results = [results_path / f"FRED_1_poses_{tag}_combined_results.csv" for tag in SIMILARITY_FILE_TAGS]
fdf = pd.concat([pd.read_csv(csv_path) for csv_path in fred_results], ignore_index=True)
fdf["Method"] = "FRED"

pdf = pd.concat([raw_pdf, fdf], ignore_index=True)
# Asymmetric error-bar lengths, clipped at 0 in case a CI bound lies on the
# wrong side of the point estimate (NaN stays NaN, as before).
pdf["Error_Lower"] = (pdf["Fraction"] - pdf["CI_Lower"]).clip(lower=0)
pdf["Error_Upper"] = (pdf["CI_Upper"] - pdf["Fraction"]).clip(lower=0)
# .copy() so later cells that add or modify columns work on an independent
# frame rather than a boolean-filtered one (avoids SettingWithCopyWarning).
pdf = pdf[pdf["Total"] == EXPECTED_TOTAL].copy()

In [None]:
# Distinct values of the ECFP fingerprint column.
pdf["ECFPData_fingerprint"].unique()

In [None]:
# Derive one "Similarity Metric" label per row from whichever similarity column
# holds a string. Precedence: ECFP fingerprint name, then MCS type, then
# "<TanimotoCombo type>_<aligned flag>". A row with none of them is an error.
def _is_str(series):
    """Return a boolean Series: True where the value is a ``str``."""
    return series.map(lambda value: isinstance(value, str))


_has_ecfp = _is_str(pdf["ECFPData_fingerprint"])
_has_mcs = _is_str(pdf["MCSData_Type"])
_has_tc = _is_str(pdf["TanimotoComboData_Type"])

_unlabeled = ~(_has_ecfp | _has_mcs | _has_tc)
if _unlabeled.any():
    # Report the first offending index label, as the row-by-row version did.
    raise ValueError(f"Row {_unlabeled.idxmax()} has no similarity metric defined.")

# Start from the TanimotoCombo label and overwrite by increasing precedence.
_metric = (
    pdf["TanimotoComboData_Type"].astype(str)
    + "_"
    + pdf["TanimotoComboData_Aligned"].astype(str)
)
_metric = _metric.mask(_has_mcs, pdf["MCSData_Type"])
_metric = _metric.mask(_has_ecfp, pdf["ECFPData_fingerprint"])

# `assign` returns a new frame, so the column is never written into a
# boolean-filtered view of the concatenated data (no SettingWithCopyWarning).
pdf = pdf.assign(**{"Similarity Metric": _metric})

In [None]:
# Distinct similarity-metric labels produced above.
pd.unique(pdf["Similarity Metric"])

In [None]:
# Inspect the full "Similarity Metric" column.
pdf.loc[:, "Similarity Metric"]

# Plotting Parameters

In [None]:
# Global configuration
fig_path = Path("./20250702_similarity_split")
fig_path.mkdir(parents=True, exist_ok=True)

# Not used by save_fig below; kept so any code still referencing it keeps working.
FIGNUM_GLOBAL = 0


def save_fig(fig, filename, dpi=200, suffix=".pdf"):
    """Save ``fig`` as ``fig_path / (filename + suffix)``.

    Parameters
    ----------
    fig : object with a ``savefig`` method
        e.g. a matplotlib Figure, a seaborn FacetGrid, or the
        ``matplotlib.pyplot`` module (which saves the current figure).
    filename : str
        Base name of the output file, relative to ``fig_path``. If it already
        ends in ``suffix``, the suffix is not added again.
    dpi : int
        Resolution passed through to ``savefig``.
    suffix : str
        File extension to write, including the leading dot.
    """
    target = fig_path / filename
    # Append the extension rather than calling Path.with_suffix: with_suffix
    # treats anything after the last dot in `filename` (e.g. "threshold_0.5")
    # as an extension and silently overwrites it.
    if target.suffix != suffix:
        target = target.with_name(target.name + suffix)
    fig.savefig(target, bbox_inches="tight", dpi=dpi)
    

sns.set_style("white")

# Display names for column names and categorical values. The DataFrame is
# relabeled with this mapping (column names and cell values), so the plotting
# code refers to columns by these display strings.
label_map = {
    "Reference_Split": "Dataset Split Type",
    "Score": "Scoring Method",
    "RandomSplit": "Randomly Ordered",
    "DateSplit": "Ordered by Date",
    "RMSD": "RMSD (Positive Control)",
    "POSIT_Probability": "POSIT Probability",
    "Similarity_Threshold": "Similarity Threshold",
    "ECFP4_2048": "ECFP4 2048",
    "MCS": "MCS",
    "TanimotoCombo_True": "Tanimoto Combo (Aligned)",
    "N_Reference_Structures": "Number of Reference Structures Available to Use \n(Log Scale)",
    "Fraction": "Fraction of Ligands Posed \n<2Å RMSD from Crystal Pose",
    "CI_Lower": "Confidence Interval Lower Bound",
    "CI_Upper": "Confidence Interval Upper Bound",
}
# Columns without an explicit display name map to themselves, so every column
# name of `pdf` can be looked up in label_map.
for column in pdf.columns:
    label_map.setdefault(column, column)

# Column names (post-relabeling) and axis labels used by the plotting helpers.
X_VAR = label_map["N_Reference_Structures"]
Y_VAR = label_map["Fraction"]
X_LABEL = label_map["Similarity_Threshold"]
Y_LABEL = label_map["Fraction"]
COLOR_VAR = label_map["Reference_Split"]
STYLE_VAR = label_map["Score"]
CI_LOWER = label_map["CI_Lower"]
CI_UPPER = label_map["CI_Upper"]

# Figure sizes in inches (width, height).
LARGE_FIG_SIZE = (12, 8)
SMALL_FIG_SIZE = (8, 6)
FONT_SIZES = {
    "xlabel": 24,
    "ylabel": 24,
    "ticks": 18,
    "legend_title": 24,
    "legend_text": 18,
}
# Opacity of the shaded confidence-interval bands.
ALPHA = 0.1

## Functions

In [None]:
def plot_filled_in_error_bars(
    raw_df,
    x_var=X_VAR,
    y_var=Y_VAR,
    color_var=COLOR_VAR,
    style_var=STYLE_VAR,
    ci_lower=CI_LOWER,
    ci_upper=CI_UPPER,
    x_label=X_LABEL,
    y_label=Y_LABEL,
    reverse_hue_order=False,
    reverse_style_order=False,
    x_ticks=None,
    y_ticks=None,
):
    """Draw a seaborn line plot with a shaded band around each line.

    One line is drawn per (``color_var``, ``style_var``) combination. The band
    between ``ci_lower`` and ``ci_upper`` is filled in that line's hue colour.

    Parameters
    ----------
    raw_df : pandas.DataFrame
        Data containing the ``x_var``, ``y_var``, ``color_var``,
        ``style_var``, ``ci_lower`` and ``ci_upper`` columns.
    x_var, y_var : str
        Columns plotted on the x and y axes.
    color_var : str
        Column mapped to line colour (hue).
    style_var : str
        Column mapped to line style.
    ci_lower, ci_upper : str
        Columns giving the lower and upper edge of the shaded band.
    x_label, y_label : str
        Axis labels.
    reverse_hue_order, reverse_style_order : bool
        Reverse the (otherwise sorted) hue / style level order.
    x_ticks, y_ticks : array-like, optional
        Tick positions. Each defaults to 0.0, 0.1, ..., 1.0.

    Returns
    -------
    module
        ``matplotlib.pyplot``. Its ``savefig`` saves the figure created here,
        which is the current figure.
    """
    # Sort so each group's rows are in x order for fill_between.
    raw_df = raw_df.sort_values(by=[x_var, style_var, color_var])
    plt.figure(figsize=LARGE_FIG_SIZE)

    hue_order = sorted(raw_df[color_var].unique(), reverse=reverse_hue_order)
    style_order = sorted(raw_df[style_var].unique(), reverse=reverse_style_order)

    # Explicit hue -> colour map so the shaded bands match their lines.
    palette = dict(zip(hue_order, sns.color_palette(n_colors=len(hue_order))))

    ax = sns.lineplot(
        data=raw_df,
        x=x_var,
        y=y_var,
        hue=color_var,
        style=style_var,
        palette=palette,
        hue_order=hue_order,
        style_order=style_order,
    )

    # Shade the CI band of every (hue, style) line in its hue colour.
    for (hue_value, _style_value), group in raw_df.groupby([color_var, style_var]):
        ax.fill_between(
            group[x_var],
            group[ci_lower],
            group[ci_upper],
            color=palette[hue_value],
            alpha=ALPHA,
        )

    default_ticks = np.round(np.linspace(0, 1, 11), 1)
    xticks = default_ticks if x_ticks is None else x_ticks
    ax.set_xticks(xticks)
    ax.set_xticklabels(xticks, fontsize=FONT_SIZES["ticks"])
    # Previously the y ticks were only applied in the default case, so a
    # caller-supplied `y_ticks` was silently ignored.
    yticks = default_ticks if y_ticks is None else y_ticks
    ax.set_yticks(yticks)
    ax.set_yticklabels(yticks, fontsize=FONT_SIZES["ticks"])

    ax.set_xlabel(x_label, fontsize=FONT_SIZES["xlabel"], fontweight="bold")
    ax.set_ylabel(y_label, fontsize=FONT_SIZES["ylabel"], fontweight="bold")

    legend = ax.legend()
    plt.setp(legend.get_title(), fontsize=FONT_SIZES["legend_title"], fontweight="bold")
    plt.setp(legend.get_texts(), fontsize=FONT_SIZES["legend_text"])
    return plt

## Update Labels

In [None]:
# Work on an independent (deep) copy so relabeling leaves `pdf` untouched.
df = pdf.copy(deep=True)

In [None]:
# Replace raw column names with their display names.
df = df.rename(label_map, axis="columns")

In [None]:
def _relabel_value(value):
    """Return the display label for ``value``, or ``value`` itself if unmapped."""
    return label_map.get(value, value)


# Translate every cell value that has a display name (e.g. "RMSD").
for col_name in list(df.columns):
    df[col_name] = df[col_name].apply(_relabel_value)

# Plotting

In [None]:
# Success fraction vs. number of reference structures, POSIT/FRED by line style.
# NOTE(review): label_map has no explicit "PairwiseSplit" entry; this lookup
# only resolves if `pdf` has a column of that name, otherwise it raises KeyError.
sns.lineplot(
    data=df,
    x=X_VAR,
    y=Y_VAR,
    hue=label_map["PairwiseSplit"],
    style="Method",
)

In [None]:
# Success fraction vs. similarity threshold using 403 reference structures,
# coloured by similarity metric.
sns.lineplot(
    data=df.loc[df[label_map["N_Reference_Structures"]].eq(403)],
    x=label_map["Similarity_Threshold"],
    y=Y_VAR,
    hue=label_map["Similarity Metric"],
    style="Method",
)

In [None]:
# Per-row totals alongside the docking method, for the ECDF below.
test_totals = df.loc[:, ["Method", "Total"]]

In [None]:
# ECDF of "Total" per docking method.
sns.displot(
    test_totals,
    x="Total",
    hue="Method",
    hue_order=["FRED", "POSIT"],
    kind="ecdf",
)

# Single-Similarity-Metric Plots

## TanimotoCombo (Aligned)

In [None]:
# For each (threshold, method, score) combination keep the TanimotoCombo
# (aligned) row that used the MOST reference structures.
n_ref_col = label_map["N_Reference_Structures"]
group_vars = [label_map["Similarity_Threshold"], "Method", STYLE_VAR]
tdf = df[df["Similarity Metric"] == label_map["TanimotoCombo_True"]]
tdf = (
    # A stable sort makes tie-breaking deterministic: among rows with equal N,
    # the first in original order wins.
    tdf.sort_values(n_ref_col, ascending=False, kind="stable")
    .groupby(group_vars)
    .head(1)
    # Return a new frame rather than sorting a derived frame in place.
    .sort_values(by=group_vars)
)

In [None]:
# Which reference-set sizes survived the per-group selection.
pd.unique(tdf[label_map["N_Reference_Structures"]])

In [None]:
# Rows whose selected reference-set size is exactly 250.
tdf.loc[tdf[label_map["N_Reference_Structures"]].eq(250)]

In [None]:
# FRED / RMSD-control rows with 0.5 < similarity threshold < 0.9.
tdf[
    tdf["Similarity Threshold"].gt(0.5)
    & tdf["Similarity Threshold"].lt(0.9)
    & tdf["Scoring Method"].eq("RMSD (Positive Control)")
    & tdf["Method"].eq("FRED")
]

In [None]:
# TanimotoCombo (aligned) curves, restricted to rows that used 403 references.
plot_filled_in_error_bars(
    raw_df=tdf.loc[tdf[label_map["N_Reference_Structures"]].eq(403)],
    x_var=label_map["Similarity_Threshold"],
    y_var=Y_VAR,
    color_var="Method",
    style_var=STYLE_VAR,
    ci_lower=CI_LOWER,
    ci_upper=CI_UPPER,
    reverse_hue_order=True,
    reverse_style_order=False,
)

In [None]:
# "Total" per similarity threshold, split by method and scoring method.
sns.scatterplot(
    data=tdf,
    x="Similarity Threshold",
    y="Total",
    hue="Method",
    style="Scoring Method",
)

In [None]:
# For each (threshold, method, score) combination keep the TanimotoCombo
# (aligned) row that used the FEWEST reference structures.
n_ref_col = label_map["N_Reference_Structures"]
group_vars = [label_map["Similarity_Threshold"], "Method", STYLE_VAR]
tdf = df[df["Similarity Metric"] == label_map["TanimotoCombo_True"]]
tdf = (
    # A stable sort makes tie-breaking deterministic: among rows with equal N,
    # the first in original order wins.
    tdf.sort_values(n_ref_col, ascending=True, kind="stable")
    .groupby(group_vars)
    .head(1)
    # Return a new frame rather than sorting a derived frame in place.
    .sort_values(by=group_vars)
)

In [None]:
# Which reference-set sizes survived the per-group selection.
pd.unique(tdf[label_map["N_Reference_Structures"]])

In [None]:
# TanimotoCombo (aligned) curves using the per-group minimum reference set.
plot_filled_in_error_bars(
    tdf,
    x_var=label_map["Similarity_Threshold"],
    y_var=Y_VAR,
    color_var="Method",
    style_var=STYLE_VAR,
    ci_lower=CI_LOWER,
    ci_upper=CI_UPPER,
    reverse_hue_order=True,
    reverse_style_order=False,
)

In [None]:
n_ref_col = label_map["N_Reference_Structures"]
threshold_col = label_map["Similarity_Threshold"]
# .copy() so the threshold rescaling below modifies an independent frame rather
# than a boolean-filtered slice of `df` (avoids SettingWithCopyWarning).
tdf = df[df["Similarity Metric"] == label_map["TanimotoCombo_True"]].copy()
# Rescale thresholds by 2 so the x axis runs 0-2.
# NOTE(review): presumably stored thresholds are on [0, 1] while TanimotoCombo
# spans [0, 2] — confirm.
tdf[threshold_col] = tdf[threshold_col].astype(float) * 2

# One figure per number of reference structures, saved as a numbered PDF.
for n_refs in sorted(tdf[n_ref_col].unique()):
    fig = plot_filled_in_error_bars(
        raw_df=tdf[tdf[n_ref_col] == n_refs],
        x_var=threshold_col,
        y_var=Y_VAR,
        color_var="Method",
        style_var=STYLE_VAR,
        ci_lower=CI_LOWER,
        ci_upper=CI_UPPER,
        reverse_hue_order=True,
        reverse_style_order=False,
        x_label=f"TanimotoCombo (Aligned) - {n_refs} Reference{'s' if n_refs > 1 else ''}",
        x_ticks=np.round(np.linspace(0, 2, 11), 1),
    )
    save_fig(fig, f"similarity_split_tc_aligned_{n_refs:03d}", suffix=".pdf")

## MCS

In [None]:
# One MCS similarity-split figure per number of reference structures.
n_ref_col = label_map["N_Reference_Structures"]
mdf = df.loc[df["Similarity Metric"] == label_map["MCS"]]
for n_refs in sorted(mdf[n_ref_col].unique()):
    fig = plot_filled_in_error_bars(
        raw_df=mdf.loc[mdf[n_ref_col] == n_refs],
        x_var=label_map["Similarity_Threshold"],
        y_var=Y_VAR,
        color_var="Method",
        style_var=STYLE_VAR,
        ci_lower=CI_LOWER,
        ci_upper=CI_UPPER,
        reverse_hue_order=True,
        reverse_style_order=False,
        x_label=f"MCS - {n_refs} Reference{'s' if n_refs > 1 else ''}",
    )
    save_fig(fig, f"similarity_split_mcs_{n_refs:03d}", suffix=".pdf")

## ECFP4

In [None]:
# One ECFP4 similarity-split figure per number of reference structures.
n_ref_col = label_map["N_Reference_Structures"]
mdf = df.loc[df["Similarity Metric"] == label_map["ECFP4_2048"]]
for n_refs in sorted(mdf[n_ref_col].unique()):
    fig = plot_filled_in_error_bars(
        raw_df=mdf.loc[mdf[n_ref_col] == n_refs],
        x_var=label_map["Similarity_Threshold"],
        y_var=Y_VAR,
        color_var="Method",
        style_var=STYLE_VAR,
        ci_lower=CI_LOWER,
        ci_upper=CI_UPPER,
        reverse_hue_order=True,
        reverse_style_order=False,
        x_label=f"ECFP4 Tanimoto Similarity - {n_refs} Reference{'s' if n_refs > 1 else ''}",
    )
    save_fig(fig, f"similarity_split_ecfp4_{n_refs:03d}", suffix=".pdf")

# TanimotoCombo (Aligned) Plots Colored by Number of Reference Structures

In [None]:
# TanimotoCombo (aligned) rows scored by POSIT probability only.
tdf = df[
    (df["Similarity Metric"] == label_map["TanimotoCombo_True"])
    & (df[label_map["Score"]] == label_map["POSIT_Probability"])
]

## Four-Facet Figure

In [None]:
from matplotlib.colors import LogNorm

In [None]:
# 2x2 grid: one panel per (Method, Score) combination of the TanimotoCombo
# (aligned) data, with lines coloured by number of reference structures.
tdf = df[df["Similarity Metric"] == label_map["TanimotoCombo_True"]]
n_ref_col = label_map["N_Reference_Structures"]
facet_groups = list(tdf.groupby(["Method", label_map["Score"]]))

fig, axes = plt.subplots(2, 2, figsize=(12, 8), sharex=True, sharey=True)
if len(facet_groups) > axes.size:
    raise ValueError(
        f"A 2x2 grid holds at most {axes.size} (Method, Score) combinations, "
        f"got {len(facet_groups)}."
    )

# One pre-scaled norm shared by every panel, so a given number of reference
# structures gets the same colour everywhere. An unscaled LogNorm() per call
# would be scaled to each panel's own data range.
shared_norm = LogNorm(vmin=tdf[n_ref_col].min(), vmax=tdf[n_ref_col].max())

for i, (ax, ((method, score), group)) in enumerate(zip(axes.flat, facet_groups)):
    sns.lineplot(
        data=group,
        x=label_map["Similarity_Threshold"],
        y=Y_VAR,
        hue=n_ref_col,
        ax=ax,
        hue_norm=shared_norm,
        # Only the first panel draws a legend; the colour mapping is shared.
        legend="auto" if i == 0 else False,
    )
    ax.set_title(f"{method} Method - {score}")

In [None]:
# Facet grid of TanimotoCombo (aligned) results: one column per docking method,
# one row per scoring method. Lines are coloured by number of reference
# structures on a log-normalised viridis palette.
fig = sns.relplot(
    data=tdf,
    x=label_map["Similarity_Threshold"],
    y=Y_VAR,
    hue=label_map["N_Reference_Structures"],
    col=label_map["Method"],
    row=label_map["Score"],
    kind="line",
    palette="viridis",
    hue_norm=LogNorm(),
    height=3,
    aspect=1.25,
    legend="full",
)
fig.set_titles("{col_name} Method - {row_name}")
fig.legend.set_title("N Refs")
# Reposition the legend through the public `legend` property rather than the
# private `_legend` attribute.
# NOTE(review): anchor (0.775, 0.7) was meant to sit over the top-left panel — verify visually.
fig.legend.set_bbox_to_anchor((0.775, 0.7))
save_fig(fig, "similarity_split_tc_aligned_hue_n_ref_facet", suffix=".pdf")

## Four-Facet ECFP4 Plot

In [None]:
tdf = df[df["Similarity Metric"] == label_map["ECFP4_2048"]]
# Facet grid of ECFP4 results: one column per docking method, one row per
# scoring method. Lines are coloured by number of reference structures on a
# log-normalised viridis palette.
fig = sns.relplot(
    data=tdf,
    x=label_map["Similarity_Threshold"],
    y=Y_VAR,
    hue=label_map["N_Reference_Structures"],
    col=label_map["Method"],
    row=label_map["Score"],
    kind="line",
    palette="viridis",
    hue_norm=LogNorm(),
    height=3,
    aspect=1.25,
    legend="full",
)
fig.set_titles("{col_name} Method - {row_name}")
fig.legend.set_title("N Refs")
# Reposition the legend through the public `legend` property rather than the
# private `_legend` attribute.
# NOTE(review): anchor (0.775, 0.7) was meant to sit over the top-left panel — verify visually.
fig.legend.set_bbox_to_anchor((0.775, 0.7))
save_fig(fig, "similarity_split_ecfp4_hue_n_ref_facet", suffix=".pdf")