# Imports

In [None]:
import pandas as pd, numpy as np
import plotly.express as px
from pathlib import Path
from asapdiscovery.docking.analysis import get_df_subset, calc_perc_good, calculate_perc_good
from importlib import reload
import asapdiscovery.docking.analysis as a

# Load Paths

In [None]:
import sys
sys.path.append(str(Path("../../../").resolve()))
from software.paths import paths

In [None]:
# Directory that holds the analysis CSV files read below.
# NOTE(review): absolute, user-specific path — adjust before running on another machine.
local_analysis = Path("/Users/alexpayne/Scientific_Projects/mers-drug-discovery/sars2-retrospective-analysis/")

In [None]:
# Load the combined results table; the first CSV column becomes the index.
df = pd.read_csv(local_analysis / "20230611-combined.csv", index_col=0)

## a bit of fixing

In [None]:
# Reversed TanimotoCombo: 2 minus the original value.
df["TanimotoCombo_R"] = 2 - df["TanimotoCombo"]

# Plotting

## standard variables

In [None]:
# Column names referenced throughout the notebook.
tc = "TanimotoCombo"
tcr = "TanimotoCombo_R"
rmsd = "RMSD"
idcol = "Compound_ID"
date_col = "Structure_Date"

# Names used for plot axes, colouring and the "sorted by" label column.
y = "Percentage"
color = "Version"
sort_col_name = "Sorted_By"

# Column sets used to split the results into groups.
split_cols = ["Version"]
method_split = ["POSIT_Method"]
full_split_cols = ["Version", "POSIT_Method"]

# Values forwarded to calculate_perc_good as ``n`` and ``good_score``.
n = 1
good = 2

# 50 evenly spaced cutoffs over [0, 2], ascending and descending.
tc_cutoffs = np.linspace(0, 2, 50)
tcr_cutoffs = np.linspace(2, 0, 50)

# Columns the analysis is sorted by, one run per column.
sort_cols = ["RMSD", "POSIT_R", "Chemgauss4", "TanimotoCombo_R"]

# Distinct structure dates present in the loaded table.
dates = df.Structure_Date.unique()

## TanimotoCombo

### use_per_split_mol=False

In [None]:
# Run calculate_perc_good over the TanimotoCombo cutoffs for both split
# schemes and every sort column, keeping one result frame per combination.
dfs = []
for splits in (split_cols, full_split_cols):
    for sort_column in sort_cols:
        perc_df = a.calculate_perc_good(
            df,
            id_column=idcol,
            filter_column=tc,
            filter_cutoffs=tc_cutoffs,
            sort_column=sort_column,
            n=n,
            score_column=rmsd,
            good_score=good,
            split_cols=splits,
            reference_col="Structure_Source",
        )
        # Remember which column this run was sorted by.
        perc_df[sort_col_name] = sort_column
        # NOTE(review): 219 is a hard-coded denominator; the original author
        # flagged this percentage estimate as possibly wrong.
        perc_df["Percentage References Used"] = perc_df["Number of References"] / 219
        dfs.append(perc_df)

In [None]:
# dfs holds one frame per sort column for split_cols, followed by one per
# sort column for full_split_cols. Derive the boundary from sort_cols so the
# split stays correct if the list of sort columns changes.
n_sort_cols = len(sort_cols)
split_combined = [pd.concat(dfs[:n_sort_cols]), pd.concat(dfs[n_sort_cols:])]

In [None]:
# Label each half: index 0 is the first group of frames, index 1 the second.
split_combined[0]["Vtype"] = "General"
split_combined[1]["Vtype"] = "Detailed"

In [None]:
# Stack the "General" and "Detailed" frames into one table for plotting.
combined = pd.concat(split_combined)

In [None]:
# "Percentage Docked" against TanimotoCombo, one panel per Vtype.
fig = px.scatter(
    combined,
    x=tc,
    y="Percentage Docked",
    color=color,
    facet_col="Vtype",
)

In [None]:
# Render the most recently built figure.
fig.show()

In [None]:
# "Percentage References Used" against TanimotoCombo, faceted by sort column
# (columns) and Vtype (rows); every column is exposed in the hover tooltip.
fig = px.scatter(
    combined,
    x=tc,
    y="Percentage References Used",
    color=color,
    facet_col=sort_col_name,
    facet_row="Vtype",
    hover_data=combined.columns,
    height=600,
    width=1200,
)

In [None]:
# Render the most recently built figure.
fig.show()

In [None]:
# Percentage against TanimotoCombo, faceted by sort column and Vtype.
fig = px.scatter(
    combined,
    x=tc,
    y=y,
    color=color,
    facet_col=sort_col_name,
    facet_row="Vtype",
    hover_data=combined.columns,
    height=600,
    width=1200,
)
fig.show()

In [None]:
# Single-panel view: only rows sorted by POSIT_R under the "General" split.
posit_general_mask = (combined.Sorted_By == "POSIT_R") & (combined.Vtype == "General")
fig = px.scatter(
    combined[posit_general_mask],
    x=tc,
    y=y,
    color=color,
    hover_data=combined.columns,
    height=400,
    width=600,
    range_y=[0, 1.1],
)
fig.update_xaxes(title="TanimotoCombo Cutoff for Inclusion of Reference Structures")
fig.show()

### use_per_split_mol=True

In [None]:
# Same TanimotoCombo sweep as before, this time with use_per_split_mol=True.
dfs = []
for splits in (split_cols, full_split_cols):
    for sort_column in sort_cols:
        perc_df = a.calculate_perc_good(
            df,
            id_column=idcol,
            filter_column=tc,
            filter_cutoffs=tc_cutoffs,
            sort_column=sort_column,
            n=n,
            score_column=rmsd,
            good_score=good,
            split_cols=splits,
            use_per_split_mol=True,
            reference_col="Structure_Source",
        )
        # Remember which column this run was sorted by.
        perc_df[sort_col_name] = sort_column
        dfs.append(perc_df)

In [None]:
# Stack every result frame from the sweep into one table.
combined = pd.concat(dfs)

In [None]:
# Percentage against TanimotoCombo, one panel per sort column.
fig = px.scatter(
    combined,
    x=tc,
    y=y,
    color=color,
    facet_col=sort_col_name,
    hover_data=combined.columns,
    height=600,
    width=1200,
)
fig.show()

### reverse TC

In [None]:
# Pick up any edits made to the analysis module since it was imported.
reload(a)

# Sweep over the reversed TanimotoCombo column for both split schemes and
# every sort column, with use_per_split_mol=True and Reference_Ligand as the
# reference column.
# NOTE(review): tcr_cutoffs is defined in the standard-variables cell but
# tc_cutoffs is what gets passed here — confirm that is intended.
dfs = []
for splits in [split_cols, full_split_cols]:
    for sort_column in sort_cols:
        new_df = a.calculate_perc_good(
            df,
            id_column=idcol,
            filter_column=tcr,
            filter_cutoffs=tc_cutoffs,
            sort_column=sort_column,
            # Was ``n_``, a name that is never defined in this notebook
            # (NameError); every other sweep passes ``n``.
            n=n,
            score_column=rmsd,
            good_score=good,
            split_cols=splits,
            use_per_split_mol=True,
            reference_col="Reference_Ligand",
        )
        # Remember which column this run was sorted by.
        new_df[sort_col_name] = sort_column
        # Convert the reversed value back so plots can share the TanimotoCombo axis.
        new_df["TanimotoCombo"] = 2 - new_df["TanimotoCombo_R"]
        dfs.append(new_df)

In [None]:
# dfs holds one frame per sort column for split_cols, followed by one per
# sort column for full_split_cols. Derive the boundary from sort_cols so the
# split stays correct if the list of sort columns changes.
n_sort_cols = len(sort_cols)
split_combined = [pd.concat(dfs[:n_sort_cols]), pd.concat(dfs[n_sort_cols:])]

In [None]:
# Label each half: index 0 is the first group of frames, index 1 the second.
split_combined[0]["Vtype"] = "General"
split_combined[1]["Vtype"] = "Detailed"

In [None]:
# Stack the "General" and "Detailed" frames into one table for plotting.
combined = pd.concat(split_combined)

In [None]:
# Percentage against TanimotoCombo, faceted by sort column and Vtype,
# with the x axis drawn in reverse.
fig = px.scatter(
    combined,
    x=tc,
    y=y,
    color=color,
    facet_col=sort_col_name,
    facet_row="Vtype",
    hover_data=combined.columns,
    height=600,
    width=1200,
    range_y=[0, 1.1],
)
fig.update_xaxes(autorange="reversed")
fig.show()

In [None]:
# Single-panel view: only rows sorted by POSIT_R under the "General" split,
# with a titled and reversed x axis.
posit_general_mask = (combined.Sorted_By == "POSIT_R") & (combined.Vtype == "General")
fig = px.scatter(
    combined[posit_general_mask],
    x=tc,
    y=y,
    color=color,
    hover_data=combined.columns,
    height=600,
    width=600,
    range_y=[0, 1.1],
)
fig.update_xaxes(title="TanimotoCombo Cutoff for Inclusion of Reference Structures")
fig.update_xaxes(autorange="reversed")
fig.show()

In [None]:
# Parallel lists, one entry per (sort column, cutoff, version).
means, cutoffs, versions, sortby, sds = [], [], [], [], []
for sort_col in sort_cols:
    for cutoff in tc_cutoffs:
        # RMSD values of the selected subset, grouped by Version.
        rmsd_by_version = a.get_df_subset(df, tcr, cutoff, sort_col).groupby("Version")["RMSD"]
        version_means = rmsd_by_version.mean()
        version_stds = rmsd_by_version.std()
        for version in version_means.index:
            means.append(version_means[version])
            # Store the cutoff on the non-reversed TanimotoCombo scale.
            cutoffs.append(2 - cutoff)
            versions.append(version)
            sortby.append(sort_col)
            sds.append(version_stds[version])

In [None]:
# Assemble the accumulated lists into one tidy table for plotting.
mean_df = pd.DataFrame(
    {
        "Mean RMSD (Å)": means,
        "TanimotoCombo": cutoffs,
        "Version": versions,
        "Sorted_By": sortby,
        "RMSD Std": sds,
    }
)

In [None]:
# Mean RMSD with standard-deviation error bars, one panel per sort column.
fig = px.scatter(
    mean_df,
    x=tc,
    y="Mean RMSD (Å)",
    error_y="RMSD Std",
    color=color,
    facet_col="Sorted_By",
)
# Kept as the cell's final expression, as in the original.
fig.update_xaxes(autorange="reversed")

## how to calculate nrefs?

In [None]:
# Parallel lists, one entry per (cutoff, version).
means, cutoffs, versions, sds = [], [], [], []
for cutoff in tc_cutoffs:
    # Rows at or below the cutoff, counted per (Version, Compound_ID).
    below_cutoff = df[df[tc] <= cutoff]
    ref_counts = below_cutoff.groupby(["Version", "Compound_ID"])["Structure_Source"].count()
    counts_by_version = ref_counts.groupby("Version")
    version_means = counts_by_version.mean()
    version_stds = counts_by_version.std()
    for version in version_means.index:
        means.append(version_means[version])
        cutoffs.append(cutoff)
        versions.append(version)
        sds.append(version_stds[version])

In [None]:
# Assemble the accumulated lists into one tidy table for plotting.
mean_df = pd.DataFrame(
    {
        "Mean Number of References": means,
        "TanimotoCombo": cutoffs,
        "Version": versions,
        "STD": sds,
    }
)

In [None]:
# Mean number of references against TanimotoCombo, coloured by Version.
fig = px.scatter(
    mean_df,
    x=tc,
    y="Mean Number of References",
    color=color,
    height=400,
    width=600,
)

In [None]:
# Render the most recently built figure.
fig.show()

## more fun with RMSD using TC bounds

In [None]:
# Nine evenly spaced edges over [0, 2] give eight bins.
bins = np.linspace(0, 2, 9)
# (lower, upper) edge pairs for consecutive bins.
bounds = list(zip(bins[:-1], bins[1:]))

In [None]:
# For each aggregation (min / max / mean of RMSD per compound) build a table
# of per-Version mean and standard deviation within every TanimotoCombo bin.
dfs = []
for name, metric in (("Min", np.min), ("Max", np.max), ("Mean", np.mean)):
    means, versions, avg_tc, sds = [], [], [], []
    for lower, upper in bounds:
        # Rows whose TanimotoCombo falls in the half-open bin (lower, upper].
        in_bin = (df[tc] > lower) & (df[tc] <= upper)
        values = df[in_bin].groupby(["Version", "Compound_ID"])["RMSD"].apply(metric)
        values_by_version = values.groupby("Version")
        mean = values_by_version.mean()
        sd = values_by_version.std()
        # Bin midpoint, stored as a string.
        midpoint_label = str(np.mean([lower, upper]))
        for version in mean.index:
            means.append(mean[version])
            versions.append(version)
            avg_tc.append(midpoint_label)
            sds.append(sd[version])
    metric_df = pd.DataFrame(
        {
            "Value": means,
            "Metric": name,
            "STD": sds,
            "Version": versions,
            "TanimotoCombo": avg_tc,
        }
    )
    dfs.append(metric_df)

In [None]:
# Stack the per-metric tables into one frame.
combined = pd.concat(dfs)

In [None]:
# Value per TanimotoCombo bin with STD error bars; one panel per Version,
# coloured by metric.
fig = px.scatter(
    combined,
    x="TanimotoCombo",
    y="Value",
    error_y="STD",
    color="Metric",
    facet_col="Version",
)
fig.show()

## Structure Dates

### use_per_split_mol=False

In [None]:
# Sweep over the structure dates (instead of TanimotoCombo cutoffs) for both
# split schemes and every sort column.
dfs = []
for splits in [split_cols, full_split_cols]:
    for sort_column in sort_cols:
        # Call through the module alias ``a`` like every other sweep in this
        # notebook. The bare ``calculate_perc_good`` name was bound by
        # ``from ... import`` and is not refreshed by ``reload(a)``, so it
        # could run a stale version of the function.
        new_df = a.calculate_perc_good(
            df,
            id_column=idcol,
            filter_column=date_col,
            filter_cutoffs=dates,
            sort_column=sort_column,
            n=n,
            score_column=rmsd,
            good_score=good,
            split_cols=splits,
            reference_col="Structure_Source",
        )
        # Remember which column this run was sorted by.
        new_df[sort_col_name] = sort_column
        dfs.append(new_df)

In [None]:
# dfs holds one frame per sort column for split_cols, followed by one per
# sort column for full_split_cols. Derive the boundary from sort_cols so the
# split stays correct if the list of sort columns changes.
n_sort_cols = len(sort_cols)
split_combined = [pd.concat(dfs[:n_sort_cols]), pd.concat(dfs[n_sort_cols:])]

In [None]:
# Label each half: index 0 is the first group of frames, index 1 the second.
split_combined[0]["Vtype"] = "General"
split_combined[1]["Vtype"] = "Detailed"

In [None]:
# Stack the "General" and "Detailed" frames of the structure-date sweep.
structure_df = pd.concat(split_combined)

In [None]:
# Percentage against structure date for rows sorted by POSIT_R under the
# "General" split only.
posit_general_mask = (structure_df.Sorted_By == "POSIT_R") & (structure_df.Vtype == "General")
fig = px.scatter(
    structure_df[posit_general_mask],
    x=date_col,
    y=y,
    color=color,
    height=400,
    width=600,
)
fig.show()

In [None]:
# Percentage against structure date, one panel per sort column.
fig = px.scatter(
    structure_df,
    x=date_col,
    y=y,
    color=color,
    facet_col=sort_col_name,
    hover_data=structure_df.columns,
    height=600,
    width=1200,
)
fig.show()

## How do the different methods compare with the POSIT vs RMSD score?

In [None]:
# One POSIT-vs-RMSD density heatmap per POSIT method, restricted to the
# "All" version. Methods are enumerated from the full table.
df_ = df[df.Version == "All"]
for method in df.POSIT_Method.unique():
    method_rows = df_[df_.POSIT_Method == method]
    fig = px.density_heatmap(
        method_rows,
        x="RMSD",
        y="POSIT",
        marginal_x="histogram",
        marginal_y="histogram",
        title=method,
        height=800,
        width=800,
        range_x=[0, 11],
        range_y=[0, 1.1],
    )
    fig.show()
    

## how about TC vs RMSD?

In [None]:
# One TanimotoCombo-vs-RMSD density heatmap per POSIT method, restricted to
# the "All" version. Methods are enumerated from the full table.
df_ = df[df.Version == "All"]
for method in df.POSIT_Method.unique():
    method_rows = df_[df_.POSIT_Method == method]
    fig = px.density_heatmap(
        method_rows,
        x="RMSD",
        y=tc,
        marginal_x="histogram",
        marginal_y="histogram",
        title=method,
        height=800,
        width=800,
        range_x=[0, 11],
        range_y=[0, 2.1],
    )
    fig.show()
    

In [None]:
# One TanimotoCombo-vs-RMSD density heatmap per Version.
for version in df.Version.unique():
    version_rows = df[df.Version == version]
    fig = px.density_heatmap(
        version_rows,
        x="RMSD",
        y=tc,
        marginal_x="histogram",
        marginal_y="histogram",
        title=version,
        height=800,
        width=800,
        range_x=[0, 11],
        range_y=[0, 2.1],
    )
    fig.show()
    

In [None]:
# One TanimotoCombo-vs-RMSD density contour plot per Version.
for version in df.Version.unique():
    version_rows = df[df.Version == version]
    fig = px.density_contour(
        version_rows,
        x="RMSD",
        y=tc,
        marginal_x="histogram",
        marginal_y="histogram",
        title=version,
        height=800,
        width=800,
        range_x=[0, 11],
        range_y=[0, 2.1],
    )
    fig.show()
    

# How do RMSDs for self-docking look?

In [None]:
# Self-docking rows: the reference ligand is the same compound being docked.
self_docked = df[(df.Reference_Ligand == df.Compound_ID)]

In [None]:
# Inspect the distinct TanimotoCombo values among the self-docked rows.
self_docked.TanimotoCombo.unique()

In [None]:
# Count the distinct values of every column within each Version.
self_docked.groupby("Version").nunique()

In [None]:
from functools import reduce

In [None]:
# Set of self-docked compound IDs for each Version ...
ids_per_version = self_docked.groupby("Version")["Compound_ID"].apply(
    lambda ids: set(ids.unique())
)
# ... reduced to the compounds that appear in every Version.
intersection = reduce(set.intersection, ids_per_version)

In [None]:
# Number of compounds that were self-docked in every Version.
len(intersection)

In [None]:
# Keep only self-docked rows whose compound is present in every Version.
self_docked_filtered = self_docked[self_docked.Compound_ID.isin(intersection)]

In [None]:
# Row count after restricting to the shared compounds.
len(self_docked_filtered)

In [None]:
# Row count before filtering, for comparison.
len(self_docked)

In [None]:
import plotly.figure_factory as ff

In [None]:
def plot_kde(df, value_column, group_column, groups):
    """Build a distribution plot of one column, with one series per group.

    Parameters
    ----------
    df
        Table containing ``value_column`` and ``group_column``.
    value_column
        Name of the column whose values are plotted.
    group_column
        Name of the column used to select the rows of each group.
    groups
        Group values to plot; also used as the trace labels.

    Returns
    -------
    The figure produced by ``ff.create_distplot``, resized to 800x800 with
    fixed axis titles and ranges (x: "RMSD (Å)" over [0, 8]; y: "Frequency"
    over [0, 1]) regardless of ``value_column``.
    """
    # One series of values per requested group, in the order given.
    series_per_group = []
    for label in groups:
        in_group = df[group_column] == label
        series_per_group.append(df.loc[in_group, value_column])

    fig = ff.create_distplot(
        series_per_group,
        group_labels=groups,
        bin_size=0.25,
        show_rug=False,
    )
    fig.update_layout(width=800, height=800)
    fig.update_xaxes(title="RMSD (Å)", range=[0, 8])
    fig.update_yaxes(title="Frequency", range=[0, 1])
    return fig

In [None]:
# RMSD distribution of all self-docked rows for the two listed versions.
plot_kde(self_docked, "RMSD", "Version", ["All", "Hybrid-Only"])

In [None]:
# Same plot, restricted to compounds self-docked in every Version.
fig2 = plot_kde(self_docked_filtered, "RMSD", "Version", ["All", "Hybrid-Only"])

In [None]:
# Save the filtered figure as a PNG; the path is relative to the working directory.
fig2.write_image("20230613_self_docking_RMSD_kde.png")