In [None]:
import plotly.express as px
import pandas as pd
import plotly.graph_objects as go
import numpy as np
from pathlib import Path

# Load Data

In [None]:
# Combined docking results CSV; the analyzed-data directory used below is derived from this file's location.
# NOTE(review): absolute, machine-specific path — the notebook will not run on another machine without editing it.
results_csv = Path("/Users/alexpayne/Scientific_Projects/mers-drug-discovery/sars2-retrospective-analysis/20240424_multi_pose_docking_cross_docking/results_csvs/20240503_combined_results_with_data.csv")

In [None]:
# Directory holding the analyzed CSVs: a sibling of the folder that contains the results CSV.
data_path = results_csv.parent.parent / "analyzed_data"

# Output directory for exported figures (relative to the notebook's working directory).
# It is created up front because fig.write_image(...) does not create missing
# parent directories, so a fresh checkout would otherwise fail at the first export.
figure_path = Path("figures")
figure_path.mkdir(parents=True, exist_ok=True)

In [None]:
# Collect every CSV one directory level below `data_path`.
# Materialised as a sorted list rather than the raw glob generator so that
# (a) the cell that reads the files can be re-run without silently getting an
#     empty list from an exhausted generator, and
# (b) the files are always read in the same order.
df_paths = sorted(data_path.glob("*/*.csv"))

In [None]:
# Read each discovered CSV into its own DataFrame, preserving the order of `df_paths`.
dfs = list(map(pd.read_csv, df_paths))

In [None]:
# Stack all per-file frames, force N_Per_Split to integers, and order the rows
# so that each (split, score, pose selection, structure choice) series is
# contiguous and increasing in N_Per_Split.
ogdf = (
    pd.concat(dfs)
    .astype({"N_Per_Split": int})
    .sort_values(
        [
            "Split",
            "Score",
            "PoseSelection",
            "StructureChoice",
            "StructureChoice_Choose_N",
            "N_Per_Split",
        ]
    )
)

In [None]:
# Convert the absolute confidence-interval bounds into distances from
# "Fraction", which is the form the plotting code passes as error_y / error_y_minus.
ogdf["Error_Lower"] = ogdf["Fraction"].sub(ogdf["CI_Lower"])
ogdf["Error_Upper"] = ogdf["CI_Upper"].sub(ogdf["Fraction"])

In [None]:
# Partition the results on PoseSelection: rows using the "Default" selection
# go to `df`, every other row goes to `multipose`.
is_default_pose_selection = ogdf.PoseSelection == "Default"
df = ogdf[is_default_pose_selection]
multipose = ogdf[~is_default_pose_selection]

In [None]:
# Inspect which StructureChoice values are present in the combined data.
ogdf.StructureChoice.unique()

In [None]:
# Inspect which Split values are present in the combined data.
ogdf.Split.unique()

# Plotting Variables

In [None]:
# Font sizes shared by all figures: `large_font` for axis/legend titles,
# `small_font` for the remaining figure text.
large_font = 24
small_font = 18

# Human-readable axis titles keyed by DataFrame column name.
labels = {
    "Fraction": "<b> Fraction of Poses Docked < 2Å from Reference </b>",
    "N_Per_Split": "<b> Total Number of Reference Structures Available to Use </b>",
}

# Axis styling unpacked into fig.update_layout(...); the y axis is pinned to
# the (0, 1) range because the plotted quantity is a fraction.
update_layout_dict = {
    "xaxis": {
        "title_font": {"size": large_font},
        "color": 'black',
    },
    "yaxis": {
        "range": (0, 1),
        "title_font": {"size": large_font},
        "color": 'black',
    },
}

In [None]:
def update_traces(fig):
    """Rewrite the legend names of `fig`'s traces into presentation form.

    For every trace that has a name, the following substitutions are applied
    in order: underscores become spaces, the word "Split" is removed,
    ", " becomes " | ", and "RMSD" becomes "RMSD (Positive Control)".
    Traces whose name is None are left untouched.

    The figure is modified in place and also returned.
    """
    substitutions = (
        ("_", " "),
        ("Split", ""),
        (", ", " | "),
        ("RMSD", "RMSD (Positive Control)"),
    )
    for trace in fig.data:
        display_name = trace.name
        if display_name is None:
            continue
        for old, new in substitutions:
            display_name = display_name.replace(old, new)
        trace.name = display_name
    return fig

In [None]:
def hex_to_rgb(hex_color: str) -> tuple:
    """Convert a hex color string to an ``(r, g, b)`` tuple of ints (0-255).

    Parameters
    ----------
    hex_color : str
        Color as "#RRGGBB" or the CSS shorthand "#RGB"; the leading "#" is
        optional.

    Returns
    -------
    tuple
        ``(red, green, blue)`` integer channel values.

    Raises
    ------
    ValueError
        If the digits cannot be parsed as hexadecimal.
    """
    hex_color = hex_color.lstrip("#")
    if len(hex_color) == 3:
        # Shorthand expands by doubling EACH digit ("abc" -> "aabbcc").
        # Repeating the whole string ("abcabc") would yield the wrong color.
        hex_color = "".join(digit * 2 for digit in hex_color)
    return int(hex_color[0:2], 16), int(hex_color[2:4], 16), int(hex_color[4:6], 16)

In [None]:
def rgb_to_rgba(rgb_str, alpha):
    """Return an "rgba(r, g, b, alpha)" string built from an "rgb(r, g, b)" string.

    The "rgb(" prefix and ")" suffix are stripped, the remainder is split on
    commas, and each of the three components is parsed as an int. `alpha` is
    interpolated into the result unchanged.
    """
    channel_text = rgb_str.strip('rgb()')
    red, green, blue = (int(component) for component in channel_text.split(','))
    return f"rgba({red}, {green}, {blue}, {alpha})"

In [None]:
def plot_scatter_with_confidence_bands(df, x, y, split_by, error_y_plus, error_y_minus, template="plotly_white", height=600, width=800, colors=px.colors.qualitative.Plotly):
    """Plot one line per split with a translucent confidence band around it.

    Parameters
    ----------
    df : pandas.DataFrame
        Data to plot.
    x, y : str
        Column names for the x and y values.
    split_by : str or list of str
        Column(s) whose distinct values (or value combinations) each get
        their own line. Any number of columns is supported.
    error_y_plus, error_y_minus : str
        Columns holding the distance from `y` to the upper / lower edge of
        the band.
    template, height, width
        Passed to ``fig.update_layout``.
    colors : sequence of str
        Colors as "#RRGGBB" hex strings or "rgb(r, g, b)" strings. The
        sequence is cycled if there are more splits than colors.

    Returns
    -------
    plotly.graph_objects.Figure
        For each split, three traces in this order: the line itself, the
        invisible upper band edge, and the lower band edge filled up to the
        upper edge.

    Raises
    ------
    ValueError
        If `colors` is empty.
    """
    if len(colors) == 0:
        raise ValueError("'colors' must contain at least one color.")

    # Convert colors: hex entries become "rgb(r, g, b)" strings so that
    # rgb_to_rgba can attach an alpha channel. Each entry is checked
    # individually, so a mixed hex / rgb palette also works.
    colors = [f"rgb{hex_to_rgb(color)}" if color.startswith("#") else color for color in colors]

    # Order the splits by their mean y value (ascending); this fixes both the
    # trace order and which color each split receives.
    ordered_splits = df.groupby(split_by)[y].mean().sort_values().index.tolist()
    print(ordered_splits)

    traces = []
    for i, split in enumerate(ordered_splits):
        # Grouping by several columns yields tuple keys: match every column
        # against its component of the key (not just the first two).
        if not isinstance(split, tuple):
            subdf = df[df[split_by] == split]
        else:
            mask = df[split_by[0]] == split[0]
            for column, value in zip(split_by[1:], split[1:]):
                mask &= df[column] == value
            subdf = df[mask]

        # Wrap around the palette instead of raising IndexError when there
        # are more splits than colors.
        color = colors[i % len(colors)]
        band_color = rgb_to_rgba(color, 0.15)

        # The visible line.
        traces.append(go.Scatter(name=f"{split}",
                                 x=subdf[x],
                                 y=subdf[y],
                                 mode='lines',
                                 showlegend=True,
                                 line_color=rgb_to_rgba(color, 1),
                                 ))
        # Upper band edge: zero-width line that only serves as the fill target.
        traces.append(go.Scatter(name=f"{split}",
                                 x=subdf[x],
                                 y=subdf[y] + subdf[error_y_plus],
                                 mode='lines',
                                 fillcolor=band_color,
                                 line_width=0,
                                 showlegend=False,
                                 ))
        # Lower band edge: 'tonexty' fills the area between this trace and the
        # previously added one (the upper edge).
        traces.append(go.Scatter(name=f"{split}",
                                 x=subdf[x],
                                 y=subdf[y] - subdf[error_y_minus],
                                 fill='tonexty',
                                 mode='lines',
                                 fillcolor=band_color,
                                 line_width=0,
                                 showlegend=False,
                                 ))
    fig = go.Figure(traces)
    fig.update_layout(template=template, height=height, width=width)
    return fig

# Remake the dataset split figure with shaded error bands

In [None]:
# Rows for the dataset-split comparison: docked against all structures
# ("Dock_to_All") with a single pose selected (PoseSelection_Choose_N == 1).
dataset_split_df = df.loc[
    df.StructureChoice.eq("Dock_to_All") & df.PoseSelection_Choose_N.eq(1)
]

The `line` helper below is adapted from this Stack Overflow answer: https://stackoverflow.com/questions/69587547/continuous-error-band-with-plotly-express-in-python

In [None]:
def _band_fill_color(color):
    """Return `color` as the translucent "rgba(r,g,b,.3)" string used for band fills.

    Accepts "#RRGGBB" / "#RGB" hex strings and "rgb(r, g, b)" strings, which
    are the two formats used by the Plotly built-in color sequences.

    Raises
    ------
    ValueError
        If `color` is in neither format (e.g. a named CSS color or None).
    """
    if isinstance(color, str) and color.startswith('#'):
        digits = color.lstrip('#')
        if len(digits) == 3:
            # CSS shorthand: double each digit.
            digits = ''.join(digit * 2 for digit in digits)
        red, green, blue = (int(digits[i:i+2], 16) for i in (0, 2, 4))
    elif isinstance(color, str) and color.startswith('rgb') and '(' in color and ')' in color:
        components = color[color.index('(') + 1:color.rindex(')')].split(',')
        red, green, blue = (int(float(component)) for component in components[:3])
    else:
        raise ValueError(
            f"Cannot derive an error-band fill color from line color {color!r}; "
            "use '#RRGGBB' or 'rgb(r, g, b)' colors."
        )
    return f"rgba({red},{green},{blue},.3)"


def line(error_y_mode=None, **kwargs):
    """Extension of `plotly.express.line` to use error bands.

    Parameters
    ----------
    error_y_mode : {'bar', 'bars', 'band', 'bands', None}
        'bar' / 'bars' / None: return the plain ``px.line(**kwargs)`` figure.
        'band' / 'bands': draw the y error as a filled, translucent band
        behind each line instead of error bars.
    **kwargs
        Forwarded to ``plotly.express.line``. 'error_y' is required in band
        mode; 'error_y_minus' is used for the lower edge when given,
        otherwise the band is symmetric.

    Returns
    -------
    plotly.graph_objects.Figure

    Raises
    ------
    ValueError
        If `error_y_mode` is not one of the accepted values, if band mode is
        requested without 'error_y', or if a line color is neither a hex nor
        an "rgb(...)" string.
    """
    ERROR_MODES = {'bar','band','bars','bands',None}
    if error_y_mode not in ERROR_MODES:
        raise ValueError(f"'error_y_mode' must be one of {ERROR_MODES}, received {repr(error_y_mode)}.")
    if error_y_mode in {'bar','bars',None}:
        return px.line(**kwargs)

    if 'error_y' not in kwargs:
        raise ValueError("If you provide argument 'error_y_mode' you must also provide 'error_y'.")

    # Build the figure twice: once with error bars, purely to let Plotly
    # Express resolve the per-trace error arrays, colors and axes, and once
    # without 'error_y' as the figure that is actually returned.
    figure_with_error_bars = px.line(**kwargs)
    fig = px.line(**{arg: val for arg, val in kwargs.items() if arg != 'error_y'})

    for data in figure_with_error_bars.data:
        x = list(data['x'])
        y_upper = list(data['y'] + data['error_y']['array'])
        # Symmetric band unless a separate lower error was supplied.
        lower_error = data['error_y']['arrayminus']
        if lower_error is None:
            lower_error = data['error_y']['array']
        y_lower = list(data['y'] - lower_error)
        fig.add_trace(
            go.Scatter(
                # Closed polygon: along the upper edge, then back along the lower edge.
                x = x+x[::-1],
                y = y_upper+y_lower[::-1],
                fill = 'toself',
                fillcolor = _band_fill_color(data['line']['color']),
                line = dict(
                    color = 'rgba(255,255,255,0)'
                ),
                hoverinfo = "skip",
                showlegend = False,
                legendgroup = data['legendgroup'],
                xaxis = data['xaxis'],
                yaxis = data['yaxis'],
            )
        )

    # Reorder data as said here: https://stackoverflow.com/a/66854398/8849755
    # fig.data holds all lines followed by all bands; interleave them as
    # (band, line) pairs so each band is drawn underneath its own line.
    n_lines = len(fig.data) // 2
    reordered_data = []
    for i in range(n_lines):
        reordered_data.append(fig.data[i + n_lines])
        reordered_data.append(fig.data[i])
    fig.data = tuple(reordered_data)
    return fig

In [None]:
# Dataset-split comparison: Fraction vs. N_Per_Split on a log x axis, one
# color per Score and one dash style per Split, with the confidence interval
# drawn as shaded bands.
fig = line(
    data_frame=dataset_split_df,
    x="N_Per_Split",
    y="Fraction",
    color="Score",
    line_dash="Split",
    error_y="Error_Upper",
    error_y_minus="Error_Lower",
    error_y_mode="band",
    labels=labels,
    log_x=True,
    template="simple_white",
    height=600,
    width=800,
)

# Fonts, legend placement (inside the plot area) and the shared axis styling.
fig.update_layout(
    font={"size": small_font, "family": 'Arial'},
    legend={
        "title": "<b> Score Function | Dataset Split </b>",
        "x": 0.4,
        "y": 0.1,
        "traceorder": 'reversed',
        "title_font_size": large_font,
        "font_color": 'black',
    },
    **update_layout_dict,
)

# Tick every 0.1 on the y axis, then tidy the legend entries.
fig.update_yaxes(tickvals=np.arange(0, 1.1, 0.1))
fig = update_traces(fig)

fig.show()
fig.write_image(figure_path / "20241024_dataset_split_comparison_v3.svg")

# Make figure of number of structures vs date

In [None]:
# Path to the combined results CSV that is loaded for the structures-vs-date figure.
# NOTE(review): this is the same absolute path already assigned to `results_csv`
# in the "Load Data" section, so this re-assignment looks redundant; it is also
# machine-specific.
results_csv = Path(
    "/Users/alexpayne/Scientific_Projects/mers-drug-discovery/sars2-retrospective-analysis/20240424_multi_pose_docking_cross_docking/results_csvs/20240503_combined_results_with_data.csv")

In [None]:
# Load the combined results table, using the CSV's first column as the index.
all_results_df = pd.read_csv(results_csv, index_col=0)

In [None]:
# Distinct Reference_Structure_Date values present in the results table.
unique_dates = all_results_df.Reference_Structure_Date.unique()

In [None]:
import json
from datetime import datetime 

In [None]:
def date_processor(date_string):
    """Parse a date string into a ``datetime``, or return None when there is no date.

    Two formats are accepted, tried in this order:
    "%Y-%m-%d %H:%M:%S" (e.g. "2020-03-04 05:06:07") and
    "%d/%m/%Y %H:%M" (e.g. "04/03/2020 05:06").

    Parameters
    ----------
    date_string : object
        The value to parse. Anything that is not a string (None, NaN, ...)
        and the literal string 'None' are treated as "no date".

    Returns
    -------
    datetime.datetime or None

    Raises
    ------
    ValueError
        If `date_string` is a string that matches neither format.
    """
    # isinstance rather than an exact type comparison, so that str subclasses
    # (e.g. numpy.str_) are parsed instead of being silently mapped to None.
    if not isinstance(date_string, str) or date_string == 'None':
        return None

    for date_format in ("%Y-%m-%d %H:%M:%S", "%d/%m/%Y %H:%M"):
        try:
            return datetime.strptime(date_string, date_format)
        except ValueError:
            continue
    raise ValueError(
        f"time data {date_string!r} does not match format "
        "'%Y-%m-%d %H:%M:%S' or '%d/%m/%Y %H:%M'"
    )

In [None]:
# Load the name -> date-string mapping from JSON, then turn it into a list of
# {"Name", "Date"} records with the dates parsed by date_processor.
with open("20240503_inputs_analysis/date_dict.json", "r") as f:
    raw_dates_by_name = json.load(f)

date_dict = [
    {"Name": name, "Date": date_processor(date)}
    for name, date in raw_dates_by_name.items()
]

In [None]:
# One row per record, with "Name" and "Date" columns.
date_df = pd.DataFrame.from_records(date_dict)

In [None]:
# Display the parsed name/date table.
date_df

In [None]:
# Keep the first result row for each Reference_Ligand.
# .copy() makes this an independent frame: the following cells add columns to
# it, which would otherwise operate on a slice of `all_results_df` and trigger
# pandas' SettingWithCopyWarning.
reference_ligand_df = all_results_df.groupby(['Reference_Ligand']).head(1).copy()

In [None]:
from rdkit.Chem import MolFromSmiles

In [None]:
# Parse each reference-ligand SMILES into an RDKit molecule; entries whose
# string form is 'nan' (missing SMILES) are stored as None instead.
reference_ligand_df["rdkit_mol"] = [
    None if str(smiles) == 'nan' else MolFromSmiles(str(smiles))
    for smiles in reference_ligand_df.Reference_Ligand_SMILES
]

In [None]:
# Drop ligands without a molecule object.
# .copy() so that the column assignments in the following cells modify an
# independent frame rather than a filtered slice (avoids SettingWithCopyWarning).
reference_ligand_df = reference_ligand_df[reference_ligand_df.rdkit_mol.notna()].copy()

In [None]:
# Atom count of each reference ligand, as reported by the molecule's GetNumAtoms().
reference_ligand_df["n_atoms"] = [
    mol.GetNumAtoms() for mol in reference_ligand_df.rdkit_mol
]

In [None]:
# Order the ligands by Reference_Structure_Date so the running maximum computed next follows that order.
# NOTE(review): if this column holds raw date strings the ordering is lexicographic, not chronological — confirm the dtype.
reference_ligand_df = reference_ligand_df.sort_values("Reference_Structure_Date")

In [None]:
# Running maximum of the ligand atom count, in the current row order.
reference_ligand_df = reference_ligand_df.assign(
    cum_max=reference_ligand_df["n_atoms"].cummax()
)

In [None]:
# Rebuild the one-row-per-reference-ligand frame straight from `all_results_df`.
# NOTE(review): this rebinding discards the rdkit_mol / n_atoms / cum_max columns,
# the missing-molecule filter and the date sort applied to `reference_ligand_df`
# in the cells above — confirm that throwing that work away is intended.
reference_ligand_df = all_results_df.groupby(['Reference_Ligand']).head(1)

In [None]:
# Load the cluster labels table (the variable name suggests Bemis-Murcko scaffold clusters).
# NOTE(review): absolute, machine-specific path into a different project directory.
bemis_murcko_cluster_df = pd.read_csv("/Users/alexpayne/Scientific_Projects/harbor/examples/cluster_labels.csv")

In [None]:
# Inspect the column names as read from the CSV, before they are renamed below.
bemis_murcko_cluster_df.columns

In [None]:
# Rename the three columns positionally so the ligand column matches the
# "Reference_Ligand" key used for the merge below.
bemis_murcko_cluster_df = bemis_murcko_cluster_df.set_axis(
    ["Reference_Ligand", "Cluster", "Scaffold_Smiles"], axis="columns"
)

In [None]:
# Attach cluster labels to the per-ligand rows; the default inner join keeps
# only ligands present in both tables.
plot_df = reference_ligand_df.merge(bemis_murcko_cluster_df, on="Reference_Ligand")

In [None]:
# Display the merged ligand / cluster table.
plot_df

In [None]:
# Non-null row counts per Cluster, for every other column.
plot_df.groupby("Cluster").count()

In [None]:
# Collapse the cluster ids into a small set of legend categories: ids up to 5
# keep their own label, everything else is pooled under "6-90".
plot_df['simple_clusters'] = [
    str(cluster) if cluster <= 5 else "6-90" for cluster in plot_df.Cluster
]

In [None]:
# Sort by the category label so the categories reach the plot in label order.
plot_df = plot_df.sort_values('simple_clusters')

In [None]:
# Cumulative count of structures over collection date, one curve per
# simplified cluster. ecdfnorm=None plots raw counts rather than fractions;
# the y axis is logarithmic.
fig = px.ecdf(
    plot_df,
    x='Reference_Structure_Date',
    color='simple_clusters',
    ecdfnorm=None,
    log_y=True,
    template='simple_white',
    height=600,
    width=800,
)

# Titles for the legend and both axes.
fig.update_layout(legend_title_text="<b> Bemis-Murcko Cluster </b>")
fig.update_xaxes(title_text="<b> Date of Crystal Structure Collection </b>")
fig.update_yaxes(title_text="<b> Cumulative Number of Structures </b>")

# Axis styling without a fixed y range, since this y axis is a count.
# NOTE(review): this rebinds the notebook-level `update_layout_dict`, so every
# later cell that unpacks it also loses the (0, 1) y range — confirm intended.
update_layout_dict = {
    "xaxis": {
        "title_font": {"size": large_font},
        "color": 'black',
    },
    "yaxis": {
        "title_font": {"size": large_font},
        "color": 'black',
    },
}

# Position the legend and apply the axis styling.
fig.update_layout(
    legend={
        "yanchor": "bottom",
        "y": 0.25,
        "xanchor": "right",
        "x": 1.1,
    },
    **update_layout_dict,
)
fig.show()
fig.write_image(figure_path / "20241024_cumulative_cluster_by_date.svg")

# Make new multipose datesplit plot without POSIT sorting

In [None]:
# Multipose plot: DateSplit rows scored by RMSD, one curve per number of
# poses (PoseSelection_Choose_N), with the confidence interval as shaded bands.
datesplit = multipose.loc[
    multipose.Split.eq("DateSplit") & multipose.Score.eq("RMSD")
]
fig = line(
    data_frame=datesplit,
    x="N_Per_Split",
    y="Fraction",
    color="PoseSelection_Choose_N",
    line_dash="Score",
    error_y="Error_Upper",
    error_y_minus="Error_Lower",
    error_y_mode='bands',
    color_discrete_sequence=px.colors.sequential.Viridis,
    labels=labels,
    log_x=True,
    template="simple_white",
    height=600,
    width=800,
)

# Fonts, legend placement and the axis styling currently bound to
# `update_layout_dict`.
fig.update_layout(
    font={"size": small_font, "family": 'Arial'},
    legend={
        "title": "<b> Number of Poses Returned by POSIT </b>",
        "x": 0.3,
        "y": 0.1,
        "title_font_size": large_font,
        "font_color": 'black',
    },
    **update_layout_dict,
)
fig.update_yaxes(tickvals=np.arange(0, 1.1, 0.1))

# Reduce each legend entry to the pose count: underscores become spaces and
# the "Split", ", " and "RMSD" fragments are removed.
for trace in fig.data:
    if trace.name is None:
        continue
    cleaned_name = trace.name
    for old, new in (("_", " "), ("Split", ""), (", ", ""), ("RMSD", "")):
        cleaned_name = cleaned_name.replace(old, new)
    trace.name = cleaned_name

fig.show()
fig.write_image(figure_path / "20241024_multipose_datesplit.svg")

In [None]:
# Display the full combined results table.
all_results_df

In [None]:
# 20 evenly spaced cutoff values from 0 to 2 (inclusive); compared against the "Tanimoto" column below.
# NOTE(review): presumably Tanimoto-similarity cutoffs — confirm the upper bound of 2 is intended.
tc_cuttofs = np.linspace(0,2, 20)

In [None]:
# Show the first (smallest) cutoff.
tc_cuttofs[0]

In [None]:
# Only the first cutoff is processed because of the `[:1]` slice.
for tc_cuttof in tc_cuttofs[:1]:
    # Rows whose Tanimoto value is at or below the cutoff, sorted ascending by "docking-confidence-POSIT".
    subset = all_results_df[all_results_df.Tanimoto <= tc_cuttof].sort_values("docking-confidence-POSIT")
    # NOTE(review): the groupby result is neither assigned nor aggregated, so this
    # line has no effect — the analysis looks unfinished.
    subset.groupby(["Query_Ligand"])