# Imports

In [None]:
import datetime

import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter
from rdkit import Chem
import seaborn as sns
import pandas as pd
from pathlib import Path
from harbor.analysis.cross_docking import DockingDataModel
import numpy as np

# Load Data

In [None]:
# Directory holding the analyzed summary CSVs, and the CSV files to load from it.
results_path = Path("/Users/alexpayne/Scientific_Projects/mers-drug-discovery/sars2-retrospective-analysis/full_cross_dock_v2_analyzed_results/")
csv_names = [
    "ALL_1_poses_x_to_y_combined_results.csv",
    "ALL_1_poses_x_to_y_5_combined_results.csv",
    "ALL_1_poses_x_to_x_5_combined_results.csv",
    "ALL_1_poses_x_to_x_combined_results.csv",
    "ALL_1_poses_x_to_not_x_combined_results.csv",
    "ALL_1_poses_not_x_to_x_combined_results.csv",
]
posit_results = [results_path / csv_name for csv_name in csv_names]

posit_raw_df = DockingDataModel.deserialize("/Users/alexpayne/Scientific_Projects/mers-drug-discovery/sars2-retrospective-analysis/full_cross_dock_v2_combined_results/ALL_1_poses.json")

# Stack all summaries into one frame with a fresh 0..n-1 index.
pdf = pd.concat([pd.read_csv(csv_file) for csv_file in posit_results], ignore_index=True)

# Distances from the point estimate to each confidence bound, floored at zero.
for error_col, gap in (
    ("Error_Lower", pdf["Fraction"] - pdf["CI_Lower"]),
    ("Error_Upper", pdf["CI_Upper"] - pdf["Fraction"]),
):
    pdf[error_col] = gap.apply(lambda width: 0 if width < 0 else width)

query = "Query_Scaffold_ID_Subset"
ref = "Reference_Scaffold_ID_Subset"


def _strip_brackets(text):
    """Drop square brackets from ``text``; strings without "[" are returned untouched."""
    if "[" in text:
        return text.replace("[", "").replace("]", "")
    return text


def _to_one_indexed(text):
    """Shift an all-digit ID string up by one; anything else becomes NaN."""
    if text.isdigit():
        return str(int(text) + 1)
    return np.nan


# Normalise each ID column in place, then store a float copy for numeric comparisons.
for id_col, numeric_col in ((query, "qint"), (ref, "rint")):
    pdf[id_col] = pdf[id_col].astype(str).apply(_strip_brackets)
    pdf[id_col] = pdf[id_col].apply(_to_one_indexed)
    pdf[numeric_col] = pdf[id_col].astype(float)

In [None]:
ligand_data = pd.read_csv("ligand_scaffold_data.csv") # generated in 20250703_scaffolds_over_time.ipynb

In [None]:
# One row per scaffold: the first ligand_data row seen for each scaffold_orig_id.
# .copy() makes scaff_data an independent frame so the date conversion below does
# not assign into a frame derived by slicing (avoids SettingWithCopyWarning).
scaffold_columns = ["scaffold_orig_id", "scaffold_smarts", "scaffold_count", "scaffold_first_date"]
scaff_data = ligand_data.groupby("scaffold_orig_id").head(1)[scaffold_columns].copy()
scaff_data["scaffold_first_date"] = pd.to_datetime(scaff_data["scaffold_first_date"])

# Plotting Params

In [None]:
# Global configuration
fig_path = Path("./20250702_scaffold_split")
fig_path.mkdir(parents=True, exist_ok=True)

FIGNUM_GLOBAL = 0

# def save_fig(fig, filename, dpi=200, suffix=".pdf"):
#     """Save the figure with a global figure number."""
#     global FIGNUM_GLOBAL
#     FIGNUM_GLOBAL += 1
#     figpath = Path(fig_path / f"{filename}_{FIGNUM_GLOBAL:02d}")
#     fig.savefig(figpath.with_suffix(suffix), 
#                 bbox_inches="tight", 
#                 dpi=dpi)

def save_fig(fig, filename, dpi=200, suffix=".pdf"):
    """Save a figure into the global ``fig_path`` directory.

    Args:
        fig: Object exposing ``savefig`` (a matplotlib figure, or ``pyplot`` itself).
        filename: Base name of the output file, without directory.
        dpi: Resolution forwarded to ``savefig``.
        suffix: File extension, including the leading dot.
    """
    figpath = fig_path / f"{filename}"
    # Append the extension instead of using with_suffix(): with_suffix() replaces
    # everything after the last dot, so a name such as "threshold_0.5" would be
    # silently truncated to "threshold_0.pdf".
    if figpath.suffix != suffix:
        figpath = figpath.with_name(figpath.name + suffix)
    fig.savefig(figpath, bbox_inches="tight", dpi=dpi)
    

sns.set_style("white")
# Human-readable display labels for column names and categorical values used in plots.
label_map = {
    "Reference_Split": "Dataset Split Type",
    "Score": "Scoring Method",
    "RandomSplit": "Randomly Ordered",
    "DateSplit": "Ordered by Date",
    "RMSD": "RMSD (Positive Control)",
    "POSIT_Probability": "POSIT Probability",
    # "PairwiseSplit": "Similarity Metric",
    "Similarity_Threshold": "Similarity Threshold",
    "ECFP4_2048": "ECFP4 2048",
    "MCS": "MCS",
    "TanimotoCombo_True": "Tanimoto Combo (Aligned)",
    # "N_Reference_Structures": "Number of Randomly Chosen Reference Structures",
    "N_Reference_Structures": "Number of Reference Structures Available to Use \n(Log Scale)",
    # \u00c5 is the angstrom sign; the previous text held a mis-decoded copy of it.
    "Fraction": "Fraction of Ligands Posed \n<2\u00c5 from Reference",
    "CI_Lower": "Confidence Interval Lower Bound",
    "CI_Upper": "Confidence Interval Upper Bound",
}


def get_label(var):
    """Return the display label registered for ``var``, or ``var`` itself if none exists."""
    return label_map.get(var, var)
# Register every dataframe column under its own name unless a label already exists.
for col_name in pdf.columns:
    label_map.setdefault(col_name, col_name)
        
# Axis variables and axis labels resolve to the same label_map entries.
X_VAR = X_LABEL = label_map["N_Reference_Structures"]
Y_VAR = Y_LABEL = label_map["Fraction"]
# QUERY_SCAFFOLD_ID = label_map["Query_Scaffold_ID_Subset_1"]
# REF_SCAFFOLD_ID = label_map["Reference_Scaffold_ID_Subset_1"]

# Display labels for the grouping variables and confidence-interval bounds.
COLOR_VAR = label_map["Reference_Split"]
STYLE_VAR = label_map["Score"]
CI_LOWER = label_map["CI_Lower"]
CI_UPPER = label_map["CI_Upper"]

# Figure sizes as (width, height) tuples passed to matplotlib's figsize.
LARGE_FIG_SIZE = (12, 8)
SMALL_FIG_SIZE = (8, 6)

# Font sizes shared by the plotting cells, keyed by the element they style.
FONT_SIZES = dict(
    xlabel=24,
    ylabel=24,
    ticks=18,
    legend_title=24,
    legend_text=18,
)
ALPHA = 0.1

# Plotting

In [None]:
# Scaffold-split rows, limited to the x_to_y and x_to_x split options.
is_scaffold_split = pdf["PairwiseSplit"] == "ScaffoldSplit"
is_x_to_y_or_x = pdf["Scaffold_Split_Option"].isin(["x_to_y", "x_to_x"])
sdf = pdf[is_scaffold_split & is_x_to_y_or_x]

## Scaffold x_to_y heatmap - all refs

### filter to first 20 scaffolds by count

In [None]:
# Rows with no reference split set, restricted to the x_to_y / x_to_x options.
no_reference_split = sdf["Reference_Split"].isna()
wanted_option = sdf["Scaffold_Split_Option"].isin(["x_to_y", "x_to_x"])
ssdf = sdf[no_reference_split & wanted_option]

# IDs of the 20 scaffolds with the largest scaffold_count.
most_populated = scaff_data.sort_values("scaffold_count", ascending=False).head(20)
top20_scaff_ids = most_populated["scaffold_orig_id"].tolist()

# Keep only pairs where both the reference and the query scaffold are in the top 20.
both_in_top20 = ssdf["rint"].isin(top20_scaff_ids) & ssdf["qint"].isin(top20_scaff_ids)
ssdf = ssdf[both_in_top20]

# One sub-frame per scoring method, keyed by the joined group key.
heatmap_dfs = {}
for score_key, score_rows in ssdf.groupby(["Score"]):
    heatmap_dfs["_".join(score_key)] = score_rows

In [None]:
# Draw one annotated heatmap of "Fraction" per scoring method for the top-20
# scaffolds and save each figure through save_fig.
for name, heatmap_df in heatmap_dfs.items():
    # Rows are query scaffold IDs, columns are reference scaffold IDs.
    pivot_fraction = heatmap_df.pivot(
        index="qint", columns="rint", values="Fraction"
    )

    # "Total" from the first row of each query scaffold, visited in numeric ID order.
    query_counts = (
        heatmap_df.sort_values("qint")
        .groupby(query)
        .head(1)[[query, "Total"]]
        .to_dict(orient="records")
    )
    query_count_dict = {row[query]: row["Total"] for row in query_counts}

    # NOTE(review): both axes are labelled from the query-derived totals. This cell
    # used to build a reference-derived dict as well, but it was overwritten before
    # being read and never reached the figure, so that dead computation is removed.
    # Confirm that "Total" is also the right count to show on the reference axis.
    ytick_labels = [
        f"$\\bf{int(cluster_id)}$ ({total})"
        for cluster_id, total in query_count_dict.items()
    ]
    xtick_labels = [
        f"$\\bf{int(cluster_id)}$\n({total})"
        for cluster_id, total in query_count_dict.items()
    ]

    # Keep a handle on the figure so the figure itself (not the pyplot module)
    # is what gets handed to save_fig.
    fig = plt.figure(figsize=LARGE_FIG_SIZE)

    # Cell text: blank for exact zeros, one decimal place otherwise.
    annotations = pivot_fraction.map(lambda x: '' if x == 0.0 else f'{x:.1f}')
    heatmap = sns.heatmap(
        data=pivot_fraction,
        xticklabels=xtick_labels,
        yticklabels=ytick_labels,
        annot=annotations,
        fmt='',  # annotations are already formatted strings
        cmap="coolwarm_r",
        vmin=0,
        vmax=1,
    )
    # Add colorbar label
    heatmap.collections[0].colorbar.set_label("Fraction")

    # Keep tick labels horizontal for readability
    plt.xticks(rotation=0)
    plt.yticks(rotation=0)

    # Flip the y-axis so the first row is drawn at the bottom
    plt.gca().invert_yaxis()

    # Set axis labels and title
    plt.xlabel(
        f"$\\bf{{Reference Scaffold ID}}$\n (# Reference Structures with Scaffold)",
        fontsize=FONT_SIZES["xlabel"],
        fontweight="normal",
    )
    plt.ylabel(
        f"$\\bf{{Query Scaffold ID}}$\n (# Query Ligands with Scaffold)",
        fontsize=FONT_SIZES["ylabel"],
        fontweight="normal",
    )
    plt.title(
        f"Scored by {get_label(name)}",
        fontsize=FONT_SIZES["xlabel"],
        fontweight="bold",
    )

    save_fig(fig, f"scaffold_x_to_y_heatmap_top_20_{name}")

# plot everything

In [None]:
# All scaffold-split rows with no reference split set, grouped by scoring method.
ssdf = sdf[sdf["Reference_Split"].isna()]
heatmap_dfs = {}
for score_key, score_rows in ssdf.groupby(["Score"]):
    heatmap_dfs["_".join(score_key)] = score_rows

In [None]:
# Draw one heatmap of "Fraction" per scoring method over every query/reference
# scaffold pair in heatmap_dfs, and save each figure through save_fig.
for name, heatmap_df in heatmap_dfs.items():
    # Rows are query scaffold IDs, columns are reference scaffold IDs.
    pivot_fraction = heatmap_df.pivot(
        index="qint", columns="rint", values="Fraction"
    )
    # NOTE(review): this reference-derived count_dict is overwritten by the
    # query-derived one below before it is ever read.
    ref_counts = (
        heatmap_df.sort_values("rint")
        .groupby(ref)
        .head(1)[[ref, "Total"]]
        .to_dict(orient="records")
    )
    count_dict = {data[ref]: data["Total"] for data in ref_counts}

    query_counts = (
        heatmap_df.sort_values("qint")
        .groupby(query)
        .head(1)[[query, "Total"]]
        .to_dict(orient="records")
    )
    count_dict = {data[query]: data["Total"] for data in query_counts}
    
    # NOTE(review): the ID columns were already shifted to 1-indexed when the data
    # was loaded, so the "+ 1" here would shift them a second time. The labels are
    # currently unused because the xticklabels/yticklabels arguments below are
    # commented out.
    ytick_labels = [
        f"$\\bf{int(cluster_id) + 1}$ ({total})" for cluster_id, total in count_dict.items()
    ]
    xtick_labels = [
        f"$\\bf{int(cluster_id) + 1}$\n({total})" for cluster_id, total in count_dict.items()
    ]
    plt.figure(figsize=LARGE_FIG_SIZE)
    
    # Create custom annotation array (blank for exact zeros, one decimal otherwise).
    # Currently unused: the annot argument below is commented out.
    annotations = pivot_fraction.copy()
    annotations = annotations.map(lambda x: '' if x in [0.0] else f'{x:.1f}')
    heatmap = sns.heatmap(
        data=pivot_fraction,
        # xticklabels=xtick_labels,
        # yticklabels=ytick_labels,
        # annot=annotations,
        fmt='',
        cmap="coolwarm_r",
        vmin=0, 
        vmax=1
    )
    # Add colorbar label
    heatmap.collections[0].colorbar.set_label("Fraction")

    # Rotate axis labels for better readability
    plt.xticks(rotation=0)
    plt.yticks(rotation=0)

    # Invert y-axis to put 0 at bottom
    plt.gca().invert_yaxis()

    # Set axis labels
    plt.xlabel(
        f"$\\bf{{Reference Scaffold ID}}$\n (# Reference Structures with Scaffold)",
        fontsize=FONT_SIZES["xlabel"],
        fontweight="normal",
    )
    plt.ylabel(
        f"$\\bf{{Query Scaffold ID}}$\n (# Query Ligands with Scaffold)",
        fontsize=FONT_SIZES["ylabel"],
        fontweight="normal",
    )
    plt.title(
        f"Scored by {get_label(name)}",
        fontsize=FONT_SIZES["xlabel"],
        fontweight="bold",
    )

    # NOTE(review): the pyplot module is passed as `fig`; this works because
    # pyplot also exposes savefig, which saves the current figure.
    save_fig(plt, f"scaffold_x_to_y_heatmap_{name}")

# Add scaffold size info

In [None]:
ssdf = sdf[sdf["Reference_Split"].isna()]

In [None]:
ssdf.nunique()

# Plot fraction vs size as ref and query

In [None]:
# Lookup from scaffold ID to its SMARTS pattern. zip pairs the two columns row by
# row, replacing one positional .iloc lookup per column per row.
id_to_scaffold = dict(zip(scaff_data["scaffold_orig_id"], scaff_data["scaffold_smarts"]))

In [None]:
len(id_to_scaffold)

In [None]:
def get_scaffold_size(smarts):
    """Return the heavy-atom count of a scaffold SMARTS pattern.

    Args:
        smarts: SMARTS string describing the scaffold. Non-string values
            (``None``, ``NaN`` from a merge with no match, ...) are accepted.

    Returns:
        int: Number of heavy atoms reported by RDKit, or 0 when ``smarts`` is
        not a string or cannot be parsed.
    """
    # Missing values are not patterns; answer 0 without calling RDKit.
    if not isinstance(smarts, str):
        return 0
    mol = Chem.MolFromSmarts(smarts)
    # RDKit reports an unparsable pattern by returning None instead of raising.
    if mol is None:
        return 0
    return int(mol.GetNumHeavyAtoms())

# get tidy dataframe from results df

In [None]:
posit_df = ssdf[ssdf["Score"] == "POSIT_Probability"]

In [None]:
# Per-row results keyed by the query scaffold. rename() without inplace returns a
# new frame, so adding the role column no longer writes into a slice of posit_df
# (avoids SettingWithCopyWarning).
query_data = posit_df[["qint", "Fraction", "CI_Lower", "CI_Upper"]].rename(
    columns={"qint": "Scaffold ID"}
)
query_data["Scaffold As"] = "Query"

In [None]:
# Per-row results keyed by the reference scaffold. rename() without inplace returns
# a new frame, so adding the role column no longer writes into a slice of posit_df
# (avoids SettingWithCopyWarning).
ref_data = posit_df[["rint", "Fraction", "CI_Lower", "CI_Upper"]].rename(
    columns={"rint": "Scaffold ID"}
)
ref_data["Scaffold As"] = "Reference"

## combine results

In [None]:
# Stack the query-role and reference-role rows, then attach the scaffold metadata.
# The left join keeps rows whose scaffold ID has no match in scaff_data.
stacked_roles = pd.concat([query_data, ref_data], ignore_index=True)
combined_data = stacked_roles.merge(
    scaff_data,
    how="left",
    left_on="Scaffold ID",
    right_on="scaffold_orig_id",
    suffixes=("", "_ref"),
)

In [None]:
combined_data["Scaffold Size"] = combined_data["scaffold_smarts"].apply(get_scaffold_size)

In [None]:
#convert to integer to properly sort
combined_data['datetime'] = combined_data.scaffold_first_date.astype(int)
dates = sorted(combined_data['scaffold_first_date'].unique())

In [None]:
from datetime import timedelta

In [None]:
# Ten evenly spaced bin edges spanning the observed dates. The extra day pushes the
# last edge past the latest date, so the bisect_left lookup that builds date_map
# always lands on a valid index.
date_range = pd.date_range(start=dates[0], end=dates[-1] + timedelta(days=1), periods=10)

In [None]:
from bisect import bisect_left

# Map each observed date to the first bin edge that is not earlier than it.
date_map = {}
for observed_date in dates:
    date_map[observed_date] = date_range[bisect_left(date_range, observed_date)]

In [None]:
combined_data['scaffold_first_date_ceiling'] = combined_data['scaffold_first_date'].apply(lambda x: date_map[x])

# plot fraction vs size

## query

In [None]:
# Query-role rows whose scaffold size is strictly between 0 and 35, ordered by
# deposition time and then size.
is_query = combined_data["Scaffold As"] == "Query"
size_in_range = (combined_data["Scaffold Size"] > 0) & (combined_data["Scaffold Size"] < 35)
plot_df = combined_data[is_query & size_in_range].sort_values(["datetime", "Scaffold Size"])

ax = sns.pointplot(
    data=plot_df,
    x="Scaffold Size",
    y="Fraction",
    hue="scaffold_first_date_ceiling",
    hue_order=date_range,
    palette="viridis",
    errorbar=("ci", 95),
    linestyle="none",
    alpha=0.5,
    legend="auto",
    native_scale=True,
)
# Show the full [0, 1] range with a small margin on both ends.
ax.set_ylim(-0.05, 1.05)

# Rebuild the legend outside the axes, to the right of the plot.
legend_handles, legend_labels = ax.get_legend_handles_labels()
ax.legend(
    handles=legend_handles,
    labels=legend_labels,
    title="First Date of Scaffold Deposition",
    loc="upper left",
    bbox_to_anchor=(1, 1),
)

## ref

In [None]:
# Reference-role rows whose scaffold size is strictly between 0 and 35, ordered by
# deposition time and then size.
is_reference = combined_data["Scaffold As"] == "Reference"
size_in_range = (combined_data["Scaffold Size"] > 0) & (combined_data["Scaffold Size"] < 35)
plot_df = combined_data[is_reference & size_in_range].sort_values(["datetime", "Scaffold Size"])

ax = sns.pointplot(
    data=plot_df,
    x="Scaffold Size",
    y="Fraction",
    hue="scaffold_first_date_ceiling",
    hue_order=date_range,
    palette="viridis",
    errorbar=("ci", 95),
    linestyle="none",
    alpha=0.5,
    legend="auto",
    native_scale=True,
)
# Show the full [0, 1] range with a small margin on both ends.
ax.set_ylim(-0.05, 1.05)

# Rebuild the legend outside the axes, to the right of the plot.
legend_handles, legend_labels = ax.get_legend_handles_labels()
ax.legend(
    handles=legend_handles,
    labels=legend_labels,
    title="First Date of Scaffold Deposition",
    loc="upper left",
    bbox_to_anchor=(1, 1),
)

# Fraction vs scaffold size for only top 20 scaffolds

## query

In [None]:
# Query-role rows for the top-20 scaffolds whose size is strictly between 0 and 35,
# ordered by deposition time and then size.
is_query = combined_data["Scaffold As"] == "Query"
size_in_range = (combined_data["Scaffold Size"] > 0) & (combined_data["Scaffold Size"] < 35)
in_top20 = combined_data["scaffold_orig_id"].isin(top20_scaff_ids)
plot_df = combined_data[is_query & size_in_range & in_top20].sort_values(["datetime", "Scaffold Size"])

ax = sns.pointplot(
    data=plot_df,
    x="Scaffold Size",
    y="Fraction",
    hue="scaffold_first_date_ceiling",
    hue_order=date_range,
    palette="viridis",
    errorbar=("ci", 95),
    linestyle="none",
    alpha=0.5,
    legend="auto",
    native_scale=True,
)
# Show the full [0, 1] range with a small margin on both ends.
ax.set_ylim(-0.05, 1.05)

# Rebuild the legend outside the axes, to the right of the plot.
legend_handles, legend_labels = ax.get_legend_handles_labels()
ax.legend(
    handles=legend_handles,
    labels=legend_labels,
    title="First Date of Scaffold Deposition",
    loc="upper left",
    bbox_to_anchor=(1, 1),
)

## ref

In [None]:
# Reference-role rows for the top-20 scaffolds whose size is strictly between 0 and
# 35, ordered by deposition time and then size.
is_reference = combined_data["Scaffold As"] == "Reference"
size_in_range = (combined_data["Scaffold Size"] > 0) & (combined_data["Scaffold Size"] < 35)
in_top20 = combined_data["scaffold_orig_id"].isin(top20_scaff_ids)
plot_df = combined_data[is_reference & size_in_range & in_top20].sort_values(["datetime", "Scaffold Size"])

ax = sns.pointplot(
    data=plot_df,
    x="Scaffold Size",
    y="Fraction",
    hue="scaffold_first_date_ceiling",
    hue_order=date_range,
    palette="viridis",
    errorbar=("ci", 95),
    linestyle="none",
    alpha=0.5,
    legend="auto",
    native_scale=True,
)
# Show the full [0, 1] range with a small margin on both ends.
ax.set_ylim(-0.05, 1.05)

# Rebuild the legend outside the axes, to the right of the plot.
legend_handles, legend_labels = ax.get_legend_handles_labels()
ax.legend(
    handles=legend_handles,
    labels=legend_labels,
    title="First Date of Scaffold Deposition",
    loc="upper left",
    bbox_to_anchor=(1, 1),
)

# X to Not X

In [None]:
xdf = pdf[pdf["Scaffold_Split_Option"].isin(['x_to_not_x'])]

In [None]:
# NOTE(review): 403 is a hard-coded N_Reference_Structures value — presumably the
# size of the full reference set; confirm against the analysis inputs.
xdf = xdf[xdf["N_Reference_Structures"] == 403]

In [None]:
xdf = xdf[xdf["Score"] == "POSIT_Probability"]

In [None]:
xdf = xdf[xdf["qint"].isin(top20_scaff_ids)]

## add datetime and size info

In [None]:
xdf = xdf.merge(scaff_data, left_on="qint", right_on="scaffold_orig_id", how="left", suffixes=("", "_ref"))

In [None]:
xdf["Scaffold Size"] = xdf["scaffold_smarts"].apply(get_scaffold_size)

In [None]:
#convert to integer to properly sort
xdf['datetime'] = xdf.scaffold_first_date.astype(int)
dates = sorted(xdf['scaffold_first_date'].unique())

In [None]:
from datetime import timedelta

In [None]:
date_range = pd.date_range(start=dates[0], end=dates[-1] + timedelta(days=1), periods=10)

In [None]:
from bisect import bisect_left

# Map each observed date to the first bin edge that is not earlier than it.
date_map = {}
for observed_date in dates:
    date_map[observed_date] = date_range[bisect_left(date_range, observed_date)]

In [None]:
xdf['scaffold_first_date_ceiling'] = xdf['scaffold_first_date'].apply(lambda x: date_map[x])

In [None]:
# Sorted copy of xdf, ordered by deposition time and then scaffold size.
plot_df = xdf.sort_values(["datetime", "Scaffold Size"])

ax = sns.pointplot(
    data=plot_df,
    x="Scaffold Size",
    y="Fraction",
    hue="scaffold_first_date_ceiling",
    hue_order=date_range,
    palette="viridis",
    errorbar=("ci", 95),
    linestyle="none",
    alpha=0.5,
    legend="auto",
    native_scale=True,
)
# Show the full [0, 1] range with a small margin on both ends.
ax.set_ylim(-0.05, 1.05)

# Rebuild the legend outside the axes, to the right of the plot.
legend_handles, legend_labels = ax.get_legend_handles_labels()
ax.legend(
    handles=legend_handles,
    labels=legend_labels,
    title="First Date of Scaffold Deposition",
    loc="upper left",
    bbox_to_anchor=(1, 1),
)

In [None]:
plot_df

In [None]:
# Refresh the binned deposition date on xdf, then plot from a copy whose bin
# column holds plain dates (no time component) for the legend.
xdf["scaffold_first_date_ceiling"] = xdf["scaffold_first_date"].apply(
    lambda first_date: date_map[first_date]
)
plot_df = xdf.assign(
    scaffold_first_date_ceiling=xdf["scaffold_first_date_ceiling"].dt.date
)

# Legend order follows the bin edges, converted to dates to match the column above.
bin_dates = [edge.date() for edge in date_range]
ax = sns.scatterplot(
    data=plot_df,
    x="Scaffold Size",
    y="Fraction",
    hue="scaffold_first_date_ceiling",
    hue_order=bin_dates,
    palette="viridis",
    s=100,
)

# Asymmetric error bars: yerr takes [lower errors, upper errors].
ax.errorbar(
    x=plot_df["Scaffold Size"],
    y=plot_df["Fraction"],
    yerr=[plot_df["Error_Lower"], plot_df["Error_Upper"]],
    fmt="none",
    color="gray",
    alpha=0.5,
)

# Show the full [0, 1] range with a small margin on both ends.
ax.set_ylim(-0.05, 1.05)

# Rebuild the legend outside the axes, to the right of the plot.
legend_handles, legend_labels = ax.get_legend_handles_labels()
ax.legend(
    handles=legend_handles,
    labels=legend_labels,
    title="First Date of Scaffold Deposition",
    loc="upper left",
    bbox_to_anchor=(1, 1),
)

# plot combined query and ref fraction vs scaffold size

In [None]:
# Rows where the scaffold of interest acts as the query (x_to_not_x) or as the
# reference (not_x_to_x). .copy() gives each subset its own data so the column
# assignments below do not write into a slice of pdf (avoids SettingWithCopyWarning).
query_df = pdf[pdf["Scaffold_Split_Option"].isin(["x_to_not_x"])].copy()
ref_df = pdf[pdf["Scaffold_Split_Option"].isin(["not_x_to_x"])].copy()

# Put both roles on a shared "Scaffold ID" column and tag each row with its role.
ref_df["Scaffold ID"] = ref_df["rint"]
query_df["Scaffold ID"] = query_df["qint"]
ref_df["Scaffold As"] = "Reference"
query_df["Scaffold As"] = "Query"

combined_df = pd.concat([query_df, ref_df], ignore_index=True)
useful_cols = ["Scaffold ID", "Scaffold As", "Fraction", "Score", "CI_Lower", "CI_Upper", "Error_Lower", "Error_Upper", "N_Reference_Structures"]
combined_df = combined_df[useful_cols]

# Keep rows with no reference-structure count, or with exactly 403.
# NOTE(review): 403 is hard-coded — presumably the full reference set; confirm.
unrestricted = combined_df["N_Reference_Structures"].isna()
all_references = combined_df["N_Reference_Structures"] == 403
combined_df = combined_df[unrestricted | all_references]

## add data

In [None]:
combined_df = combined_df.merge(scaff_data, left_on="Scaffold ID", right_on="scaffold_orig_id", how="left", suffixes=("", "_ref"))
combined_df["Scaffold Size"] = combined_df["scaffold_smarts"].apply(get_scaffold_size)

In [None]:
combined_df.columns

In [None]:
dates = sorted(combined_df['scaffold_first_date'].unique())

In [None]:
from bisect import bisect_left
from datetime import timedelta

# Ten evenly spaced bin edges; the extra day keeps the last edge past the latest
# date so every lookup below finds a valid index.
last_edge = dates[-1] + timedelta(days=1)
date_range = pd.date_range(start=dates[0], end=last_edge, periods=10)

# Map each observed date to the first bin edge that is not earlier than it.
date_map = {}
for observed_date in dates:
    date_map[observed_date] = date_range[bisect_left(date_range, observed_date)]

In [None]:
combined_df['scaffold_first_date_ceiling'] = combined_df['scaffold_first_date'].apply(lambda x: date_map[x])

## POSIT Probability, Everything

In [None]:
plot_df = combined_df[combined_df["Score"] == "POSIT_Probability"]

In [None]:
from matplotlib.colors import LogNorm

In [None]:
# Bucket scaffolds by how many ligands share them: exact count below 10, '>=10'
# otherwise. assign() and sort_values() both return new frames, so nothing is
# written into the slice of combined_df that plot_df started as (avoids
# SettingWithCopyWarning).
# NOTE(review): a missing scaffold_count (NaN) fails the "< 10" test and is
# therefore labelled '>=10' — confirm that is intended.
plot_df = plot_df.assign(
    count_category=plot_df["scaffold_count"].apply(
        lambda count: str(count) if count < 10 else ">=10"
    )
).sort_values("count_category")

In [None]:
sns.scatterplot(data=plot_df,
                x="Scaffold Size",
                y="Fraction",
                markers=["o", "D"],
                hue="count_category",
                style="Scaffold As",
                palette="mako_r",)

In [None]:
sns.displot(data=plot_df[plot_df["Scaffold As"] == "Query"],
            x="Fraction")

In [None]:
sns.displot(data=plot_df[plot_df["Scaffold As"] == "Reference"],
            x="Fraction")

In [None]:
sns.ecdfplot(data=plot_df[plot_df["Scaffold As"] == "Query"],
             x="Scaffold Size",
             )

In [None]:
sns.ecdfplot(data=plot_df[plot_df["Scaffold As"] == "Reference"],
             x="Scaffold Size",
             )

In [None]:
ax = sns.scatterplot(data=plot_df, 
                    x="Scaffold Size", 
                    y="Fraction", 
                    hue="scaffold_first_date_ceiling",
                     style="Scaffold As",
                     # markers = ["o", "D"],
                    # alpha=0.5, 
                    #  s=100,
                    palette="viridis",
                    # hue_order=[d.date() for d in date_range],
                    )

In [None]:
sns.scatterplot(data=plot_df, 
                    x="Scaffold Size", 
                    y="Fraction", 
                    hue="scaffold_first_date_ceiling",
                     style="Scaffold As",
                    palette="viridis",
                    )

In [None]:
# Fraction vs. scaffold size for every scaffold, coloured by deposition-date bin
# and with the marker encoding the scaffold's role (query or reference).
ax = sns.scatterplot(
    data=plot_df,
    x="Scaffold Size",
    y="Fraction",
    hue="scaffold_first_date_ceiling",
    style="Scaffold As",
    markers=["o", "D"],
    # alpha=0.5,
    s=100,
    palette="viridis",
    # The hue column holds the Timestamp bin edges taken from date_range, so the
    # order must be given as those same Timestamps. The previous list of
    # datetime.date objects matched none of the values in the column.
    hue_order=list(date_range),
)

# Add asymmetric error bars manually
# ax.errorbar(x=plot_df["Scaffold Size"],
#            y=plot_df["Fraction"],
#            yerr=[plot_df["Error_Lower"], plot_df["Error_Upper"]],
#            fmt='none',
#            color='black',
#            alpha=0.5
#            )

# Show the full [0, 1] range with a small margin on both ends.
ax.set_ylim(-0.05, 1.05)

# Rebuild the legend outside the axes, to the right of the plot.
handles, labels = ax.get_legend_handles_labels()
ax.legend(
    handles=handles,
    labels=labels,
    loc="upper left",
    bbox_to_anchor=(1, 1),
    frameon=False,
)

# Axis labels and tick sizes.
plt.xlabel(
    "Scaffold Size (# Heavy Atoms)",
    fontsize=FONT_SIZES["xlabel"],
    fontweight="normal",
)
plt.ylabel(
    Y_VAR,
    fontsize=FONT_SIZES["ylabel"],
    fontweight="normal",
)
plt.xticks(fontsize=FONT_SIZES["ticks"])
plt.yticks(fontsize=FONT_SIZES["ticks"])

# Restyle the two legend section headers and give the hue header a readable name.
# Uses the public get_text/set_text accessors rather than the private _text attribute.
section_headers = ("scaffold_first_date_ceiling", "Scaffold As")
for text in ax.get_legend().get_texts():
    entry = text.get_text()
    if entry in section_headers:
        text.set_fontsize(FONT_SIZES["legend_text"])
        text.set_text(entry.replace("scaffold_first_date_ceiling", "First Date of \nScaffold \nDeposition"))
plt.savefig(fig_path / "scaffold_size_vs_fraction_query_ref.pdf", bbox_inches="tight", dpi=200)

## posit_probability, top 20

## actually make plot

In [None]:
plot_df = combined_df[(combined_df["Scaffold ID"].isin(top20_scaff_ids))&(combined_df["Score"] == "POSIT_Probability")]

In [None]:
# Order rows by deposition-date bin. Rebinding to the sorted result avoids an
# in-place sort on a frame that was produced by filtering combined_df.
plot_df = plot_df.sort_values(["scaffold_first_date_ceiling"])

In [None]:
# Fraction vs. scaffold size for the top-20 scaffolds, coloured by deposition-date
# bin and with the marker encoding the scaffold's role (query or reference).
sns.set_style("ticks")
ax = sns.scatterplot(
    data=plot_df,
    x="Scaffold Size",
    y="Fraction",
    hue="scaffold_first_date_ceiling",
    style="Scaffold As",
    markers=["o", "D"],
    # alpha=0.5,
    s=100,
    palette="viridis",
)

# Asymmetric error bars: yerr takes [lower errors, upper errors].
ax.errorbar(
    x=plot_df["Scaffold Size"],
    y=plot_df["Fraction"],
    yerr=[plot_df["Error_Lower"], plot_df["Error_Upper"]],
    fmt="none",
    color="black",
    alpha=0.5,
)

# Show the full [0, 1] range with a small margin on both ends.
ax.set_ylim(-0.05, 1.05)

# Rebuild the legend outside the axes, to the right of the plot.
handles, labels = ax.get_legend_handles_labels()
ax.legend(
    handles=handles,
    labels=labels,
    loc="upper left",
    bbox_to_anchor=(1, 1),
    frameon=False,
)

# Axis labels and tick sizes.
plt.xlabel(
    "Scaffold Size (# Heavy Atoms)",
    fontsize=FONT_SIZES["xlabel"],
    fontweight="normal",
)
plt.ylabel(
    Y_VAR,
    fontsize=FONT_SIZES["ylabel"],
    fontweight="normal",
)
plt.xticks(fontsize=FONT_SIZES["ticks"])
plt.yticks(fontsize=FONT_SIZES["ticks"])

# Restyle the two legend section headers and give the hue header a readable name.
# Uses the public get_text/set_text accessors rather than the private _text attribute.
section_headers = ("scaffold_first_date_ceiling", "Scaffold As")
for text in ax.get_legend().get_texts():
    entry = text.get_text()
    if entry in section_headers:
        text.set_fontsize(FONT_SIZES["legend_text"])
        text.set_text(entry.replace("scaffold_first_date_ceiling", "First Date of \nScaffold \nDeposition"))

# Saved under its own name: the all-scaffold figure is written to
# "scaffold_size_vs_fraction_query_ref.pdf", and reusing that name here would
# overwrite it with the top-20 plot.
plt.savefig(fig_path / "scaffold_size_vs_fraction_query_ref_top_20.pdf", bbox_inches="tight", dpi=200)