# Imports

In [None]:
import pandas as pd
from pathlib import Path
import numpy as np
from importlib import reload
import software.analysis as a
reload(a)

# Loading

In [None]:
# Read the combined CSV with the pyarrow parser and pyarrow-backed dtypes;
# the first CSV column becomes the index.
# NOTE(review): absolute, machine-specific path - this cell only runs where that file exists.
df = pd.read_csv(
    "/Users/alexpayne/Scientific_Projects/mers-drug-discovery/sars2-retrospective-analysis/20230611-combined.csv",
    index_col=0,
    engine="pyarrow",
    dtype_backend="pyarrow",
)

# Analyzing Slow Code

In [None]:
%%prun
random_stats = a.calculate_rmsd_stats(df, query_mol_id="Compound_ID", reference_selection="random", ref_structure_stride=10, score_column="POSIT_R", group_by=["Version"], n_bootstraps=10)

## splitting up functions into sub-functions

In [None]:
def get_random_sample(df, random_state=None) -> pd.DataFrame:
    """Return all rows of ``df`` in a shuffled order.

    Parameters
    ----------
    df : pd.DataFrame
        Frame whose rows should be shuffled. It is not modified.
    random_state : optional
        Forwarded to ``DataFrame.sample``. Leave as ``None`` (the default) for a
        different shuffle on every call, or pass a seed / generator to make the
        shuffle reproducible.

    Returns
    -------
    pd.DataFrame
        A new frame holding every row of ``df`` (``frac=1``) with the original
        index labels attached to their rows.
    """
    # frac=1 samples every row without replacement, i.e. a permutation of the rows.
    return df.sample(frac=1, random_state=random_state)

In [None]:
def get_structure_sort(df, group_by: list, query_mol_id, reference_selection, n_struc) -> pd.DataFrame:
    """Keep the first ``n_struc`` rows of every (query molecule + ``group_by``) group.

    Parameters
    ----------
    df : pd.DataFrame
        Input rows.
    group_by : list
        Extra column names that, together with ``query_mol_id``, define a group.
    query_mol_id : str
        Column identifying the query molecule.
    reference_selection : str
        Column to sort by (ascending) before taking the head of each group.
        The value ``"random"`` is treated as "keep the incoming row order" when
        ``df`` has no column of that name, so an already shuffled frame can be
        passed straight through instead of raising ``KeyError``.
    n_struc : int
        Number of rows to keep per group.

    Returns
    -------
    pd.DataFrame
        The selected rows, in the (sorted or incoming) row order.
    """
    grouping_keys = [query_mol_id] + group_by

    # "random" is a selection mode rather than a column: rely on the caller's
    # row order unless a real column with that name exists.
    if reference_selection == "random" and "random" not in df.columns:
        ordered = df
    else:
        ordered = df.sort_values(reference_selection)

    return ordered.groupby(grouping_keys).head(n_struc)

In [None]:
def get_score_sort(subset_df, score_column, query_mol_id, group_by):
    """Pick the top-ranked row of each (query molecule + ``group_by``) group.

    Rows are sorted ascending by ``score_column`` and the first row of every
    group is returned, i.e. the row with the lowest score in that group.
    """
    grouping_keys = [query_mol_id] + group_by
    ranked = subset_df.sort_values(score_column)
    best_per_group = ranked.groupby(grouping_keys).head(1)
    return best_per_group

In [None]:
def calculate_fraction(scored_df, score_column, query_mol_id, group_by, rmsd_col, rmsd_cutoff, n_mols):
    """Per-group count of rows with ``rmsd_col <= rmsd_cutoff``, divided by ``n_mols``.

    Parameters
    ----------
    scored_df : pd.DataFrame
        Rows to evaluate; must contain ``rmsd_col`` and the ``group_by`` columns.
    score_column, query_mol_id :
        Unused here; kept so the call signature stays unchanged.
    group_by : list or str
        Column name(s) defining the groups.
    rmsd_col : str
        Column compared against the cutoff.
    rmsd_cutoff : number
        Inclusive upper bound for a row to be counted.
    n_mols : number
        Denominator applied to every group's count.

    Returns
    -------
    pd.Series
        One value per group, indexed by the ``group_by`` key(s) and named
        ``rmsd_col``.
    """
    keys = [group_by] if isinstance(group_by, str) else list(group_by)

    # One vectorised comparison over the whole column instead of a Python
    # lambda applied group by group.
    within_cutoff = scored_df[rmsd_col] <= rmsd_cutoff

    # Summing booleans counts the True entries in each group.
    return within_cutoff.groupby([scored_df[key] for key in keys]).sum() / n_mols

In [None]:
def collect_results(rmsd_stats_series, n_ref, score_column, query_mol_id, group_by, n_mols,
                    reference_selection=None, subset_df=None):
    """Turn one per-group fraction Series into a tidy results frame.

    Parameters
    ----------
    rmsd_stats_series : pd.Series
        One value per group; the index labels go into the ``"Version"`` column
        and the values into ``"Fraction"``.
    n_ref : int
        Value written to ``"Number of References"`` on every row.
    score_column, query_mol_id, group_by, n_mols :
        Unused here; kept so the call signature stays unchanged.
    reference_selection : str, optional
        Selection mode / column name recorded in ``"Structure_Split"``. When
        omitted, the notebook-level variable of the same name is used, which is
        what this function previously read implicitly.
    subset_df : pd.DataFrame, optional
        Frame whose ``reference_selection`` column supplies the min/max split
        values. Only needed when ``reference_selection`` is not ``"random"``.
        When omitted, the notebook-level ``subset_df`` is used.

    Returns
    -------
    pd.DataFrame
        Columns ``Fraction``, ``Version``, ``Number of References``,
        ``Structure_Split``, ``Split_Value_min`` and ``Split_Value_max``.

    Raises
    ------
    NameError
        If a required value is neither passed in nor defined at notebook level.
    """
    # Prefer explicit arguments; fall back to notebook globals so existing
    # six-argument calls keep working.
    if reference_selection is None:
        if "reference_selection" not in globals():
            raise NameError("reference_selection was not passed and is not defined at module level")
        reference_selection = globals()["reference_selection"]

    labels = list(rmsd_stats_series.index)
    return_df = pd.DataFrame(
        {
            "Fraction": [rmsd_stats_series[label] for label in labels],
            "Version": labels,
            "Number of References": [n_ref] * len(labels),
            "Structure_Split": reference_selection,
        }
    )

    if reference_selection == "random":
        # No meaningful split value exists for random selection.
        return_df["Split_Value_min"] = "Random"
        return_df["Split_Value_max"] = "Random"
    else:
        if subset_df is None:
            if "subset_df" not in globals():
                raise NameError("subset_df was not passed and is not defined at module level")
            subset_df = globals()["subset_df"]
        selection_values = subset_df[reference_selection]
        return_df["Split_Value_min"] = selection_values.min()
        return_df["Split_Value_max"] = selection_values.max()
    return return_df

### combined workflow

In [None]:
# Settings shared by the workflow cells below.
reference_selection = "Structure_Date"  # column used to order candidate references
query_mol_id = "Compound_ID"            # column identifying the query molecule
group_by = ["Version"]                  # extra grouping columns
score_column = "POSIT_R"                # column used to rank rows within a group
rmsd_col = "RMSD"                       # column compared against the RMSD cutoff

## with pyarrow

In [None]:
# Reload with the pyarrow parser and pyarrow-backed dtypes for the profiling run below.
# NOTE(review): absolute, machine-specific path - this cell only runs where that file exists.
df = pd.read_csv(
    "/Users/alexpayne/Scientific_Projects/mers-drug-discovery/sars2-retrospective-analysis/20230611-combined.csv",
    index_col=0,
    engine="pyarrow",
    dtype_backend="pyarrow",
)

In [None]:
%%prun
dfs = []
for j in range(1,10):
    randomized = get_random_sample(df)
    n_mols = len(df[query_mol_id].unique())
    for i in range(1, 200, 10):
        subset_df = get_structure_sort(randomized, group_by=group_by, query_mol_id = query_mol_id, reference_selection=reference_selection, n_struc=i)
        scored_df = get_score_sort(subset_df=subset_df, score_column=score_column, query_mol_id=query_mol_id, group_by=group_by)
        fraction = calculate_fraction(scored_df, score_column=score_column, query_mol_id=query_mol_id, group_by=group_by, rmsd_col=rmsd_col, rmsd_cutoff=2, n_mols=n_mols)
        dfs.append(collect_results(fraction, i, score_column, query_mol_id, group_by, n_mols))
combined = pd.concat(dfs)

## without pyarrow

In [None]:
# Reload with pandas' default parser and dtypes to compare against the pyarrow run.
# NOTE(review): absolute, machine-specific path - this cell only runs where that file exists.
df = pd.read_csv(
    "/Users/alexpayne/Scientific_Projects/mers-drug-discovery/sars2-retrospective-analysis/20230611-combined.csv",
    index_col=0,
)

In [None]:
%%prun
dfs = []
for j in range(1,10):
    randomized = get_random_sample(df)
    n_mols = len(df[query_mol_id].unique())
    for i in range(1, 200, 10):
        subset_df = get_structure_sort(randomized, group_by=group_by, query_mol_id = query_mol_id, reference_selection=reference_selection, n_struc=i)
        scored_df = get_score_sort(subset_df=subset_df, score_column=score_column, query_mol_id=query_mol_id, group_by=group_by)
        fraction = calculate_fraction(scored_df, score_column=score_column, query_mol_id=query_mol_id, group_by=group_by, rmsd_col=rmsd_col, rmsd_cutoff=2, n_mols=n_mols)
        dfs.append(collect_results(fraction, i, score_column, query_mol_id, group_by, n_mols))
combined = pd.concat(dfs)

# plotting

In [None]:
import plotly.express as px
from plotly.graph_objs import Figure

In [None]:
# Shared y-range for fraction plots, padded slightly beyond [0, 1].
fraction_plot_kwargs = {"range_y": [-0.1, 1.1]}

### cleanup functions

In [None]:
def replace_xaxis_labels(fig: Figure, axis_title):
    """Blank every x-axis title and add one shared label as an annotation.

    The annotation is placed in paper coordinates at x=0.5, y=-0.15 with no
    arrow. ``fig`` is modified in place and returned.
    """
    def _blank_title(axis):
        axis.update(title="")

    fig.for_each_xaxis(_blank_title)

    shared_label = dict(
        text=axis_title,
        x=0.5,
        y=-0.15,
        xref="paper",
        yref="paper",
        textangle=0,
        font=dict(size=16),
        showarrow=False,
    )
    fig.add_annotation(**shared_label)
    return fig

In [None]:
def replace_yaxis_labels(fig: Figure, axis_title):
    """Blank every y-axis title and add one shared, rotated label as an annotation.

    The annotation is placed in paper coordinates at x=-0.05, y=0.5, rotated
    by -90 degrees, with no arrow. ``fig`` is modified in place and returned.
    """
    def _blank_title(axis):
        axis.update(title="")

    fig.for_each_yaxis(_blank_title)

    shared_label = dict(
        text=axis_title,
        x=-0.05,
        y=0.5,
        xref="paper",
        yref="paper",
        textangle=-90,
        font=dict(size=16),
        showarrow=False,
    )
    fig.add_annotation(**shared_label)
    return fig

In [None]:
def clean_labels(fig):
    """Reduce every annotation's text to the part after its last ``=``.

    Text without an ``=`` is left unchanged. ``fig`` is modified in place and
    returned.
    """
    def _strip_prefix(annotation):
        annotation.update(text=annotation.text.rsplit("=", 1)[-1])

    fig.for_each_annotation(_strip_prefix)
    return fig

### scatterplot wrapper

In [None]:
def scatter_wrapper(df, kwarg_dict,
                    x_axis_title=None,
                    y_axis_title=None,
                    replace_xaxis=False,
                    replace_y_axis=False,
                    clean=True,
                    x_axis_reversed=False
                   ):
    """Build a ``px.scatter`` figure and apply optional axis-title and label clean-up.

    Parameters
    ----------
    df : pd.DataFrame
        Data to plot; all of its columns are passed as ``hover_data``.
    kwarg_dict : dict
        Extra keyword arguments forwarded to ``px.scatter``.
    x_axis_title, y_axis_title : str, optional
        Axis titles. Nothing is changed for an axis whose title is falsy.
    replace_xaxis, replace_y_axis : bool
        If true, the title is drawn once via ``replace_xaxis_labels`` /
        ``replace_yaxis_labels``; otherwise it is set on every matching axis.
    clean : bool
        If true, run ``clean_labels`` on the figure.
    x_axis_reversed : bool
        If true, set the x axes' autorange to ``"reversed"``.

    Returns
    -------
    The resulting figure.
    """
    figure: Figure = px.scatter(df, **kwarg_dict, hover_data=df.columns)

    if x_axis_title and replace_xaxis:
        figure = replace_xaxis_labels(figure, x_axis_title)
    elif x_axis_title:
        figure.update_xaxes(title=x_axis_title)

    if y_axis_title and replace_y_axis:
        figure = replace_yaxis_labels(figure, y_axis_title)
    elif y_axis_title:
        figure.update_yaxes(title=y_axis_title)

    if clean:
        figure = clean_labels(figure)

    if x_axis_reversed:
        figure.update_xaxes(autorange="reversed")

    return figure

In [None]:
# Fraction vs. number of references, coloured by Structure_Split, one facet per Version.
# error_y / error_y_minus ("Max" / "Min") and single_plot_kwargs are currently not passed.
fig = scatter_wrapper(
    combined,
    dict(
        x="Number of References",
        y="Fraction",
        color="Structure_Split",
        facet_col="Version",
        color_discrete_sequence=px.colors.qualitative.Dark24,
        template="seaborn",
        height=600,
        width=800,
        **fraction_plot_kwargs,
    ),
    x_axis_title="Number of References",
    y_axis_title="Fraction of Poses < 2Å from Reference",
)

# Blank every facet's y title, then label only the first y axis.
fig.for_each_yaxis(lambda axis: axis.update(title=""))
# This layout update also resets the figure size to 400 x 800.
fig.update_layout(
    yaxis1=dict(title="Fraction of Poses < 2Å from Reference"),
    height=400,
    width=800,
)
fig.show()
fig.write_image("20231101_sasa_comparison.png")

# What if we didn't sort every time?

In [None]:
# Sort once by the reference-selection column and keep the GroupBy object for reuse.
structure_sorted = (
    df.sort_values(reference_selection)
    .groupby([query_mol_id] + group_by)
)

In [None]:
# Show the first 10 rows of every group, in the sorted row order.
structure_sorted.head(10)

## Yeah, let's do that