# Imports

In [None]:
import pandas as pd
from pathlib import Path
import numpy as np
import plotly.express as px
from plotly.graph_objs import Figure
from importlib import reload
import software.analysis as a
reload(a)

# Load CSVs

In [None]:
active_site_torsions = pd.read_csv("/Users/alexpayne/Scientific_Projects/mers-drug-discovery/sars2-retrospective-analysis/torsion_actst.csv")

In [None]:
active_site_sasa = pd.read_csv("/Users/alexpayne/Scientific_Projects/mers-drug-discovery/sars2-retrospective-analysis/summation_matrix(1).csv")

In [None]:
active_site_sasa.iloc[4,0]

In [None]:
active_site_sasa.iloc[0, :]

In [None]:
full_sasa = pd.read_csv("/Users/alexpayne/Scientific_Projects/mers-drug-discovery/sars2-retrospective-analysis/sasa_data.csv")

In [None]:
full_sasa.sum(numeric_only=True, axis=1)

In [None]:
# Load the second SASA summation table.
sasa_v2 = pd.read_csv("/Users/alexpayne/Scientific_Projects/mers-drug-discovery/sars2-retrospective-analysis/summation_matrix(2).csv")
# Derive an identifier from each complex file path: keep only the file name,
# then drop everything from "-prepped" onward and everything from the first "_0" onward.
sasa_v2["Structure_Source"] = sasa_v2.Complex_File.apply(lambda x: Path(x).name.split("-prepped")[0].split("_0")[0])
sasa_v2_edited = sasa_v2.drop(columns=["Complex_File"])
# Positional rename: assumes exactly two columns remain at this point, the SASA
# value followed by the identifier added above -- TODO confirm against the CSV.
# Note the identifier ends up labelled "Structure_Name", not "Structure_Source".
sasa_v2_edited.columns = ["SASA (nm^2)", "Structure_Name"]
# Display the resulting table.
sasa_v2_edited

## load docking results

In [None]:
df = pd.read_csv("/Users/alexpayne/Scientific_Projects/mers-drug-discovery/sars2-retrospective-analysis/20230611-combined.csv", index_col=0)

In [None]:
df.Structure_Name

In [None]:
df_sasa = pd.merge(df, sasa_v2_edited, on="Structure_Name", how="outer")

In [None]:
df=df_sasa

## Add reverse TC score so that it can be ranked ascending

In [None]:
df["TanimotoCombo_R"] = 2-df.TanimotoCombo

# Overall Analysis

In [None]:
df.groupby("Version").nunique()[["Complex_ID", "Compound_ID", "Structure_Source"]]

# Drop self-docking results and results with different compounds

# Keep only compounds and structures that appear in the Hybrid-Only results

In [None]:
cmpds = df[df["Version"] == "Hybrid-Only"].Compound_ID.unique()

In [None]:
clean = df[df.Compound_ID.isin(cmpds)]

In [None]:
structures = clean[clean["Version"] == "Hybrid-Only"].Structure_Source.unique()

In [None]:
clean = clean[clean.Structure_Source.isin(structures)]

## remove self-docked

In [None]:
clean = clean[clean.Compound_ID != clean.Reference_Ligand]

In [None]:
clean.groupby("Version").nunique()

## remove failed hybrid-only ones

In [None]:
ref_ligs = clean[clean["Version"] == "Hybrid-Only"].Reference_Ligand.unique()

In [None]:
len(np.intersect1d(ref_ligs, cmpds))

In [None]:
n_successfully_docked =clean.groupby("Compound_ID").nunique().Docked_File

In [None]:
px.histogram(n_successfully_docked / 2)

In [None]:
n_successfully_docked[n_successfully_docked < 308]

In [None]:
clean[clean.Compound_ID == "ALP-POS-ecbed2ba-12"].groupby("Version").nunique()

In [None]:
minimal_success = clean[(clean.Compound_ID == "ALP-POS-ecbed2ba-12") & (clean.Version == "Hybrid-Only")].Reference_Ligand.unique()

In [None]:
superclean = clean[clean.Compound_ID.isin(minimal_success)]
superclean = superclean[superclean.Reference_Ligand.isin(minimal_success)]

In [None]:
n_successfully_docked_superclean = superclean.groupby("Compound_ID").nunique().Docked_File

In [None]:
px.histogram(n_successfully_docked_superclean / 2)

In [None]:
minimal_success = superclean[(superclean.Compound_ID == "MAT-POS-5cd9ea36-16") & (superclean.Version == "Hybrid-Only")].Reference_Ligand.unique()

In [None]:
superclean = superclean[superclean.Compound_ID.isin(minimal_success)]
superclean = superclean[superclean.Reference_Ligand.isin(minimal_success)]

In [None]:
n_successfully_docked_superclean = superclean.groupby("Compound_ID").nunique().Docked_File

In [None]:
px.histogram(n_successfully_docked_superclean / 2)

In [None]:
superclean.groupby("Compound_ID").nunique().sort_values("Docked_File")

In [None]:
minimal_success = superclean[(superclean.Compound_ID == "MAT-POS-5cd9ea36-13") & (superclean.Version == "Hybrid-Only")].Reference_Ligand.unique()

In [None]:
superclean = superclean[superclean.Compound_ID.isin(minimal_success)]
superclean = superclean[superclean.Reference_Ligand.isin(minimal_success)]

In [None]:
n_successfully_docked_superclean = superclean.groupby("Compound_ID").nunique().Docked_File

In [None]:
px.histogram(n_successfully_docked_superclean / 2)

In [None]:
#clean = clean[clean.Complex_ID.isin(complexes)]

In [None]:
#clean.groupby("Version").nunique()[["Complex_ID", "Compound_ID", "Structure_Source"]]

In [None]:
clean_all = clean[clean.Version == "All"]

In [None]:
clean_all.groupby("Version").nunique()[["Complex_ID", "Compound_ID", "Structure_Source"]]

## Variables

In [None]:
# Shared column names, axis labels and grouping settings for the cells below.
# NOTE(review): several of these (e.g. y, n, tc_cutoffs, sort_cols) are not
# referenced in the visible cells -- confirm whether they are still needed.

# TanimotoCombo column, its axis label, and the reversed-score column.
tc = "TanimotoCombo"
tc_title = "TanimotoCombo Cutoff for Inclusion of Reference Structures"
tcr = "TanimotoCombo_R"
y="Fraction"
posit_r = "POSIT_R"
posit_method="POSIT_Method"
color="Version"
id_col="Compound_ID"
rmsd="RMSD"
method_split=[posit_method]
n=1
# Distance threshold in Å; interpolated into frac_title below.
good=2
# 50 evenly spaced cutoffs from 0 to 2, ascending and descending respectively.
tc_cutoffs = np.linspace(0,2,50)
tcr_cutoffs = np.linspace(2,0,50)
sort_cols = [rmsd, posit_r, "Chemgauss4", tcr]
# Unique values of the Structure_Date column in the merged docking table.
dates = df.Structure_Date.unique()
date_col = "Structure_Date"
date_title = "Date for Inclusion of Reference Structures"
reference_col = "Structure_Source"
# Grouping column sets: "General" splits by Version only, "Detailed" by Version and POSIT method.
split_cols=["Version"]
sort_col_name="Sorted_By"
full_split_cols=["Version", posit_method]
split_column_sets={"General":split_cols, "Detailed":full_split_cols}
general_split_cols = {"General":split_cols}
detailed_split_cols = {"Detailed":full_split_cols}
# Y-axis label used by the fraction plots.
frac_title=f"Fraction of Poses < {good}Å from Reference"

## Calculation Functions

## Plotting Functions

### plot kwargs

In [None]:
df.Version.unique()

In [None]:
df.POSIT_Method.unique()

In [None]:
# this doesn't actually work
# (version, POSIT method) pairs: every method seen in df under "All", plus the single Hybrid-Only pair.
full_versions = [
    *(("All", method) for method in df.POSIT_Method.unique()),
    ("Hybrid-Only", "HYBRID"),
]
# Human-readable label for each pair, and a lookup from pair to label.
full_version_labels = [f"{version}: {method}" for version, method in full_versions]
full_version_label_dict = dict(zip(full_versions, full_version_labels))

In [None]:
basic_plot_kwargs = dict(color=color, 
                         )

In [None]:
big_plot_kwargs = dict(facet_col=sort_col_name,
                         facet_row="Split", 
                         height=600, 
                         width=1200, )

In [None]:
single_plot_kwargs = dict(height=400, width=600)

In [None]:
tc_plot_kwargs = dict(x=tc,  
                      labels={tc: tc_title},
                     range_x=[-0.1,2.1],)

In [None]:
date_plot_kwargs = dict(x=date_col,
                        labels={date_col:date_title},
                       )

In [None]:
fraction_plot_kwargs = dict(range_y=[-0.1,1.1])

In [None]:
stats_kwargs = dict(y="Value", error_y="STD")

In [None]:
def combine_labels_kwargs(list_of_kwargs):
    """Merge several plot-kwarg dicts into one new dict.

    A key that appears in only one input is carried over as is. When a key is
    repeated and both values are dicts (e.g. ``labels``), the inner dicts are
    merged, with later inputs winning on inner-key clashes. The input dicts
    are never modified.

    Parameters
    ----------
    list_of_kwargs : iterable of dict
        Keyword-argument dicts to combine, in order.

    Returns
    -------
    dict
        A new dict with the combined keyword arguments.

    Raises
    ------
    NotImplementedError
        If a key is repeated and the two values are not both dicts.
    """
    new_dict = {}
    for kwargs in list_of_kwargs:
        for k, v in kwargs.items():
            if k not in new_dict:
                # Copy dict values so a later merge never writes into the
                # caller's own dict (e.g. a shared ``labels`` mapping).
                new_dict[k] = dict(v) if isinstance(v, dict) else v
            elif isinstance(v, dict) and isinstance(new_dict[k], dict):
                new_dict[k].update(v)
            else:
                raise NotImplementedError(f"combining these kwargs will not work due to repeated use of {k}")
    return new_dict
    

In [None]:
combine_labels_kwargs([big_plot_kwargs, tc_plot_kwargs])

In [None]:
general_posit_kwargs = {sort_col_name: posit_r, "Split":"General"}

### cleanup functions

In [None]:
def replace_xaxis_labels(fig: Figure, axis_title):
    """Blank every x-axis title and add one annotation carrying ``axis_title``.

    The annotation is positioned in paper coordinates at x=0.5, y=-0.15.
    Returns the same figure object.
    """
    def _blank_title(axis):
        return axis.update(title='')

    fig.for_each_xaxis(_blank_title)

    shared_title = dict(
        x=0.5,
        y=-0.15,
        xref="paper",
        yref="paper",
        text=axis_title,
        textangle=0,
        font=dict(size=16),
        showarrow=False,
    )
    fig.add_annotation(**shared_title)
    return fig

In [None]:
def replace_yaxis_labels(fig: Figure, axis_title):
    """Blank every y-axis title and add one rotated annotation carrying ``axis_title``.

    The annotation is positioned in paper coordinates at x=-0.05, y=0.5 with
    a text angle of -90. Returns the same figure object.
    """
    def _blank_title(axis):
        return axis.update(title='')

    fig.for_each_yaxis(_blank_title)

    shared_title = dict(
        x=-0.05,
        y=0.5,
        xref="paper",
        yref="paper",
        text=axis_title,
        textangle=-90,
        font=dict(size=16),
        showarrow=False,
    )
    fig.add_annotation(**shared_title)
    return fig

In [None]:
def clean_labels(fig):
    """Reduce each annotation's text to the part after its last "=".

    Annotations without an "=" keep their text unchanged. Returns the same
    figure object.
    """
    def _keep_value(annotation):
        # rpartition yields the whole string as the tail when "=" is absent.
        _, _, value = annotation.text.rpartition("=")
        return annotation.update(text=value)

    fig.for_each_annotation(_keep_value)
    return fig

### scatterplot wrapper

In [None]:
def scatter_wrapper(
    df,
    kwarg_dict,
    x_axis_title=None,
    y_axis_title=None,
    replace_xaxis=False,
    replace_y_axis=False,
    clean=True,
    x_axis_reversed=False,
):
    """Create a plotly express scatter plot of ``df`` and tidy its labels.

    ``kwarg_dict`` is forwarded to ``px.scatter``; every column of ``df`` is
    passed as hover data. When an axis title is given it is either set on all
    axes of that kind or, if the matching ``replace_*`` flag is true, handed
    to ``replace_xaxis_labels`` / ``replace_yaxis_labels``. With ``clean``
    true the figure is passed through ``clean_labels``. With
    ``x_axis_reversed`` true the x-axes use a reversed autorange.

    Returns the resulting figure.
    """
    fig: Figure = px.scatter(df, hover_data=df.columns, **kwarg_dict)

    # X-axis title handling.
    if x_axis_title and replace_xaxis:
        fig = replace_xaxis_labels(fig, x_axis_title)
    elif x_axis_title:
        fig.update_xaxes(title=x_axis_title)

    # Y-axis title handling.
    if y_axis_title and replace_y_axis:
        fig = replace_yaxis_labels(fig, y_axis_title)
    elif y_axis_title:
        fig.update_yaxes(title=y_axis_title)

    if clean:
        fig = clean_labels(fig)

    if x_axis_reversed:
        fig.update_xaxes(autorange="reversed")

    return fig

### simplify df

In [None]:
def simplify_df(df, condition_dict):
    """Return the rows of ``df`` where every listed column equals its given value.

    ``condition_dict`` maps column name -> required value; conditions are
    applied one after another. The input frame is left untouched, and an empty
    ``condition_dict`` yields a plain copy.
    """
    subset = df.copy()
    for col_name, wanted in condition_dict.items():
        keep_rows = subset[col_name] == wanted
        subset = subset.loc[keep_rows]
    return subset

# Add new analysis

In [None]:
reload(a)
random_stats = a.calculate_rmsd_stats(clean, query_mol_id="Compound_ID", reference_selection="random", ref_structure_stride=10, score_column="POSIT_R", group_by=["Version"], n_bootstraps=100)
date_stats = a.calculate_rmsd_stats(clean, query_mol_id="Compound_ID", reference_selection="Structure_Date", ref_structure_stride=10, score_column="POSIT_R", group_by=["Version"], n_bootstraps=100)

In [None]:
# NOTE(review): tc_stats selects references by "TanimotoCombo_R" while tcr_stats
# (next cell) selects by plain "TanimotoCombo" -- the variable names look swapped
# relative to the column used; confirm this is intended.
tc_stats = a.calculate_rmsd_stats(clean, query_mol_id="Compound_ID", reference_selection="TanimotoCombo_R", ref_structure_stride=10, score_column="POSIT_R", group_by=["Version"], n_bootstraps=100)

In [None]:
tcr_stats = a.calculate_rmsd_stats(clean, query_mol_id="Compound_ID", reference_selection="TanimotoCombo", ref_structure_stride=10, score_column="POSIT_R", group_by=["Version"], n_bootstraps=100)

In [None]:
sasa_stats = a.calculate_rmsd_stats(clean, query_mol_id="Compound_ID", reference_selection="SASA (nm^2)", ref_structure_stride=10, score_column="POSIT_R", group_by=["Version"], n_bootstraps=100)

In [None]:
# Negated SASA, i.e. the reverse ordering of "SASA (nm^2)".
# `clean` is a filtered slice of `df`, so writing a column into it in place
# triggers pandas' SettingWithCopyWarning; assign() returns a new frame instead.
clean = clean.assign(SASA_R=0 - clean["SASA (nm^2)"])

In [None]:
sasa_r_stats = a.calculate_rmsd_stats(clean, query_mol_id="Compound_ID", reference_selection="SASA_R", ref_structure_stride=10, score_column="POSIT_R", group_by=["Version"], n_bootstraps=100)

In [None]:
all_stats = pd.concat([random_stats, date_stats, tc_stats, tcr_stats, sasa_stats, sasa_r_stats])

In [None]:
tc_stats_1 = a.calculate_rmsd_stats(clean, query_mol_id="Compound_ID", reference_selection="TanimotoCombo_R", ref_structure_stride=10, score_column="POSIT_R", group_by=["Version"], n_bootstraps=1)

In [None]:
clean.sort_values("TanimotoCombo_R").groupby(["Compound_ID", "Version"]).head(11).max()

In [None]:
# Mean of every remaining column within each group (presumably averaging over
# the bootstrap replicates returned by a.calculate_rmsd_stats -- confirm).
aggregated = all_stats.groupby(["Version", "Number of References", "Structure_Split", "Split_Value_min", "Split_Value_max"]).mean().reset_index()

# 2.5% / 97.5% quantiles of "Fraction", stored as distances above ("Max") and
# below ("Min") the mean, which is how they are later passed as error bars.
# The quantiles are grouped on fewer keys than the means, so they are looked up
# by key for each row of `aggregated` instead of being matched by row position,
# which silently misaligns whenever the two groupings differ in size or order.
# Selecting "Fraction" first also keeps quantile() away from non-numeric columns.
interval_keys = ["Version", "Number of References", "Structure_Split"]
fraction_by_group = all_stats.groupby(interval_keys)["Fraction"]
aggregated_keys = pd.MultiIndex.from_frame(aggregated[interval_keys])
aggregated["Max"] = fraction_by_group.quantile(0.975).reindex(aggregated_keys).to_numpy() - aggregated["Fraction"]
aggregated["Min"] = aggregated["Fraction"] - fraction_by_group.quantile(0.025).reindex(aggregated_keys).to_numpy()

In [None]:
aggregated

In [None]:
# Scatter of mean Fraction against number of references, one facet per Version,
# coloured by Structure_Split, with asymmetric error bars from the "Max"/"Min" columns.
fig = scatter_wrapper(aggregated, 
                      dict(
                          y="Fraction", color="Structure_Split", facet_col="Version",
                          color_discrete_sequence=px.colors.qualitative.Dark24,
                          error_y="Max", 
                          error_y_minus="Min",
                          template="seaborn",
                           **fraction_plot_kwargs,
                           x="Number of References", 
                          height=600,
                          width=800
                           # **single_plot_kwargs
                      ),
                     y_axis_title=frac_title,
                     x_axis_title="Number of References"
                     )
# Clear the y-axis title on every facet, then set it again on yaxis1 only.
fig.for_each_yaxis(lambda y: y.update(title = ''))
# The height given here (400) overrides the height=600 passed to px.scatter above.
fig.update_layout(yaxis1=dict(title=frac_title), height=400, width=800)
fig.show()
# Save the figure as a PNG at this relative path.
fig.write_image("20231101_sasa_comparison.png")

In [None]:
aggregated