# Imports

In [None]:
import plotly.express as px
import pandas as pd
import plotly.graph_objects as go
import numpy as np
from pathlib import Path

# Load Data

In [None]:
# Combined docking results table, read later in the pose-count analysis.
# NOTE(review): hard-coded absolute path on one machine — adjust before running elsewhere.
results_csv = Path("/Users/alexpayne/Scientific_Projects/mers-drug-discovery/sars2-retrospective-analysis/20240424_multi_pose_docking_cross_docking/results_csvs/20240503_combined_results_with_data.csv")

In [None]:
# The analysed CSVs live in "analyzed_data", two directory levels above the results CSV.
data_path = results_csv.parent.parent / "analyzed_data"
# Output directory for all figures. It is created up front so the later
# write_image / write_html calls do not fail on a missing folder.
figure_path = Path("figures")
figure_path.mkdir(parents=True, exist_ok=True)

In [None]:
# Collect the analysed CSVs as a sorted list. Sorting gives a deterministic
# concatenation order, and a list (unlike the one-shot generator glob returns)
# can be iterated again when a later cell is re-run.
df_paths = sorted(data_path.glob("*/*.csv"))

In [None]:
# Load each CSV found above into its own DataFrame.
dfs = [pd.read_csv(path) for path in df_paths]

In [None]:
# Stack the per-file frames into one table. The original row labels are kept,
# so index values may repeat across files.
ogdf = pd.concat(dfs)

In [None]:
# Asymmetric error-bar lengths: distance from the point estimate to each CI bound.
ogdf["Error_Lower"] = ogdf["Fraction"].sub(ogdf["CI_Lower"])
ogdf["Error_Upper"] = ogdf["CI_Upper"].sub(ogdf["Fraction"])

In [None]:
# Keep only the rows whose PoseSelection is "Default"; the other rows are used
# in the multi-pose section at the end.
df = ogdf[ogdf.PoseSelection == "Default"]

# Plotting Variables

In [None]:
# Font sizes shared by every figure: large_font for axis and legend titles,
# small_font as the base font size.
large_font = 18
small_font = 12

# Plotting Functions

In [None]:
def hex_to_rgb(hex_color: str) -> tuple:
    """Convert a hex color string to an ``(r, g, b)`` tuple of ints.

    Accepts the 6-digit form (``"#RRGGBB"``) and the 3-digit shorthand
    (``"#RGB"``), with or without the leading ``#``.

    Raises:
        ValueError: if a component is empty or not valid hexadecimal.
    """
    hex_color = hex_color.lstrip("#")
    if len(hex_color) == 3:
        # Shorthand expands by doubling each digit ("abc" -> "aabbcc"),
        # not by repeating the whole string.
        hex_color = "".join(digit * 2 for digit in hex_color)
    return int(hex_color[0:2], 16), int(hex_color[2:4], 16), int(hex_color[4:6], 16)

In [None]:
def rgb_to_rgba(rgb_str, alpha):
    """Return an ``"rgba(r, g, b, alpha)"`` string built from an ``"rgb(r, g, b)"`` string.

    The three channels must parse as integers; ``alpha`` is inserted as given.
    """
    # Peel the "rgb(" / ")" wrapper characters off both ends, then split the channels.
    channels = rgb_str.strip('rgb()').split(',')
    red, green, blue = (int(channel) for channel in channels)
    return f"rgba({red}, {green}, {blue}, {alpha})"

In [None]:
def plot_scatter_with_confidence_bands(df, x, y, split_by, error_y_plus, error_y_minus, template="plotly_white", height=600, width=800, colors=px.colors.qualitative.Plotly):
    """Plot one line per group with a shaded band between its error bounds.

    Args:
        df: DataFrame holding the data to plot.
        x: Column used for the x axis.
        y: Column used for the y axis (the line itself).
        split_by: Column whose distinct values define the groups; one line
            (and one band) is drawn per value.
        error_y_plus: Column added to ``y`` to get the upper band edge.
        error_y_minus: Column subtracted from ``y`` to get the lower band edge.
        template: Plotly layout template name.
        height: Figure height.
        width: Figure width.
        colors: Sequence of colors as ``"#RRGGBB"`` or ``"rgb(r, g, b)"``
            strings. The palette is cycled when there are more groups than
            colors.

    Returns:
        A ``go.Figure`` with three traces per group (line, upper edge, lower edge).

    Raises:
        ValueError: if ``colors`` is empty.

    Groups are ordered by their mean ``y`` (ascending). Rows inside a group
    are drawn in the order they appear in ``df``; they are not sorted by ``x``.
    """
    if not colors:
        raise ValueError("colors must contain at least one color")

    # Convert hex entries so rgb_to_rgba can attach an alpha channel.
    colors = [f"rgb{hex_to_rgb(color)}" if color.startswith("#") else color for color in colors]

    traces = []

    # order by mean
    ordered_splits = df.groupby(split_by)[y].mean().sort_values().index.tolist()
    for i, split in enumerate(ordered_splits):
        subdf = df[df[split_by] == split]
        # Wrap around the palette instead of raising IndexError when the
        # number of groups exceeds the number of colors.
        color = colors[i % len(colors)]
        band_color = rgb_to_rgba(color, 0.15)

        # Central line; the only trace of the group shown in the legend.
        traces.append(go.Scatter(name=f"{split}",
                                 x=subdf[x],
                                 y=subdf[y],
                                 mode='lines',
                                 showlegend=True,
                                 line_color=rgb_to_rgba(color, 1),
                                 ))
        # Upper band edge (invisible line).
        traces.append(go.Scatter(name=f"{split}",
                                 x=subdf[x],
                                 y=subdf[y] + subdf[error_y_plus],
                                 mode='lines',
                                 fillcolor=band_color,
                                 line_width=0,
                                 showlegend=False,
                                 ))
        # Lower band edge. 'tonexty' fills up to the previous trace, so this
        # must directly follow the upper edge.
        traces.append(go.Scatter(name=f"{split}",
                                 x=subdf[x],
                                 y=subdf[y] - subdf[error_y_minus],
                                 fill='tonexty',
                                 mode='lines',
                                 fillcolor=band_color,
                                 line_width=0,
                                 showlegend=False,
                                 ))
    fig = go.Figure(traces)
    fig.update_layout(template=template, height=height, width=width)
    return fig

# Dataset Split Comparison

In [None]:
# Rows where StructureChoice is "Dock_to_All" and PoseSelection_Choose_N is 1.
dataset_split_df = df[
    (df["StructureChoice"] == "Dock_to_All")
    & (df["PoseSelection_Choose_N"] == 1)
]

In [None]:
# Inspect how many distinct values each column of the subset holds.
dataset_split_df.nunique()

In [None]:
# there are duplicates but they are identical
# Collapse them to one row per (Score, Split, N_Per_Split). max() is taken per
# column, so the result only matches a real row if the duplicates truly agree.
dataset_split_df_simple = dataset_split_df.groupby(["Score", "Split", "N_Per_Split"]).max().reset_index()

In [None]:
# Same distinct-value check after collapsing the duplicates.
dataset_split_df_simple.nunique()

In [None]:
# Confidence-band plot of Fraction vs. N_Per_Split, one line per dataset split.
fig = plot_scatter_with_confidence_bands(
    df=dataset_split_df_simple,
    x="N_Per_Split",
    y="Fraction",
    split_by="Split",
    error_y_plus="Error_Upper",
    error_y_minus="Error_Lower",
    template="simple_white",
    height=600,
    width=800,
    colors=px.colors.qualitative.Plotly,
)

# Figure styling: Arial text, enlarged bold titles, legend placed inside the axes.
fig.update_layout(
    font={"size": small_font, "family": "Arial"},
    legend={
        "title": "<b> Dataset Split </b>",
        "x": 0.4,
        "y": 0.1,
        "traceorder": "reversed",
        "title_font_size": large_font,
        "font_color": "black",
    },
    xaxis={
        "title": "<b> Total Number of References Available to Use </b>",
        "title_font": {"size": large_font},
        "color": "black",
    },
    yaxis={
        "range": (0, 1),
        "title": "<b> Fraction of Poses Docked < 2Å from Reference </b>",
        "title_font": {"size": large_font},
        "color": "black",
    },
)
fig.update_yaxes(tickvals=np.arange(0, 1.1, 0.1))  # y ticks in steps of 0.1
fig.show()

# Save both a static and an interactive copy.
fig.write_image(figure_path / "20240528_dataset_split_comparison.png")
fig.write_html(figure_path / "20240528_dataset_split_comparison.html")

# Same comparison using default Plotly Express plotting

In [None]:
# Plotly Express version of the split comparison: error bars instead of bands,
# line color per Split and marker symbol per Score.
fig = px.line(
    dataset_split_df_simple,
    x="N_Per_Split",
    y="Fraction",
    color="Split",
    error_y="Error_Upper",
    error_y_minus="Error_Lower",
    template="simple_white",
    symbol="Score",
    height=600,
    width=800,
    color_discrete_sequence=px.colors.qualitative.Safe,
)
fig.update_layout(
    font={"size": small_font, "family": "Arial"},
    legend={
        "title": "<b> Dataset Split, Score Function </b>",
        "x": 0.4,
        "y": 0.1,
        "traceorder": "reversed",
        "title_font_size": large_font,
        "font_color": "black",
    },
    xaxis={
        "title": "<b> Total Number of References Available to Use </b>",
        "title_font": {"size": large_font},
        "color": "black",
    },
    yaxis={
        "range": (0, 1),
        "title": "<b> Fraction of Poses Docked < 2Å from Reference </b>",
        "title_font": {"size": large_font},
        "color": "black",
    },
)
fig.update_yaxes(tickvals=np.arange(0, 1.1, 0.1))  # y ticks in steps of 0.1
fig.show()
fig.write_image(figure_path / "20240528_dataset_split_comparison_v2.png")

# Plot Everything Separately

In [None]:
# One confidence-band figure per (Split, StructureChoice) combination, with one
# line per StructureChoice_Choose_N value. Each figure is saved as PNG and HTML.
for split in df.Split.unique().tolist():
    structure_choice_df = df[df.Split == split]
    for choice in structure_choice_df.StructureChoice.unique().tolist():
        subset_df = structure_choice_df[structure_choice_df.StructureChoice == choice]
        if not len(subset_df):
            continue
        fig = plot_scatter_with_confidence_bands(
            df=subset_df,
            x="N_Per_Split",
            y="Fraction",
            split_by="StructureChoice_Choose_N",
            error_y_plus="Error_Upper",
            error_y_minus="Error_Lower",
            template="simple_white",
            height=600,
            width=800,
            colors=px.colors.qualitative.Plotly,
        )
        fig.update_layout(
            font={"size": small_font, "family": "Arial"},
            legend={
                # The legend title doubles as the figure caption.
                "title": f"<b> DatasetSplit:</b> {split} <br>"
                         f"<b> StructureChoice:</b> {choice}",
                "x": 0.4,
                "y": 0.1,
                "traceorder": "reversed",
                "title_font_size": large_font,
                "font_color": "black",
            },
            xaxis={
                "title": "<b> Total Number of References Available to Use </b>",
                "title_font": {"size": large_font},
                "color": "black",
            },
            yaxis={
                "range": (0, 1),
                "title": "<b> Fraction of Poses Docked < 2Å from Reference </b>",
                "title_font": {"size": large_font},
                "color": "black",
            },
        )
        fig.update_yaxes(tickvals=np.arange(0, 1.1, 0.1))  # y ticks in steps of 0.1
        fig.write_image(figure_path / f"20240503_{split}_{choice}.png")
        fig.write_html(figure_path / f"20240503_{split}_{choice}.html")

# DateSplit - Structure Choice Comparison

In [None]:
# Restrict the DateSplit rows to a few Choose_N settings and add a combined
# "<StructureChoice>_<StructureChoice_Choose_N>" label to split the plot by.
# NOTE(review): choose_n mixes ints and a string; a string entry such as "1"
# in the column would not match the int 1 — confirm how the column is typed.
choose_n = [1, 10, "All"]
structure_choice_comparison_df = pd.concat(
    [
        df[(df.Split == "DateSplit") & (df.StructureChoice_Choose_N.isin(choose_n))],
    ],
)
structure_choice_comparison_df["Structure Choice"] = [
    f"{structure}_{n_chosen}"
    for structure, n_chosen in zip(
        structure_choice_comparison_df["StructureChoice"].tolist(),
        structure_choice_comparison_df["StructureChoice_Choose_N"].tolist(),
    )
]

In [None]:
# Confidence-band plot for DateSplit, one line per combined "Structure Choice" label.
fig = plot_scatter_with_confidence_bands(
    df=structure_choice_comparison_df,
    x="N_Per_Split",
    y="Fraction",
    split_by="Structure Choice",
    error_y_plus="Error_Upper",
    error_y_minus="Error_Lower",
    template="simple_white",
    height=600,
    width=800,
    colors=px.colors.qualitative.Safe,
)
fig.update_layout(
    font={"size": small_font, "family": "Arial"},
    legend={
        "title": "<b> Structure Choice</b>",
        "x": 0.4,
        "y": 0.1,
        "traceorder": "reversed",
        "title_font_size": large_font,
        "font_color": "black",
    },
    xaxis={
        "title": "<b> Total Number of References Available to Use </b>",
        "title_font": {"size": large_font},
        "color": "black",
    },
    yaxis={
        "range": (0, 1),
        "title": "<b> Fraction of Poses Docked < 2Å from Reference </b>",
        "title_font": {"size": large_font},
        "color": "black",
    },
)
fig.update_yaxes(tickvals=np.arange(0, 1.1, 0.1))  # y ticks in steps of 0.1

# The list repr of choose_n (brackets, commas, quotes) ends up in the file name.
fig.write_image(figure_path / f"20240503_date_split_structure_choice_comparison_{choose_n}.png")
fig.write_html(figure_path / f"20240503_date_split_structure_choice_comparison_{choose_n}.html")

# Calculate how many poses

In [None]:
# Load the combined results table, using the first CSV column as the index.
rdf = pd.read_csv(results_csv, index_col=0)

In [None]:
# Order rows by query ligand, then reference ligand, then ascending RMSD.
rdf = rdf.sort_values(["Query_Ligand", "Reference_Ligand", "RMSD"])

In [None]:
# Number of rows with a missing RMSD.
rdf.RMSD.isna().sum()

## drop the previously added missing values

In [None]:
# Discard rows that have no RMSD value.
rdf = rdf.dropna(subset=["RMSD"])

In [None]:
# Non-null value count of every column per (query, reference) pair; the
# Pose_ID count is renamed to N_Poses in a later cell.
n_poses = rdf.groupby(["Query_Ligand", "Reference_Ligand"]).count().reset_index()

In [None]:
# Lowest-RMSD row of each (query, reference) pair.
combined_data = rdf.sort_values("RMSD").groupby(["Query_Ligand", "Reference_Ligand"]).head(1).reset_index()

In [None]:
# The per-pair count of Pose_ID values serves as the number of poses.
n_poses = n_poses.rename(columns={"Pose_ID": "N_Poses"})

In [None]:
# Keep only the pair keys and the pose count.
n_poses = n_poses[["Query_Ligand", "Reference_Ligand", "N_Poses"]]

In [None]:
# Attach the pose count to the best-RMSD row of each pair.
merged = n_poses.merge(combined_data, on=["Query_Ligand", "Reference_Ligand"])

In [None]:
# Reference ligands in order of Reference_Structure_Date; used as the category
# order for both heatmap axes.
lig_orders = merged.sort_values("Reference_Structure_Date").Reference_Ligand.unique().tolist()

In [None]:
# Heatmap of N_Poses for every query/reference pair, both axes ordered by lig_orders.
fig = px.density_heatmap(
    merged,
    x="Query_Ligand",
    y="Reference_Ligand",
    z="N_Poses",
    category_orders={"Query_Ligand": lig_orders, "Reference_Ligand": lig_orders},
    color_continuous_scale="Viridis",
    height=800,
    width=1000,
)
fig.write_image(figure_path / "20240521_n_poses_density_heatmap.png")
fig.write_html(figure_path / "20240521_n_poses_density_heatmap.html")

# is the number of poses related to how similar reference / query are?

In [None]:
# 2D density of pose count against Tanimoto similarity, with marginal histograms.
fig = px.density_heatmap(
    merged,
    y="N_Poses",
    x="Tanimoto",
    marginal_x="histogram",
    marginal_y="histogram",
    color_continuous_scale=px.colors.sequential.Viridis,
    height=800,
    width=1000,
    template="simple_white",
    title="Number of Poses vs Chemical Similarity (ECFP4) for All Complex Pairs",
)
fig.write_image(figure_path / "20240521_n_poses_vs_tanimoto_density_heatmap.png")
fig.write_html(figure_path / "20240521_n_poses_vs_tanimoto_density_heatmap.html")

## no

## is it more obvious without the zero pose ones?

## no difference because there are no longer any pairs with only 1 pose (or 0 poses)

In [None]:
# Keep only pairs that have at least one pose.
merged_no_zero_poses = merged[merged.N_Poses > 0]

In [None]:
# Number of pairs that have exactly one pose.
sum(merged_no_zero_poses.N_Poses == 1)

In [None]:
# Same density heatmap, restricted to pairs with at least one pose.
fig = px.density_heatmap(
    merged_no_zero_poses,
    y="N_Poses",
    x="Tanimoto",
    marginal_x="histogram",
    marginal_y="histogram",
    color_continuous_scale=px.colors.sequential.Viridis,
    height=800,
    width=1000,
    template="simple_white",
)
fig.write_image(figure_path / "20240521_n_poses_vs_tanimoto_density_heatmap_no_zero_poses.png")
fig.write_html(figure_path / "20240521_n_poses_vs_tanimoto_density_heatmap_no_zero_poses.html")

## is there a difference if you just look at good rmsd poses?

In [None]:
# Lowest-RMSD row of each (query, reference) pair; the original index is kept.
best_rmsd_per_complex = rdf.sort_values(["RMSD"]).groupby(["Query_Ligand", "Reference_Ligand"]).head(1)

In [None]:
# Flag pairs whose best pose has RMSD < 2, then keep only those.
# The comparison is vectorised (no row-wise apply, which is slow and misbehaves
# on an empty frame), and assign() builds a new frame rather than writing a
# column into a slice taken from rdf.
best_rmsd_per_complex = best_rmsd_per_complex.assign(Good_RMSD=best_rmsd_per_complex["RMSD"] < 2)
best_rmsd = best_rmsd_per_complex[best_rmsd_per_complex.Good_RMSD]

In [None]:
# Attach the pose count to each pair whose best RMSD is below 2.
best_rmsd_with_n_poses = best_rmsd.merge(n_poses, on=["Query_Ligand", "Reference_Ligand"])

In [None]:
# Density heatmap of pose count against Tanimoto similarity for the RMSD < 2 subset.
fig = px.density_heatmap(
    best_rmsd_with_n_poses,
    y="N_Poses",
    x="Tanimoto",
    marginal_x="histogram",
    marginal_y="histogram",
    color_continuous_scale=px.colors.sequential.Viridis,
    height=800,
    width=1000,
    template="simple_white",
    title="Number of Poses vs Chemical Similarity (ECFP4) for Best RMSD Poses (<2Å)",
)
fig.write_image(figure_path / "20240521_best_rmsd_vs_tanimoto_density_heatmap.png")
fig.write_html(figure_path / "20240521_best_rmsd_vs_tanimoto_density_heatmap.html")

## scatter plot with trendline

In [None]:
# Scatter of pose count against Tanimoto similarity with an OLS trendline.
# `merged` holds every complex pair (no RMSD filter), so the title says so,
# matching the other figures drawn from `merged`.
fig = px.scatter(
    merged,
    x="Tanimoto",
    y="N_Poses",
    trendline="ols",
    height=800,
    width=1000,
    template="simple_white",
    title="Number of Poses vs Chemical Similarity (ECFP4) for All Complex Pairs",
)

## actually it's kind of annoying to add a trendline to this plot

In [None]:
from software.plotting import plot_scatter_with_regression_line_plotly

In [None]:
# Pose count against Tanimoto similarity for all pairs, drawn with the project's
# regression-line helper.
fig = plot_scatter_with_regression_line_plotly(merged.Tanimoto, merged.N_Poses)
fig.update_layout(
    template="simple_white",
    height=600,
    width=800,
    title="Number of Poses vs Chemical Similarity (ECFP4) for All Complex Pairs",
)
fig.update_layout(
    font={"size": small_font, "family": "Arial"},
    legend={"x": 0.4, "y": 0.1},
    xaxis={
        "title": "<b> Tanimoto (ECFP4) </b>",
        "title_font": {"size": large_font},
        "color": "black",
    },
    yaxis={
        "title": "<b> Number of Poses </b>",
        "title_font": {"size": large_font},
        "color": "black",
    },
)
fig.write_image(figure_path / "20240521_n_poses_vs_tanimoto_regression.png")

## only good rmsds

In [None]:
# Same regression plot, restricted to pairs whose best RMSD is below 2.
fig = plot_scatter_with_regression_line_plotly(
    best_rmsd_with_n_poses.Tanimoto,
    best_rmsd_with_n_poses.N_Poses,
)
fig.update_layout(
    template="simple_white",
    height=600,
    width=800,
    title="Number of Poses vs Chemical Similarity (ECFP4) for Best RMSD Poses (<2Å)",
)
fig.update_layout(
    font={"size": small_font, "family": "Arial"},
    legend={"x": 0.4, "y": 0.1},
    xaxis={
        "title": "<b> Tanimoto (ECFP4) </b>",
        "title_font": {"size": large_font},
        "color": "black",
    },
    yaxis={
        "title": "<b> Number of Poses </b>",
        "title_font": {"size": large_font},
        "color": "black",
    },
)
fig.write_image(figure_path / "20240521_n_poses_vs_tanimoto_regression_best_rmsd.png")

# what's the correlation between RMSD and number of poses?

In [None]:
# Best RMSD of each pair against its pose count, drawn with the project's
# regression-line helper.
fig = plot_scatter_with_regression_line_plotly(merged.N_Poses, merged.RMSD)
fig.update_layout(
    template="simple_white",
    height=600,
    width=800,
    title="Number of Poses vs Best RMSD for All Complex Pairs",
)
fig.update_layout(
    font={"size": small_font, "family": "Arial"},
    legend={"x": 0.4, "y": 0.1},
    xaxis={
        "title": "<b> Number of Poses </b>",
        "title_font": {"size": large_font},
        "color": "black",
    },
    yaxis={
        "title": "<b> RMSD (A) </b>",
        "title_font": {"size": large_font},
        "color": "black",
    },
)
fig.write_image(figure_path / "20240521_n_poses_vs_rmsd.png")

# Do we do better if we have more poses?

In [None]:
# Every row whose PoseSelection is not "Default" (the complement of `df`).
multipose = ogdf[ogdf.PoseSelection != "Default"]

In [None]:
# Distinct PoseSelection_Choose_N values present in the full table.
ogdf.PoseSelection_Choose_N.unique()

In [None]:
# Multi-pose rows of the RandomSplit.
# NOTE(review): this name would shadow the stdlib `random` module if that is
# ever imported in this notebook.
random = multipose[multipose.Split == "RandomSplit"]

In [None]:
# Which score functions appear in the RandomSplit multi-pose rows.
random.Score.unique()

In [None]:
# Display the RandomSplit multi-pose table.
random

In [None]:
# Confidence-band plot for RandomSplit, one line per PoseSelection_Choose_N value.
# NOTE(review): unlike the DateSplit frames, `random` is not filtered by Score
# although the output name ends in "rmsd" — confirm only RMSD rows are present.
fig = plot_scatter_with_confidence_bands(
    df=random,
    x="N_Per_Split",
    y="Fraction",
    split_by="PoseSelection_Choose_N",
    error_y_plus="Error_Upper",
    error_y_minus="Error_Lower",
    template="simple_white",
    height=600,
    width=800,
    colors=px.colors.qualitative.Safe,
)
fig.update_layout(
    font={"size": small_font, "family": "Arial"},
    legend={
        "title": "<b> Number of Poses Included </b>",
        "x": 0.4,
        "y": 0.1,
        "traceorder": "reversed",
        "title_font_size": large_font,
        "font_color": "black",
    },
    xaxis={
        "title": "<b> Total Number of References Available to Use </b>",
        "title_font": {"size": large_font},
        "color": "black",
    },
    yaxis={
        "range": (0, 1),
        "title": "<b> Fraction of Poses Docked < 2Å from Reference </b>",
        "title_font": {"size": large_font},
        "color": "black",
    },
)
fig.update_yaxes(tickvals=np.arange(0, 1.1, 0.1))  # y ticks in steps of 0.1
fig.show()
fig.write_image(figure_path / "20240527_n_poses_random_rmsd.png")
fig.write_html(figure_path / "20240527_n_poses_random_rmsd.html")

In [None]:
# Multi-pose DateSplit rows whose Score is "RMSD".
datesplit = multipose[(multipose.Split == "DateSplit")&(multipose.Score == "RMSD")]

In [None]:
# Confidence-band plot for DateSplit with Score "RMSD", one line per
# PoseSelection_Choose_N value.
fig = plot_scatter_with_confidence_bands(
    df=datesplit,
    x="N_Per_Split",
    y="Fraction",
    split_by="PoseSelection_Choose_N",
    error_y_plus="Error_Upper",
    error_y_minus="Error_Lower",
    template="simple_white",
    height=600,
    width=800,
    colors=px.colors.qualitative.Safe,
)
fig.update_layout(
    font={"size": small_font, "family": "Arial"},
    legend={
        "title": "<b> Number of Poses Included </b>",
        "x": 0.4,
        "y": 0.1,
        "traceorder": "reversed",
        "title_font_size": large_font,
        "font_color": "black",
    },
    xaxis={
        "title": "<b> Total Number of References Available to Use </b>",
        "title_font": {"size": large_font},
        "color": "black",
    },
    yaxis={
        "range": (0, 1),
        "title": "<b> Fraction of Poses Docked < 2Å from Reference </b>",
        "title_font": {"size": large_font},
        "color": "black",
    },
)
fig.update_yaxes(tickvals=np.arange(0, 1.1, 0.1))  # y ticks in steps of 0.1
fig.show()
fig.write_image(figure_path / "20240527_n_poses_datesplit_rmsd.png")
fig.write_html(figure_path / "20240527_n_poses_datesplit_rmsd.html")

In [None]:
# Multi-pose DateSplit rows whose Score is "POSIT_Probability".
posit_datesplit = multipose[(multipose.Split == "DateSplit")&(multipose.Score == "POSIT_Probability")]

In [None]:
# Display the POSIT_Probability DateSplit table.
posit_datesplit

In [None]:
# Confidence-band plot for DateSplit with Score "POSIT_Probability", one line
# per PoseSelection_Choose_N value.
fig = plot_scatter_with_confidence_bands(
    df=posit_datesplit,
    x="N_Per_Split",
    y="Fraction",
    split_by="PoseSelection_Choose_N",
    error_y_plus="Error_Upper",
    error_y_minus="Error_Lower",
    template="simple_white",
    height=600,
    width=800,
    colors=px.colors.qualitative.Safe,
)
fig.update_layout(
    font={"size": small_font, "family": "Arial"},
    legend={
        "title": "<b> Number of Poses Included </b>",
        "x": 0.4,
        "y": 0.1,
        "traceorder": "reversed",
        "title_font_size": large_font,
        "font_color": "black",
    },
    xaxis={
        "title": "<b> Total Number of References Available to Use </b>",
        "title_font": {"size": large_font},
        "color": "black",
    },
    yaxis={
        "range": (0, 1),
        "title": "<b> Fraction of Poses Docked < 2Å from Reference </b>",
        "title_font": {"size": large_font},
        "color": "black",
    },
)
fig.update_yaxes(tickvals=np.arange(0, 1.1, 0.1))  # y ticks in steps of 0.1
fig.show()
fig.write_image(figure_path / "20240527_n_poses_datesplit_posit.png")
fig.write_html(figure_path / "20240527_n_poses_datesplit_posit.html")