In [None]:
# Because of pydantic discrepancies, this notebook must be run in the harbor environment rather than asap2025.

# Imports

In [None]:
import multiprocessing as mp
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from pydantic import BaseModel
from tqdm import tqdm

import harbor.analysis.cross_docking as cd

In [None]:
# Load the per-engine docking result tables and tag every row with the engine
# that produced it, so the two sources stay distinguishable after stacking.
fred_data = pd.read_csv(
    '/Users/alexpayne/Scientific_Projects/mers-drug-discovery/sars2-retrospective-analysis/20250113_fred_docking/rmsd_csvs/20250122_combined_results_with_data.csv',
    index_col=0,
).assign(Engine='FRED')
posit_data = pd.read_csv(
    "/Users/alexpayne/Scientific_Projects/mers-drug-discovery/sars2-retrospective-analysis/20240424_multi_pose_docking_cross_docking/results_csvs/20240503_combined_results_with_data.csv",
    index_col=0,
).assign(Engine='POSIT')

# Directory (relative to the working directory) that exported figures go to.
figure_path = Path("figures")

# Raw rows of both engines stacked into one table.
results_df = pd.concat([fred_data, posit_data])

In [None]:
def _first_pose_per_pair(docked: pd.DataFrame) -> pd.DataFrame:
    """Keep, for each (Query_Ligand, Reference_Structure) pair, the row with the lowest Pose_ID."""
    ordered = docked.sort_values(['Pose_ID'])
    return ordered.groupby(['Query_Ligand', 'Reference_Structure']).head(1)


# One row per ligand/reference pair for each engine
# (the "_sp" suffix presumably stands for "single pose" -- TODO confirm).
fred_data_sp = _first_pose_per_pair(fred_data)
posit_data_sp = _first_pose_per_pair(posit_data)

In [None]:
# Numbers of reference structures to evaluate: every value from 1 to 20,
# followed by a coarser grid from 25 to 205 in steps of 20.
dense_counts = np.arange(1, 21)
coarse_counts = np.arange(25, 206, 20)
n_per_splits = np.concatenate((dense_counts, coarse_counts))

In [None]:
class Results(BaseModel):
    """Pairs an evaluator with the `FractionGood` obtained by running it on a dataframe."""

    # The evaluation configuration that was run.
    evaluator: cd.Evaluator
    # What `evaluator.run(df)` returned for that configuration.
    fraction_good: cd.FractionGood

    def get_records(self) -> dict:
        """Return the evaluator's records dict updated with the fraction-good records."""
        record = self.evaluator.get_records()
        record.update(self.fraction_good.get_records())
        return record

    @classmethod
    def calculate_result(cls, evaluator: cd.Evaluator, df: pd.DataFrame) -> "Results":
        """Run `evaluator` on `df` and wrap the evaluator together with its output."""
        return cls(evaluator=evaluator, fraction_good=evaluator.run(df))

    @classmethod
    def calculate_results(
        cls, df: pd.DataFrame, evaluators: list[cd.Evaluator], cpus: int = 1
    ) -> list["Results"]:
        """Run every evaluator on `df` using a pool of `cpus` worker processes.

        The returned list follows the order of `evaluators`.
        """
        jobs = [(evaluator, df) for evaluator in evaluators]
        with mp.Pool(cpus) as pool:
            return pool.starmap(cls.calculate_result, jobs)

    @classmethod
    def df_from_results(cls, results: list["Results"]) -> pd.DataFrame:
        """Build a dataframe with one row per result, from each result's `get_records()`."""
        records = [result.get_records() for result in results]
        return pd.DataFrame.from_records(records)

In [None]:
# Filled with one cd.Evaluator per configuration by the loop further down.
evaluators = []

# Pose selectors: a single selector keyed on Pose_ID that returns one pose.
pose_selectors = [
    cd.PoseSelector(name="Default", variable="Pose_ID", number_to_return=1),
]

# Dataset splits: a single random split over query ligands.
# n_per_split=-1 presumably means "use everything" -- TODO confirm against harbor.
dataset_splits = [
    cd.RandomSplit(variable="Query_Ligand", n_splits=1, n_per_split=-1),
]

# Structure choices: one per entry of n_per_splits, differing only in how many
# reference structures are returned.
# NOTE(review): higher_is_better=False on the "Tanimoto" variable -- confirm the
# column holds a distance rather than a similarity, otherwise the ordering is inverted.
structure_choices = [
    cd.StructureChoice(
        name="ECFP4_Similarity",
        variable="Tanimoto",
        higher_is_better=False,
        number_to_return=n_per_split,
    )
    for n_per_split in n_per_splits
]

# Scorers: the POSIT docking confidence, and RMSD itself (lower is better).
posit_probability_scorer = cd.Scorer(
    name="POSIT_Probability",
    variable="docking-confidence-POSIT",
    number_to_return=1,
)
rmsd_scorer = cd.Scorer(
    name="RMSD",
    variable="RMSD",
    higher_is_better=False,
    number_to_return=1,
)
scorers = [posit_probability_scorer, rmsd_scorer]

# Binary evaluation on the RMSD column with a cutoff of 2.
rmsd_evaluator = cd.BinaryEvaluation(variable="RMSD", cutoff=2)

In [None]:
# One Evaluator per (scorer, split, pose selector, structure choice) combination.
# The clause order below reproduces the nesting scorer -> split -> selector ->
# structure choice, so the evaluators are appended in the same sequence.
evaluators.extend(
    cd.Evaluator(
        pose_selector=selector,
        dataset_split=split,
        structure_choice=structure_choice,
        scorer=scorer,
        evaluator=rmsd_evaluator,
        groupby=["Query_Ligand"],
        n_bootstraps=100,
    )
    for scorer in scorers
    for split in dataset_splits
    for selector in pose_selectors
    for structure_choice in structure_choices
)

In [None]:
# Run every evaluator on the single-pose data of both engines.
# Results are appended in strict POSIT, FRED, POSIT, FRED, ... order; keep that
# order, because the two engines are later separated again purely by position.
# (`tqdm` is imported in the notebook's import cell.)
results = []
for evaluator in tqdm(evaluators):
    for engine_data in (posit_data_sp, fred_data_sp):
        results.append(Results.calculate_result(evaluator=evaluator, df=engine_data))

In [None]:
# The results list alternates POSIT / FRED entries: even positions are POSIT,
# odd positions are FRED.
posit_results, fred_results = results[0::2], results[1::2]

In [None]:
# One summary dataframe per engine (pr = POSIT results, fr = FRED results).
pr_df, fr_df = (
    Results.df_from_results(engine_results)
    for engine_results in (posit_results, fred_results)
)

In [None]:
# Label each summary table with the docking engine it belongs to.
pr_df = pr_df.assign(Engine='POSIT')
fr_df = fr_df.assign(Engine='FRED')

In [None]:
# Stack the POSIT and FRED summaries (POSIT rows first) into the table used for plotting.
engine_frames = [pr_df, fr_df]
results_df = pd.concat(engine_frames)

In [None]:
# Cast the split size to int, order the rows for plotting, and convert the
# confidence-interval bounds into offsets from the point estimate (the form
# plotly's error arguments take).
results_df = (
    results_df
    .assign(N_Per_Split=lambda frame: frame["N_Per_Split"].astype(int))
    .sort_values(
        [
            "Split",
            "Score",
            "PoseSelection",
            "StructureChoice",
            "StructureChoice_Choose_N",
            "N_Per_Split",
        ]
    )
    .assign(
        Error_Lower=lambda frame: frame["Fraction"] - frame["CI_Lower"],
        Error_Upper=lambda frame: frame["CI_Upper"] - frame["Fraction"],
    )
)

# Plotting Variables

In [None]:
# Font sizes shared by the figures: large for axis and legend titles, small for everything else.
large_font = 24
small_font = 18

# Display text for column names and trace-name fragments.
# "\u00c5" is the angstrom sign; it is written as an escape so the label cannot
# be garbled by an encoding round-trip of the notebook file.
labels = {
    "Fraction": "<b> Fraction of Poses Docked < 2\u00c5 from Reference </b>",
    "N_Per_Split": "<b> Total Number of Reference Structures Available to Use </b>",
    "Date": "Temporally Ordered",
    "Random": "Randomly Shuffled",
}

# Axis styling passed to `fig.update_layout`: black axes with large titles, and
# the y axis pinned to the [0, 1] range.
update_layout_dict = dict(
    xaxis=dict(
        title_font=dict(size=large_font),
        color='black',
    ),
    yaxis=dict(
        range=(0, 1),
        title_font=dict(size=large_font),
        color='black',
    ),
)

In [None]:
def update_traces(fig):
    """Rewrite the legend name of every named trace in `fig` and return the same figure.

    The rewrites are applied in order: underscores become spaces, "Split" is
    removed, ", " becomes " | ", "RMSD" is expanded to "RMSD (Positive Control)",
    and finally every key of the module-level `labels` dict is replaced by its value.
    Traces whose name is None are left untouched.
    """
    fixed_rewrites = (
        ("_", " "),
        ("Split", ""),
        (", ", " | "),
        ("RMSD", "RMSD (Positive Control)"),
    )
    for trace in fig.data:
        if trace.name is None:
            continue
        pretty_name = trace.name
        for old, new in fixed_rewrites:
            pretty_name = pretty_name.replace(old, new)
        for old, new in labels.items():
            pretty_name = pretty_name.replace(old, new)
        trace.name = pretty_name
    return fig

In [None]:
def hex_to_rgb(hex_color: str) -> tuple:
    """Convert a hexadecimal colour string to an ``(r, g, b)`` tuple of ints.

    Parameters
    ----------
    hex_color : str
        Colour as ``"#rrggbb"`` or the 3-digit shorthand ``"#rgb"``; the leading
        ``#`` is optional.

    Returns
    -------
    tuple
        The red, green and blue channels, each in the range 0-255.

    Raises
    ------
    ValueError
        If a channel cannot be parsed as hexadecimal.
    """
    hex_color = hex_color.lstrip("#")
    if len(hex_color) == 3:
        # Shorthand expands by doubling each digit ("abc" -> "aabbcc").
        # Repeating the whole string ("abcabc") would give the wrong colour.
        hex_color = "".join(digit * 2 for digit in hex_color)
    return int(hex_color[0:2], 16), int(hex_color[2:4], 16), int(hex_color[4:6], 16)

In [None]:
def rgb_to_rgba(rgb_str, alpha):
    """Turn an ``"rgb(r, g, b)"`` string into ``"rgba(r, g, b, alpha)"``.

    The three channels are parsed as integers; `alpha` is inserted as given.
    """
    # Drop the surrounding "rgb(" / ")" characters, then parse each channel.
    channels = [int(part) for part in rgb_str.strip('rgb()').split(',')]
    red, green, blue = channels
    return f"rgba({red}, {green}, {blue}, {alpha})"

In [None]:
def line(error_y_mode=None, **kwargs):
    """Extension of `plotly.express.line` to use error bands.

    Parameters
    ----------
    error_y_mode : {'bar', 'bars', 'band', 'bands', None}
        'bar'/'bars'/None forward everything to `px.line` unchanged.
        'band'/'bands' draw the y error as a translucent filled band behind
        each line instead of as error bars.
    **kwargs
        Arguments for `px.line`. In band mode `error_y` is required; if
        `error_y_minus` is also given it sets the lower edge of the band,
        otherwise the band is symmetric.

    Returns
    -------
    The plotly figure. In band mode its traces are ordered band, line, band,
    line, ... so each band is drawn underneath its line.

    Raises
    ------
    ValueError
        If `error_y_mode` is not one of the accepted values, if band mode is
        requested without `error_y`, or if a line colour is neither a hex
        string nor an ``rgb(r, g, b)`` string.
    """
    ERROR_MODES = {'bar','band','bars','bands',None}
    if error_y_mode not in ERROR_MODES:
        raise ValueError(f"'error_y_mode' must be one of {ERROR_MODES}, received {repr(error_y_mode)}.")
    if error_y_mode in {'bar','bars',None}:
        return px.line(**kwargs)

    if 'error_y' not in kwargs:
        raise ValueError("If you provide argument 'error_y_mode' you must also provide 'error_y'.")

    # The figure with error bars is built only to read the per-trace error
    # arrays, colours and axes that plotly express computed.
    figure_with_error_bars = px.line(**kwargs)
    # The figure that is returned carries bare lines: both error arguments are
    # stripped so that no error bars can be drawn on top of the bands.
    fig = px.line(
        **{arg: val for arg, val in kwargs.items() if arg not in ('error_y', 'error_y_minus')}
    )

    for data in figure_with_error_bars.data:
        x = list(data['x'])
        y = np.asarray(data['y'])
        error_plus = np.asarray(data['error_y']['array'])
        error_minus = data['error_y']['arrayminus']
        # Without an explicit lower error the band is symmetric around the line.
        error_minus = error_plus if error_minus is None else np.asarray(error_minus)
        y_upper = list(y + error_plus)
        y_lower = list(y - error_minus)

        # Band colour = line colour at 30% opacity. Both the hex form and the
        # "rgb(r, g, b)" form (as used by some named colour sequences) are accepted.
        line_color = data['line']['color']
        if line_color.startswith('#'):
            red, green, blue = hex_to_rgb(line_color)
            color = f"rgba({red},{green},{blue},.3)"
        elif line_color.startswith('rgb('):
            color = rgb_to_rgba(line_color, 0.3)
        else:
            raise ValueError(
                f"Cannot derive a band colour from line colour {line_color!r}; "
                "expected '#rrggbb' or 'rgb(r, g, b)'."
            )

        fig.add_trace(
            go.Scatter(
                # Closed polygon: upper edge left-to-right, then lower edge right-to-left.
                x = x+x[::-1],
                y = y_upper+y_lower[::-1],
                fill = 'toself',
                fillcolor = color,
                line = dict(
                    color = 'rgba(255,255,255,0)'
                ),
                hoverinfo = "skip",
                showlegend = False,
                legendgroup = data['legendgroup'],
                xaxis = data['xaxis'],
                yaxis = data['yaxis'],
            )
        )

    # Reorder data as said here: https://stackoverflow.com/a/66854398/8849755
    # fig.data currently holds all lines followed by all bands; interleave them
    # as (band, line) pairs so every band sits directly beneath its own line.
    n_lines = len(fig.data) // 2
    reordered_data = []
    for i in range(n_lines):
        reordered_data.append(fig.data[i + n_lines])
        reordered_data.append(fig.data[i])
    fig.data = tuple(reordered_data)
    return fig

In [None]:
# Display the combined results table as a visual check before plotting.
results_df

In [None]:
# Fraction of successful poses against the number of reference structures
# chosen, one facet per docking engine, with the confidence interval drawn as a band.
fig = line(
    data_frame=results_df,
    x="StructureChoice_Choose_N",
    y="Fraction",
    color="Score",
    line_dash="Split",
    error_y="Error_Upper",
    error_y_minus="Error_Lower",
    error_y_mode="band",
    template="simple_white",
    facet_col="Engine",
    # symbol="Score",
    height=600,
    width=800,
    log_x=True,
    # color_discrete_sequence=px.colors.qualitative.Dark2,
    labels=labels,
)
fig.update_layout(
    font=dict(
        size=small_font,
        family='Arial',
    ),
    legend=dict(
        title="<b> Score Function | Dataset Split </b>",
        x=0.4,
        y=0.75,
        traceorder='reversed',
        title_font_size=large_font,
        font_color='black',
    ),
    **update_layout_dict,
)
fig.update_yaxes(tickvals=np.arange(0, 1.1, 0.1))
fig = update_traces(fig)
fig.show()

# Make sure the output directory exists; write_image does not create it, so the
# export would otherwise fail on a fresh checkout.
figure_path.mkdir(parents=True, exist_ok=True)
for image_format in ("svg", "png"):
    fig.write_image(figure_path / f"20250207_chemical_similarity.{image_format}")