# Imports

In [None]:
import plotly.express as px
import pandas as pd
import plotly.graph_objects as go
import numpy as np
from pathlib import Path

# Load Data

In [None]:
# Location of the combined cross-docking results CSV.
# NOTE(review): this is an absolute, user-specific path; it has to be edited to run elsewhere.
results_csv = Path("/Users/alexpayne/Scientific_Projects/mers-drug-discovery/sars2-retrospective-analysis/20240424_multi_pose_docking_cross_docking/results_csvs/20240503_combined_results_with_data.csv")

In [None]:
# Relative directories (resolved against the current working directory):
# analyzed CSVs are read from data_path, figures are written under figure_path.
data_path = Path("analyzed_data")
figure_path = Path("figures")

In [None]:
# Collect every CSV in data_path as a sorted list. A list (rather than the
# one-shot generator returned by glob) can be iterated again when a later cell
# is re-run, and sorting makes the load order independent of the filesystem.
df_paths = sorted(data_path.glob("*.csv"))

In [None]:
# Load each path in df_paths into its own DataFrame.
dfs = [pd.read_csv(path) for path in df_paths]

In [None]:
# Stack the per-file DataFrames row-wise. The original row indices are kept,
# so index labels may repeat across files.
df = pd.concat(dfs)

In [None]:
# Asymmetric error-bar lengths: the distance from the point estimate
# ("Fraction") down to CI_Lower and up to CI_Upper.
df["Error_Lower"] = df["Fraction"].sub(df["CI_Lower"])
df["Error_Upper"] = df["CI_Upper"].sub(df["Fraction"])

# Plotting Variables

In [None]:
# Font sizes shared by all figures below: large_font for axis and legend
# titles, small_font for the general figure font.
large_font = 18
small_font = 12

# Plotting Functions

In [None]:
def hex_to_rgb(hex_color: str) -> tuple:
    """
    Converts a hex color string into an (r, g, b) tuple of integers.

    Parameters:
    hex_color (str): A color such as "#636EFA" or the 3-digit shorthand "#6EF".
        The leading "#" is optional.

    Returns:
    tuple: The red, green and blue channels as ints in the range 0-255.

    Raises:
    ValueError: If the string contains non-hex characters or too few digits.
    """
    hex_color = hex_color.lstrip("#")
    if len(hex_color) == 3:
        # Shorthand expands by doubling each digit ("f0a" -> "ff00aa"),
        # not by repeating the whole string ("f0af0a").
        hex_color = "".join(digit * 2 for digit in hex_color)
    return int(hex_color[0:2], 16), int(hex_color[2:4], 16), int(hex_color[4:6], 16)

In [None]:
def rgb_to_rgba(rgb_str, alpha):
    """
    Turns an "rgb(r, g, b)" string into an "rgba(r, g, b, alpha)" string.

    Parameters:
    rgb_str (str): A color written as "rgb(r, g, b)" with integer channels.
    alpha: The opacity value appended as the fourth component.

    Returns:
    str: The color formatted as "rgba(r, g, b, alpha)".
    """
    # Drop the surrounding "rgb(" / ")" characters, then parse each channel.
    channels = [int(part) for part in rgb_str.strip('rgb()').split(',')]
    red, green, blue = channels
    return f"rgba({red}, {green}, {blue}, {alpha})"

In [None]:
def group_by_two_columns(df, col1, col2):
    """
    Splits a DataFrame into one sub-DataFrame per unique (col1, col2) pair.

    Parameters:
    df (pd.DataFrame): The input DataFrame.
    col1 (str): The name of the first column to group by.
    col2 (str): The name of the second column to group by.

    Returns:
    list: The sub-DataFrames, in the order pandas' groupby yields the groups.
    """
    pieces = []
    for _pair, frame in df.groupby([col1, col2]):
        pieces.append(frame)
    return pieces

# Example usage:
# df = pd.DataFrame({
#     'A': ['foo', 'foo', 'bar', 'bar', 'foo', 'bar', 'foo'],
#     'B': ['one', 'one', 'one', 'two', 'two', 'two', 'one'],
#     'C': [1, 2, 3, 4, 5, 6, 7]
# })

# result = group_by_two_columns(df, 'A', 'B')
# for df in result:
#     print(df)

In [None]:
def plot_scatter_with_confidence_bands(df, x, y, split_by, error_y_plus, error_y_minus, extra_split=None, template="plotly_white", height=600, width=800, colors=px.colors.qualitative.Plotly):
    """
    Plots one line per group, each with a shaded band spanning
    y - error_y_minus to y + error_y_plus.

    Parameters:
    df (pd.DataFrame): The input DataFrame.
    x (str): The column plotted on the x axis.
    y (str): The column plotted on the y axis.
    split_by (str): The column whose values define the groups (one line each).
    error_y_plus (str): The column added to y to get the upper band edge.
    error_y_minus (str): The column subtracted from y to get the lower band edge.
    extra_split (str, optional): A second column to group by together with
        split_by. When None, groups are defined by split_by alone.
    template (str): The plotly layout template.
    height (int): The figure height passed to the layout.
    width (int): The figure width passed to the layout.
    colors (list): Colors written as "#rrggbb" hex or "rgb(r, g, b)" strings.
        They are reused cyclically when there are more groups than colors.

    Returns:
    go.Figure: A figure holding three traces per group (line, upper band
    edge, lower band edge), with groups added in order of increasing mean y.

    Raises:
    ValueError: If colors is empty.
    """
    if not colors:
        raise ValueError("colors must contain at least one color")

    # Convert colors: rgb_to_rgba needs "rgb(r, g, b)" strings, so translate
    # any hex entries (checked per color, not just the first one).
    colors = [
        f"rgb{hex_to_rgb(color)}" if color.startswith("#") else color
        for color in colors
    ]

    # Group by one or two columns, and build each group's legend label from
    # its own key rather than from a variable outside the function.
    group_cols = [split_by] if extra_split is None else [split_by, extra_split]
    groups = []
    for key, group in df.groupby(group_cols):
        key_parts = key if isinstance(key, tuple) else (key,)
        label = ", ".join(str(part) for part in key_parts)
        groups.append((label, group))

    # order by mean
    ordered_idx = np.argsort([group[y].mean() for _, group in groups])

    traces = []
    for i, idx in enumerate(ordered_idx):
        label, subdf = groups[idx]
        color = colors[i % len(colors)]
        band_color = rgb_to_rgba(color, 0.15)
        traces.append(go.Scatter(name=label,
                                 x=subdf[x],
                                 y=subdf[y],
                                 mode='lines',
                                 showlegend=True,
                                 line_color=rgb_to_rgba(color, 1),
                                 ))
        # The upper edge must come immediately before the lower edge:
        # fill='tonexty' on the lower edge fills up to the previous trace.
        traces.append(go.Scatter(name=label,
                                 x=subdf[x],
                                 y=subdf[y] + subdf[error_y_plus],
                                 mode='lines',
                                 fillcolor=band_color,
                                 line_width=0,
                                 showlegend=False,
                                 ))
        traces.append(go.Scatter(name=label,
                                 x=subdf[x],
                                 y=subdf[y] - subdf[error_y_minus],
                                 fill='tonexty',
                                 mode='lines',
                                 fillcolor=band_color,
                                 line_width=0,
                                 showlegend=False,
                                 ))
    fig = go.Figure(traces)
    fig.update_layout(template=template, height=height, width=width)
    return fig

# Dataset Split Comparison

In [None]:
# Keep only the rows whose StructureChoice is "Dock_to_All".
dataset_split_df = df[df.StructureChoice == "Dock_to_All"]

In [None]:
# One line (with its confidence band) per value of "Split", drawn from the
# Dock_to_All subset, then styled and saved as PNG and HTML.
fig = plot_scatter_with_confidence_bands(
    dataset_split_df,
    "N_Per_Split",
    "Fraction",
    "Split",
    "Error_Upper",
    "Error_Lower",
    template="simple_white",
    height=600,
    width=800,
    colors=px.colors.qualitative.Plotly,
)

fig.update_layout(
    font={"size": small_font, "family": 'Arial'},
    legend={
        "title": "<b> Dataset Split </b>",
        "x": 0.4,
        "y": 0.1,
        "traceorder": 'reversed',
        "title_font_size": large_font,
        "font_color": 'black',
    },
    xaxis={
        "title": "<b> Total Number of References Available to Use </b>",
        "title_font": {"size": large_font},
        "color": 'black',
    },
    yaxis={
        "range": (0, 1),
        "title": "<b> Fraction of Poses Docked < 2Å from Reference </b>",
        "title_font": {"size": large_font},
        "color": 'black',
    },
)
# Ticks every 0.1 across the fixed 0-1 y range.
fig.update_yaxes(tickvals=np.arange(0, 1.1, 0.1))
fig.write_image(figure_path / "20240510_dataset_split_comparison.png")
fig.write_html(figure_path / "20240510_dataset_split_comparison.html")

# Plot Everything Separately

In [None]:
# One figure per (Split, StructureChoice) combination present in df, with one
# line per StructureChoice_Choose_N value; each is saved as PNG and HTML.
for split in df.Split.unique().tolist():
    structure_choice_df = df[df.Split == split]
    for choice in structure_choice_df.StructureChoice.unique().tolist():
        subset_df = structure_choice_df[structure_choice_df.StructureChoice == choice]
        if not len(subset_df):
            continue
        fig = plot_scatter_with_confidence_bands(
            df=subset_df,
            x="N_Per_Split",
            y="Fraction",
            split_by="StructureChoice_Choose_N",
            error_y_plus="Error_Upper",
            error_y_minus="Error_Lower",
            template="simple_white",
            height=600,
            width=800,
            colors=px.colors.qualitative.Plotly,
        )
        fig.update_layout(
            font={"size": small_font, "family": 'Arial'},
            legend={
                "title": f"<b> DatasetSplit:</b> {split} <br>"
                         f"<b> StructureChoice:</b> {choice}",
                "x": 0.4,
                "y": 0.1,
                "traceorder": 'reversed',
                "title_font_size": large_font,
                "font_color": 'black',
            },
            xaxis={
                "title": "<b> Total Number of References Available to Use </b>",
                "title_font": {"size": large_font},
                "color": 'black',
            },
            yaxis={
                "range": (0, 1),
                "title": "<b> Fraction of Poses Docked < 2Å from Reference </b>",
                "title_font": {"size": large_font},
                "color": 'black',
            },
        )
        # Ticks every 0.1 across the fixed 0-1 y range.
        fig.update_yaxes(tickvals=np.arange(0, 1.1, 0.1))
        fig.write_image(figure_path / f"20240503_{split}_{choice}.png")
        fig.write_html(figure_path / f"20240503_{split}_{choice}.html")

# DateSplit - Structure Choice Comparison

In [None]:
# Restrict to the DateSplit rows whose StructureChoice_Choose_N is one of the
# values listed in choose_n.
choose_n = [1, 10]
structure_choice_comparison_df = df[
    df["Split"].eq("DateSplit") & df["StructureChoice_Choose_N"].isin(choose_n)
]

In [None]:
# One line (with its confidence band) per (StructureChoice,
# StructureChoice_Choose_N) group of the DateSplit subset, styled and saved
# as PNG and HTML.
fig = plot_scatter_with_confidence_bands(
    structure_choice_comparison_df,
    "N_Per_Split",
    "Fraction",
    "StructureChoice",
    "Error_Upper",
    "Error_Lower",
    extra_split="StructureChoice_Choose_N",
    template="simple_white",
    height=600,
    width=800,
    colors=px.colors.qualitative.Safe,
)
fig.update_layout(
    font={"size": small_font, "family": 'Arial'},
    legend={
        "title": "<b> StructureChoice</b>",
        "x": 0.4,
        "y": 0.1,
        "traceorder": 'reversed',
        "title_font_size": large_font,
        "font_color": 'black',
    },
    xaxis={
        "title": "<b> Total Number of References Available to Use </b>",
        "title_font": {"size": large_font},
        "color": 'black',
    },
    yaxis={
        "range": (0, 1),
        "title": "<b> Fraction of Poses Docked < 2Å from Reference </b>",
        "title_font": {"size": large_font},
        "color": 'black',
    },
)
# Ticks every 0.1 across the fixed 0-1 y range.
fig.update_yaxes(tickvals=np.arange(0, 1.1, 0.1))
# NOTE(review): choose_n is a list, so its repr (brackets, comma, space) ends
# up in the file name.
fig.write_image(figure_path / f"20240503_date_split_structure_choice_comparison_{choose_n}.png")
fig.write_html(figure_path / f"20240503_date_split_structure_choice_comparison_{choose_n}.html")

# Calculate how many poses there are per query/reference ligand pair

In [None]:
# Load the combined results, using the CSV's first column as the index.
rdf = pd.read_csv(results_csv, index_col=0)

In [None]:
# Count the non-null entries of every column for each
# (Query_Ligand, Reference_Ligand) pair; the pair keys become regular columns.
n_poses = rdf.groupby(["Query_Ligand", "Reference_Ligand"]).count().reset_index()

In [None]:
# Keep the first row of each (Query_Ligand, Reference_Ligand) pair;
# reset_index moves the previous index into a column.
combined_data = rdf.groupby(["Query_Ligand", "Reference_Ligand"]).head(1).reset_index()

In [None]:
# Display the frame as the notebook cell's output.
combined_data

In [None]:
# The per-pair count of Pose_ID is used as the number of poses.
n_poses = n_poses.rename(columns={"Pose_ID": "N_Poses"})

In [None]:
# Keep only the pair keys and the pose count.
n_poses = n_poses[["Query_Ligand", "Reference_Ligand", "N_Poses"]]

In [None]:
# Attach the pose count to the one-row-per-pair data, joining on the pair keys
# (merge's default inner join).
merged = n_poses.merge(combined_data, on=["Query_Ligand", "Reference_Ligand"])

In [None]:
# Reference ligand names ordered by ascending Reference_Structure_Date,
# de-duplicated in that order; used to order both heatmap axes.
lig_orders = merged.sort_values("Reference_Structure_Date").Reference_Ligand.unique().tolist()

In [None]:
# Heatmap of N_Poses for every query/reference ligand pair, with both axes
# ordered by lig_orders; saved as PNG and HTML.
fig = px.density_heatmap(
    data_frame=merged,
    x="Query_Ligand",
    y="Reference_Ligand",
    z="N_Poses",
    category_orders={
        "Query_Ligand": lig_orders,
        "Reference_Ligand": lig_orders,
    },
    color_continuous_scale="Viridis",
    height=800,
    width=1000,
)
fig.write_image(figure_path / "20240521_n_poses_density_heatmap.png")
fig.write_html(figure_path / "20240521_n_poses_density_heatmap.html")

# Is the number of poses related to how similar the reference and query are?

In [None]:
# 2D density of N_Poses against Tanimoto over all pairs, with marginal
# histograms on both axes; saved as PNG and HTML.
fig = px.density_heatmap(
    data_frame=merged,
    x="Tanimoto",
    y="N_Poses",
    marginal_x="histogram",
    marginal_y="histogram",
    color_continuous_scale=px.colors.sequential.Viridis,
    height=800,
    width=1000,
)
fig.write_image(figure_path / "20240521_n_poses_vs_tanimoto_density_heatmap.png")
fig.write_html(figure_path / "20240521_n_poses_vs_tanimoto_density_heatmap.html")

## no

In [None]:
# Lowest-RMSD row for each (Query_Ligand, Reference_Ligand) pair:
# sort ascending by RMSD, then take the first row of every group.
best_rmsd_per_complex = rdf.sort_values(["RMSD"]).groupby(["Query_Ligand", "Reference_Ligand"]).head(1)

In [None]:
# Flag the complexes whose best pose has RMSD < 2, then keep only those.
# The comparison is vectorised instead of a row-wise apply, and assign()
# returns a new frame rather than writing a column into a groupby-derived
# subset (which can trigger pandas' chained-assignment warning).
best_rmsd_per_complex = best_rmsd_per_complex.assign(
    Good_RMSD=best_rmsd_per_complex["RMSD"] < 2
)
best_rmsd = best_rmsd_per_complex[best_rmsd_per_complex["Good_RMSD"]]

In [None]:
# Add N_Poses to the pairs kept in best_rmsd, joining on the pair keys.
best_rmsd_with_n_poses = best_rmsd.merge(n_poses, on=["Query_Ligand", "Reference_Ligand"])

In [None]:
# Same N_Poses-vs-Tanimoto density as the previous figure, but drawn from
# best_rmsd_with_n_poses; saved as PNG and HTML.
# NOTE(review): the file name says "best_rmsd_vs_tanimoto" while the y axis is
# N_Poses; confirm which was intended.
fig = px.density_heatmap(
    data_frame=best_rmsd_with_n_poses,
    x="Tanimoto",
    y="N_Poses",
    marginal_x="histogram",
    marginal_y="histogram",
    color_continuous_scale=px.colors.sequential.Viridis,
    height=800,
    width=1000,
)
fig.write_image(figure_path / "20240521_best_rmsd_vs_tanimoto_density_heatmap.png")
fig.write_html(figure_path / "20240521_best_rmsd_vs_tanimoto_density_heatmap.html")