# Import

In [None]:
import pandas as pd, numpy as np
import plotly.express as px
from plotly.graph_objs import Figure
from pathlib import Path
from importlib import reload
import software.analysis as a
from software.plotting import plot_kde, rename_legend_labels, replace_xaxis_labels, replace_yaxis_labels, clean_labels, scatter_wrapper
reload(a)

## Import CSV

In [None]:
# Aggregated statistics table that every plot in this notebook is built from.
df = pd.read_csv(filepath_or_buffer="20240120_aggregated_all_stats_bootstraps100_stride10.csv")

# Core Variables and Shared Plot Keyword Arguments

In [None]:
# Column names referenced by the shared plotting keyword arguments.
color, tc, sort_col_name = "Version", "TanimotoCombo", "Sorted_By"

# Axis titles for the two kinds of reference-inclusion cutoff.
date_title = "Date for Inclusion of Reference Structures"
tc_title = "TanimotoCombo Cutoff for Inclusion of Reference Structures"

# Distance cutoff (in Å) interpolated into the fraction axis title.
good = 2
frac_title = f"Fraction of Poses < {good}Å from Reference"

In [None]:
# Keyword arguments shared by every plot: colour traces by the column named in `color`.
basic_plot_kwargs = {"color": color}

In [None]:
# Layout for large faceted figures: one facet column per `sort_col_name`
# value, one facet row per "Split" value.
big_plot_kwargs = {
    "facet_col": sort_col_name,
    "facet_row": "Split",
    "height": 600,
    "width": 1200,
}

In [None]:
# Figure size for a single, un-faceted plot.
single_plot_kwargs = {"height": 400, "width": 600}

In [None]:
# Plots with the TanimotoCombo cutoff on the x-axis; the range is padded
# slightly beyond [0, 2] so edge points are not clipped.
tc_plot_kwargs = {
    "x": tc,
    "labels": {tc: tc_title},
    "range_x": [-0.1, 2.1],
}

In [None]:
# y-axis range for fraction plots, padded slightly beyond [0, 1].
fraction_plot_kwargs = {"range_y": [-0.1, 1.1]}

In [None]:
# Plot the "Value" column with "STD" as its error bar.
stats_kwargs = {"y": "Value", "error_y": "STD"}

## Hybrid-Only Plots

In [None]:
# "Hybrid-Only" results: success fraction against the number of references,
# one colour per Structure_Split strategy, exported as SVG and PNG.
hybrid_only = df[df["Version"] == "Hybrid-Only"]
hybrid_scatter_kwargs = {
    "y": "Fraction",
    "color": "Structure_Split",
    "category_orders": {
        "Structure_Split": ["Random", "Structure_Date", "TanimotoCombo", "TanimotoCombo_R"],
    },
    # NOTE(review): five colours are listed but category_orders names only
    # four splits.
    "color_discrete_sequence": ["#5ba300", "#89ce00", "#0073e6", "#e6308a", "#b51963"],
    # Error-bar columns ("Max" upwards, "Min" downwards); presumably
    # forwarded to plotly express by scatter_wrapper.
    "error_y": "Max",
    "error_y_minus": "Min",
    "template": "plotly_white",
    "x": "Number of References",
    "height": 600,
    "width": 800,
}
fig = scatter_wrapper(
    hybrid_only,
    hybrid_scatter_kwargs,
    y_axis_title=frac_title,
    x_axis_title="Number of References",
)

# Human-readable legend entries keyed by Structure_Split value.
newlabels = {
    "random": "Random",
    "TanimotoCombo": "Increasing Chemical Similarity",
    "TanimotoCombo_R": "Decreasing Chemical Similarity",
    "Structure_Date": "Date of Structure Deposition",
}
fig = rename_legend_labels(newlabels, fig)

# Both axes get the same large black styling; the comprehension builds a
# fresh dict for each axis.
fig.update_layout(
    legend={
        "title": "Structure Selection",
        "x": 0.5,
        "y": 0.1,
        "traceorder": "reversed",
        "title_font_size": 24,
        "font_size": 18,
        "font_color": "black",
    },
    **{
        axis: {"title_font": {"size": 24}, "color": "black", "tickfont": {"size": 18}}
        for axis in ("xaxis", "yaxis")
    },
)
fig.update_yaxes(tickvals=np.arange(0, 1.1, 0.1))  # a tick every 0.1 from 0 to 1
fig.show()
for image_path in ("20240120_structure_splits_hybrid.svg", "20240120_structure_splits_hybrid.png"):
    fig.write_image(image_path)

In [None]:
import plotly.graph_objects as go
def plot_scatter_with_confidence_bands(df, x, y, split_by, error_y_plus, error_y_minus, template="plotly_white", height=600, width=800, colors=px.colors.qualitative.Plotly):
    """Plot one line per ``split_by`` group with a shaded band around it.

    For every distinct value of ``split_by``, three traces are added:
    - the ``y`` line itself;
    - an invisible upper edge at ``y + error_y_plus``;
    - an invisible lower edge at ``y - error_y_minus``, filled back to the
      upper edge (``fill='tonexty'``).

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format data containing the columns named below.
    x, y : str
        Column names for the horizontal axis and the plotted value.
    split_by : str
        Column whose distinct values become separate lines. Groups are added
        in ascending order of their mean ``y``.
    error_y_plus, error_y_minus : str
        Columns holding the distance from ``y`` to the upper / lower edge of
        the band (offsets added to / subtracted from ``y``, not absolute bounds).
    template, height, width
        Passed to ``fig.update_layout``.
    colors : sequence of str
        Colour palette. It is reused cyclically when there are more groups
        than colours.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    traces = []

    # Order groups by mean y so the legend order is meaningful.
    ordered_splits = df.groupby(split_by)[y].mean().sort_values().index
    for i, split in enumerate(ordered_splits):
        # Cycle the palette instead of raising IndexError when it is too short.
        line_color = colors[i % len(colors)]
        # Sort by x: line/fill traces connect points in row order, so unsorted
        # rows would draw a zig-zag line and a garbled band.
        subdf = df[df[split_by] == split].sort_values(x)
        xs = subdf[x]
        # A shared legendgroup makes clicking the legend entry hide the band too.
        common = dict(name=split, legendgroup=str(split), mode="lines", line_color=line_color)
        traces.append(go.Scatter(x=xs, y=subdf[y], showlegend=True, **common))
        traces.append(go.Scatter(x=xs, y=subdf[y] + subdf[error_y_plus],
                                 line_width=0, showlegend=False, **common))
        # 'tonexty' fills down to the previous trace (the upper edge just added).
        traces.append(go.Scatter(x=xs, y=subdf[y] - subdf[error_y_minus],
                                 fill="tonexty", line_width=0, showlegend=False, **common))

    fig = go.Figure(traces)
    fig.update_layout(template=template, height=height, width=width)
    return fig
        
    
    

In [None]:
# Hybrid-Only success fraction against the number of available references:
# one line per reference-selection strategy, with a shaded band built from the
# "Max"/"Min" offset columns.
fig = plot_scatter_with_confidence_bands(
    df[df["Version"] == "Hybrid-Only"],
    "Number of References",
    "Fraction",
    "Structure_Split",
    "Max",
    "Min",
    template="simple_white",
    height=600,
    width=800,
    colors=px.colors.qualitative.Safe,
)

# Legend text per Structure_Split value. Values without an entry keep their
# original name instead of raising KeyError. The random split is spelled
# "Random" elsewhere in this notebook, so the "random" key never matches it.
newlabels = {"random": "Random",
              "TanimotoCombo": "Picking <b> least </b> similar reference",
              "TanimotoCombo_R": "Picking <b> most </b> similar reference",
              "Structure_Date": "Picking the earliest deposition date", }
fig.for_each_trace(lambda t: t.update(name=newlabels.get(t.name, t.name)))

large_font = 18
small_font = 12
fig.update_layout(
    font=dict(size=small_font, family="Arial"),
    legend=dict(
        title="<b> Reference Selection Strategy </b>",
        x=0.4,
        y=0.1,
        # Groups are added in ascending mean order; reversing puts the
        # highest-mean strategy at the top of the legend.
        traceorder="reversed",
        title_font_size=large_font,
        font_color="black",
    ),
    xaxis=dict(
        title="<b> Total Number of References Available to Use </b>",
        title_font=dict(size=large_font),
        color="black",
    ),
    yaxis=dict(
        range=(0, 1),
        # Use the shared `good` cutoff so this title stays in sync with frac_title.
        title=f"<b> Fraction of Poses Docked < {good}Å from Reference </b>",
        title_font=dict(size=large_font),
        color="black",
    ),
)
fig.update_yaxes(tickvals=np.arange(0, 1.1, 0.1))  # a tick every 0.1 from 0 to 1
fig.show()

In [None]:
# Export the confidence-band figure as vector (SVG) and raster (PNG) images.
for out_path in ("20240229_structure_splits_hybrid.svg", "20240229_structure_splits_hybrid.png"):
    fig.write_image(out_path)