# Imports

In [None]:
import plotly.express as px
import pandas as pd
import plotly.graph_objects as go
import numpy as np
from pathlib import Path

# Load Data

In [None]:
# Directory that holds the analysed CSV result files loaded further down.
data_path = Path("analyzed_data")
# Directory the rendered figures (PNG and HTML) are written to.
figure_path = Path("figures")
# Create the output directory up front: writing a figure into a directory that
# does not exist fails, which would otherwise happen on a fresh checkout.
figure_path.mkdir(parents=True, exist_ok=True)

In [None]:
# Collect every CSV in the data directory. The glob result is materialised and
# sorted because a bare glob() generator can only be consumed once (re-running
# the loading cell would then find no files) and yields paths in arbitrary,
# filesystem-dependent order.
df_paths = sorted(data_path.glob("*.csv"))

In [None]:
# Read each CSV into its own DataFrame and stack them into one table.
dfs = [pd.read_csv(path) for path in df_paths]
# ignore_index=True gives the combined frame a fresh 0..n-1 index; without it
# every file contributes its own 0-based index and the labels are duplicated.
df = pd.concat(dfs, ignore_index=True)

In [None]:
# Bare expression: shows the combined DataFrame as the cell output.
df

In [None]:
# Express the confidence-interval bounds as non-absolute offsets from the
# point estimate: distance below "Fraction" and distance above "Fraction".
df["Error_Lower"] = df["Fraction"].sub(df["CI_Lower"])
df["Error_Upper"] = df["CI_Upper"].sub(df["Fraction"])

# Plotting Variables

In [None]:
# Font size used for the axis titles and the legend title of the figures below.
large_font = 18
# Base font size applied to the figure layout as a whole.
small_font = 12

# Plotting Functions

In [None]:
def plot_scatter_with_confidence_bands(df, x, y, split_by, error_y_plus, error_y_minus, template="plotly_white", height=600, width=800, colors=px.colors.qualitative.Plotly):
    """Build a line plot with a shaded error band for every group in ``df``.

    For each distinct value of ``split_by`` three traces are added, all in the
    same color: the visible line of ``y`` against ``x``, a zero-width line at
    ``y + error_y_plus`` and a zero-width line at ``y - error_y_minus`` that is
    filled up to the trace added just before it (the upper bound).

    Groups are processed in ascending order of their mean ``y`` value, so the
    i-th color of ``colors`` goes to the group with the i-th smallest mean.
    When there are more groups than colors the palette is reused from the
    start instead of raising ``IndexError``.

    Args:
        df: DataFrame containing the columns named by the arguments below.
        x: Name of the column plotted on the x axis.
        y: Name of the column plotted on the y axis.
        split_by: Name of the column whose values define the groups.
        error_y_plus: Name of the column added to ``y`` for the upper bound.
        error_y_minus: Name of the column subtracted from ``y`` for the lower
            bound.
        template: Plotly layout template name.
        height: Figure height passed to the layout.
        width: Figure width passed to the layout.
        colors: Sequence of color strings, one per group (cycled if shorter
            than the number of groups).

    Returns:
        A ``go.Figure`` holding three traces per group.

    Raises:
        ValueError: If ``colors`` is empty while ``df`` contains at least one
            group.
    """
    traces = []

    # order by mean
    ordered_splits = df.groupby(split_by)[y].mean().sort_values().index
    n_colors = len(colors)
    if len(ordered_splits) > 0 and n_colors == 0:
        raise ValueError("colors must contain at least one color")

    for i, split in enumerate(ordered_splits):
        subdf = df[df[split_by] == split]
        # Wrap around the palette so extra groups reuse colors rather than
        # indexing past its end.
        color = colors[i % n_colors]

        # Visible line; the only trace of the group shown in the legend.
        traces.append(go.Scatter(name=split,
                                 x=subdf[x],
                                 y=subdf[y],
                                 mode='lines',
                                 showlegend=True,
                                 line_color=color))
        # Upper bound: zero-width line that only serves as the fill target.
        traces.append(go.Scatter(name=split,
                                 x=subdf[x],
                                 y=subdf[y] + subdf[error_y_plus],
                                 mode='lines',
                                 line_color=color,
                                 line_width=0,
                                 showlegend=False))
        # Lower bound: 'tonexty' fills the area up to the previous trace,
        # i.e. the upper bound added right above.
        traces.append(go.Scatter(name=split,
                                 x=subdf[x],
                                 y=subdf[y] - subdf[error_y_minus],
                                 fill='tonexty',
                                 mode='lines',
                                 line_color=color,
                                 line_width=0,
                                 showlegend=False))
    fig = go.Figure(traces)
    fig.update_layout(template=template, height=height, width=width)
    return fig

# Actual Plots

In [None]:
# Overview figure: "Fraction" against "N_Per_Split", with one line and error
# band per "StructureChoice" value.
fig = plot_scatter_with_confidence_bands(
    df,
    x="N_Per_Split",
    y="Fraction",
    split_by="StructureChoice",
    error_y_plus="Error_Upper",
    error_y_minus="Error_Lower",
    template="simple_white",
    height=600,
    width=800,
    colors=px.colors.qualitative.Safe,
)

# Typography, legend placement and axis titles.
fig.update_layout(
    font={"size": small_font, "family": "Arial"},
    legend={
        "title": "<b> Dataset Split </b>",
        "x": 0.4,
        "y": 0.1,
        # Traces are added in ascending order of mean, so reversing the legend
        # lists the group with the highest mean first.
        "traceorder": "reversed",
        "title_font_size": large_font,
        "font_color": "black",
    },
    xaxis={
        "title": "<b> Total Number of References Available to Use </b>",
        "title_font": {"size": large_font},
        "color": "black",
    },
    yaxis={
        "range": (0, 1),
        "title": "<b> Fraction of Poses Docked < 2Å from Reference </b>",
        "title_font": {"size": large_font},
        "color": "black",
    },
)
# Place a y tick every 0.1 across the fixed 0-1 range.
fig.update_yaxes(tickvals=np.arange(0, 1.1, 0.1))

# Export a static image and an interactive HTML version.
fig.write_image(figure_path / "20240503_random_split.png")
fig.write_html(figure_path / "20240503_random_split.html")

In [None]:
# Bare expression: shows the column labels of df as the cell output.
df.columns

In [None]:
# One figure per "StructureChoice" value: restrict the data to that choice and
# draw one line and error band per "StructureChoice_Choose_N" value.
for choice in df["StructureChoice"].unique().tolist():
    fig = plot_scatter_with_confidence_bands(
        df[df["StructureChoice"] == choice],
        x="N_Per_Split",
        y="Fraction",
        split_by="StructureChoice_Choose_N",
        error_y_plus="Error_Upper",
        error_y_minus="Error_Lower",
        template="simple_white",
        height=600,
        width=800,
        colors=px.colors.qualitative.Safe,
    )

    # Same typography and axis titles as the overview figure; the legend title
    # additionally names the current choice.
    fig.update_layout(
        font={"size": small_font, "family": "Arial"},
        legend={
            "title": f"<b> Dataset Split: {choice} </b>",
            "x": 0.4,
            "y": 0.1,
            "traceorder": "reversed",
            "title_font_size": large_font,
            "font_color": "black",
        },
        xaxis={
            "title": "<b> Total Number of References Available to Use </b>",
            "title_font": {"size": large_font},
            "color": "black",
        },
        yaxis={
            "range": (0, 1),
            "title": "<b> Fraction of Poses Docked < 2Å from Reference </b>",
            "title_font": {"size": large_font},
            "color": "black",
        },
    )
    # Place a y tick every 0.1 across the fixed 0-1 range.
    fig.update_yaxes(tickvals=np.arange(0, 1.1, 0.1))

    # Export a static image and an interactive HTML version, with the choice
    # embedded in the file name.
    fig.write_image(figure_path / f"20240503_random_split_{choice}.png")
    fig.write_html(figure_path / f"20240503_random_split_{choice}.html")