# Imports

In [None]:
import pandas as pd, numpy as np
import plotly.express as px
from plotly.graph_objs import Figure
from pathlib import Path
from importlib import reload
import software.analysis as a
reload(a)

# Load Paths

In [None]:
import sys
sys.path.append(str(Path("../../../").resolve()))
from software.paths import paths

In [None]:
local_analysis = Path("/Users/alexpayne/Scientific_Projects/mers-drug-discovery/sars2-retrospective-analysis/")

In [None]:
df = pd.read_csv(local_analysis / "20230611-combined.csv", index_col=0)

## Add reverse TC score so that it can be ranked ascending

In [None]:
df["TanimotoCombo_R"] = 2-df.TanimotoCombo

# Functions and Variables

## Variables

In [None]:
# Shared column names, plot titles and cutoffs used by the cells below.

# TanimotoCombo column and the axis title used when it is plotted on x.
tc = "TanimotoCombo"
tc_title = "TanimotoCombo Cutoff for Inclusion of Reference Structures"
# Reversed TanimotoCombo column (2 - TanimotoCombo) added to df above.
tcr = "TanimotoCombo_R"
y="Fraction"
posit_r = "POSIT_R"
posit_method="POSIT_Method"
color="Version"
id_col="Compound_ID"
rmsd="RMSD"
method_split=[posit_method]
# NOTE(review): n is not referenced in the visible part of the notebook.
n=1
# RMSD cutoff (in Å, per frac_title below) under which a pose counts as good.
good=2
# 50 evenly spaced cutoffs over [0, 2]: ascending for tc, descending for tcr.
tc_cutoffs = np.linspace(0,2,50)
tcr_cutoffs = np.linspace(2,0,50)
sort_cols = [rmsd, posit_r, "Chemgauss4", tcr]
# Distinct values of the Structure_Date column of the loaded dataframe.
dates = df.Structure_Date.unique()
date_col = "Structure_Date"
date_title = "Date for Inclusion of Reference Structures"
reference_col = "Structure_Source"
# Column sets used to split results: by Version only, or by Version and POSIT method.
split_cols=["Version"]
sort_col_name="Sorted_By"
full_split_cols=["Version", posit_method]
split_column_sets={"General":split_cols, "Detailed":full_split_cols}
general_split_cols = {"General":split_cols}
detailed_split_cols = {"Detailed":full_split_cols}
# y-axis title for success-rate plots; embeds the `good` cutoff.
frac_title=f"Fraction of Poses < {good}Å from Reference"

## Calculation Functions

## Plotting Functions

### plot kwargs

In [None]:
df.Version.unique()

In [None]:
df.POSIT_Method.unique()

In [None]:
# this doesn't actually work
# NOTE: the original author flagged this mapping as not actually working.
# Every (Version, POSIT_Method) pair: each method seen in df for "All",
# plus the single ("Hybrid-Only", "HYBRID") pair.
full_versions = [("All", posit) for posit in df.POSIT_Method.unique()]
full_versions.append(("Hybrid-Only", "HYBRID"))
# Display label "<version>: <method>" for each pair, and a pair -> label lookup.
full_version_labels = ["{}: {}".format(*pair) for pair in full_versions]
full_version_label_dict = dict(zip(full_versions, full_version_labels))

In [None]:
# Keyword arguments shared by the basic plots: color traces by the `color` column.
basic_plot_kwargs = {"color": color}

In [None]:
# Layout for the large faceted figure: one column per sort key, one row per split.
big_plot_kwargs = {
    "facet_col": sort_col_name,
    "facet_row": "Split",
    "height": 600,
    "width": 1200,
}

In [None]:
# Figure size for a single (non-faceted) plot.
single_plot_kwargs = {"height": 400, "width": 600}

In [None]:
# x-axis settings for plots against the TanimotoCombo cutoff.
tc_plot_kwargs = {
    "x": tc,
    "labels": {tc: tc_title},
    "range_x": [-0.1, 2.1],
}

In [None]:
# x-axis settings for plots against the structure date.
date_plot_kwargs = {
    "x": date_col,
    "labels": {date_col: date_title},
}

In [None]:
# y-range for fraction plots, padded slightly beyond [0, 1].
fraction_plot_kwargs = {"range_y": [-0.1, 1.1]}

In [None]:
# Column names for plotting a value with its standard deviation as error bars.
stats_kwargs = {"y": "Value", "error_y": "STD"}

In [None]:
def combine_labels_kwargs(list_of_kwargs):
    """Merge several plot-kwarg dicts into one new dict.

    A key that appears in only one input is carried over as-is. A key that
    appears more than once is accepted only when both values are dicts
    (e.g. ``labels``); the nested dicts are then merged, with later inputs
    overriding earlier ones on conflicting inner keys.

    The inputs are never modified: a nested dict is shallow-copied before
    anything is merged into it, so shared kwarg templates stay reusable.

    Parameters
    ----------
    list_of_kwargs : iterable of dict
        Keyword-argument dicts to combine, in priority order (later wins for
        nested dict entries).

    Returns
    -------
    dict
        The combined keyword arguments.

    Raises
    ------
    NotImplementedError
        If a key is repeated and either of the two values is not a dict.
    """
    combined = {}
    for kwargs in list_of_kwargs:
        for key, value in kwargs.items():
            if key not in combined:
                # Copy nested dicts so a later merge cannot leak into the input.
                combined[key] = value.copy() if isinstance(value, dict) else value
                continue
            if not (isinstance(value, dict) and isinstance(combined[key], dict)):
                raise NotImplementedError(f"combining these kwargs will not work due to repeated use of {key}")
            combined[key].update(value)
    return combined
    

In [None]:
combine_labels_kwargs([big_plot_kwargs, tc_plot_kwargs])

In [None]:
general_posit_kwargs = {sort_col_name: posit_r, "Split":"General"}

### cleanup functions

In [None]:
def replace_xaxis_labels(fig: Figure, axis_title):
    """Swap the per-subplot x-axis titles for a single shared title.

    Every x-axis title is blanked, then one annotation carrying
    ``axis_title`` is added in paper coordinates at (0.5, -0.15).
    The same figure object is returned.
    """
    def _blank_title(axis):
        return axis.update(title = '')

    fig.for_each_xaxis(_blank_title)

    shared_title = dict(
        x=0.5,
        y=-0.15,
        text=axis_title,
        textangle=0,
        font=dict(size=16),
        xref="paper",
        yref="paper",
        showarrow=False,
    )
    fig.add_annotation(**shared_title)
    return fig

In [None]:
def replace_yaxis_labels(fig: Figure, axis_title):
    """Swap the per-subplot y-axis titles for a single shared title.

    Every y-axis title is blanked, then one annotation carrying
    ``axis_title`` is added in paper coordinates at (-0.05, 0.5),
    rotated by -90 degrees. The same figure object is returned.
    """
    def _blank_title(axis):
        return axis.update(title = '')

    fig.for_each_yaxis(_blank_title)

    shared_title = dict(
        x=-0.05,
        y=0.5,
        text=axis_title,
        textangle=-90,
        font=dict(size=16),
        xref="paper",
        yref="paper",
        showarrow=False,
    )
    fig.add_annotation(**shared_title)
    return fig

In [None]:
def clean_labels(fig):
    """Shorten every annotation to the text after its last ``=``.

    Annotations without an ``=`` are left unchanged. The figure is edited
    in place and returned.
    """
    def _keep_value_only(annotation):
        _, _, value = annotation.text.rpartition("=")
        return annotation.update(text=value)

    fig.for_each_annotation(_keep_value_only)
    return fig

### scatterplot wrapper

In [None]:
def scatter_wrapper(df, kwarg_dict, 
                    x_axis_title=None, 
                    y_axis_title=None, 
                    replace_xaxis=False,
                    replace_y_axis=False,
                    clean=True,
                    x_axis_reversed=False
                   ):
    """Build a plotly express scatter plot and apply the notebook's cleanup.

    Parameters
    ----------
    df : DataFrame
        Data passed to ``px.scatter``.
    kwarg_dict : dict
        Keyword arguments forwarded to ``px.scatter``. May include
        ``hover_data``; when it does not, every column of ``df`` is shown
        on hover.
    x_axis_title, y_axis_title : str, optional
        Axis titles. Nothing is changed for an axis whose title is falsy.
    replace_xaxis, replace_y_axis : bool
        When true, the per-subplot titles are blanked and a single shared
        annotation is added instead (``replace_xaxis_labels`` /
        ``replace_yaxis_labels``); otherwise the title is set on every axis.
    clean : bool
        When true, annotations are shortened with ``clean_labels``.
    x_axis_reversed : bool
        When true, the x axes are drawn with ``autorange="reversed"``.

    Returns
    -------
    Figure
        The configured figure.
    """
    # Hovering over all columns is only a default: passing hover_data both
    # explicitly and through kwarg_dict raised a duplicate-keyword TypeError.
    scatter_kwargs = {"hover_data": df.columns, **kwarg_dict}
    fig: Figure = px.scatter(df, **scatter_kwargs)

    if x_axis_title:
        if replace_xaxis:
            fig = replace_xaxis_labels(fig, x_axis_title)
        else:
            fig.update_xaxes(title=x_axis_title)

    if y_axis_title:
        if replace_y_axis:
            fig = replace_yaxis_labels(fig, y_axis_title)
        else:
            fig.update_yaxes(title=y_axis_title)

    if clean:
        fig = clean_labels(fig)
    if x_axis_reversed:
        fig.update_xaxes(autorange="reversed")
    return fig

### simplify df

In [None]:
def simplify_df(df, condition_dict):
    """Return a copy of ``df`` restricted to rows matching every condition.

    ``condition_dict`` maps a column name to the value that column must
    equal. Conditions are applied one after another; an empty dict yields
    an unfiltered copy. The input frame is not modified.
    """
    filtered = df.copy()
    for col_name, wanted in condition_dict.items():
        matches = filtered[col_name] == wanted
        filtered = filtered.loc[matches]
    return filtered

# Overall Analysis

In [None]:
df.groupby("Version").nunique()[["Complex_ID", "Compound_ID", "Structure_Source"]]

# Drop self-docking results and results with different compounds

# filter by hybrid-only ones

In [None]:
cmpds = df[df["Version"] == "Hybrid-Only"].Compound_ID.unique()

In [None]:
clean = df[df.Compound_ID.isin(cmpds)]

In [None]:
structures = clean[clean["Version"] == "Hybrid-Only"].Structure_Source.unique()

In [None]:
clean = clean[clean.Structure_Source.isin(structures)]

## remove self-docked

In [None]:
clean = clean[clean.Compound_ID != clean.Reference_Ligand]

In [None]:
clean.groupby("Version").nunique()

## remove failed hybrid-only ones

In [None]:
ref_ligs = clean[clean["Version"] == "Hybrid-Only"].Reference_Ligand.unique()

In [None]:
len(np.intersect1d(ref_ligs, cmpds))

In [None]:
n_successfully_docked =clean.groupby("Compound_ID").nunique().Docked_File

In [None]:
px.histogram(n_successfully_docked / 2)

In [None]:
n_successfully_docked[n_successfully_docked < 308]

In [None]:
clean[clean.Compound_ID == "ALP-POS-ecbed2ba-12"].groupby("Version").nunique()

In [None]:
minimal_success = clean[(clean.Compound_ID == "ALP-POS-ecbed2ba-12") & (clean.Version == "Hybrid-Only")].Reference_Ligand.unique()

In [None]:
superclean = clean[clean.Compound_ID.isin(minimal_success)]
superclean = superclean[superclean.Reference_Ligand.isin(minimal_success)]

In [None]:
n_successfully_docked_superclean = superclean.groupby("Compound_ID").nunique().Docked_File

In [None]:
px.histogram(n_successfully_docked_superclean / 2)

In [None]:
minimal_success = superclean[(superclean.Compound_ID == "MAT-POS-5cd9ea36-16") & (superclean.Version == "Hybrid-Only")].Reference_Ligand.unique()

In [None]:
superclean = superclean[superclean.Compound_ID.isin(minimal_success)]
superclean = superclean[superclean.Reference_Ligand.isin(minimal_success)]

In [None]:
n_successfully_docked_superclean = superclean.groupby("Compound_ID").nunique().Docked_File

In [None]:
px.histogram(n_successfully_docked_superclean / 2)

In [None]:
superclean.groupby("Compound_ID").nunique().sort_values("Docked_File")

In [None]:
minimal_success = superclean[(superclean.Compound_ID == "MAT-POS-5cd9ea36-13") & (superclean.Version == "Hybrid-Only")].Reference_Ligand.unique()

In [None]:
superclean = superclean[superclean.Compound_ID.isin(minimal_success)]
superclean = superclean[superclean.Reference_Ligand.isin(minimal_success)]

In [None]:
n_successfully_docked_superclean = superclean.groupby("Compound_ID").nunique().Docked_File

In [None]:
px.histogram(n_successfully_docked_superclean / 2)

In [None]:
#clean = clean[clean.Complex_ID.isin(complexes)]

In [None]:
#clean.groupby("Version").nunique()[["Complex_ID", "Compound_ID", "Structure_Source"]]

In [None]:
clean_all = clean[clean.Version == "All"]

In [None]:
clean_all.groupby("Version").nunique()[["Complex_ID", "Compound_ID", "Structure_Source"]]

# By Random Sort

In [None]:
from tqdm import tqdm

In [None]:
# For 100 random orderings of the reference structures, measure the fraction
# of compounds with a good pose as more references become available.
structures = clean.Structure_Source.unique()
random_dfs = []
for i in tqdm(range(0,100)):
    # Shuffle WITHOUT replacement. np.random.choice samples with replacement
    # by default, which put duplicates in the "shuffled" list, so the first
    # n_refs entries held fewer than n_refs distinct structures.
    shuffled_structures = np.random.choice(structures, len(structures), replace=False)
    split_cols_list = []
    score_list = []
    # cutoff_list / perc_mols_list are not filled by this loop; they are kept
    # only so the names stay defined for any later cell.
    cutoff_list = []
    perc_mols_list = []
    n_references = []
    for n_refs in range(1,len(structures), 5):
        # restrict the docking results to the first n_refs shuffled structures
        sampled_structures = shuffled_structures[0:n_refs]
        df_subset = clean[clean.Structure_Source.isin(sampled_structures)]

        # keep the top row per (Version, Compound_ID) after sorting by POSIT, descending
        best = df_subset.sort_values("POSIT", ascending=False).groupby(["Version", "Compound_ID"]).head(1)

        perc_good = a.calc_perc_good(best, 
                         score_column=rmsd, 
                         good_score=good,
                         total_mol=len(df_subset.Compound_ID.unique()))
        # one record per index entry of the result (stored under "Version")
        for split_col in perc_good.index:
            split_cols_list.append(split_col)
            score_list.append(perc_good[split_col])
            n_references.append(n_refs)

    return_df = pd.DataFrame(
        {
            "Fraction": score_list,
            "Version": split_cols_list,
            "Number of References": n_references,
        }
    )
    random_dfs.append(return_df)

In [None]:
combined = pd.concat(random_dfs)

In [None]:
stats = combined.groupby(["Version", "Number of References"]).describe().reset_index()

In [None]:
stats.columns = ["Version", "Number of References", "count", "mean", "std", "min", "25%", "50%", "75%", "max"]

In [None]:
stats

In [None]:
# Plot the "mean" column against the number of references with "std" error
# bars, colored by Version (via basic_plot_kwargs), at the single-plot size
# and with the padded fraction y-range.
fig = scatter_wrapper(stats, kwarg_dict=dict(y="mean", x="Number of References", error_y="std", **basic_plot_kwargs, **single_plot_kwargs, **fraction_plot_kwargs))
fig.update_yaxes(title=frac_title)
fig.update_xaxes(title="Number of References")
fig.show()
# Saved under a relative file name.
fig.write_image("20231002_random_posit_full_all.png")

## use new calculation function

In [None]:
reload(a)
random_stats = a.calculate_rmsd_stats(clean, query_mol_id="Compound_ID", reference_selection="random", ref_structure_stride=10, score_column="POSIT_R", group_by=["Version"], n_bootstraps=100)
date_stats = a.calculate_rmsd_stats(clean, query_mol_id="Compound_ID", reference_selection="Structure_Date", ref_structure_stride=10, score_column="POSIT_R", group_by=["Version"], n_bootstraps=100)

In [None]:
tc_stats = a.calculate_rmsd_stats(clean, query_mol_id="Compound_ID", reference_selection="TanimotoCombo_R", ref_structure_stride=10, score_column="POSIT_R", group_by=["Version"], n_bootstraps=100)

In [None]:
tcr_stats = a.calculate_rmsd_stats(clean, query_mol_id="Compound_ID", reference_selection="TanimotoCombo", ref_structure_stride=10, score_column="POSIT_R", group_by=["Version"], n_bootstraps=100)

In [None]:
all_stats = pd.concat([random_stats, date_stats, tc_stats, tcr_stats])

In [None]:
all_stats

In [None]:
# Remove the reference-count and split-value bookkeeping columns before averaging.
simple = all_stats.drop(
    columns=[
        "Mean Number of References",
        "Max Number of References",
        "Min Number of References",
        "Split_Value_min",
        "Split_Value_max",
    ]
)

# Summarise within each (Version, Number of References, Structure_Split) group:
# the mean, plus the 97.5% / 2.5% quantiles of Fraction.
split_groups = simple.groupby(["Version", "Number of References", "Structure_Split"])
aggregated = split_groups.mean().reset_index()
fraction_q975 = split_groups.quantile(0.975).reset_index()["Fraction"]
fraction_q025 = split_groups.quantile(0.025).reset_index()["Fraction"]

# Store the bounds as distances from the mean (upper = "Max", lower = "Min").
aggregated["Max"] = fraction_q975 - aggregated["Fraction"]
aggregated["Min"] = aggregated["Fraction"] - fraction_q025

In [None]:
aggregated.to_csv("20240120_aggregated_all_stats_bootstraps100_stride10")

In [None]:
# Fraction vs number of references for the Hybrid-Only rows only, one color
# per Structure_Split, with asymmetric error bars from the Max / Min columns.
fig = scatter_wrapper(aggregated[aggregated["Version"] == "Hybrid-Only"], 
                      dict(
                          y="Fraction", color="Structure_Split", facet_col="Version",
                          color_discrete_sequence=px.colors.qualitative.Dark24,
                          error_y="Max", 
                          error_y_minus="Min",
                          template="seaborn",
                           **fraction_plot_kwargs,
                           x="Number of References", 
                          height=400,
                          width=600,
                           # **single_plot_kwargs
                          # NOTE(review): plotly express documents `labels` keys as
                          # column names. These keys match the reference_selection
                          # values passed to calculate_rmsd_stats (presumably the
                          # Structure_Split values), so they may not rename the
                          # legend entries - confirm.
                          labels={"Structure_Date": "Date of Structure Deposition", 
                                  "TanimotoCombo": "Decreasing Chemical Similarity (TC)", 
                                  "TanimotoCombo_R": "Increasing Chemical Similarity (TC)",
                                  "random":"Random"}
                      ),
                     y_axis_title=frac_title,
                     x_axis_title="Number of References",
                     )
fig.show()
# Saved under a relative file name.
fig.write_image("20240116_structure_splits_hybrid.png")

In [None]:
# Same plot for every Version (one facet column each), colored by Structure_Split,
# with asymmetric error bars from the Max / Min columns.
fig = scatter_wrapper(aggregated, 
                      dict(
                          y="Fraction", color="Structure_Split", facet_col="Version",
                          color_discrete_sequence=px.colors.qualitative.Dark24,
                          error_y="Max", 
                          error_y_minus="Min",
                          template="seaborn",
                           **fraction_plot_kwargs,
                           x="Number of References", 
                          height=600,
                          width=800
                           # **single_plot_kwargs
                      ),
                     y_axis_title=frac_title,
                     x_axis_title="Number of References"
                     )
# Blank every facet's y-axis title, then set it again on the first y axis only.
fig.for_each_yaxis(lambda y: y.update(title = ''))
# fig.add_annotation(x=-0.03, y=0.5,text=frac_title, textangle=-90,
#                     xref="paper", yref="paper")
# fig.update_yaxes(title=frac_title)
# The layout height set here (400) replaces the 600 passed to px.scatter above.
fig.update_layout(yaxis1=dict(title=frac_title), height=400, width=800)
fig.show()
# Saved under a relative file name.
fig.write_image("20240116_structure_splits_big.png")

# Test different scoring functions for Chemical Similarity

In [None]:
tcr_posit = a.calculate_rmsd_stats(clean, 
                       query_mol_id="Compound_ID", 
                       reference_selection="TanimotoCombo", 
                       ref_structure_stride=10, score_column="POSIT_R", group_by=["Version"], n_bootstraps=100)

In [None]:
tcr_rmsd = a.calculate_rmsd_stats(clean, 
                       query_mol_id="Compound_ID", 
                       reference_selection="TanimotoCombo", 
                       ref_structure_stride=10, score_column="RMSD", group_by=["Version"], n_bootstraps=100)

In [None]:
tcr_chemgauss = a.calculate_rmsd_stats(clean, 
                       query_mol_id="Compound_ID", 
                       reference_selection="TanimotoCombo", 
                       ref_structure_stride=10, score_column="Chemgauss4", group_by=["Version"], n_bootstraps=100)

In [None]:
tcr_tcr = a.calculate_rmsd_stats(clean, 
                       query_mol_id="Compound_ID", 
                       reference_selection="TanimotoCombo", 
                       ref_structure_stride=10, score_column="TanimotoCombo_R", group_by=["Version"], n_bootstraps=100)

In [None]:
clean.columns

In [None]:
random = clean.sample(frac=1)
subset_df = random.sort_values("SCHNET_score").groupby(["Compound_ID", "Version"]).head(1)

In [None]:
rmsd_col = subset_df.groupby("Version", group_keys=True)["RMSD"].apply(lambda x: x <= 2).groupby("Version").sum() / len(random["Compound_ID"].unique())

In [None]:
rmsd_col

In [None]:
tcr_random = a.calculate_rmsd_stats(clean, 
                       query_mol_id="Compound_ID", 
                       reference_selection="TanimotoCombo_R", 
                       ref_structure_stride=10, score_column="SCHNET_score", group_by=["Version"], n_bootstraps=100)

In [None]:
tcr_random

In [None]:
tcr_posit["Score"] = "POSIT_R"
tcr_rmsd["Score"] = "RMSD"
tcr_chemgauss["Score"] = "Chemgauss4"
tcr_tcr["Score"] = "TanimotoCombo"
tcr_random["Score"] = "Random"

In [None]:
tcr_combined = pd.concat([tcr_rmsd, tcr_posit, tcr_chemgauss, tcr_tcr, tcr_random])

In [None]:
tcr_combined

In [None]:
tcr_simple = tcr_combined.drop(columns=["Mean Number of References", "Max Number of References", "Min Number of References", "Split_Value_min", "Split_Value_max", "Structure_Split"])

In [None]:
tcr_simple

In [None]:
# Summarise within each (Version, Number of References, Score) group:
# the mean, plus the 97.5% / 2.5% quantiles of Fraction.
score_groups = tcr_simple.groupby(["Version", "Number of References", "Score"])
tcr_agg = score_groups.mean().reset_index()
# (disabled in the original) tcr_agg = tcr_agg["Score"] != "Random"
score_fraction_q975 = score_groups.quantile(0.975).reset_index()["Fraction"]
score_fraction_q025 = score_groups.quantile(0.025).reset_index()["Fraction"]

# Store the bounds as distances from the mean (upper = "Max", lower = "Min").
tcr_agg["Max"] = score_fraction_q975 - tcr_agg["Fraction"]
tcr_agg["Min"] = tcr_agg["Fraction"] - score_fraction_q025

In [None]:
tcr_simple.groupby(["Version", "Number of References", "Score"]).quantile(0.025)

In [None]:
# Fraction vs number of references for the Hybrid-Only rows, one color per
# Score, with asymmetric error bars from the Max / Min columns.
fig = scatter_wrapper(tcr_agg[tcr_agg.Version == "Hybrid-Only"], 
                      dict(
                          y="Fraction", color="Score",
                          color_discrete_sequence=px.colors.qualitative.Pastel,
                          error_y="Max", 
                          error_y_minus="Min",
                          template="simple_white",
                           **fraction_plot_kwargs,
                           x="Number of References", 
                          height=600,
                          width=800,
                           # **single_plot_kwargs
                          # Only three Score values are ordered explicitly; the
                          # "TanimotoCombo" and "Random" labels assigned above are
                          # not listed here.
                          category_orders=dict(Score=["RMSD", "POSIT_R", "Chemgauss4"])
                      ),
                     y_axis_title=frac_title,
                     x_axis_title="Number of References"
                     )
# The y title is blanked, set on every y axis, then set again on yaxis1 below.
fig.for_each_yaxis(lambda y: y.update(title = ''))
# fig.add_annotation(x=-0.03, y=0.5,text=frac_title, textangle=-90,
#                     xref="paper", yref="paper")
fig.update_yaxes(title=frac_title)
# The layout size set here (400 x 600) replaces the 600 x 800 passed to px.scatter above.
fig.update_layout(title="Posed by: Hybrid-Only POSIT", yaxis1=dict(title=frac_title), height=400, width=600)
fig.show()
# Saved under a relative file name (note the space in the name).
fig.write_image("20231017_hybrid_score_comparison_increasing similarity.png")

In [None]:
fig = scatter_wrapper(tcr_agg, 
                      dict(
                          y="Fraction", color="Version", facet_col="Score",
                          color_discrete_sequence=px.colors.qualitative.Pastel,
                          # error_y="Max", 
                          # error_y_minus="Min",
                          template="simple_white",
                           **fraction_plot_kwargs,
                           x="Number of References", 
                          height=600,
                          width=800,
                           # **single_plot_kwargs
                          category_orders=dict(Score=["RMSD", "POSIT_R", "Chemgauss4"])
                      ),
                     y_axis_title=frac_title,
                     x_axis_title="Number of References"
                     )
fig.for_each_yaxis(lambda y: y.update(title = ''))
# fig.add_annotation(x=-0.03, y=0.5,text=frac_title, textangle=-90,
#                     xref="paper", yref="paper")
# fig.update_yaxes(title=frac_title)
fig.update_layout(title="Variance of Docking Performance with Increasing Chemical Similarity", yaxis1=dict(title=frac_title), height=400, width=1200)
fig.show()
fig.write_image("20231017_all_score_comparison_increasing_similarity.png")

## group by POSIT_Method as well

In [None]:
tcr_posit = a.calculate_rmsd_stats(clean, 
                       query_mol_id="Compound_ID", 
                       reference_selection="TanimotoCombo", 
                       ref_structure_stride=10, score_column="POSIT_R", group_by=["Version", "POSIT_Method"], n_bootstraps=1)

In [None]:
tcr_rmsd = a.calculate_rmsd_stats(clean, 
                       query_mol_id="Compound_ID", 
                       reference_selection="TanimotoCombo", 
                       ref_structure_stride=10, score_column="RMSD", group_by=["Version", "POSIT_Method"], n_bootstraps=1)

In [None]:
tcr_chemgauss = a.calculate_rmsd_stats(clean, 
                       query_mol_id="Compound_ID", 
                       reference_selection="TanimotoCombo", 
                       ref_structure_stride=10, score_column="Chemgauss4", group_by=["Version", "POSIT_Method"], n_bootstraps=1)

In [None]:
tcr_tcr = a.calculate_rmsd_stats(clean, 
                       query_mol_id="Compound_ID", 
                       reference_selection="TanimotoCombo", 
                       ref_structure_stride=10, score_column="TanimotoCombo_R", group_by=["Version", "POSIT_Method"], n_bootstraps=1)

In [None]:
tcr_posit["Score"] = "POSIT_R"
tcr_rmsd["Score"] = "RMSD"
tcr_chemgauss["Score"] = "Chemgauss4"
tcr_tcr["Score"] = "TanimotoCombo"

In [None]:
tcr_combined = pd.concat([tcr_rmsd, tcr_posit, tcr_chemgauss, tcr_tcr])

In [None]:
# Mean of each statistic for every
# (Version, Number of References, Structure_Split, split-value range, Score) row.
tcr_agg = tcr_combined.groupby(["Version", "Number of References", "Structure_Split", "Split_Value_min", "Split_Value_max", "Score"]).mean().reset_index()

# 2.5% / 97.5% quantiles of Fraction over the raw rows of each condition.
# Fixes compared with the previous version of this cell:
#   * both bounds come from tcr_combined (the upper bound used to be computed
#     from the already-averaged tcr_agg);
#   * "Score" is part of the condition, so each scoring function gets its own
#     interval instead of one pooled across scores;
#   * bounds are matched to tcr_agg rows by key rather than by row position,
#     because the two groupings do not produce the same rows.
_tcr_condition_keys = ["Version", "Number of References", "Structure_Split", "Score"]
_tcr_fraction_groups = tcr_combined.groupby(_tcr_condition_keys)["Fraction"]
_tcr_bounds = pd.DataFrame(
    {
        "_q975": _tcr_fraction_groups.quantile(0.975),
        "_q025": _tcr_fraction_groups.quantile(0.025),
    }
).reset_index()
tcr_agg = tcr_agg.merge(_tcr_bounds, on=_tcr_condition_keys, how="left")

# Store the bounds as distances from the mean (upper = "Max", lower = "Min").
tcr_agg["Max"] = tcr_agg.pop("_q975") - tcr_agg["Fraction"]
tcr_agg["Min"] = tcr_agg["Fraction"] - tcr_agg.pop("_q025")

In [None]:
fig = scatter_wrapper(tcr_agg, 
                      dict(
                          y="Fraction", color="Version", facet_col="Score",
                          color_discrete_sequence=px.colors.qualitative.Pastel,
                          # error_y="Max", 
                          # error_y_minus="Min",
                          template="simple_white",
                           **fraction_plot_kwargs,
                           x="Number of References", 
                          height=600,
                          width=800,
                           # **single_plot_kwargs
                          category_orders=dict(Score=["RMSD", "POSIT_R", "Chemgauss4"])
                      ),
                     y_axis_title=frac_title,
                     x_axis_title="Number of References"
                     )
fig.for_each_yaxis(lambda y: y.update(title = ''))
# fig.add_annotation(x=-0.03, y=0.5,text=frac_title, textangle=-90,
#                     xref="paper", yref="paper")
# fig.update_yaxes(title=frac_title)
fig.update_layout(yaxis1=dict(title=frac_title), height=400, width=1200)
fig.show()
fig.write_image("20231019_scoring_comparison_by_poser.png")

In [None]:
tcr_subset = tcr_agg[tcr_agg.Score == "TanimotoCombo"]
tcr_subset

In [None]:
fig = scatter_wrapper(tcr_subset, 
                      dict(
                          y="Fraction", color="Version", 
                          # facet_col="Score",
                          color_discrete_sequence=px.colors.qualitative.Pastel,
                          # error_y="Max", 
                          # error_y_minus="Min",
                          template="simple_white",
                           **fraction_plot_kwargs,
                           x="Number of References", 
                          height=600,
                          width=800,
                           # **single_plot_kwargs
                          # category_orders=dict(Score=["RMSD", "POSIT_R", "Chemgauss4"])
                      ),
                     y_axis_title=frac_title,
                     x_axis_title="Number of References"
                     )
fig.for_each_yaxis(lambda y: y.update(title = ''))
# fig.add_annotation(x=-0.03, y=0.5,text=frac_title, textangle=-90,
#                     xref="paper", yref="paper")
# fig.update_yaxes(title=frac_title)
fig.update_layout(yaxis1=dict(title=frac_title), height=400, width=600)
fig.show()
fig.write_image("20231019_tcr_posing_function_comparison.png")

# Test getting tranches of structures

In [None]:
reload(a)
traunch_random_stats = a.calculate_rmsd_stats(clean, query_mol_id="Compound_ID", reference_selection="random", ref_structure_stride=10, score_column="POSIT_R", group_by=["Version"], n_bootstraps=3, cumulative=False, count_nrefs=True)
traunch_date_stats = a.calculate_rmsd_stats(clean, query_mol_id="Compound_ID", reference_selection="Structure_Date", ref_structure_stride=10, score_column="POSIT_R", group_by=["Version"], n_bootstraps=3, cumulative=False, count_nrefs=True)

In [None]:
traunch_tc_stats = a.calculate_rmsd_stats(clean, query_mol_id="Compound_ID", reference_selection="TanimotoCombo_R", ref_structure_stride=10, score_column="POSIT_R", group_by=["Version"], n_bootstraps=3, cumulative=False, count_nrefs=True)

In [None]:
traunch_tcr_stats = a.calculate_rmsd_stats(clean, query_mol_id="Compound_ID", reference_selection="TanimotoCombo", ref_structure_stride=10, score_column="POSIT_R", group_by=["Version", "POSIT_Method"], n_bootstraps=3, cumulative=False, count_nrefs=True)

In [None]:
traunch_all_stats = pd.concat([
    # traunch_random_stats, 
    # traunch_date_stats, 
    # traunch_tc_stats, 
    traunch_tcr_stats
])

In [None]:
traunch_all_stats

In [None]:
# Mean of each statistic per tranche row. The split-value and reference-count
# columns are part of the group key, so they are carried through as columns.
traunch_aggregated = traunch_all_stats.groupby(["Version", "Number of References", "Structure_Split", "Split_Value_min", "Split_Value_max", "Mean Number of References", "Max Number of References", "Min Number of References"]).mean().reset_index()

# 2.5% / 97.5% quantiles of Fraction per (Version, Number of References,
# Structure_Split). This grouping has fewer keys than the mean above, so the
# bounds are matched to rows by key; assigning the quantile column directly
# paired rows by position and could attach a bound to the wrong row.
_traunch_condition_keys = ["Version", "Number of References", "Structure_Split"]
_traunch_fraction_groups = traunch_all_stats.groupby(_traunch_condition_keys)["Fraction"]
_traunch_bounds = pd.DataFrame(
    {
        "_q975": _traunch_fraction_groups.quantile(0.975),
        "_q025": _traunch_fraction_groups.quantile(0.025),
    }
).reset_index()
traunch_aggregated = traunch_aggregated.merge(_traunch_bounds, on=_traunch_condition_keys, how="left")

# Store the bounds as distances from the mean (upper = "Max", lower = "Min").
traunch_aggregated["Max"] = traunch_aggregated.pop("_q975") - traunch_aggregated["Fraction"]
traunch_aggregated["Min"] = traunch_aggregated["Fraction"] - traunch_aggregated.pop("_q025")

In [None]:
traunch_aggregated

In [None]:
n_docked = clean[clean["Version"] == "Hybrid-Only"].groupby("Compound_ID").count()["Docked_File"]

In [None]:
n_docked.sum()

In [None]:
fig = px.histogram(n_docked, height=600, width=600)
fig.update_layout(title="Distribution of Successful Docking Runs")
fig.update_xaxes(title="Number of Successful Docked Runs")
fig.show()
fig.write_image("20231010_distribution_successful_docking_runs.png")

In [None]:
# Label each row "<start>--><start + 10>", where start is its
# "Number of References" value (a start of 1 gives "1-->11").
traunch_aggregated["Range"] = [
    f"{start}-->{start + 10}"
    for start in traunch_aggregated["Number of References"]
]

In [None]:
traunch_aggregated

In [None]:
fig = scatter_wrapper(traunch_aggregated, 
                      dict(
                          y="Mean Number of References", color="Version",
                          color_discrete_sequence=px.colors.qualitative.Dark24,
                          template="seaborn",
                           # **fraction_plot_kwargs,
                           x="Range", 
                          height=400,
                          width=600
                           # **single_plot_kwargs
                      ),
                     # y_axis_title=frac_title,
                     x_axis_title="Number of References"
                     )
fig.show()


In [None]:
fig = scatter_wrapper(traunch_aggregated[traunch_aggregated["Version"] == "Hybrid-Only"], 
                      dict(
                          y="Min Number of References", color="Structure_Split", facet_col="Version",
                          color_discrete_sequence=px.colors.qualitative.Dark24,
                          template="seaborn",
                           # **fraction_plot_kwargs,
                           x="Range", 
                          height=400,
                          width=600
                           # **single_plot_kwargs
                      ),
                     # y_axis_title=frac_title,
                     x_axis_title="Number of References"
                     )
fig.show()

In [None]:
fig = scatter_wrapper(traunch_aggregated, 
                      dict(
                          y="Mean Number of References", color="Structure_Split", facet_col="Version",
                          color_discrete_sequence=px.colors.qualitative.Dark24,
                          template="seaborn",
                           # **fraction_plot_kwargs,
                           x="Range", 
                          height=400,
                          width=600
                           # **single_plot_kwargs
                      ),
                     y_axis_title=frac_title,
                     x_axis_title="Number of References"
                     )
fig.show()

In [None]:
fig = scatter_wrapper(traunch_aggregated, 
                      dict(
                          y="Fraction", color="Structure_Split", facet_col="Version",
                          color_discrete_sequence=px.colors.qualitative.Dark24,
                          error_y="Max", 
                          error_y_minus="Min",
                          template="seaborn",
                           **fraction_plot_kwargs,
                           x="Range", 
                          height=600,
                          width=800
                           # **single_plot_kwargs
                      ),
                     y_axis_title=frac_title,
                     x_axis_title="Number of References"
                     )
fig.for_each_yaxis(lambda y: y.update(title = ''))
# fig.add_annotation(x=-0.03, y=0.5,text=frac_title, textangle=-90,
#                     xref="paper", yref="paper")
# fig.update_yaxes(title=frac_title)
fig.update_layout(yaxis1=dict(title=frac_title), height=400, width=800)
fig.show()


## Using the Kneed Package to Locate the Knee Point

In [None]:
from kneed import KneeLocator

In [None]:
x1 = stats[stats["Version"] == "All"]["Number of References"].unique()

In [None]:
x2 = stats[stats["Version"] == "Hybrid-Only"]["Number of References"].unique()

In [None]:
x1 == x2

### convert each version's points to x,y

In [None]:
# Locate the knee of the mean-fraction curve separately for each Version.
for version in stats.Version.unique():
    version_stats = stats[stats["Version"] == version]
    # Take x and y from the same rows so their lengths always match, even if
    # the versions do not share the same set of reference counts.
    # NOTE: as in the original cell, this rebinds the notebook-level `y`.
    x = version_stats["Number of References"].to_list()
    y = version_stats["mean"].to_list()
    kneedle = KneeLocator(x, y, S=1.0, curve="concave", direction="increasing")
    kneedle.plot_knee_normalized(title=f"{version} Difference Plot")
    print(version, kneedle.knee)