In [None]:
import pandas as pd
import json
from datetime import datetime
import plotly.express as px
from pathlib import Path
from matplotlib.pyplot import ScalarFormatter
from asapdiscovery.data.readers.molfile import MolFileFactory
from harbor.analysis.cross_docking import DockingDataModel
import seaborn as sns
from matplotlib import pyplot as plt

In [None]:
# Load the serialized cross-docking results via DockingDataModel.deserialize.
# NOTE(review): this is an absolute, user-specific path, so the cell only runs on
# a machine with the same directory layout; consider a configurable data directory.
data = DockingDataModel.deserialize("/Users/alexpayne/Scientific_Projects/mers-drug-discovery/sars2-retrospective-analysis/full_cross_dock_v2_combined_results/ALL_1_poses.parquet")

In [None]:
# Read the ligand/scaffold table from the working directory (relative path).
ligand_data = pd.read_csv("ligand_scaffold_data.csv") # generated in 20250703_scaffolds_over_time.ipynb

In [None]:
# One row per scaffold: keep the first row encountered for each scaffold ID and
# only the scaffold-level columns.
# .copy() detaches the selection from ligand_data, so the datetime conversion
# below writes into an independent frame instead of a slice (which would
# trigger pandas' SettingWithCopyWarning).
scaff_data = ligand_data.groupby("scaffold_orig_id").head(1)[
    ["scaffold_orig_id", "scaffold_smarts", "scaffold_count", "scaffold_first_date"]
].copy()
scaff_data["scaffold_first_date"] = pd.to_datetime(scaff_data["scaffold_first_date"])

In [None]:
# Display the per-scaffold summary table.
scaff_data

Need to make an algorithm that
1. only docks molecules to structures collected before the molecule itself was collected
2. only uses the first structure collected for each scaffold

In [None]:
# Scaffold IDs as a plain list, in scaff_data row order.
scaff_ids = scaff_data["scaffold_orig_id"].tolist()

In [None]:
from itertools import product

In [None]:
# Every ordered (query, ref) combination of scaffold IDs, including self-pairs.
pairs = list(product(scaff_ids, scaff_ids))

In [None]:
# Number of ordered scaffold pairs (len(scaff_ids) squared).
len(pairs)

In [None]:
# Keep one row per reference ligand: the first row encountered for each.
ref_data = data.dataframe.groupby(["Reference_Ligand"]).head(1)

## get first collected structure for each scaffold

In [None]:
# Order the reference rows by date, then keep the first row of each scaffold ID.
ref_ligs = (
    ref_data
    .sort_values("RefData_Date")
    .groupby("RefData_Scaffold_ID")
    .head(1)
    .reset_index(drop=True)
)

In [None]:
from harbor.analysis.cross_docking import get_unique_structures_randomized_by_date

In [None]:
# For every subset size from 1 up to len(ref_ligs), store what the helper returns
# when asked for that many structures with a one-day randomization window.
refs_to_use = [
    get_unique_structures_randomized_by_date(
        ref_ligs, "Reference_Structure", "RefData_Date", n_structures, 1
    )
    for n_structures in range(1, len(ref_ligs) + 1)
]

In [None]:
import harbor.analysis.cross_docking as cd

In [None]:
from pydantic import Field
class ScaffoldDateSplit(cd.ReferenceStructureSplitBase):
    """
    Reference split that restricts docking results to one reference structure
    per scaffold: the one appearing first when rows are sorted by ``date_column``.

    The surviving structures are then handed to
    ``cd.get_unique_structures_randomized_by_date`` to choose
    ``n_reference_structures`` of them for each bootstrap.
    """

    name: str = "ScaffoldDateSplit"
    type_: str = "ScaffoldDateSplit"
    date_column: str = Field(
        ...,
        description="Column corresponding to date deposition",
    )
    scaffold_id_column: str = Field(
        ...,
        description="Column corresponding to the scaffold ID of the ligand",
    )
    randomize_by_n_days: int = Field(
        0,
        description="Randomize the structures by n days. If 0 no randomization is done. If 1 or greater, for each structure, it can be randomly replaced by any other structure collected on that day or n-1 days from it's collection date.",
    )

    def get_records(self) -> dict:
        """Return the base-class records extended with this split's settings."""
        records = super().get_records()
        records["Randomize_by_N_Days"] = self.randomize_by_n_days
        records["Date_Column"] = self.date_column
        records["Scaffold_ID_Column"] = self.scaffold_id_column
        return records

    def run(self, data: DockingDataModel, bootstraps=1) -> [DockingDataModel]:
        """
        Build one DockingDataModel per reference list returned by the helper.

        :param data: docking results to split.
        :param bootstraps: forwarded to ``cd.get_unique_structures_randomized_by_date``.
        :return: list of DockingDataModel objects, each restricted to the rows
            whose reference structure is in the corresponding reference list.
        """
        ref_col = self.reference_structure_column
        docking_df = data.dataframe

        # Earliest row per scaffold -> the reference structures allowed in this split.
        earliest_per_scaffold = (
            docking_df.sort_values(self.date_column)
            .groupby(self.scaffold_id_column)
            .head(1)
        )
        unique_refs = earliest_per_scaffold[ref_col].unique()
        first_structures_df = docking_df[docking_df[ref_col].isin(unique_refs)]

        # When no size was requested, use every per-scaffold structure.
        # Note: this stores the resolved size on the instance.
        if self.n_reference_structures is None:
            self.n_reference_structures = len(unique_refs)

        ref_lists = cd.get_unique_structures_randomized_by_date(
            first_structures_df,
            ref_col,
            self.date_column,
            self.n_reference_structures,
            self.randomize_by_n_days,
            bootstraps=bootstraps,
        )

        split_models = []
        for ref_list in ref_lists:
            subset = docking_df[docking_df[ref_col].isin(ref_list)]
            split_models.append(DockingDataModel(dataframe=subset, **data.model_dump()))
        return split_models

class ScaffoldSizeSplit(cd.ReferenceStructureSplitBase):
    """
    Reference split that restricts docking results to one reference structure
    per scaffold: the one appearing first when rows are sorted by ``date_column``.

    NOTE(review): the intended tie-break ("if multiple candidates were deposited
    on the same day, pick the larger one") is NOT implemented yet.
    ``smiles_column`` is recorded in ``get_records`` but is not used for
    selection, so this currently selects the same way as a date-only split.
    TODO: add the size-based tie-break.
    """

    # Previously copy-pasted as "ScaffoldDateSplit", which made the two splits
    # indistinguishable by name/type.
    name: str = "ScaffoldSizeSplit"
    type_: str = "ScaffoldSizeSplit"
    smiles_column: str = Field(
        ...,
        description="Column corresponding to reference ligand scaffold smiles column.",
    )
    scaffold_id_column: str = Field(..., description="Column corresponding to the scaffold ID of the ligand")
    # date_column and randomize_by_n_days are read by get_records() and run()
    # but were never declared on this class; declare them here.
    date_column: str = Field(
        ...,
        description="Column corresponding to date deposition",
    )
    randomize_by_n_days: int = Field(
        0,
        description="Randomize the structures by n days. If 0 no randomization is done. If 1 or greater, for each structure, it can be randomly replaced by any other structure collected on that day or n-1 days from it's collection date.",
    )

    def get_records(self) -> dict:
        """Return the base-class records extended with this split's settings."""
        records = super().get_records()
        records.update(
            {
                "Randomize_by_N_Days": self.randomize_by_n_days,
                "Date_Column": self.date_column,
                "SMILES_Column": self.smiles_column,
                "Scaffold_ID_Column": self.scaffold_id_column,
            }
        )
        return records

    def run(self, data: DockingDataModel, bootstraps=1) -> [DockingDataModel]:
        """
        Build one DockingDataModel per reference list returned by the helper.

        :param data: docking results to split.
        :param bootstraps: forwarded to ``cd.get_unique_structures_randomized_by_date``.
        :return: list of DockingDataModel objects, each restricted to the rows
            whose reference structure is in the corresponding reference list.
        """
        # Extract the reference-structure IDs of the earliest row per scaffold.
        # The original passed the whole DataFrame to .isin(), which iterates over
        # column labels rather than structure IDs.
        unique_refs = (
            data.dataframe.sort_values(self.date_column)
            .groupby(self.scaffold_id_column)
            .head(1)[self.reference_structure_column]
            .unique()
        )

        filtered_df = data.dataframe[data.dataframe[self.reference_structure_column].isin(unique_refs)]

        # When no size was requested, use every per-scaffold structure.
        if self.n_reference_structures is None:
            self.n_reference_structures = len(unique_refs)

        ref_lists = cd.get_unique_structures_randomized_by_date(
            filtered_df,
            self.reference_structure_column,
            self.date_column,
            self.n_reference_structures,
            self.randomize_by_n_days,
            bootstraps=bootstraps,
        )
        return [
            DockingDataModel(
                dataframe=data.dataframe[
                    data.dataframe[self.reference_structure_column].isin(ref_list)
                ],
                **data.model_dump(),
            )
            for ref_list in ref_lists
        ]

In [None]:
# List the columns available in the docking results frame.
data.dataframe.columns

In [None]:
# Display the docking results frame.
# NOTE(review): this renders the whole frame; a .head() preview may be enough.
data.dataframe

In [None]:
# Build one Evaluator per (scorer, reference-set size) combination.
#
# The development-time `reload(cd)` that used to open this cell was removed:
# reloading the module here re-creates its classes after ScaffoldDateSplit and
# `data` were already built from the previous ones, which mixes stale and
# fresh class objects. On a fresh kernel the reload has no other effect.

# Reference-set sizes to evaluate.
# NOTE(review): 137 is presumably the total number of per-scaffold reference
# structures — confirm, or derive it from the data instead of hard-coding.
# n_refs = cd.generate_logarithmic_scale(len(refs_to_use))
n_refs = [1, 2, 5, 10, 20, 30, 40, 50, 75, 100, 137]

scorers = [
    cd.POSITScorer(variable='PoseData_docking-confidence-POSIT'),
    cd.RMSDScorer(variable='PoseData_RMSD', cutoff=2),
]

evs = []
for scorer in scorers:
    for n in n_refs:
        ev = cd.Evaluator(
            scorer=scorer,
            evaluator=cd.BinaryEvaluation(variable="PoseData_RMSD", cutoff=2),
            n_bootstraps=100,
        )
        ev.dataset_split = ScaffoldDateSplit(
            date_column="RefData_Date",
            scaffold_id_column="RefData_Scaffold_ID",
            randomize_by_n_days=1,
            reference_structure_column="Reference_Structure",
        )
        ev.dataset_split.n_reference_structures = n
        evs.append(ev)

In [None]:
# Sanity check: the reference-set size configured on each evaluator.
print([ev.dataset_split.n_reference_structures for ev in evs])

In [None]:
# Run every configured evaluator against the docking data.
# NOTE(review): each evaluator uses n_bootstraps=100, so this is likely the slow cell — confirm.
results = cd.Results.calculate_results(data, evs)

In [None]:
# Convert the results into a DataFrame for saving and plotting.
results_df = cd.Results.df_from_results(results)

# Write results to file

In [None]:
# Persist the results (without the index) so the plotting section below can
# reload them without recomputing.
results_df.to_csv("scaffold_date_split_results.csv", index=False)

# Load results from file

## imports

In [None]:
import pandas as pd
from pathlib import Path
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib.pyplot import ScalarFormatter

In [None]:
# Same reference-set sizes as used when building the evaluators above; repeated
# here so this section does not depend on that cell having been run.
n_refs = [1, 2, 5, 10, 20, 30, 40, 50, 75, 100, 137]

In [None]:
# Reload the saved evaluation results and the ligand/scaffold table.
results_df = pd.read_csv("scaffold_date_split_results.csv")
ligand_data = pd.read_csv("ligand_scaffold_data.csv") # generated in 20250703_scaffolds_over_time.ipynb

# Plot Results

In [None]:
# Distances from Fraction down to CI_Lower and up to CI_Upper, floored at zero
# so neither error-bar length is negative.
results_df["Error_Lower"] = (results_df["Fraction"] - results_df["CI_Lower"]).clip(lower=0.0)
results_df["Error_Upper"] = (results_df["CI_Upper"] - results_df["Fraction"]).clip(lower=0.0)

# Plot function

In [None]:
# Global configuration
# Output directory that save_fig writes into; created here if it does not exist.
fig_path = Path("./20250706_combined_scaffold_date_split")
fig_path.mkdir(parents=True, exist_ok=True)

# In the visible notebook this counter is only referenced by the commented-out
# save_fig variant in this cell.
FIGNUM_GLOBAL = 0

# def save_fig(fig, filename, dpi=200, suffix=".pdf"):
#     """Save the figure with a global figure number."""
#     global FIGNUM_GLOBAL
#     FIGNUM_GLOBAL += 1
#     figpath = Path(fig_path / f"{filename}_{FIGNUM_GLOBAL:02d}")
#     fig.savefig(figpath.with_suffix(suffix), 
#                 bbox_inches="tight", 
#                 dpi=dpi)

def save_fig(fig, filename, dpi=200, suffix=".pdf"):
    """
    Save a figure into the global ``fig_path`` directory.

    :param fig: any object exposing ``savefig`` (a Figure, or ``pyplot`` itself).
    :param filename: file name without extension.
    :param dpi: resolution forwarded to ``savefig``.
    :param suffix: file extension, with or without the leading dot.
    """
    # Accept "pdf" as well as ".pdf".
    if suffix and not suffix.startswith("."):
        suffix = f".{suffix}"
    # Append the suffix rather than using Path.with_suffix, which would replace
    # everything after the last dot in names such as "results_v1.5".
    figpath = fig_path / f"{filename}{suffix}"
    fig.savefig(figpath, bbox_inches="tight", dpi=dpi)
    

sns.set_style("white")

# Display names, keyed by raw column name or raw categorical value.
label_map = {
    "Reference_Split": "Dataset Split Type",
    "Score": "Scoring Method",
    "RandomSplit": "Randomly Ordered",
    "DateSplit": "Ordered by Date",
    "RMSD": "RMSD (Positive Control)",
    "POSIT_Probability": "POSIT Probability",
    "ScaffoldDateSplit": "One Structure Per Scaffold",
    # "N_Reference_Structures": "Number of Randomly Chosen Reference Structures",
    "N_Reference_Structures": "Number of Reference Structures Available to Use \n(Log Scale)",
    "Fraction": "Fraction of Ligands Posed \n<2Å from Reference",
    "CI_Lower": "Confidence Interval Lower Bound",
    "CI_Upper": "Confidence Interval Upper Bound",
}
# Any result column without an explicit display name maps to itself.
for column in results_df.columns:
    label_map.setdefault(column, column)

# Column names (after relabeling) used as plot defaults.
X_VAR = X_LABEL = label_map["N_Reference_Structures"]
Y_VAR = Y_LABEL = label_map["Fraction"]
# QUERY_SCAFFOLD_ID = label_map["Query_Scaffold_ID_Subset_1"]
# REF_SCAFFOLD_ID = label_map["Reference_Scaffold_ID_Subset_1"]
COLOR_VAR = label_map["Reference_Split"]
STYLE_VAR = label_map["Score"]
CI_LOWER, CI_UPPER = label_map["CI_Lower"], label_map["CI_Upper"]

# Figure geometry and typography.
LARGE_FIG_SIZE = (12, 8)
SMALL_FIG_SIZE = (8, 6)
FONT_SIZES = dict(
    xlabel=24,
    ylabel=24,
    ticks=18,
    legend_title=24,
    legend_text=18,
)
# Opacity of the shaded confidence bands.
ALPHA = 0.1

In [None]:
def plot_filled_in_error_bars(
    raw_df,
    x_var=X_VAR,
    y_var=Y_VAR,
    color_var=COLOR_VAR,
    style_var=STYLE_VAR,
    ci_lower=CI_LOWER,
    ci_upper=CI_UPPER,
    reverse_hue_order=False,
    reverse_style_order=False,
    n_refs= n_refs,
):
    """
    Draw a line plot of ``y_var`` against ``x_var`` (log-scaled x axis) with a
    shaded band between ``ci_lower`` and ``ci_upper`` for every
    (``color_var``, ``style_var``) group.

    :param raw_df: DataFrame containing all of the columns named below.
    :param x_var: column for the x axis; also used as the x-axis label.
    :param y_var: column for the y axis; also used as the y-axis label.
    :param color_var: column mapped to line color.
    :param style_var: column mapped to line style.
    :param ci_lower: column holding the lower edge of the shaded band.
    :param ci_upper: column holding the upper edge of the shaded band.
    :param reverse_hue_order: sort the color categories in descending order.
    :param reverse_style_order: sort the style categories in descending order.
    :param n_refs: x positions (and labels) of the major ticks.
    :return: the ``matplotlib.pyplot`` module, so the caller can call
        ``savefig`` on the current figure.
    """
    # Sort so that every group's x values are increasing for fill_between.
    raw_df = raw_df.sort_values(by=[x_var, style_var, color_var])
    plt.figure(figsize=LARGE_FIG_SIZE)

    # Category orders for the legend.
    hue_order = sorted(raw_df[color_var].unique(), reverse=reverse_hue_order)
    style_order = sorted(raw_df[style_var].unique(), reverse=reverse_style_order)

    # One palette color per color category; reused for the shaded bands.
    color_map = dict(zip(hue_order, sns.color_palette(n_colors=len(hue_order))))

    # sns.lineplot returns the Axes it drew on.
    ax = sns.lineplot(
        data=raw_df,
        x=x_var,
        y=y_var,
        hue=color_var,
        style=style_var,
        palette=color_map,
        hue_order=hue_order,
        style_order=style_order,
    )

    # Shade the confidence band of each group in its line's color. The group
    # key is (color_var value, style_var value); only the first picks the color.
    for (color_name, _style_name), group in raw_df.groupby([color_var, style_var]):
        ax.fill_between(
            group[x_var],
            group[ci_lower],
            group[ci_upper],
            color=color_map[color_name],
            alpha=ALPHA,
        )

    # Log-scaled x axis with plain-number ticks at the requested positions.
    ax.set_xscale("log")
    ax.xaxis.set_major_formatter(ScalarFormatter())
    ax.set_xticks(n_refs)
    ax.set_xticklabels(n_refs, fontsize=FONT_SIZES["ticks"])
    ax.tick_params(axis='y', labelsize=FONT_SIZES["ticks"])

    # Label the axes with the columns actually plotted. The original always
    # used the global X_LABEL / Y_LABEL, ignoring x_var / y_var; the defaults
    # are identical, so default calls are unaffected.
    ax.set_xlabel(x_var, fontsize=FONT_SIZES["xlabel"], fontweight="bold")
    ax.set_ylabel(y_var, fontsize=FONT_SIZES["ylabel"], fontweight="bold")

    # Enlarge the legend title and entries.
    legend = ax.legend()
    plt.setp(legend.get_title(), fontsize=FONT_SIZES["legend_title"], fontweight="bold")
    plt.setp(legend.get_texts(), fontsize=FONT_SIZES["legend_text"])
    return plt

## update labels

In [None]:
# Relabel a copy, leaving results_df with its raw column names.
df = results_df.copy().rename(columns=label_map)
# Swap raw categorical values that have a display name; every other value
# passes through unchanged.
for column in df.columns:
    df[column] = df[column].apply(lambda value: label_map.get(value, value))
df.columns

In [None]:
# Plot the relabeled scaffold-split results with descending color-category order.
fig = plot_filled_in_error_bars(df, 
                                    x_var=label_map['N_Reference_Structures'], 
                                    y_var=label_map['Fraction'], 
                                    color_var=label_map['Reference_Split'], 
                                    style_var=label_map['Score'], 
                                    ci_lower=label_map['CI_Lower'], 
                                    ci_upper=label_map['CI_Upper'],
                                    reverse_hue_order=True,
                                    )
# Save a PDF (save_fig defaults) and a 300-dpi PNG under the same base name.
save_fig(fig, "scaffold_date_split_results")
save_fig(fig, "scaffold_date_split_results", suffix=".png", dpi=300)

# How different is "first structure for each scaffold" vs "first 50 structures"?

In [None]:
# Scaffold IDs found among the first 50 rows of ligand_data when sorted by compound_date.
scaffolds_from_first_50_ligs = set(ligand_data.sort_values("compound_date").head(50)["scaffold_orig_id"].unique().tolist())

In [None]:
# Scaffold IDs of the first 50 rows of scaff_data when sorted by scaffold_first_date.
# NOTE(review): scaff_data is built in the first section of the notebook and is
# not recreated in the "Load results from file" section, so this cell fails if
# the notebook is started from that section.
first_50_scaffolds = set(scaff_data.sort_values("scaffold_first_date").head(50)["scaffold_orig_id"].unique().tolist())

In [None]:
# Scaffolds seen among the first 50 ligands that are not among the first 50 scaffolds.
scaffolds_from_first_50_ligs - first_50_scaffolds

In [None]:
# How many of the first 50 scaffolds are absent from the first 50 ligands.
len(first_50_scaffolds - scaffolds_from_first_50_ligs)

In [None]:
# Number of distinct scaffolds covered by the first 50 ligands.
len(scaffolds_from_first_50_ligs)

In [None]:
# First-50 scaffolds that the first 50 ligands do not cover.
missing_scaffolds = first_50_scaffolds - scaffolds_from_first_50_ligs

In [None]:
# Display the uncovered scaffold IDs.
missing_scaffolds

In [None]:
# Scaffold IDs of the first 26 rows of scaff_data when sorted by scaffold_first_date.
# NOTE(review): the cutoff 26 is not explained here — presumably chosen to match
# the number of scaffolds covered by the first 50 ligands; confirm.
first_26_scaffolds = set(scaff_data.sort_values("scaffold_first_date").head(26)["scaffold_orig_id"].unique().tolist())

In [None]:
# First-26 scaffolds that the first 50 ligands do not cover.
first_26_scaffolds - scaffolds_from_first_50_ligs

In [None]:
# Scaffolds covered by the first 50 ligands that fall outside the first 26 scaffolds.
scaffolds_from_first_50_ligs - first_26_scaffolds

In [None]:
# first 26 minus first 20

In [None]:
# Scaffolds ranked 21st to 26th by scaffold_first_date (first 26 minus first 20).
first_26_scaffolds - set(scaff_data.sort_values("scaffold_first_date").head(20)["scaffold_orig_id"].unique().tolist())

The scaffolds that result in the big jump are weird, big ones.

# Remake plot with date split posit probability result

In [None]:
# Previously computed date-split results, used as a comparison curve.
# NOTE(review): absolute, user-specific path — only resolves on a machine with
# the same directory layout; consider a configurable data directory.
datesplit_results = Path(
    "/Users/alexpayne/Scientific_Projects/mers-drug-discovery/sars2-retrospective-analysis/full_cross_dock_v2_analyzed_results/ALL_1_poses_datesplit_combined_results.csv")
pdf = pd.read_csv(datesplit_results)
# Distances from Fraction to each confidence bound, floored at zero. Uses the
# same vectorized .clip as the results_df error bars instead of a per-row lambda.
pdf["Error_Lower"] = (pdf["Fraction"] - pdf["CI_Lower"]).clip(lower=0.0)
pdf["Error_Upper"] = (pdf["CI_Upper"] - pdf["Fraction"]).clip(lower=0.0)

In [None]:
# Display the loaded date-split results.
pdf

In [None]:
# Relabel a copy of the date-split results, leaving pdf with its raw column names.
df2 = pdf.copy().rename(columns=label_map)
# Swap raw categorical values that have a display name; every other value
# passes through unchanged.
for column in df2.columns:
    df2[column] = df2[column].apply(lambda value: label_map.get(value, value))
df2.columns

In [None]:
# Keep only the rows scored by POSIT probability under the date-ordered split.
df2_plot_df = df2.loc[
    (df2[label_map["Score"]] == label_map["POSIT_Probability"])
    & (df2[label_map["Reference_Split"]] == label_map["DateSplit"])
].copy()

In [None]:
# Stack the scaffold-split results with the date-split comparison rows.
# The original row indices are kept, so index values may repeat.
plot_df = pd.concat([df, df2_plot_df])

In [None]:
# Display the date-split rows that were added to the plot data.
df2_plot_df

In [None]:
# Re-plot with the date-split curve included. The x ticks are overridden
# because the combined data uses a different tick set, going up to 403.
fig = plot_filled_in_error_bars(plot_df, 
                                    x_var=label_map['N_Reference_Structures'], 
                                    y_var=label_map['Fraction'], 
                                    color_var=label_map['Reference_Split'], 
                                    style_var=label_map['Score'], 
                                    ci_lower=label_map['CI_Lower'], 
                                    ci_upper=label_map['CI_Upper'],
                                    reverse_hue_order=True,
                                # n_refs = sorted(plot_df[label_map["N_Reference_Structures"]].unique().tolist())
                                    n_refs = [1, 2, 5, 20, 50, 100, 250, 403],
                                    )
# Save both a PDF and a 300-dpi PNG under the same base name.
save_fig(fig, "scaffold_date_split_results_with_date_split", suffix=".pdf")
save_fig(fig, "scaffold_date_split_results_with_date_split", suffix=".png", dpi=300)