# Imports

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import ScalarFormatter
import seaborn as sns
import pandas as pd
from pathlib import Path
import json

# Load Data

In [None]:
# Load the combined POSIT scaffold/date-split results table.
# NOTE(review): this is an absolute, machine-specific path; adjust it when
# running the notebook anywhere else.
posit_results = Path("/Users/alexpayne/Scientific_Projects/mers-drug-discovery/sars2-retrospective-analysis/full_cross_dock_v2_analyzed_results/posit_scaffold_date_split_combined_results.csv")
pdf = pd.read_csv(posit_results)

# Distance from the point estimate ("Fraction") to each confidence bound,
# clamped at zero so a bound that crosses the estimate never yields a
# negative offset. ``clip`` does the clamping in one vectorized pass instead
# of calling a Python lambda per row.
pdf["Error_Lower"] = (pdf["Fraction"] - pdf["CI_Lower"]).clip(lower=0)
pdf["Error_Upper"] = (pdf["CI_Upper"] - pdf["Fraction"]).clip(lower=0)

In [None]:
# Bare expression: as the last statement of a notebook cell this renders the
# loaded table for a quick visual check.
pdf

# Plotting

## Plot function

In [None]:
# Global configuration
# Directory that receives every saved figure; it is created up front
# (including missing parents) and re-running the cell is harmless.
fig_path = Path("./20250729_scaffold_date_split_v2")
fig_path.mkdir(exist_ok=True, parents=True)

# Running figure counter. Only the commented-out save_fig variant refers to it.
FIGNUM_GLOBAL = 0

# def save_fig(fig, filename, dpi=200, suffix=".pdf"):
#     """Save the figure with a global figure number."""
#     global FIGNUM_GLOBAL
#     FIGNUM_GLOBAL += 1
#     figpath = Path(fig_path / f"{filename}_{FIGNUM_GLOBAL:02d}")
#     fig.savefig(figpath.with_suffix(suffix), 
#                 bbox_inches="tight", 
#                 dpi=dpi)

def save_fig(fig, filename, dpi=200, suffix=".pdf", out_dir=None):
    """Save ``fig`` as ``<out_dir>/<filename><suffix>``.

    The suffix is appended to ``filename`` instead of being substituted via
    ``Path.with_suffix``. ``with_suffix`` replaces everything after the last
    dot, so a name such as ``"results_1.0_refs_per_scaffold"`` used to be
    written as ``"results_1.pdf"``; it is now kept intact.

    Args:
        fig: Any object exposing ``savefig(path, bbox_inches=..., dpi=...)``.
        filename: Base file name without extension; dots are preserved.
        dpi: Resolution forwarded to ``savefig``.
        suffix: File extension. A missing leading dot is added.
        out_dir: Target directory. Defaults to the module-level ``fig_path``.
    """
    if suffix and not suffix.startswith("."):
        suffix = f".{suffix}"
    target_dir = fig_path if out_dir is None else Path(out_dir)
    fig.savefig(
        target_dir / f"{filename}{suffix}",
        bbox_inches="tight",
        dpi=dpi,
    )
    

# Plain white seaborn style for every figure in this notebook.
sns.set_style("white")

# Translation table from raw column names and raw categorical values to the
# human-readable labels shown on axes and legends.
label_map = {
    "Reference_Split": "Dataset Split Type",
    "Score": "Scoring Method",
    "RandomSplit": "Randomly Ordered",
    "DateSplit": "Ordered by Date",
    "RMSD": "RMSD (Positive Control)",
    "POSIT_Probability": "POSIT Probability",
    "ScaffoldDateSplit": "One Structure Per Scaffold",
    # "N_Reference_Structures": "Number of Randomly Chosen Reference Structures",
    "N_Reference_Structures": "Number of Reference Structures Available to Use \n(Log Scale)",
    "Fraction": "Fraction of Ligands Posed \n<2Å from Reference",
    "CI_Lower": "Confidence Interval Lower Bound",
    "CI_Upper": "Confidence Interval Upper Bound",
    "N_Refs_Per_Scaffold": "# Reference Structures Per Scaffold",
}
# Columns without an explicit label map to themselves, so every column of
# the loaded table can be looked up in label_map.
for column in pdf.columns:
    label_map.setdefault(column, column)

# Display-name columns used as the plotting defaults. The axis labels are the
# same strings as the x / y column names.
X_VAR = X_LABEL = label_map["N_Reference_Structures"]
Y_VAR = Y_LABEL = label_map["Fraction"]
# QUERY_SCAFFOLD_ID = label_map["Query_Scaffold_ID_Subset_1"]
# REF_SCAFFOLD_ID = label_map["Reference_Scaffold_ID_Subset_1"]
COLOR_VAR = label_map["Reference_Split"]
STYLE_VAR = label_map["Score"]
CI_LOWER = label_map["CI_Lower"]
CI_UPPER = label_map["CI_Upper"]

# Figure geometry and typography.
LARGE_FIG_SIZE = (12, 8)
SMALL_FIG_SIZE = (8, 6)
FONT_SIZES = dict(
    xlabel=24,
    ylabel=24,
    ticks=18,
    legend_title=24,
    legend_text=18,
)
# Opacity of the shaded confidence bands.
ALPHA = 0.1

# Distinct reference-structure counts in the data, ascending; used as the
# default x tick positions.
n_refs = sorted(pdf["N_Reference_Structures"].unique())

In [None]:
def plot_filled_in_error_bars(
    raw_df,
    x_var=X_VAR,
    y_var=Y_VAR,
    color_var=COLOR_VAR,
    style_var=STYLE_VAR,
    ci_lower=CI_LOWER,
    ci_upper=CI_UPPER,
    reverse_hue_order=False,
    reverse_style_order=False,
    n_refs=n_refs,
    fill_between=True,
    log_scale=True,
):
    """Draw a seaborn line plot with optional shaded confidence bands.

    Args:
        raw_df: Table holding the columns named by the arguments below.
        x_var: Column plotted on the x axis.
        y_var: Column plotted on the y axis.
        color_var: Column whose values select the line colour.
        style_var: Column whose values select the line style.
        ci_lower: Column with the lower edge of the shaded band.
        ci_upper: Column with the upper edge of the shaded band.
        reverse_hue_order: Reverse the sorted order of the colour values.
        reverse_style_order: Reverse the sorted order of the style values.
        n_refs: x positions that receive a tick and a tick label.
        fill_between: Shade the region between ``ci_lower`` and ``ci_upper``
            for every (colour, style) group.
        log_scale: Put the x axis on a log scale with plain-number ticks.

    Returns:
        The ``matplotlib.pyplot`` module itself (not a Figure), so a later
        ``.savefig`` call acts on the current figure.

    Note:
        The axis titles always come from the module-level ``X_LABEL`` and
        ``Y_LABEL``, regardless of ``x_var`` / ``y_var``.
    """
    ordered = raw_df.sort_values(by=[x_var, style_var, color_var])
    plt.figure(figsize=(LARGE_FIG_SIZE[0], LARGE_FIG_SIZE[1]))

    # Sorted category orders, flipped in place on request.
    hue_order = sorted(ordered[color_var].unique())
    if reverse_hue_order:
        hue_order.reverse()
    style_order = sorted(ordered[style_var].unique())
    if reverse_style_order:
        style_order.reverse()

    # One palette colour per hue value, shared by the lines and their bands.
    palette = sns.color_palette(n_colors=len(hue_order))
    color_map = dict(zip(hue_order, palette))

    # sns.lineplot hands back the Axes it drew on.
    ax = sns.lineplot(
        data=ordered,
        x=x_var,
        y=y_var,
        hue=color_var,
        style=style_var,
        palette=color_map,
        hue_order=hue_order,
        style_order=style_order,
    )

    if fill_between:
        # Group keys are (color_var value, style_var value); the band takes
        # the colour of its hue value.
        for (hue_value, _style_value), band in ordered.groupby([color_var, style_var]):
            ax.fill_between(
                band[x_var],
                band[ci_lower],
                band[ci_upper],
                color=color_map[hue_value],
                alpha=ALPHA,
            )

    if log_scale:
        ax.set_xscale("log")
        # Plain numbers on the ticks instead of powers of ten.
        ax.xaxis.set_major_formatter(ScalarFormatter())

    tick_positions = n_refs
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_positions, fontsize=FONT_SIZES["ticks"])
    ax.tick_params(axis='y', labelsize=FONT_SIZES["ticks"])

    ax.set_xlabel(X_LABEL, fontsize=FONT_SIZES["xlabel"], fontweight="bold")
    ax.set_ylabel(Y_LABEL, fontsize=FONT_SIZES["ylabel"], fontweight="bold")

    # Legend typography.
    legend = ax.legend()
    plt.setp(legend.get_title(), fontsize=FONT_SIZES["legend_title"], fontweight="bold")
    plt.setp(legend.get_texts(), fontsize=FONT_SIZES["legend_text"])
    return plt

## Update labels

In [None]:
# Work on a relabelled copy: column names first, then every cell value that
# has an entry in label_map. Values without an entry are left unchanged.
df = pdf.copy().rename(columns=label_map)
for col_name in df.columns:
    df[col_name] = df[col_name].apply(lambda value: label_map.get(value, value))
# Show the resulting column names.
df.columns

In [None]:
# Keep all rows of the other split types. For the
# "One Structure Per Scaffold" split keep only the rows with exactly one
# reference structure per scaffold.
is_scaffold_split = df[label_map["Reference_Split"]] == "One Structure Per Scaffold"
has_single_ref = df[label_map["N_Refs_Per_Scaffold"]] == 1
plot_df = df[~is_scaffold_split | has_single_ref]

In [None]:
# Plot the filtered table (log-scaled x axis, shaded confidence bands) and
# write it out as both PDF and PNG.
plot_kwargs = dict(
    x_var=label_map["N_Reference_Structures"],
    y_var=label_map["Fraction"],
    color_var=label_map["Reference_Split"],
    style_var=label_map["Score"],
    ci_lower=label_map["CI_Lower"],
    ci_upper=label_map["CI_Upper"],
    reverse_hue_order=True,
)
# The helper returns the pyplot module, so save_fig saves the current figure.
fig = plot_filled_in_error_bars(plot_df, **plot_kwargs)
save_fig(fig, "scaffold_date_split_results")
save_fig(fig, "scaffold_date_split_results", suffix=".png", dpi=300)

In [None]:
# Keep only the "Ordered by Date" and "One Structure Per Scaffold" splits,
# each with both scoring methods (POSIT Probability and the RMSD control).
# NOTE(review): there is no filter on "# Reference Structures Per Scaffold"
# here - confirm that plotting all values together is intended.
split_values = df[label_map["Reference_Split"]]
score_values = df[label_map["Score"]]
simplified_df = df[
    split_values.isin(['Ordered by Date', 'One Structure Per Scaffold'])
    & score_values.isin(['POSIT Probability', 'RMSD (Positive Control)'])
]

plot_kwargs = dict(
    x_var=label_map["N_Reference_Structures"],
    y_var=label_map["Fraction"],
    color_var=label_map["Reference_Split"],
    style_var=label_map["Score"],
    ci_lower=label_map["CI_Lower"],
    ci_upper=label_map["CI_Upper"],
    reverse_hue_order=True,
)
fig = plot_filled_in_error_bars(simplified_df, **plot_kwargs)
save_fig(fig, "scaffold_date_split_results_simple")
save_fig(fig, "scaffold_date_split_results_simple", suffix=".png", dpi=300)

In [None]:
# Same two-split / two-score subset as the simplified plot.
split_values = df[label_map["Reference_Split"]]
score_values = df[label_map["Score"]]
simplified_df = df[
    split_values.isin(['Ordered by Date', 'One Structure Per Scaffold'])
    & score_values.isin(['POSIT Probability', 'RMSD (Positive Control)'])
]

plot_kwargs = dict(
    x_var=label_map["N_Reference_Structures"],
    y_var=label_map["Fraction"],
    color_var=label_map["Reference_Split"],
    style_var=label_map["Score"],
    ci_lower=label_map["CI_Lower"],
    ci_upper=label_map["CI_Upper"],
    reverse_hue_order=True,
)

# One figure per distinct refs-per-scaffold value: the non-scaffold rows are
# always kept, the scaffold-split rows only when they match the value n.
refs_per_scaffold = simplified_df[label_map["N_Refs_Per_Scaffold"]]
not_scaffold_split = simplified_df[label_map["Reference_Split"]] != "One Structure Per Scaffold"
for n in refs_per_scaffold.unique():
    plot_df = simplified_df[not_scaffold_split | (refs_per_scaffold == n)]

    fig = plot_filled_in_error_bars(plot_df, **plot_kwargs)
    name = f"scaffold_date_split_results_simple_{n}_refs_per_scaffold"
    save_fig(fig, name)
    save_fig(fig, name, suffix=".png", dpi=300)

# Color by N Refs Per Scaffold

In [None]:
# Keep only the "One Structure Per Scaffold" rows with a positive number of
# reference structures per scaffold, and colour the lines by that number.
is_scaffold_split = df[label_map["Reference_Split"]] == "One Structure Per Scaffold"
has_refs = df[label_map["N_Refs_Per_Scaffold"]] > 0
plot_df = df[is_scaffold_split & has_refs]

plot_kwargs = dict(
    x_var=label_map["N_Reference_Structures"],
    y_var=label_map["Fraction"],
    color_var=label_map["N_Refs_Per_Scaffold"],
    style_var=label_map["Score"],
    ci_lower=label_map["CI_Lower"],
    ci_upper=label_map["CI_Upper"],
    reverse_hue_order=True,
    # No shaded confidence bands on this figure.
    fill_between=False,
)
fig = plot_filled_in_error_bars(plot_df, **plot_kwargs)
name = "scaffold_date_split_results_n_refs_per_scaffold_colorised"
save_fig(fig, name)
save_fig(fig, name, suffix=".png", dpi=300)

# Plot on linear scale

In [None]:
# Two-step filter for the linear-scale figure:
# 1. keep the "Ordered by Date" and "One Structure Per Scaffold" splits with
#    both scoring methods;
# 2. within the scaffold split keep only rows with exactly one reference
#    structure per scaffold.
wanted_split = df[label_map["Reference_Split"]].isin(['Ordered by Date', 'One Structure Per Scaffold'])
wanted_score = df[label_map["Score"]].isin(['POSIT Probability', 'RMSD (Positive Control)'])
plot_df = df[wanted_split & wanted_score]

not_scaffold_split = plot_df[label_map["Reference_Split"]] != "One Structure Per Scaffold"
has_single_ref = plot_df[label_map["N_Refs_Per_Scaffold"]] == 1
plot_df = plot_df[not_scaffold_split | has_single_ref]

In [None]:
# Same plot as the log-scale version, but with a linear x axis and a
# hand-picked set of tick positions.
plot_kwargs = dict(
    x_var=label_map["N_Reference_Structures"],
    y_var=label_map["Fraction"],
    color_var=label_map["Reference_Split"],
    style_var=label_map["Score"],
    ci_lower=label_map["CI_Lower"],
    ci_upper=label_map["CI_Upper"],
    reverse_hue_order=True,
    log_scale=False,
    n_refs=[1, 50, 100, 137, 200, 300, 403],
)
fig = plot_filled_in_error_bars(plot_df, **plot_kwargs)
save_fig(fig, "scaffold_date_split_results_linear")
save_fig(fig, "scaffold_date_split_results_linear", suffix=".png", dpi=300)