# Imports

In [None]:
import pandas as pd, numpy as np
import plotly.express as px
from plotly.graph_objs import Figure
from pathlib import Path
from asapdiscovery.docking.analysis import get_df_subset, calc_perc_good, calculate_perc_good
from importlib import reload
import asapdiscovery.docking.analysis as a

# Load Paths

In [None]:
import sys
sys.path.append(str(Path("../../../").resolve()))
from software.paths import paths

In [None]:
local_analysis = Path("/Users/alexpayne/Scientific_Projects/mers-drug-discovery/sars2-retrospective-analysis/")

In [None]:
df = pd.read_csv(local_analysis / "20230611-combined.csv", index_col=0)

## a bit of fixing

In [None]:
# Reversed similarity score (2 - TanimotoCombo), stored as its own column so
# "<= cutoff" filters can sweep from TanimotoCombo = 2 downwards.
df["TanimotoCombo_R"] = 2 - df["TanimotoCombo"]

# Functions and Variables

# Variables

In [None]:
# --- Column names in `df` (and in the derived summary frames) ---
tc = "TanimotoCombo"
tcr = "TanimotoCombo_R"
rmsd = "RMSD"
posit_r = "POSIT_R"
posit_method = "POSIT_Method"
id_col = "Compound_ID"
date_col = "Structure_Date"
reference_col = "Structure_Source"
color = "Version"
y = "Fraction"
sort_col_name = "Sorted_By"

# --- Axis titles ---
tc_title = "TanimotoCombo Cutoff for Inclusion of Reference Structures"
date_title = "Date for Inclusion of Reference Structures"

# --- Scoring settings ---
# `n` is forwarded as the `n` argument of calculate_perc_good
# (presumably the number of top poses considered -- TODO confirm).
n = 1
# Threshold forwarded as `good_score` (scored on the RMSD column) and shown in
# the y-axis title below.
good = 2
frac_title = f"Fraction of Poses < {good}Å from Reference"

# --- Cutoffs swept by the cumulative analyses ---
tc_cutoffs = np.linspace(0, 2, 50)
tcr_cutoffs = np.linspace(2, 0, 50)
dates = df[date_col].unique()

# --- Sorting and grouping ---
sort_cols = [rmsd, posit_r, "Chemgauss4", tcr]
split_cols = ["Version"]
full_split_cols = ["Version", posit_method]
method_split = [posit_method]
split_column_sets = {"General": split_cols, "Detailed": full_split_cols}
general_split_cols = {"General": split_cols}
detailed_split_cols = {"Detailed": full_split_cols}

## Calculation Functions

In [None]:
def calculate_perc_good_wrapper(split_column_sets: dict, sort_columns, n_refs, **kwargs):
    """Run ``a.calculate_perc_good`` for every split / sort-column pair and stack the results.

    Parameters
    ----------
    split_column_sets : dict
        Maps a label (written to the ``"Split"`` column) to the list of
        columns passed to ``calculate_perc_good`` as ``split_cols``.
    sort_columns : iterable
        Each entry is passed as ``sort_column`` and recorded in the column
        named by the module-level ``sort_col_name``.
    n_refs : number
        Denominator used for the ``"Fraction of References Used"`` column.
    **kwargs
        Forwarded unchanged to ``a.calculate_perc_good``.

    Returns
    -------
    pandas.DataFrame
        Concatenation of every per-pair result. Each result must contain a
        ``"Number of References"`` column, which is read to compute the
        fraction.
    """

    def _scored_frame(sort_column, split):
        # One call of the underlying helper, tagged with how it was sorted.
        frame = a.calculate_perc_good(sort_column=sort_column,
                                      split_cols=split,
                                      **kwargs)
        frame[sort_col_name] = sort_column
        frame["Fraction of References Used"] = frame["Number of References"] / n_refs
        return frame

    per_split = []
    for name, split in split_column_sets.items():
        stacked = pd.concat([_scored_frame(column, split) for column in sort_columns])
        stacked["Split"] = name
        per_split.append(stacked)
    return pd.concat(per_split)

In [None]:
def calculate_stats(df,
                    metric_dict,
                    summary_col,
                    filter_column,
                    filter_cutoffs,
                    value_column,
                    extra_groupby_cols=None):
    """Summarise per-group metrics of ``value_column`` at a series of cutoffs.

    For every cutoff, the rows with ``df[filter_column] <= cutoff`` are grouped
    by ``[summary_col] + extra_groupby_cols`` and each metric is applied to
    ``value_column`` within those groups. The per-group results are then
    reduced to a mean and a standard deviation for each ``summary_col`` value.

    Parameters
    ----------
    df : pandas.DataFrame
        Input data containing ``summary_col``, ``filter_column``,
        ``value_column`` and any ``extra_groupby_cols``.
    metric_dict : dict
        Maps a metric label (written to the ``"Metric"`` column) to a callable
        applied to each group's values.
    summary_col : str
        Column whose values define the rows of the summary.
    filter_column : str
        Column compared against each cutoff with ``<=``.
    filter_cutoffs : iterable
        Cutoff values to sweep.
    value_column : str
        Column the metrics are computed on.
    extra_groupby_cols : list of str, optional
        Additional grouping columns used before averaging. ``None`` (the
        default) means "group by ``summary_col`` only".

    Returns
    -------
    pandas.DataFrame
        Long-format frame with columns ``"Value"``, ``"Metric"``,
        ``filter_column``, ``summary_col`` and ``"STD"``.
    """
    # The None default used to be concatenated directly, which raised
    # TypeError whenever the argument was omitted.
    groupby_cols = [summary_col] + list(extra_groupby_cols or [])

    dfs = []
    for name, metric in metric_dict.items():
        means = []
        cutoffs = []
        summary_types = []
        sds = []
        for cutoff in filter_cutoffs:
            selected = df[df[filter_column] <= cutoff]
            values = selected.groupby(groupby_cols, group_keys=True)[value_column].apply(metric)
            by_summary = values.groupby([summary_col])
            mean_list = by_summary.mean()
            sd_list = by_summary.std()
            for summary_type in mean_list.index:
                means.append(mean_list.loc[summary_type])
                cutoffs.append(cutoff)
                summary_types.append(summary_type)
                sds.append(sd_list.loc[summary_type])
        mean_df = pd.DataFrame({"Value": means,
                                "Metric": name,
                                filter_column: cutoffs,
                                summary_col: summary_types,
                                "STD": sds})
        dfs.append(mean_df)

    return pd.concat(dfs)

In [None]:
def calculate_rmsd_stats(df, bins, tc_column="TanimotoCombo"):
    """Summarise per-compound RMSD statistics within consecutive score bins.

    Rows are assigned to half-open bins ``(bins[i], bins[i + 1]]`` of
    ``tc_column``. Within each bin the minimum, maximum and mean ``"RMSD"`` are
    computed per ``("Version", "Compound_ID")`` group, and those per-compound
    values are then averaged (with standard deviation) per ``"Version"``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``"Version"``, ``"Compound_ID"``, ``"RMSD"`` and
        ``tc_column``.
    bins : sequence of numbers
        Bin edges; consecutive pairs define the bins.
    tc_column : str, optional
        Column that is binned, also used as the name of the output column
        holding the bin midpoint. Defaults to ``"TanimotoCombo"``.

    Returns
    -------
    pandas.DataFrame
        Columns ``"Value"``, ``"Metric"`` (``"Min"``, ``"Max"`` or ``"Mean"``),
        ``"STD"``, ``"Version"`` and ``tc_column`` (bin midpoint, stored as a
        string).
    """
    metrics = {"Min": np.min, "Max": np.max, "Mean": np.mean}
    bounds = list(zip(bins[:-1], bins[1:]))
    dfs = []
    for name, metric in metrics.items():
        means = []
        versions = []
        avg_tc = []
        sds = []
        for lower, upper in bounds:
            in_bin = df[(df[tc_column] > lower) & (df[tc_column] <= upper)]
            values = in_bin.groupby(["Version", "Compound_ID"])["RMSD"].apply(metric)
            mean = values.groupby(["Version"]).mean()
            sd = values.groupby(["Version"]).std()
            for version in mean.index:
                means.append(mean[version])
                versions.append(version)
                avg_tc.append(str(np.mean([lower, upper])))
                sds.append(sd[version])
        dfs.append(pd.DataFrame({"Value": means,
                                 "Metric": name,
                                 "STD": sds,
                                 "Version": versions,
                                 tc_column: avg_tc}))
    # The collected frames were previously discarded: the function fell off the
    # end and returned None.
    return pd.concat(dfs)

## Plotting Functions

### plot kwargs

In [None]:
df.Version.unique()

In [None]:
df.POSIT_Method.unique()

In [None]:
# this doesn't actually work
# (Version, POSIT method) pairs and their "Version: Method" display labels.
full_versions = [("All", method) for method in df.POSIT_Method.unique()]
full_versions.append(("Hybrid-Only", "HYBRID"))
full_version_labels = [f"{version}: {method}" for version, method in full_versions]
full_version_label_dict = dict(zip(full_versions, full_version_labels))

In [None]:
# Colour every plot by the column named in `color`.
basic_plot_kwargs = {"color": color}

In [None]:
# Faceted layout: one column per sort column, one row per split.
big_plot_kwargs = {
    "facet_col": sort_col_name,
    "facet_row": "Split",
    "height": 600,
    "width": 1200,
}

In [None]:
single_plot_kwargs = dict(height=400, width=600)

In [None]:
# x-axis settings for plots against the TanimotoCombo cutoff.
tc_plot_kwargs = {
    "x": tc,
    "labels": {tc: tc_title},
    "range_x": [-0.1, 2.1],
}

In [None]:
# x-axis settings for plots against the structure date.
date_plot_kwargs = {
    "x": date_col,
    "labels": {date_col: date_title},
}

In [None]:
fraction_plot_kwargs = dict(range_y=[-0.1,1.1])

In [None]:
stats_kwargs = dict(y="Value", error_y="STD")

In [None]:
def combine_labels_kwargs(list_of_kwargs):
    """Merge several keyword-argument dicts into one new dict.

    Keys that appear only once are copied over. When a key appears more than
    once and both values are dicts (e.g. ``labels``), the nested dicts are
    merged, with later entries overriding earlier ones. Any other repeated key
    is rejected.

    Parameters
    ----------
    list_of_kwargs : iterable of dict
        The keyword-argument dicts to combine, in priority order.

    Returns
    -------
    dict
        The combined keyword arguments. The input dicts are never modified.

    Raises
    ------
    NotImplementedError
        If a key is repeated and the values are not both dicts.
    """
    new_dict = {}
    for kwargs in list_of_kwargs:
        for k, v in kwargs.items():
            if k not in new_dict:
                # Copy dict values: storing the caller's dict by reference let
                # later merges write into it.
                new_dict[k] = dict(v) if isinstance(v, dict) else v
            elif isinstance(v, dict) and isinstance(new_dict[k], dict):
                new_dict[k].update(v)
            else:
                raise NotImplementedError(f"combining these kwargs will not work due to repeated use of {k}")
    return new_dict
    

In [None]:
combine_labels_kwargs([big_plot_kwargs, tc_plot_kwargs])

In [None]:
general_posit_kwargs = {sort_col_name: posit_r, "Split":"General"}

### cleanup functions

In [None]:
def replace_xaxis_labels(fig: Figure, axis_title):
    """Blank every x-axis title and add one shared title annotation.

    The annotation is placed in paper coordinates at ``x=0.5, y=-0.15``.
    Returns the same figure object that was passed in.
    """

    def _blank_title(axis):
        axis.update(title='')

    fig.for_each_xaxis(_blank_title)

    shared_title = dict(
        text=axis_title,
        x=0.5,
        y=-0.15,
        xref="paper",
        yref="paper",
        textangle=0,
        font=dict(size=16),
        showarrow=False,
    )
    fig.add_annotation(**shared_title)
    return fig

In [None]:
def replace_yaxis_labels(fig: Figure, axis_title):
    """Blank every y-axis title and add one shared, rotated title annotation.

    The annotation is placed in paper coordinates at ``x=-0.05, y=0.5`` with a
    text angle of -90. Returns the same figure object that was passed in.
    """

    def _blank_title(axis):
        axis.update(title='')

    fig.for_each_yaxis(_blank_title)

    shared_title = dict(
        text=axis_title,
        x=-0.05,
        y=0.5,
        xref="paper",
        yref="paper",
        textangle=-90,
        font=dict(size=16),
        showarrow=False,
    )
    fig.add_annotation(**shared_title)
    return fig

In [None]:
def clean_labels(fig):
    """Trim each annotation's text to the part after its last ``"="``.

    Text without an ``"="`` is left unchanged. Returns the same figure object.
    """

    def _keep_value(annotation):
        # "name=value" -> "value"
        annotation.update(text=annotation.text.rpartition("=")[2])

    fig.for_each_annotation(_keep_value)
    return fig

### scatterplot wrapper

In [None]:
def scatter_wrapper(df, kwarg_dict, 
                    x_axis_title=None, 
                    y_axis_title=None, 
                    replace_xaxis=False,
                    replace_y_axis=False,
                    clean=True,
                    x_axis_reversed=False
                   ):
    """Build a ``px.scatter`` figure and apply the notebook's standard clean-up.

    Parameters
    ----------
    df : pandas.DataFrame
        Data to plot; all of its columns are shown on hover unless
        ``kwarg_dict`` supplies its own ``hover_data``.
    kwarg_dict : dict
        Keyword arguments forwarded to ``px.scatter``.
    x_axis_title, y_axis_title : str, optional
        Axis titles. Nothing is changed for an axis whose title is falsy.
    replace_xaxis, replace_y_axis : bool
        If true, the per-axis titles are blanked and a single shared title
        annotation is added instead (via ``replace_xaxis_labels`` /
        ``replace_yaxis_labels``); otherwise the title is set on every axis.
    clean : bool
        If true, annotations are passed through ``clean_labels`` (text after
        the last ``"="`` is kept).
    x_axis_reversed : bool
        If true, the x axes are set to ``autorange="reversed"``.

    Returns
    -------
    Figure
        The configured figure.
    """
    # Let an explicit hover_data in kwarg_dict win instead of colliding with
    # the default (which raised a duplicate-keyword TypeError).
    plot_kwargs = {"hover_data": df.columns, **kwarg_dict}
    fig: Figure = px.scatter(df, **plot_kwargs)

    # Clean the generated annotations *before* adding the shared axis-title
    # annotations, so an axis title containing "=" is not truncated.
    if clean:
        fig = clean_labels(fig)

    if x_axis_title:
        if replace_xaxis:
            fig = replace_xaxis_labels(fig, x_axis_title)
        else:
            fig.update_xaxes(title=x_axis_title)

    if y_axis_title:
        if replace_y_axis:
            fig = replace_yaxis_labels(fig, y_axis_title)
        else:
            fig.update_yaxes(title=y_axis_title)

    if x_axis_reversed:
        fig.update_xaxes(autorange="reversed")
    return fig

### simplify df

In [None]:
def simplify_df(df, condition_dict):
    """Return a copy of ``df`` restricted to rows matching every condition.

    ``condition_dict`` maps column names to the value that column must equal.
    An empty dict returns an unfiltered copy; the input frame is not modified.
    """
    subset = df.copy()
    for column in condition_dict:
        matches = subset[column].eq(condition_dict[column])
        subset = subset.loc[matches]
    return subset

# Overall Analysis

In [None]:
df.groupby("Version").nunique()[["Complex_ID", "Compound_ID", "Structure_Source"]]

# By TanimotoCombo

## calculate cumulative from 0 to 2 TC

In [None]:
# Arguments forwarded to calculate_perc_good for the sweep over TanimotoCombo
# cutoffs (0 -> 2).
tc_kwargs = {
    "df": df,
    "id_column": id_col,
    "filter_column": tc,
    "filter_cutoffs": tc_cutoffs,
    "n": n,
    "good_score": good,
    "score_column": rmsd,
    "reference_col": reference_col,
}
              

In [None]:
# Fraction-good table for every split / sort column across the TanimotoCombo cutoffs.
df1 = calculate_perc_good_wrapper(
    n_refs=219,
    sort_columns=sort_cols,
    split_column_sets=split_column_sets,
    **tc_kwargs,
)

### Full Plot

In [None]:
# Faceted overview: fraction of good poses vs. TanimotoCombo cutoff for every
# sort column and split; shown and saved as a PNG.
fig = scatter_wrapper(
    df1,
    dict(
        y=y,
        **fraction_plot_kwargs,
        **tc_plot_kwargs,
        **basic_plot_kwargs,
        **big_plot_kwargs,
    ),
    x_axis_title=tc_title,
    y_axis_title=frac_title,
    replace_xaxis=True,
    replace_y_axis=True,
)
fig.show()
fig.write_image("20230615_tc_full_frac.png")

### Simple Plot

In [None]:
# Single panel: only the "General" split sorted by the POSIT_R column.
fig = scatter_wrapper(
    simplify_df(df1, general_posit_kwargs),
    dict(
        y=y,
        **fraction_plot_kwargs,
        **tc_plot_kwargs,
        **basic_plot_kwargs,
        **single_plot_kwargs,
    ),
    x_axis_title=tc_title,
    y_axis_title=frac_title,
)
fig.update_layout(title="Sorted by POSIT Score")
fig.show()
fig.write_image("20230615_tc_posit_frac.png")

### Calculate N Refs 

In [None]:
# Per-compound count of reference entries available at each TanimotoCombo
# cutoff, averaged per Version.
# NOTE(review): np.count_nonzero counts rows, so repeated entries for the same
# reference would each be counted -- confirm this is intended.
nrefdf1 = calculate_stats(
    df,
    metric_dict={"Number of References": np.count_nonzero},
    summary_col="Version",
    filter_column=tc,
    filter_cutoffs=tc_cutoffs,
    value_column=reference_col,
    extra_groupby_cols=[id_col],
)

In [None]:
# Mean number of references (with STD error bars) vs. TanimotoCombo cutoff.
nref_column = "Mean Number of References"
fig = scatter_wrapper(
    nrefdf1,
    dict(
        **stats_kwargs,
        **tc_plot_kwargs,
        **basic_plot_kwargs,
        **single_plot_kwargs,
    ),
    x_axis_title=tc_title,
    y_axis_title=nref_column,
)
fig.show()
fig.write_image("20230615_tc_nref.png")

## calculate cumulative from 2 to 0 TC

In [None]:
# Same sweep as tc_kwargs, but filtering on the reversed score so the
# cumulative inclusion runs from TanimotoCombo = 2 down to 0.
tcr_kwargs = {
    "df": df,
    "id_column": id_col,
    "filter_column": tcr,
    "filter_cutoffs": tc_cutoffs,
    "n": n,
    "good_score": good,
    "score_column": rmsd,
    "reference_col": reference_col,
}
              

In [None]:
# Fraction-good table for every split / sort column across the reversed cutoffs.
df2 = calculate_perc_good_wrapper(
    n_refs=219,
    sort_columns=sort_cols,
    split_column_sets=split_column_sets,
    **tcr_kwargs,
)

In [None]:
# Convert the reversed cutoff back to the TanimotoCombo scale for plotting.
df2[tc] = 2 - df2[tcr]

### Full Plot

In [None]:
# Faceted overview for the reversed sweep; the x axis is drawn reversed.
fig = scatter_wrapper(
    df2,
    combine_labels_kwargs(
        [
            dict(y=y),
            fraction_plot_kwargs,
            tc_plot_kwargs,
            basic_plot_kwargs,
            big_plot_kwargs,
        ]
    ),
    x_axis_title=tc_title,
    y_axis_title=frac_title,
    replace_xaxis=True,
    replace_y_axis=True,
    x_axis_reversed=True,
)
fig.show()
fig.write_image("20230615_tcr_full_frac.png")

### Simple Plot

In [None]:
# Single panel for the reversed sweep: "General" split sorted by POSIT_R.
fig = scatter_wrapper(
    simplify_df(df2, general_posit_kwargs),
    dict(
        y=y,
        **fraction_plot_kwargs,
        **tc_plot_kwargs,
        **basic_plot_kwargs,
        **single_plot_kwargs,
    ),
    x_axis_title=tc_title,
    y_axis_title=frac_title,
    x_axis_reversed=True,
)
fig.show()
fig.write_image("20230615_tcr_posit_frac.png")

### Calculate N Refs 

In [None]:
# Reference counts for the reversed sweep, then mapped back to the
# TanimotoCombo scale for plotting.
nrefdf2 = calculate_stats(
    df,
    metric_dict={"Number of References": np.count_nonzero},
    summary_col="Version",
    filter_column=tcr,
    filter_cutoffs=tc_cutoffs,
    value_column=reference_col,
    extra_groupby_cols=[id_col],
)
nrefdf2[tc] = 2 - nrefdf2[tcr]

In [None]:
# Mean number of references (with STD error bars) for the reversed sweep.
nref_column = "Mean Number of References"
fig = scatter_wrapper(
    nrefdf2,
    dict(
        **stats_kwargs,
        **tc_plot_kwargs,
        **basic_plot_kwargs,
        **single_plot_kwargs,
    ),
    x_axis_title=tc_title,
    y_axis_title=nref_column,
    x_axis_reversed=True,
)
fig.show()
fig.write_image("20230615_tcr_nref.png")

# By Structure Date

In [None]:
# Arguments forwarded to calculate_perc_good for the sweep over structure dates.
date_kwargs = {
    "df": df,
    "id_column": id_col,
    "filter_column": date_col,
    "filter_cutoffs": dates,
    "n": n,
    "good_score": good,
    "score_column": rmsd,
    "reference_col": reference_col,
}
              

In [None]:
# Fraction-good table for every split / sort column across the structure dates.
sdf1 = calculate_perc_good_wrapper(
    n_refs=219,
    sort_columns=sort_cols,
    split_column_sets=split_column_sets,
    **date_kwargs,
)

### Full Plot

In [None]:
# Faceted overview: fraction of good poses vs. structure date for every sort
# column and split.
fig = scatter_wrapper(
    sdf1,
    dict(
        y=y,
        **fraction_plot_kwargs,
        **date_plot_kwargs,
        **basic_plot_kwargs,
        **big_plot_kwargs,
    ),
    x_axis_title=date_title,
    y_axis_title=frac_title,
    replace_xaxis=True,
    replace_y_axis=True,
)
fig.show()
fig.write_image("20230615_dates_full_frac.png")

### Simple Plot

In [None]:
# Single panel by date: "General" split sorted by POSIT_R.
fig = scatter_wrapper(
    simplify_df(sdf1, general_posit_kwargs),
    dict(
        y=y,
        **fraction_plot_kwargs,
        **date_plot_kwargs,
        **basic_plot_kwargs,
        **single_plot_kwargs,
    ),
    x_axis_title=date_title,
    y_axis_title=frac_title,
)
fig.show()
fig.write_image("20230615_dates_posit_frac.png")

### Calculate N Refs 

In [None]:
# Per-compound count of reference entries available up to each structure date,
# averaged per Version.
nrefsdf = calculate_stats(
    df,
    metric_dict={"Number of References": np.count_nonzero},
    summary_col="Version",
    filter_column=date_col,
    filter_cutoffs=dates,
    value_column=reference_col,
    extra_groupby_cols=[id_col],
)

In [None]:
# Mean number of references vs. structure date.
# NOTE(review): unlike the TanimotoCombo versions of this plot, no error_y is
# passed here, so the STD column is not drawn -- confirm this is intentional.
fig = scatter_wrapper(
    nrefsdf,
    dict(
        y="Value",
        **date_plot_kwargs,
        **basic_plot_kwargs,
        **single_plot_kwargs,
    ),
    x_axis_title=date_title,
    y_axis_title="Mean Number of References",
)
fig.show()
fig.write_image("20230615_dates_nref.png")

# How is RMSD changing

In [None]:
# Shared settings for the RMSD plots: one panel per Version, one colour per
# metric, STD drawn as error bars.
y_axis_title = "RMSD (Å)"
rmsd_kwargs = {
    "y": "Value",
    "facet_col": "Version",
    "color": "Metric",
    "error_y": "STD",
    "width": 1000,
    "height": 400,
}

## by tc

In [None]:
# Per-compound minimum / maximum / mean RMSD at each TanimotoCombo cutoff,
# averaged per Version. The dict keys become the values of the "Metric" column
# (used as the plot legend), so the misspelled "Mininum RMSD" label is
# corrected here.
rmsddf = calculate_stats(
    df=df,
    metric_dict={
        "Minimum RMSD": np.min,
        "Maximum RMSD": np.max,
        "Mean RMSD": np.mean,
    },
    summary_col="Version",
    filter_column=tc,
    filter_cutoffs=tc_cutoffs,
    value_column=rmsd,
    extra_groupby_cols=[id_col],
)
                         

In [None]:
# RMSD metrics vs. TanimotoCombo cutoff, one panel per Version.
fig = scatter_wrapper(
    rmsddf,
    dict(x=tc, **rmsd_kwargs),
    x_axis_title=tc_title,
    y_axis_title=y_axis_title,
    replace_xaxis=True,
    replace_y_axis=True,
)
fig.show()
fig.write_image("20230615_rmsd_tc.png")

## by tcr

In [None]:
# Per-compound minimum / maximum / mean RMSD for the reversed sweep (filtering
# on TanimotoCombo_R), averaged per Version. The dict keys become the values of
# the "Metric" column (used as the plot legend), so the misspelled
# "Mininum RMSD" label is corrected here.
rmsddf = calculate_stats(
    df=df,
    metric_dict={
        "Minimum RMSD": np.min,
        "Maximum RMSD": np.max,
        "Mean RMSD": np.mean,
    },
    summary_col="Version",
    filter_column=tcr,
    filter_cutoffs=tc_cutoffs,
    value_column=rmsd,
    extra_groupby_cols=[id_col],
)
# Map the reversed cutoff back to the TanimotoCombo scale for plotting.
rmsddf["TanimotoCombo"] = 2 - rmsddf["TanimotoCombo_R"]

In [None]:
# RMSD metrics for the reversed sweep, one panel per Version, x axis reversed.
fig = scatter_wrapper(
    rmsddf,
    dict(x=tc, **rmsd_kwargs),
    x_axis_title=tc_title,
    y_axis_title=y_axis_title,
    replace_xaxis=True,
    replace_y_axis=True,
    x_axis_reversed=True,
)
fig.show()
fig.write_image("20230615_rmsd_tcr.png")

## by date

In [None]:
# Per-compound minimum / maximum / mean RMSD using references up to each
# structure date, averaged per Version. The dict keys become the values of the
# "Metric" column (used as the plot legend), so the misspelled "Mininum RMSD"
# label is corrected here.
rmsddf = calculate_stats(
    df=df,
    metric_dict={
        "Minimum RMSD": np.min,
        "Maximum RMSD": np.max,
        "Mean RMSD": np.mean,
    },
    summary_col="Version",
    filter_column=date_col,
    filter_cutoffs=dates,
    value_column=rmsd,
    extra_groupby_cols=[id_col],
)

In [None]:
# RMSD metrics vs. structure date, one panel per Version.
# NOTE(review): nref_column is reassigned here but not read anywhere in this
# cell; the axis title comes from y_axis_title.
nref_column = "RMSD"
fig = scatter_wrapper(
    rmsddf,
    dict(x=date_col, **rmsd_kwargs),
    x_axis_title=date_title,
    y_axis_title=y_axis_title,
    replace_xaxis=True,
    replace_y_axis=True,
)
fig.show()
fig.write_image("20230615_rmsd_dates.png")

## How do the different methods compare with the POSIT vs RMSD score?

In [None]:
# One POSIT-vs-RMSD density heatmap per POSIT method, restricted to the
# "All" version.
df_ = df[df.Version == "All"]
for method in df.POSIT_Method.unique():
    fig = px.density_heatmap(
        df_.loc[df_.POSIT_Method == method],
        x="RMSD",
        y="POSIT",
        title=method,
        marginal_x="histogram",
        marginal_y="histogram",
        range_x=[0, 11],
        range_y=[0, 1.1],
        width=800,
        height=800,
    )
    fig.show()
    

## how about TC vs RMSD?

In [None]:
# One TanimotoCombo-vs-RMSD density heatmap per POSIT method, restricted to
# the "All" version.
df_ = df[df.Version == "All"]
for method in df.POSIT_Method.unique():
    fig = px.density_heatmap(
        df_.loc[df_.POSIT_Method == method],
        x="RMSD",
        y=tc,
        title=method,
        marginal_x="histogram",
        marginal_y="histogram",
        range_x=[0, 11],
        range_y=[0, 2.1],
        width=800,
        height=800,
    )
    fig.show()
    

In [None]:
# One TanimotoCombo-vs-RMSD density heatmap per Version.
for version in df.Version.unique():
    fig = px.density_heatmap(
        df.loc[df.Version == version],
        x="RMSD",
        y=tc,
        title=version,
        marginal_x="histogram",
        marginal_y="histogram",
        range_x=[0, 11],
        range_y=[0, 2.1],
        width=800,
        height=800,
    )
    fig.show()
    

In [None]:
# Same TanimotoCombo-vs-RMSD view per Version, drawn as density contours.
for version in df.Version.unique():
    fig = px.density_contour(
        df.loc[df.Version == version],
        x="RMSD",
        y=tc,
        title=version,
        marginal_x="histogram",
        marginal_y="histogram",
        range_x=[0, 11],
        range_y=[0, 2.1],
        width=800,
        height=800,
    )
    fig.show()
    

# How do RMSDs for self-docking look?

In [None]:
self_docked = df[(df.Reference_Ligand == df.Compound_ID)]

In [None]:
self_docked.TanimotoCombo.unique()

In [None]:
self_docked.groupby("Version").nunique()[["Compound_ID"]]

In [None]:
from functools import reduce

In [None]:
# Compound IDs that appear in the self-docked rows of every Version.
intersection = set.intersection(
    *self_docked.groupby("Version")["Compound_ID"].apply(lambda ids: set(ids.unique()))
)

In [None]:
len(intersection)

In [None]:
self_docked_filtered = self_docked[self_docked.Compound_ID.isin(intersection)]

In [None]:
len(self_docked_filtered)

In [None]:
len(self_docked)

In [None]:
import plotly.figure_factory as ff

In [None]:
self_docked_filtered.groupby(["Version", "POSIT_Method"]).nunique()

In [None]:
def plot_kde(df, value_column, group_column, groups):
    """Draw one ``ff.create_distplot`` distribution curve per group.

    For each entry of ``groups``, the ``value_column`` values of the rows where
    ``group_column`` equals that entry are plotted (bin size 0.25, no rug).
    The figure is sized 600x400 with the x axis fixed to [0, 8] and the y axis
    to [0, 1]; the axis titles are hard-coded. Returns the figure.
    """
    samples = []
    for group in groups:
        samples.append(df.loc[df[group_column] == group, value_column])

    fig = ff.create_distplot(samples, group_labels=groups, show_rug=False, bin_size=0.25)
    fig.update_layout(height=400, width=600)
    fig.update_xaxes(range=[0, 8], title="RMSD (Å)")
    fig.update_yaxes(range=[0, 1], title="Frequency")
    return fig

In [None]:
fig = plot_kde(self_docked, "RMSD", "Version", ["All", "Hybrid-Only"])
fig.update_layout(title="RMSD Distribution for Self-Docking Results")

In [None]:
fig2 = plot_kde(self_docked_filtered, "RMSD", "Version", ["All", "Hybrid-Only"])

In [None]:
fig2.write_image("20230613_self_docking_RMSD_kde.png")