# Imports

In [None]:
import pandas as pd
import json
from datetime import datetime
import plotly.express as px
from pathlib import Path
import sys
from matplotlib.pyplot import ScalarFormatter
from asapdiscovery.data.readers.molfile import MolFileFactory
from harbor.analysis.cross_docking import DockingDataModel
import harbor.analysis.cross_docking as cd
import seaborn as sns
from matplotlib import pyplot as plt

In [None]:
# Parquet file holding the serialized cross-docking data model.
RESULTS_PARQUET = "/Users/alexpayne/Scientific_Projects/mers-drug-discovery/sars2-retrospective-analysis/full_cross_dock_v2_combined_results/ALL_1_poses.parquet"
data = cd.DockingDataModel.deserialize(RESULTS_PARQUET)

In [None]:
# Keep only the rows whose query ligand carries scaffold ID 0
# (these become "Scaffold 1" in the combined results).
is_first_scaffold = data.dataframe['QueryData_Scaffold_ID'] == 0
s1_df = data.dataframe[is_first_scaffold]

In [None]:
# Build a data model around the scaffold-1 rows, passing along everything
# the full model reports through model_dump().
s1_fields = data.model_dump()
s1 = cd.DockingDataModel(dataframe=s1_df, **s1_fields)

In [None]:
# Baseline evaluator factory; the split-specific factory below is copied from it.
default = cd.EvaluatorFactory(name="default")

# Success rate is evaluated on the pose RMSD column.
success_settings = default.success_rate_evaluator_settings
success_settings.use = True
success_settings.success_rate_column = "PoseData_RMSD"

# Enable both scorers: RMSD and the POSIT docking-confidence score.
rmsd_settings = default.scorer_settings.rmsd_scorer_settings
rmsd_settings.use = True
rmsd_settings.rmsd_column_name = "PoseData_RMSD"

posit_settings = default.scorer_settings.posit_scorer_settings
posit_settings.use = True
posit_settings.posit_score_column_name = "PoseData_docking-confidence-POSIT"

# Reference-split comparison: configure a deep copy so `default` keeps the
# baseline settings only.
evf = default.__deepcopy__()
evf.name = "reference_split_comparison"

split_settings = evf.reference_split_settings
split_settings.use = True

# Date-ordered split, keyed on the reference structure date column.
split_settings.date_split_settings.use = True
split_settings.date_split_settings.reference_structure_date_column = "RefData_Date"

# Random split alongside the date split.
split_settings.random_split_settings.use = True

# Vary the reference set with logarithmic scaling switched on.
split_settings.update_reference_settings.use = True
split_settings.update_reference_settings.use_logarithmic_scaling = True

In [None]:
# Create the evaluators configured on `evf` for the scaffold-1 data model.
evs = evf.create_evaluators(s1)

In [None]:
# Notebook display: how many evaluators the factory produced.
len(evs)

In [None]:
# evaluate

In [None]:
# Compute the results for the scaffold-1 data model from the evaluators in `evs`.
results = cd.Results.calculate_results(s1, evs)

# Scaffold 2

In [None]:
# Rows whose query ligand carries scaffold ID 1 (labelled "Scaffold 2" in the
# combined results), wrapped in a data model that reuses the full model's dump.
is_second_scaffold = data.dataframe['QueryData_Scaffold_ID'] == 1
s2_df = data.dataframe[is_second_scaffold]
s2_fields = data.model_dump()
s2 = cd.DockingDataModel(dataframe=s2_df, **s2_fields)

In [None]:
# Same factory as for scaffold 1: create the evaluators for the scaffold-2
# data model, then compute its results.
evs2 = evf.create_evaluators(s2)
results2 = cd.Results.calculate_results(s2, evs2)

# Scaffold 3

In [None]:
# Rows whose query ligand carries scaffold ID 2 (labelled "Scaffold 3" in the
# combined results), wrapped in a data model that reuses the full model's dump.
is_third_scaffold = data.dataframe['QueryData_Scaffold_ID'] == 2
s3_df = data.dataframe[is_third_scaffold]
s3_fields = data.model_dump()
s3 = cd.DockingDataModel(dataframe=s3_df, **s3_fields)

In [None]:
# Same factory as for scaffold 1: create the evaluators for the scaffold-3
# data model, then compute its results.
evs3 = evf.create_evaluators(s3)
results3 = cd.Results.calculate_results(s3, evs3)

# Scaffold 4

In [None]:
# Rows whose query ligand carries scaffold ID 3 (labelled "Scaffold 4" in the
# combined results), wrapped in a data model that reuses the full model's dump.
is_fourth_scaffold = data.dataframe['QueryData_Scaffold_ID'] == 3
s4_df = data.dataframe[is_fourth_scaffold]
s4_fields = data.model_dump()
s4 = cd.DockingDataModel(dataframe=s4_df, **s4_fields)

In [None]:
# Same factory as for scaffold 1: create the evaluators for the scaffold-4
# data model, then compute its results.
evs4 = evf.create_evaluators(s4)
results4 = cd.Results.calculate_results(s4, evs4)

# Save Results

In [None]:
# Convert each scaffold's results into a table, in scaffold order.
r1, r2, r3, r4 = (
    cd.Results.df_from_results(scaffold_results)
    for scaffold_results in (results, results2, results3, results4)
)

## Combine results

In [None]:
# Tag every row of each per-scaffold table with its 1-based scaffold label.
for scaffold_number, scaffold_frame in enumerate((r1, r2, r3, r4), start=1):
    scaffold_frame["Query_Scaffold"] = f"Scaffold {scaffold_number}"

In [None]:
# Stack the four per-scaffold tables into one, renumbering the row index.
scaffold_frames = [r1, r2, r3, r4]
df = pd.concat(scaffold_frames, ignore_index=True)

In [None]:
# Persist the combined table; the plotting section reads this file back in.
# NOTE(review): the row index is written too (pandas default), so it comes
# back as an extra unnamed column on read_csv — confirm that is intended.
df.to_csv("single_scaffold_split_results.csv")

# Load results for plotting

In [None]:
# Load the saved results and derive the error-bar lengths around "Fraction".
posit_results = Path("single_scaffold_split_results.csv")
pdf = pd.read_csv(posit_results)

# Distance from the point estimate to each confidence bound. Negative
# distances (a bound on the wrong side of the estimate) are clamped to zero
# with a single vectorized clip rather than a per-row Python lambda.
pdf["Error_Lower"] = (pdf["Fraction"] - pdf["CI_Lower"]).clip(lower=0)
pdf["Error_Upper"] = (pdf["CI_Upper"] - pdf["Fraction"]).clip(lower=0)

# Plotting

## Plot function

In [None]:
# Global configuration

# Running figure counter; nothing in the active code reads it.
FIGNUM_GLOBAL = 0

# Every figure is written below this directory, so create it up front.
fig_path = Path("./20250804_single_scaffold_split")
fig_path.mkdir(exist_ok=True, parents=True)

# def save_fig(fig, filename, dpi=200, suffix=".pdf"):
#     """Save the figure with a global figure number."""
#     global FIGNUM_GLOBAL
#     FIGNUM_GLOBAL += 1
#     figpath = Path(fig_path / f"{filename}_{FIGNUM_GLOBAL:02d}")
#     fig.savefig(figpath.with_suffix(suffix), 
#                 bbox_inches="tight", 
#                 dpi=dpi)

def save_fig(fig, filename, dpi=200, suffix=".pdf"):
    """Save ``fig`` inside the global ``fig_path`` directory.

    ``filename`` is joined onto ``fig_path`` and its extension is set to
    ``suffix`` with ``Path.with_suffix``. ``fig`` only needs to provide a
    ``savefig`` method; it is called with a tight bounding box and ``dpi``.
    """
    destination = (fig_path / f"{filename}").with_suffix(suffix)
    fig.savefig(destination, bbox_inches="tight", dpi=dpi)
    

sns.set_style("white")

# Raw column names and category values -> the text shown on the plots.
label_map = {
    "Reference_Split": "Dataset Split Type",
    "Score": "Scoring Method",
    "RandomSplit": "Randomly Ordered",
    "DateSplit": "Ordered by Date",
    "RMSD": "RMSD (Positive Control)",
    "POSIT_Probability": "POSIT Probability",
    "ScaffoldDateSplit": "One Structure Per Scaffold",
    # Alternative wording: "Number of Randomly Chosen Reference Structures"
    "N_Reference_Structures": "Number of Reference Structures Available to Use \n(Log Scale)",
    "Fraction": "Fraction of Ligands Posed \n<2Å from Reference",
    "CI_Lower": "Confidence Interval Lower Bound",
    "CI_Upper": "Confidence Interval Upper Bound",
    "N_Refs_Per_Scaffold": "# Reference Structures Per Scaffold",
    "Query_Scaffold": "Query Scaffold ID",
}
# Columns without a dedicated label map to themselves, so every column of
# `pdf` can be looked up in `label_map`.
for column in pdf.columns:
    label_map.setdefault(column, column)

# Default plot variables, expressed as display labels (the column names the
# data has after being renamed through `label_map`).
X_VAR = X_LABEL = label_map["N_Reference_Structures"]
Y_VAR = Y_LABEL = label_map["Fraction"]
# QUERY_SCAFFOLD_ID = label_map["Query_Scaffold_ID_Subset_1"]
# REF_SCAFFOLD_ID = label_map["Reference_Scaffold_ID_Subset_1"]
COLOR_VAR = label_map["Reference_Split"]
STYLE_VAR = label_map["Score"]
CI_LOWER = label_map["CI_Lower"]
CI_UPPER = label_map["CI_Upper"]

# Figure geometry and typography.
LARGE_FIG_SIZE = (12, 8)
SMALL_FIG_SIZE = (8, 6)
FONT_SIZES = dict(
    xlabel=24,
    ylabel=24,
    ticks=18,
    legend_title=24,
    legend_text=18,
)
ALPHA = 0.1

# Distinct reference-set sizes in ascending order; used as x tick positions.
n_refs = sorted(pdf["N_Reference_Structures"].unique())

In [1]:
def plot_filled_in_error_bars(
    raw_df,
    x_var=X_VAR,
    y_var=Y_VAR,
    color_var=COLOR_VAR,
    style_var=STYLE_VAR,
    ci_lower=CI_LOWER,
    ci_upper=CI_UPPER,
    reverse_hue_order=False,
    reverse_style_order=False,
    n_refs=n_refs,
    fill_between=True,
    log_scale=True,
    y_ticks=None,
):
    """Draw a seaborn line plot with optional shaded confidence bands.

    The defaults reference module-level constants (``X_VAR``, ``n_refs``, ...)
    and are bound when this ``def`` executes, so the configuration cell has
    to run before this one.

    Parameters
    ----------
    raw_df : DataFrame
        Data to plot; must contain every column named by the arguments below.
    x_var, y_var : str
        Columns plotted on the x and y axes.
    color_var, style_var : str
        Columns mapped to line colour (hue) and line style.
    ci_lower, ci_upper : str
        Columns holding the lower/upper bound of the shaded band.
    reverse_hue_order, reverse_style_order : bool
        Reverse the sorted order of the hue / style categories.
    n_refs : sequence
        Positions (and labels) of the x ticks.
    fill_between : bool
        Shade the area between ``ci_lower`` and ``ci_upper`` for each
        (colour, style) group.
    log_scale : bool
        Put the x axis on a log scale with plain (scalar) tick formatting.
    y_ticks : sequence, optional
        Positions (and labels) of the y ticks. When omitted, a module-level
        ``y_ticks`` is used if one exists; otherwise matplotlib's automatic
        y ticks are kept.

    Returns
    -------
    module
        ``matplotlib.pyplot`` itself, whose current figure is the one just
        drawn; ``plt.savefig`` / ``plt.show`` therefore act on this plot.
    """
    raw_df = raw_df.sort_values(by=[x_var, style_var, color_var])
    plt.figure(figsize=LARGE_FIG_SIZE)

    # Category orders, optionally reversed.
    hue_order = sorted(raw_df[color_var].unique(), reverse=reverse_hue_order)
    style_order = sorted(raw_df[style_var].unique(), reverse=reverse_style_order)

    # One palette colour per hue category; reused for the shaded bands so the
    # band colour always matches its line.
    color_map = dict(zip(hue_order, sns.color_palette(n_colors=len(hue_order))))

    # sns.lineplot returns the Axes it drew on.
    ax = sns.lineplot(
        data=raw_df,
        x=x_var,
        y=y_var,
        hue=color_var,
        style=style_var,
        palette=color_map,
        hue_order=hue_order,
        style_order=style_order,
    )

    if fill_between:
        # Group keys are (color_var value, style_var value); only the colour
        # category is needed to pick the band colour.
        for (color_name, _), group in raw_df.groupby([color_var, style_var]):
            ax.fill_between(
                group[x_var],
                group[ci_lower],
                group[ci_upper],
                color=color_map[color_name],
                alpha=ALPHA,
            )

    if log_scale:
        ax.set_xscale("log")
        # Show tick values as plain numbers rather than powers of ten.
        ax.xaxis.set_major_formatter(ScalarFormatter())

    ax.set_xticks(n_refs)
    ax.set_xticklabels(n_refs, fontsize=FONT_SIZES["ticks"])
    ax.tick_params(axis="y", labelsize=FONT_SIZES["ticks"])

    # `y_ticks` used to be read as an undefined global, which raised NameError
    # on every call. Still honour such a global when a session defines one,
    # but fall back to matplotlib's automatic ticks instead of crashing.
    if y_ticks is None:
        y_ticks = globals().get("y_ticks")
    if y_ticks is not None:
        ax.set_yticks(y_ticks)
        ax.set_yticklabels(y_ticks, fontsize=FONT_SIZES["ticks"])

    ax.set_xlabel(X_LABEL, fontsize=FONT_SIZES["xlabel"], fontweight="bold")
    ax.set_ylabel(Y_LABEL, fontsize=FONT_SIZES["ylabel"], fontweight="bold")

    # Enlarge the legend title and entries.
    legend = ax.legend()
    plt.setp(legend.get_title(), fontsize=FONT_SIZES["legend_title"], fontweight="bold")
    plt.setp(legend.get_texts(), fontsize=FONT_SIZES["legend_text"])
    return plt

NameError: name 'X_VAR' is not defined

## Update labels

In [None]:
# Translate raw identifiers into display labels: first the column headers,
# then every cell whose value is itself a key of `label_map`.
df = pdf.copy().rename(columns=label_map)
for column in df.columns:
    df[column] = df[column].map(lambda value: label_map.get(value, value))
df.columns

In [None]:
# Keep only the rows belonging to the date-ordered reference split.
is_date_split = df[label_map["Reference_Split"]] == label_map["DateSplit"]
plot_df = df[is_date_split]

In [None]:
# Date-split rows only: one line colour per query scaffold, no shaded bands.
fig = plot_filled_in_error_bars(
    plot_df,
    x_var=label_map["N_Reference_Structures"],
    color_var=label_map["Query_Scaffold"],
    fill_between=False,
)
save_fig(fig, "scaffolds_by_color")

In [None]:
# One figure per query scaffold, coloured by reference-split type.
for s in range(1, 5):
    plot_df = df[df[label_map["Query_Scaffold"]] == f"Scaffold {s}"]
    # No sorting here: plot_filled_in_error_bars sorts its input itself.
    fig = plot_filled_in_error_bars(
        plot_df,
        x_var=label_map["N_Reference_Structures"],
        color_var=label_map["Reference_Split"],
        reverse_hue_order=True,
    )
    save_fig(fig, f"date_split_by_scaffold_{s}")