In [None]:
# Because of pydantic version discrepancies, this notebook must be run in the `harbor` environment instead of `asap2025`.

In [None]:
import plotly.express as px
import pandas as pd
import plotly.graph_objects as go
import numpy as np
from pathlib import Path
import harbor.analysis.cross_docking as cd

# Load Data

In [None]:
# Input: aggregated cross-docking results. Output: folder for the exported figures.
data_path = Path("01_analysis_scripts/analyzed_data/results.csv")
figure_path = Path("figures")
# Create the output folder up front so the fig.write_image(figure_path / ...) calls below
# do not fail on a fresh checkout where `figures/` does not exist yet.
figure_path.mkdir(parents=True, exist_ok=True)

In [None]:
# Load the analysed results, then derive asymmetric error lengths from the confidence interval:
# Error_Lower / Error_Upper are the distances from Fraction down to CI_Lower and up to CI_Upper.
# Sorting with N_Per_Split as the last key keeps x increasing within each plotted line,
# which the error-band polygons (x + reversed x) rely on.
results_df = (
    pd.read_csv(data_path)
    .astype({"N_Per_Split": int})
    .sort_values(
        ["Split", "Score", "PoseSelection", "StructureChoice", "StructureChoice_Choose_N", "N_Per_Split"]
    )
    .assign(
        Error_Lower=lambda frame: frame["Fraction"] - frame["CI_Lower"],
        Error_Upper=lambda frame: frame["CI_Upper"] - frame["Fraction"],
    )
)

# Plotting Variables

In [None]:
# Font sizes (pt) for axis/legend titles and for all other text.
large_font = 24
small_font = 18

# Display names. They are passed to px as `labels=` (axis titles), and `update_traces` also
# substitutes them into legend entry names, in this key order.
labels = {
    "Fraction": "<b> Fraction of Poses Docked < 2Å from Reference </b>",
    "N_Per_Split": "<b> Total Number of Reference Structures Available to Use </b>",
    "Date": "Temporally Ordered",
    "Random": "Randomly Shuffled",
}

# Shared axis styling: large black axis titles, y fixed to the [0, 1] fraction range.
update_layout_dict = {
    "xaxis": {
        "title_font": {"size": large_font},
        "color": "black",
    },
    "yaxis": {
        "range": (0, 1),
        "title_font": {"size": large_font},
        "color": "black",
    },
}

In [None]:
def update_traces(fig):
    """Rewrite the legend names of ``fig``'s traces into readable labels, in place.

    The substitutions are applied in a fixed order:
    underscores become spaces, "Split" is dropped, ", " becomes " | ",
    "RMSD" is marked as the positive control, then every key of the global
    ``labels`` dict is replaced by its value.
    Traces whose name is ``None`` (e.g. unnamed band traces) are skipped.

    Returns the same figure object so calls can be chained.
    """
    replacements = [
        ("_", " "),
        ("Split", ""),
        (", ", " | "),
        ("RMSD", "RMSD (Positive Control)"),
    ]
    replacements += list(labels.items())

    for trace in fig.data:
        if trace.name is None:
            continue
        pretty = trace.name
        for old, new in replacements:
            pretty = pretty.replace(old, new)
        trace.name = pretty
    return fig

In [None]:
def clean_labels(fig):
    """Strip the ``column=`` prefix from every annotation of ``fig``, in place.

    Each annotation's text is replaced by whatever follows its last ``=``.
    Text without ``=`` is left unchanged.
    This turns facet titles such as ``"Split=DateSplit"`` into ``"DateSplit"``.
    Returns ``fig``.
    """
    def _value_only(annotation):
        annotation.update(text=annotation.text.rpartition("=")[2])

    fig.for_each_annotation(_value_only)
    return fig

In [None]:
def hex_to_rgb(hex_color: str) -> tuple:
    """Convert a hex colour string to an ``(r, g, b)`` tuple of ints in 0-255.

    Accepted forms, each with or without the leading ``#``:
    - 6 digits, e.g. ``"#1F77B4"``;
    - CSS shorthand of 3 digits, e.g. ``"#abc"``, read as ``"#aabbcc"``;
    - 8 digits with alpha, e.g. ``"#1F77B4FF"``, where the alpha is ignored.

    Raises
    ------
    ValueError
        If the string is not a 3-, 6- or 8-digit hex colour, or contains non-hex digits.
    """
    digits = hex_color.strip().lstrip("#")
    if len(digits) == 3:
        # Shorthand doubles each digit ("abc" -> "aabbcc"). Repeating the whole
        # string ("abcabc") would give the wrong colour.
        digits = "".join(ch * 2 for ch in digits)
    if len(digits) not in (6, 8):
        raise ValueError(f"Expected a 3-, 6- or 8-digit hex colour, got {hex_color!r}.")
    return int(digits[0:2], 16), int(digits[2:4], 16), int(digits[4:6], 16)

In [None]:
def rgb_to_rgba(rgb_str, alpha):
    """Turn an ``"rgb(r, g, b)"`` string into ``"rgba(r, g, b, alpha)"``.

    The characters ``r``, ``g``, ``b``, ``(`` and ``)`` are stripped from both ends.
    The comma-separated channels are then parsed as ints (surrounding spaces are allowed).
    The result is always formatted with a single space after each comma.
    """
    channels = rgb_str.strip('rgb()').split(',')
    red_val, green_val, blue_val = (int(channel) for channel in channels)
    return "rgba({}, {}, {}, {})".format(red_val, green_val, blue_val, alpha)

In [None]:
def plot_scatter_with_confidence_bands(df, x, y, split_by, error_y_plus, error_y_minus, template="plotly_white", height=600, width=800, colors=px.colors.qualitative.Plotly):
    """Draw one line per group of ``df`` with a shaded confidence band around it.

    Groups are drawn in order of increasing mean ``y``. Each group gets three traces:
    - the central line;
    - an invisible upper edge (``y + error_y_plus``);
    - a lower edge (``y - error_y_minus``) filled back to the upper edge.
    All three share a legend group, so toggling the legend entry hides the band too.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format data. Rows are plotted in their existing order within each group.
    x, y : str
        Columns holding the x values and the central y values.
    split_by : column label or list of column labels
        Column(s) defining the groups.
    error_y_plus, error_y_minus : str
        Columns holding the distance from ``y`` to the upper / lower band edge.
    template, height, width
        Passed to ``fig.update_layout``.
    colors : sequence of str
        Hex (``"#rrggbb"``) or ``"rgb(r, g, b)"`` colour strings. They are cycled when
        there are more groups than colours.

    Returns
    -------
    plotly.graph_objects.Figure

    Raises
    ------
    ValueError
        If ``colors`` is empty.
    """
    if len(colors) == 0:
        raise ValueError("'colors' must contain at least one colour.")
    # rgb_to_rgba only understands "rgb(r, g, b)", so convert hex entries first (per colour).
    rgb_colors = [f"rgb{hex_to_rgb(c)}" if c.startswith("#") else c for c in colors]
    keys = split_by if isinstance(split_by, list) else [split_by]

    traces = []
    # Order groups by their mean y so the legend/trace order is reproducible.
    ordered_splits = df.groupby(split_by)[y].mean().sort_values().index.tolist()
    for i, split in enumerate(ordered_splits):
        # Multi-key groups come back as tuples; single keys as scalars.
        values = split if isinstance(split, tuple) else (split,)
        mask = pd.Series(True, index=df.index)
        for key, value in zip(keys, values):
            mask &= df[key] == value
        subdf = df[mask]

        # "A, B" naming matches plotly express's convention for multi-dimension traces.
        name = ", ".join(map(str, values))
        color = rgb_colors[i % len(rgb_colors)]
        band_color = rgb_to_rgba(color, 0.15)

        traces.append(go.Scatter(name=name,
                                 legendgroup=name,
                                 x=subdf[x],
                                 y=subdf[y],
                                 mode='lines',
                                 showlegend=True,
                                 line_color=rgb_to_rgba(color, 1),
                                 ))
        traces.append(go.Scatter(name=name,
                                 legendgroup=name,
                                 x=subdf[x],
                                 y=subdf[y] + subdf[error_y_plus],
                                 mode='lines',
                                 fillcolor=band_color,
                                 line_width=0,
                                 showlegend=False,
                                 ))
        # 'tonexty' fills from this (lower) edge up to the previous (upper) edge trace.
        traces.append(go.Scatter(name=name,
                                 legendgroup=name,
                                 x=subdf[x],
                                 y=subdf[y] - subdf[error_y_minus],
                                 fill='tonexty',
                                 mode='lines',
                                 fillcolor=band_color,
                                 line_width=0,
                                 showlegend=False,
                                 ))
    fig = go.Figure(traces)
    fig.update_layout(template=template, height=height, width=width)
    return fig

The error-band extension of `plotly.express.line` below is adapted from this Stack Overflow answer: https://stackoverflow.com/questions/69587547/continuous-error-band-with-plotly-express-in-python

In [None]:
def _color_to_rgb(color):
    """Parse a plotly colour string into an ``(r, g, b)`` tuple of ints.

    Accepted inputs:
    - hex colours with or without a leading ``#``: ``"#1F77B4"``, ``"1F77B4"``,
      shorthand ``"#abc"``, or 8-digit hex whose alpha digits are ignored;
    - ``"rgb(r, g, b)"`` / ``"rgba(r, g, b, a)"`` strings, where alpha is ignored.

    Raises
    ------
    ValueError
        If ``color`` is in neither format (e.g. a named CSS colour such as ``"blue"``).
    """
    text = color.strip()
    if text.lower().startswith("rgb"):
        inner = text[text.find("(") + 1:text.rfind(")")]
        return tuple(int(float(part)) for part in inner.split(",")[:3])
    digits = text.lstrip("#")
    if len(digits) == 3:
        digits = "".join(ch * 2 for ch in digits)
    if len(digits) not in (6, 8):
        raise ValueError(f"Unsupported colour {color!r}: expected hex ('#rrggbb') or 'rgb(r, g, b)'.")
    return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))


def line(error_y_mode=None, override_color=None, **kwargs):
    """Extension of `plotly.express.line` that can draw error bands instead of error bars.

    Parameters
    ----------
    error_y_mode : {'bar', 'bars', 'band', 'bands', None}
        ``None``, ``'bar'`` or ``'bars'`` simply return ``px.line(**kwargs)``.
        ``'band'`` or ``'bands'`` instead draw, for every line, a filled polygon
        between ``y + error_y`` and ``y - error_y_minus``. If no minus column was
        given, ``error_y`` is used on both sides.
    override_color : str, optional
        Colour used for every band instead of the colour of its line. Either hex
        (with or without ``#``) or an ``rgb(...)`` string.
    **kwargs
        Forwarded to ``px.line``. ``error_y`` is required in band mode.

    Returns
    -------
    plotly.graph_objects.Figure
        In band mode, each band trace is placed immediately before its line, so
        the line is drawn on top of the band.

    Raises
    ------
    ValueError
        If ``error_y_mode`` is not a recognised mode, or band mode is requested without ``error_y``.
    """
    ERROR_MODES = {'bar', 'band', 'bars', 'bands', None}
    if error_y_mode not in ERROR_MODES:
        raise ValueError(f"'error_y_mode' must be one of {ERROR_MODES}, received {repr(error_y_mode)}.")
    if error_y_mode in {'bar', 'bars', None}:
        return px.line(**kwargs)

    if 'error_y' not in kwargs:
        raise ValueError("If you provide argument 'error_y_mode' you must also provide 'error_y'.")
    # This figure is only used to read the per-trace error arrays that px computes.
    figure_with_error_bars = px.line(**kwargs)
    # The returned figure must carry no error arrays at all. Dropping only 'error_y'
    # could leave lower-only error bars from 'error_y_minus' attached to the lines.
    fig = px.line(**{arg: val for arg, val in kwargs.items() if arg not in ('error_y', 'error_y_minus')})

    for data in figure_with_error_bars.data:
        x = list(data['x'])
        # Coerce to float arrays so '+'/'-' are element-wise even if plotly hands back tuples.
        y = np.asarray(data['y'], dtype=float)
        plus = np.asarray(data['error_y']['array'], dtype=float)
        minus = data['error_y']['arrayminus']
        minus = plus if minus is None else np.asarray(minus, dtype=float)
        y_upper = list(y + plus)
        y_lower = list(y - minus)

        r, g, b = _color_to_rgb(data['line']['color'] if override_color is None else override_color)
        color = f"rgba({r},{g},{b},.3)"
        fig.add_trace(
            go.Scatter(
                # Closed polygon: upper edge left-to-right, then lower edge right-to-left.
                x=x + x[::-1],
                y=y_upper + y_lower[::-1],
                fill='toself',
                fillcolor=color,
                line=dict(color=color,
                          dash=data['line']['dash'],
                          ),
                hoverinfo="skip",
                showlegend=False,
                legendgroup=data['legendgroup'],
                xaxis=data['xaxis'],
                yaxis=data['yaxis'],
            )
        )

    # fig.data is now [line_0..line_{n-1}, band_0..band_{n-1}]. Interleave it as
    # [band_0, line_0, band_1, line_1, ...] so each band sits under its line
    # (see https://stackoverflow.com/a/66854398/8849755).
    n_lines = len(fig.data) // 2
    reordered_data = []
    for band, line_trace in zip(fig.data[n_lines:], fig.data[:n_lines]):
        reordered_data.extend((band, line_trace))
    fig.data = tuple(reordered_data)
    return fig

In [None]:
# Combined figure: both dataset splits (colour) and score functions (dash) on one set of axes.
combined_layout = dict(
    font=dict(size=small_font, family='Arial'),
    legend=dict(
        title="<b> Dataset Split | Score Function </b>",
        x=0.4,
        y=0.1,
        traceorder='reversed',
        title_font_size=large_font,
        font_color='black',
    ),
    **update_layout_dict,
)

fig = line(
    data_frame=results_df,
    x="N_Per_Split",
    y="Fraction",
    color="Split",
    line_dash="Score",
    error_y="Error_Upper",
    error_y_minus="Error_Lower",
    error_y_mode="band",
    template="simple_white",
    height=600,
    width=800,
    log_x=True,
    labels=labels,
)
fig.update_layout(**combined_layout)
fig.update_yaxes(tickvals=np.arange(0, 1.1, 0.1))
fig = update_traces(fig)
fig.show()
for extension in ("svg", "png"):
    fig.write_image(figure_path / f"20250124_dataset_split_comparison.{extension}")

In [None]:
# Inspect the trace objects of the combined figure built above (e.g. to check names/colours after update_traces).
fig.data

In [None]:
# Hex colours for the splits. Plotly only accepts hex colours with the leading '#'.
orange = '#FF7F0E'
blue = '#1F77B4'

## Two Panels: One Facet per Dataset Split

In [None]:
# One facet (panel) per dataset split, so the two splits can be compared side by side.
fig = line(data_frame=results_df,
           x="N_Per_Split",
           y="Fraction",
           color="Split",
           line_dash="Score",
           error_y="Error_Upper",
           error_y_minus="Error_Lower",
           error_y_mode="band",
           template="simple_white",
           facet_col="Split",
           height=600,
           width=1600,
           log_x=True,
           labels=labels
           )
fig.update_layout(
    font=dict(size=small_font,
              family='Arial'
              ),
    legend=dict(title="<b> Dataset Split | Score Function </b>",
                x=0.4, y=0.1,
                traceorder='reversed',
                title_font_size=large_font,
                font_color='black'))
# update_layout(xaxis=..., yaxis=...) would only style the first facet's axes.
# update_xaxes / update_yaxes apply the same axis styling to every panel (xaxis, xaxis2, ...).
fig.update_xaxes(**update_layout_dict["xaxis"])
fig.update_yaxes(**update_layout_dict["yaxis"])
# Facet titles "Split=DateSplit" -> "DateSplit".
fig = clean_labels(fig)
fig.update_yaxes(tickvals=np.arange(0, 1.1, 0.1))
fig = update_traces(fig)
fig.show()
fig.write_image(figure_path / "20250124_dataset_split_comparison_split.svg")
fig.write_image(figure_path / "20250124_dataset_split_comparison_split.png")

## Individual Splits: Randomly Shuffled and Temporally Ordered

In [None]:
# Distinct dataset-split labels; the single-split figures below filter on these values.
results_df["Split"].unique()

In [None]:
# Randomly shuffled split only. The score functions are distinguished by line dash.
# NOTE(review): no `color=` column is mapped here, so a `color_discrete_map` has no effect and was
# dropped; the earlier map's values also lacked the leading '#'. All traces presumably use the
# template's first colour. If this figure should reuse the split's colour from the combined figure,
# pass it explicitly, e.g. color_discrete_sequence=["#1F77B4"]. TODO confirm the intended colour.
fig = line(data_frame=results_df[results_df.Split == "RandomSplit"],
           x="N_Per_Split",
           y="Fraction",
           line_dash="Score",
           error_y="Error_Upper",
           error_y_minus="Error_Lower",
           error_y_mode="band",
           template="simple_white",
           height=600,
           width=800,
           log_x=True,
           labels=labels
           )
fig.update_layout(
    font=dict(size=small_font,
              family='Arial'
              ),
    legend=dict(title="<b> Score Function </b>",
                x=0.4, y=0.1,
                traceorder='reversed',
                title_font_size=large_font,
                font_color='black'),
    **update_layout_dict)
fig.update_yaxes(tickvals=np.arange(0, 1.1, 0.1))
fig = update_traces(fig)
fig.show()
fig.write_image(figure_path / "20250124_dataset_split_comparison_random.svg")
fig.write_image(figure_path / "20250124_dataset_split_comparison_random.png")

In [None]:
# Temporally ordered split only. The score functions are distinguished by line dash.
# NOTE(review): no `color=` column is mapped here, so a `color_discrete_map` has no effect and was
# dropped. All traces presumably use the template's first colour. If this figure should reuse the
# split's colour from the combined figure, pass it explicitly,
# e.g. color_discrete_sequence=["#FF7F0E"]. TODO confirm the intended colour.
fig = line(data_frame=results_df[results_df.Split == "DateSplit"],
           x="N_Per_Split",
           y="Fraction",
           line_dash="Score",
           error_y="Error_Upper",
           error_y_minus="Error_Lower",
           error_y_mode="band",
           template="simple_white",
           height=600,
           width=800,
           log_x=True,
           labels=labels
           )
fig.update_layout(
    font=dict(size=small_font,
              family='Arial'
              ),
    legend=dict(title="<b> Score Function </b>",
                x=0.4, y=0.1,
                traceorder='reversed',
                title_font_size=large_font,
                font_color='black'),
    **update_layout_dict)
fig.update_yaxes(tickvals=np.arange(0, 1.1, 0.1))
fig = update_traces(fig)
fig.show()
fig.write_image(figure_path / "20250124_dataset_split_comparison_date.svg")
fig.write_image(figure_path / "20250124_dataset_split_comparison_date.png")