# Intro

The purpose of this notebook is to plot the fraction of protein-ligand complexes with at least one pose docked within 2 Å RMSD of the reference (y-axis) against the number of docked poses kept per complex (x-axis).

# Imports

In [None]:
import os
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.express as px
from plotly import figure_factory as ff

from harbor.analysis.cross_docking import FractionGood

In [None]:
# Combined cross-docking results (one row per docked pose). Override with the RESULTS_CSV
# environment variable instead of editing this user-specific absolute path.
results_csv = Path(os.environ.get(
    "RESULTS_CSV",
    "/Users/alexpayne/Scientific_Projects/mers-drug-discovery/sars2-retrospective-analysis/20240424_multi_pose_docking_cross_docking/results_csvs/20240503_combined_results_with_data.csv",
))

In [None]:
# Sibling "analyzed_data" directory next to results_csvs/ (not used in the cells shown here — TODO confirm or remove).
data_path = results_csv.parent.parent / "analyzed_data"
# Output directory for all figures; create it so the later `write_image` calls do not fail.
figure_path = Path("figures")
figure_path.mkdir(parents=True, exist_ok=True)

In [None]:
# Load the raw results; the first CSV column is the saved DataFrame index.
df = pd.read_csv(filepath_or_buffer=results_csv, index_col=0)

In [None]:
# A "complex" is one query ligand docked into one reference structure.
df = df.assign(Complex_ID=df["Query_Ligand"] + "_" + df["Reference_Structure"])

# Remove failed complexes (rows with no RMSD)

In [None]:
# Poses with a missing RMSD failed; drop them before any counting.
df = df.dropna(subset=["RMSD"])

In [None]:
# For each cutoff n, keep only the first n poses (lowest Pose_ID) of every complex and
# compute the fraction of complexes whose best kept pose is within 2 Å RMSD.
fractions = []
# Denominator: complexes with at least one successfully scored pose.
total = df.Complex_ID.nunique()
n_poses_list = [1] + list(range(5, 51, 5))
# One entry per cutoff: Series of kept-pose counts indexed by Complex_ID.
n_actual_poses_list = []
# Sort once so `head(n)` means "the first n poses" rather than "the first n CSV rows";
# a stable sort keeps the file order for any ties.
df_by_pose = df.sort_values("Pose_ID", kind="stable")
for n_poses in n_poses_list:
    subset_df = df_by_pose.groupby("Complex_ID").head(n_poses)
    # Number of poses actually available per complex (may be fewer than n_poses).
    # Do not add a scalar "N_Poses" entry to this Series: it would become a fake complex row;
    # the cutoff is attached as a column when the per-cutoff DataFrames are built.
    n_poses_df = subset_df.groupby("Complex_ID")["Pose_ID"].count()
    n_actual_poses_list.append(n_poses_df)
    # Lowest RMSD among the kept poses of each complex.
    best_rmsd = subset_df.groupby("Complex_ID")["RMSD"].min()
    fraction = (best_rmsd <= 2.0).sum() / total
    fg = FractionGood(name=f'{n_poses}_Poses', total=total, fraction=fraction, replicates=[fraction])
    fractions.append(fg)

In [None]:
# Sanity check: for every cutoff, each complex keeps between 1 and `cutoff` poses.
# (Replaces a debug line that inserted a fake "Test" complex into the first Series.)
for cutoff, kept_counts in zip(n_poses_list, n_actual_poses_list):
    assert kept_counts.between(1, cutoff).all(), f"unexpected pose counts for N_Poses={cutoff}"

In [None]:
# Pair each per-complex count Series with the cutoff that produced it.
n_actual_poses_dfs = [
    pd.DataFrame({"Actual_Poses": kept_counts, "N_Poses": cutoff})
    for kept_counts, cutoff in zip(n_actual_poses_list, n_poses_list)
]

In [None]:
# Long-format table: one row per (complex, cutoff), index is Complex_ID.
n_actual_poses_df = pd.concat(n_actual_poses_dfs, axis=0)

In [None]:
def plot_kde(df, value_column, group_column, groups=None, bin_size=0.25):
    """
    Plot overlaid KDE curves (with histograms) of `value_column`, one curve per group.

    :param df: DataFrame containing both `value_column` and `group_column`.
    :param value_column: Column whose values are plotted.
    :param group_column: Column used to split rows into groups.
    :param groups: Group values to plot, in order. If None or empty, all unique values
        of `group_column` are used in order of appearance.
    :param bin_size: Histogram bin width passed to `ff.create_distplot`.
    :return: plotly Figure sized 600 x 400.
    """
    # `if not groups` raises "truth value is ambiguous" for numpy arrays / Series,
    # so test for None / empty explicitly.
    if groups is None or len(groups) == 0:
        groups = df[group_column].unique()
    arrays = [df.loc[df[group_column] == group, value_column] for group in groups]
    fig = ff.create_distplot(
        arrays,
        group_labels=[str(group) for group in groups],
        bin_size=bin_size,
        show_rug=False,
    )
    fig.update_layout(width=600, height=400)
    return fig

In [None]:
# KDE of kept-pose counts per complex for every cutoff above 1 (cutoff 1 is trivially all ones).
kde_fig = plot_kde(n_actual_poses_df[n_actual_poses_df.N_Poses > 1], "Actual_Poses", "N_Poses", groups=None)
kde_fig.show()

In [None]:
# Display the pose-count cutoffs used above (the x-axis values of the fraction plot).
n_poses_list

In [None]:
# Overlaid histograms (with marginal box plots) of kept-pose counts, one colour per cutoff.
fig = px.histogram(
    n_actual_poses_df,
    x="Actual_Poses",
    color="N_Poses",
    category_orders={"N_Poses": n_poses_list},
    barmode="overlay",
    marginal="box",
    opacity=0.6,
    template="simple_white",
    width=1200,
    height=1200,
    color_discrete_sequence=px.colors.qualitative.Safe,
)

In [None]:
# Render the kept-pose-count histogram built in the previous cell.
fig.show()

In [None]:
# One row per cutoff, built from each FractionGood's record representation.
fraction_records = [fraction_good.get_records() for fraction_good in fractions]
fraction_df = pd.DataFrame.from_records(fraction_records)

In [None]:
# Rows follow the order of `fractions`, which was built by iterating n_poses_list.
fraction_df = fraction_df.assign(N_Poses=n_poses_list)

In [None]:
# Asymmetric error-bar lengths: distance from the point estimate to each CI bound.
fraction_df = fraction_df.assign(
    Error_Lower=fraction_df["Fraction"] - fraction_df["CI_Lower"],
    Error_Upper=fraction_df["CI_Upper"] - fraction_df["Fraction"],
)

In [None]:
# Font sizes (pt) for axis titles and for all other text.
large_font, small_font = 18, 12

In [None]:
# Fraction of complexes with a pose < 2 Å vs. number of poses kept, with CI error bars.
fig = px.line(
    fraction_df,
    x="N_Poses",
    y="Fraction",
    error_y="Error_Upper",
    error_y_minus="Error_Lower",
    template="simple_white",
    height=600,
    width=800,
    color_discrete_sequence=px.colors.qualitative.Safe,
)
fig.update_layout(
    font=dict(size=small_font, family="Arial"),
    xaxis=dict(
        title="<b> Total Number of Kept Poses </b>",
        title_font=dict(size=large_font),
        color="black",
    ),
    yaxis=dict(
        range=(0, 1),
        # Each value is (complexes with a kept pose <= 2 Å) / total complexes, not a fraction of poses.
        title="<b> Fraction of Complexes Docked < 2Å from Reference </b>",
        title_font=dict(size=large_font),
        color="black",
    ),
)
# linspace gives exact 0.0, 0.1, ..., 1.0 (arange(0, 1.1, 0.1) accumulates float drift).
fig.update_yaxes(tickvals=np.linspace(0, 1, 11))
fig.show()
# Needs an explicit extension and descriptive name ("20240620_png" had neither).
fig.write_image(figure_path / "20240620_fraction_within_2A_vs_n_poses.png")

# For each complex with a pose < 2 Å, where in the ranked list of returned poses does that pose appear?

## First, find the complexes (Complex_ID) with at least one pose < 2 Å

In [None]:
# Individual poses (rows) within 2 Å RMSD; a complex appears here once per qualifying pose.
good_complexes_df = df.loc[df["RMSD"].le(2.0)]

In [None]:
# Pose_ID of the < 2 Å pose with the lowest Pose_ID, one per complex.
first_below_2A = (
    good_complexes_df
    .sort_values("Pose_ID")
    .groupby("Complex_ID")
    .head(1)["Pose_ID"]
)
# Pose_ID of the lowest-RMSD pose, one per complex.
best_pose = (
    good_complexes_df
    .sort_values("RMSD")
    .groupby("Complex_ID")
    .head(1)["Pose_ID"]
)

In [None]:
# Long-format table of Pose_IDs labelled by which selection rule produced them.
labelled_pose_ids = [
    (first_below_2A, "First < 2Å"),
    (best_pose, "Best < 2Å"),
]
grouped_df = pd.concat(
    [pd.DataFrame({"Pose_ID": pose_ids, "Label": label}) for pose_ids, label in labelled_pose_ids]
)

In [None]:
# Overlaid histograms of the Pose_ID of the first and of the best < 2 Å pose per complex.
px.histogram(
    grouped_df,
    x="Pose_ID",
    color="Label",
    barmode="overlay",
    template="simple_white",
    width=800,
    height=600,
    color_discrete_sequence=px.colors.qualitative.Safe,
)

In [None]:
# ECDF of Pose_ID for the first and the best < 2 Å pose of each complex; saved to disk only.
n_good_complexes = len(first_below_2A)
fig = px.ecdf(
    grouped_df,
    x="Pose_ID",
    color="Label",
    template="simple_white",
    width=800,
    height=600,
    color_discrete_sequence=px.colors.qualitative.Safe,
)
fig.update_layout(
    title="Distribution of < 2Å poses in the list of returned poses",
    xaxis_title="Pose Number",
    yaxis_title=f"Fraction of Complexes ({n_good_complexes} total)",
    legend_title="Pose Type",
)
fig.write_image(figure_path / "20240620_pose_distribution_of_2A_poses.png")

In [None]:
# Per-complex RMSD (indexed by Complex_ID, single "RMSD" column) of:
#  - first_rmsds: the < 2 Å pose with the lowest Pose_ID;
#  - best_rmsds: the lowest-RMSD pose.
# Indexing by Complex_ID makes `first_rmsds - best_rmsds` align by complex rather than by
# position, and avoids the meaningless difference of row-index columns that reset_index() produced.
first_rmsds = (
    good_complexes_df.sort_values("Pose_ID")
    .groupby("Complex_ID")[["RMSD"]]
    .first()
    .astype(float)
)
best_rmsds = good_complexes_df.groupby("Complex_ID")[["RMSD"]].min().astype(float)

In [None]:
# How much worse (Å) the first < 2 Å pose is than the best pose, per complex (>= 0).
rmsd_diffs = first_rmsds.sub(best_rmsds)

In [None]:
# Summary statistics (count, mean, std, quantiles) of the per-complex RMSD gap.
# The previous bare `.mean()` line was discarded (not the cell's last expression) and is
# covered by describe(); select "RMSD" so no index-derived column is summarized.
rmsd_diffs["RMSD"].describe()

In [None]:
# ECDF of the RMSD gap between the first < 2 Å pose and the best pose; saved to disk only.
n_complexes = len(rmsd_diffs)
fig = px.ecdf(rmsd_diffs, x="RMSD", template="simple_white", width=800, height=600)
fig.update_layout(
    title="Difference between the best pose and the first <2Å pose for all complexes with a <2Å Pose",
    xaxis_title="dRMSD (Å)",
    yaxis_title=f"Fraction of Complexes ({n_complexes} total)",
)
fig.write_image(figure_path / "20240620_dRMSD_ecdf_first_2A_pose.png")

## When the first pose isn't the best pose, how much worse is it?

In [None]:
# Every pose (any RMSD) belonging to a complex that has at least one pose < 2 Å.
good_complex_ids = good_complexes_df["Complex_ID"].unique()
good_complexes_df_all_poses = df.loc[df["Complex_ID"].isin(good_complex_ids)]

In [None]:
# Full row of the lowest-Pose_ID pose (regardless of RMSD) for each good complex, ordered by Complex_ID.
first_pose_rmsds = (
    good_complexes_df_all_poses
    .sort_values("Pose_ID")
    .groupby("Complex_ID")
    .head(1)
    .sort_values("Complex_ID")
)
# Full row of the lowest-RMSD pose for each good complex, ordered by Complex_ID.
# NOTE: this rebinds `best_rmsds` from the earlier dRMSD cell with a different (full-row) shape.
best_rmsds = (
    good_complexes_df
    .sort_values("RMSD")
    .groupby("Complex_ID")
    .head(1)
    .sort_values("Complex_ID")
)

In [None]:
# RMSD gap (Å) between the lowest-Pose_ID pose and the best pose, aligned by Complex_ID.
# Subtracting reset_index() frames also subtracted the row-index columns (meaningless) and
# relied on both frames having the same row order; aligning on Complex_ID avoids both.
rmsd_diffs2 = (
    first_pose_rmsds.set_index("Complex_ID")[["RMSD"]].astype(float)
    - best_rmsds.set_index("Complex_ID")[["RMSD"]].astype(float)
)

In [None]:
# ECDF of the RMSD gap between the first returned pose and the best pose; saved to disk only.
fig = px.ecdf(rmsd_diffs2, x="RMSD", template="simple_white", height=600, width=800, range_x=[0, 8])
fig.update_layout(
    title="Difference between the best pose and the first pose for all complexes with a <2Å Pose",
    xaxis_title="dRMSD (Å)",
    # Count the complexes actually plotted here (rmsd_diffs2), not those of the earlier rmsd_diffs.
    yaxis_title=f"Fraction of Complexes ({len(rmsd_diffs2)} total)",
)
fig.write_image(figure_path / "20240620_dRMSD_ecdf.png")

In [None]:
# Long-format table of both RMSD-gap distributions, labelled by the pose-selection rule.
labelled_gaps = [
    (rmsd_diffs, "First Pose < 2Å"),
    (rmsd_diffs2, "First Pose"),
]
rmsd_df = pd.concat(
    [pd.DataFrame({"RMSD": gaps["RMSD"], "Type": label}) for gaps, label in labelled_gaps]
)

In [None]:
# Both RMSD-gap ECDFs on one plot, coloured by pose-selection rule.
fig = px.ecdf(rmsd_df, x="RMSD", color="Type", template="simple_white", height=600, width=800, range_x=[0, 8])
fig.update_layout(
    title="<b> RMSD difference from the best pose for complexes with a <2Å Pose </b>",
    xaxis_title="dRMSD (Å)",
    yaxis_title=f"<b>Fraction of Complexes ({len(rmsd_diffs)} total)</b>",
    legend=dict(
        title="<b> Chosen Pose </b>",
        x=0.4,
        y=0.1,
        traceorder="reversed",
        font_color="black",
    ),
)
# Distinct file name: "20240620_dRMSD_ecdf.png" is already written by the single-curve ECDF cell.
fig.write_image(figure_path / "20240620_dRMSD_ecdf_comparison.png")

In [None]:
# Render the combined RMSD-gap ECDF built in the previous cell.
fig.show()