In [None]:
import pandas as pd
import json
from datetime import datetime
import plotly.express as px
from pathlib import Path
from matplotlib.pyplot import ScalarFormatter
from asapdiscovery.data.readers.molfile import MolFileFactory
from harbor.analysis.cross_docking import DockingDataModel
import seaborn as sns
from matplotlib import pyplot as plt

In [None]:
import os

# Location of the combined cross-docking results. Set the DOCKING_RESULTS_PARQUET
# environment variable to point elsewhere instead of editing this
# machine-specific default path.
DOCKING_RESULTS_PARQUET = os.environ.get(
    "DOCKING_RESULTS_PARQUET",
    "/Users/alexpayne/Scientific_Projects/mers-drug-discovery/sars2-retrospective-analysis/full_cross_dock_v2_combined_results/ALL_1_poses.parquet",
)
data = DockingDataModel.deserialize(DOCKING_RESULTS_PARQUET)

In [None]:
# Keep the first row (in data.dataframe order) for each reference ligand.
# `.copy()` makes raw_df an independent frame, so the column assignments in later
# cells do not trigger pandas' SettingWithCopyWarning.
raw_df = data.dataframe.groupby("Reference_Ligand").head(1).copy()

# Increment scaffold IDs by 1

In [None]:
# Shift every scaffold ID up by one (vectorized). `assign` returns a new frame,
# so this never writes into a possible slice of data.dataframe.
# NOTE: not idempotent. Re-running only this cell shifts the IDs again; re-run
# from the cell that builds raw_df instead.
raw_df = raw_df.assign(RefData_Scaffold_ID=raw_df["RefData_Scaffold_ID"] + 1)

## Count the number of structures per scaffold cluster

In [None]:
# Number of reference ligands (one row each in raw_df) in every scaffold cluster,
# as a two-column frame: RefData_Scaffold_ID, count.
cluster_counts = (
    raw_df.groupby("RefData_Scaffold_ID")["Reference_Ligand"]
    .count()
    .reset_index(name="count")
)

## Get the earliest structure date for each cluster

In [None]:
# Earliest RefData_Date per scaffold together with the scaffold SMARTS,
# ordered from oldest to newest scaffold.
# Note: groupby().first() takes the first NON-NULL value of each column
# independently, so the SMARTS is not guaranteed to come from the earliest row.
date_df = (
    raw_df.sort_values("RefData_Date")
    .groupby("RefData_Scaffold_ID")
    .first()
    .reset_index()
    .loc[:, ["RefData_Scaffold_ID", "RefData_Date", "RefData_Scaffold_Smarts"]]
    .sort_values("RefData_Date")
)

### Combine per-cluster counts and dates

In [None]:
# One row per scaffold: molecule count joined (inner) with first date and SMARTS.
df = pd.merge(cluster_counts, date_df, on="RefData_Scaffold_ID")

In [None]:
# Rename scaffold-level columns by name rather than by position.
df = df.rename(
    columns={
        "RefData_Scaffold_ID": "scaffold_orig_id",
        "count": "scaffold_count",
        "RefData_Date": "scaffold_first_date",
        "RefData_Scaffold_Smarts": "scaffold_smarts",
    }
)

In [None]:
# Ligand-level columns carried into the scaffold-over-time analysis.
ligand_cols = [
    "Reference_Ligand",
    "RefData_Scaffold_ID",
    "RefData_Date",
    "RefData_Scaffold_Smarts",
    "PoseData_SMILES",
]
ligand_df = raw_df.loc[:, ligand_cols]

In [None]:
# Rename ligand-level columns by name rather than by position.
ligand_df = ligand_df.rename(
    columns={
        "Reference_Ligand": "compound_name",
        "RefData_Scaffold_ID": "scaffold_orig_id",
        "RefData_Date": "compound_date",
        "RefData_Scaffold_Smarts": "scaffold_smarts",
        "PoseData_SMILES": "smiles",
    }
)

## Add scaffold-level data to `ligand_df`

In [None]:
# Attach the scaffold-level columns (count, first date, SMARTS) to every ligand.
# - validate="many_to_one" raises if df has more than one row per scaffold ID,
#   so the left merge cannot silently duplicate ligands.
# - The overlapping SMARTS column from df gets the "_from_scaff_data" suffix.
# - merge already returns a fresh RangeIndex; drop=True avoids adding a
#   redundant "index" column that would otherwise end up in the exported CSV.
ligand_df = ligand_df.merge(
    df,
    on="scaffold_orig_id",
    how="left",
    suffixes=("", "_from_scaff_data"),
    validate="many_to_one",
).reset_index(drop=True)

### Check that each ligand's scaffold SMARTS matches its cluster's SMARTS

In [None]:
# Ligands whose own scaffold SMARTS differs from the cluster-level SMARTS.
# A plain `!=` treats missing-vs-missing as a mismatch (NaN != NaN), so rows
# where BOTH values are null are excluded explicitly. Rows where only one side
# is missing are still reported.
smarts_lig = ligand_df["scaffold_smarts"]
smarts_scaff = ligand_df["scaffold_smarts_from_scaff_data"]
both_missing = smarts_lig.isna() & smarts_scaff.isna()
ligand_df[smarts_lig.ne(smarts_scaff) & ~both_missing]

"AAR-POS-d2a4d1df-17" doesn't have a scaffold smarts?

### Check that every `compound_date` is on or after its scaffold's first date

In [None]:
# Element-wise: is each compound dated no earlier than its scaffold's first
# structure? Any comparison against a missing date yields False.
dates_ok = ligand_df["compound_date"].ge(ligand_df["scaffold_first_date"])
all(dates_ok)

The check passes: no compound predates the first structure of its scaffold.

# Plotting variables

In [None]:
# Global configuration: every figure written by this notebook goes under fig_path.
fig_path = Path("./20250703_scaffolds_over_time")
fig_path.mkdir(exist_ok=True, parents=True)

# Figure counter left over from an older numbered-filename variant of save_fig
# (that dead, commented-out variant has been removed). Kept so any code that
# still reads it continues to work.
FIGNUM_GLOBAL = 0

def save_fig(fig, filename, dpi=200, suffix=".pdf", fig_dir=None, **savefig_kwargs):
    """Save ``fig`` as ``<fig_dir>/<filename><suffix>``.

    Parameters
    ----------
    fig : object with a ``savefig`` method (e.g. a matplotlib Figure)
    filename : str
        Base file name. Dots inside it are preserved (``"thresh_0.5"`` ->
        ``"thresh_0.5.pdf"``). If it already ends with ``suffix``, no second
        suffix is appended.
    dpi : int
        Resolution forwarded to ``savefig``.
    suffix : str
        File extension, including the leading dot.
    fig_dir : path-like, optional
        Output directory. Defaults to the notebook-level ``fig_path``.
    **savefig_kwargs
        Extra keyword arguments forwarded to ``fig.savefig``.
        ``bbox_inches`` defaults to ``"tight"``.
    """
    out_dir = Path(fig_path if fig_dir is None else fig_dir)
    target = out_dir / str(filename)
    # Path.with_suffix would replace everything after the last dot in the name,
    # truncating names such as "thresh_0.5". Append the suffix instead.
    if target.suffix != suffix:
        target = target.with_name(target.name + suffix)
    savefig_kwargs.setdefault("bbox_inches", "tight")
    fig.savefig(target, dpi=dpi, **savefig_kwargs)
    

# Global seaborn style plus shared label/size constants for plots.
sns.set_style("white")
# Maps dataframe column / category names to human-readable axis and legend labels.
label_map = {
    "Reference_Split": "Dataset Split Type",
    "Score": "Scoring Method",
    "RandomSplit": "Randomly Ordered",
    "DateSplit": "Ordered by Date",
    "RMSD": "RMSD (Positive Control)",
    "POSIT_Probability": "POSIT Probability",
    # "PairwiseSplit": "Similarity Metric",
    "Similarity_Threshold": "Similarity Threshold",
    "ECFP4_2048": "ECFP4 2048",
    "MCS": "MCS",
    "TanimotoCombo_True": "Tanimoto Combo (Aligned)",
    # "N_Reference_Structures": "Number of Randomly Chosen Reference Structures",
    "N_Reference_Structures": "Number of Reference Structures Available to Use \n(Log Scale)",
    "Fraction": "Fraction of Ligands Posed \n<2Å from Reference",
    "CI_Lower": "Confidence Interval Lower Bound",
    "CI_Upper": "Confidence Interval Upper Bound",
    
}
        
# NOTE(review): label_map and the constants below (X_/Y_/COLOR_/STYLE_/CI_ names,
# figure sizes, FONT_SIZES, ALPHA) are not referenced by any other cell in this
# notebook as shown. They look carried over from a related cross-docking
# plotting notebook; confirm before relying on them.
# X_VAR/Y_VAR hold the same display strings as X_LABEL/Y_LABEL.
X_VAR = label_map["N_Reference_Structures"]
Y_VAR = label_map["Fraction"]
X_LABEL = label_map["N_Reference_Structures"]
Y_LABEL = label_map["Fraction"]
# QUERY_SCAFFOLD_ID = label_map["Query_Scaffold_ID_Subset_1"]
# REF_SCAFFOLD_ID = label_map["Reference_Scaffold_ID_Subset_1"]
COLOR_VAR = label_map["Reference_Split"]
STYLE_VAR = label_map["Score"]
CI_LOWER = label_map["CI_Lower"]
CI_UPPER = label_map["CI_Upper"]
# Figure sizes in inches (matplotlib figsize convention).
LARGE_FIG_SIZE = (12, 8)
SMALL_FIG_SIZE = (8, 6)
# Font sizes (points) per plot element.
FONT_SIZES = {
    "xlabel": 24,
    "ylabel": 24,
    "ticks": 18,
    "legend_title": 24,
    "legend_text": 18,
}
# Transparency value; presumably for confidence-interval shading, not used in this notebook.
ALPHA = 0.1

# Plot Scaffolds Over Time

In [None]:
def make_image(df, x_col="Date", color="cluster_id", legend_title="Scaffold",
               height=600, width=800):
    """Build a plotly ECDF of cumulative structure counts over time.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per structure / ligand.
    x_col : str
        Column plotted on the x-axis (the structure collection date).
    color : str
        Column whose values split the data into separately colored curves.
    legend_title : str
        Legend heading, rendered in bold.
    height, width : int
        Figure size in pixels.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    large_font = 24

    # ecdfnorm=None plots raw cumulative counts rather than fractions,
    # matching the y-axis title set below.
    fig = px.ecdf(
        df,
        x=x_col,
        color=color,
        ecdfnorm=None,
        template="simple_white",
        height=height,
        width=width,
    )
    fig.update_layout(legend_title_text=f"<b> {legend_title} </b>")
    fig.update_xaxes(title_text="<b> Date of Crystal Structure Collection </b>")
    fig.update_yaxes(title_text="<b> Cumulative Number of Structures </b>")

    # Legend anchored by its bottom-right corner at (x=1.1, y=0.25) in paper
    # coordinates: over the lower-right of the plot, reaching slightly past
    # its right edge.
    fig.update_layout(
        legend=dict(yanchor="bottom", y=0.25, xanchor="right", x=1.1),
        xaxis=dict(title_font=dict(size=large_font), color="black"),
        yaxis=dict(title_font=dict(size=large_font), color="black"),
    )
    return fig

In [None]:
# Ligand table ordered from the most- to the least-populated scaffold.
pdf = ligand_df.sort_values(by="scaffold_count", ascending=False)

In [None]:
# Scaffolds with more than `min_counts` molecules get their own legend entry.
# All smaller scaffolds are pooled into a single "Misc Scaffolds" group.
min_counts = 8
# Number of molecules (ligand rows) in the pooled group. Computed once instead
# of re-summing the whole column for every row (previously O(n^2)).
n_misc_molecules = int((pdf.scaffold_count <= min_counts).sum())
misc_label = f'Misc Scaffolds - {n_misc_molecules} Molecules'
labels = [
    f'Scaffold {scaffold_id} - {n_mols} Molecules' if n_mols > min_counts else misc_label
    for scaffold_id, n_mols in zip(pdf.scaffold_orig_id, pdf.scaffold_count)
]

In [None]:
# Attach the legend label computed for each row (positional alignment).
pdf = pdf.assign(simplified_scaffold_ids=labels)

In [None]:
# ECDF over time with small scaffolds pooled into one "Misc" legend entry.
grouped_svg_path = fig_path / "scaffs_over_time_grouped.svg"
fig = make_image(pdf, x_col="compound_date", color="simplified_scaffold_ids")
fig.write_image(grouped_svg_path)

In [None]:
# Same ECDF, one curve per original scaffold ID.
all_scaffs_svg_path = fig_path / "scaffs_over_time.svg"
fig = make_image(pdf, x_col="compound_date", color="scaffold_orig_id")
fig.write_image(all_scaffs_svg_path)

# Save SVG and PNG images of all scaffolds

In [None]:
from rdkit import Chem
from rdkit.Chem import Draw, rdDepictor

In [None]:
# get tuple of scaffold_id and rdkit mol
scaffold_mols = df.copy()
scaffold_mols['scaffold_mol'] = scaffold_mols.scaffold_smarts.apply(lambda x: Chem.MolFromSmiles(x) if x is not None else None)

In [None]:
def draw_single_mol(mol, fn, size=(400, 400), legend=None):
    """Write a 2D depiction of ``mol`` as an SVG at ``fn`` and a PNG beside it.

    Parameters
    ----------
    mol : rdkit.Chem.Mol
        Molecule to draw. Explicit Hs are removed on a copy returned by
        ``Chem.RemoveHs``; the depiction is computed on that copy.
    fn : str or pathlib.Path
        Output SVG path. The PNG uses the same path with a ``.png`` suffix.
    size : tuple[int, int]
        Image width and height in pixels.
    legend : str, optional
        Caption under the PNG image. Defaults to the molecule's ``_Name``
        property, or an empty string if it is unset.
    """
    fn = Path(fn)
    if legend is None:
        # Do not depend on a caller's global loop variable; the calling loop
        # stores the caption text in the molecule's _Name property.
        legend = mol.GetProp("_Name") if mol.HasProp("_Name") else ""

    mol = Chem.RemoveHs(mol)
    rdDepictor.Compute2DCoords(mol)
    rdDepictor.StraightenDepiction(mol)

    # Vector depiction (no legend text on the SVG).
    drawer = Draw.rdMolDraw2D.MolDraw2DSVG(*size)
    drawer.DrawMolecule(mol)
    drawer.FinishDrawing()
    fn.write_text(drawer.GetDrawingText())

    # Raster copy with the legend underneath.
    img = Draw.MolsToImage([mol], subImgSize=size, legends=[legend])
    img.save(fn.with_suffix(".png"))

In [None]:
# One SVG/PNG pair per scaffold that parsed into a molecule.
scaff_dir = fig_path / "scaffolds"
scaff_dir.mkdir(exist_ok=True)
for row in scaffold_mols.itertuples(index=False):
    if row.scaffold_mol is None:
        continue
    # The label is stored on the molecule so the drawing can reuse it.
    row.scaffold_mol.SetProp("_Name", f"Scaffold {row.scaffold_orig_id}")
    draw_single_mol(
        row.scaffold_mol,
        scaff_dir / f"generic_scaffold_{row.scaffold_orig_id}.svg",
    )

# Plot a histogram of molecules per scaffold

In [None]:
# Inspect the ligand-level plotting table: sorted by scaffold size, with the
# simplified legend labels attached.
pdf

In [None]:
# Histogram of scaffold sizes (df has one row per scaffold) on a log-scaled count axis.
sns.set_style("ticks")
hist_fig, ax = plt.subplots(figsize=(4, 4))
sns.histplot(df, x="scaffold_count", binwidth=2, ax=ax)

# Log-scale the y-axis but label ticks with plain numbers instead of 10^n.
ax.set_yscale('log')
ax.yaxis.set_major_formatter(ScalarFormatter())

# 1-2-5 ticks per decade up to at least 100 (the previous fixed cap), extended
# further if the tallest bar needs it. set_yticks expands the view to include
# every tick it is given.
from itertools import product
tick_cap = max(100, ax.get_ylim()[1])
decades = [1]
while decades[-1] * 10 <= tick_cap:
    decades.append(decades[-1] * 10)
custom_ticks = sorted(m * d for m, d in product((1, 2, 5), decades) if m * d <= tick_cap)
ax.set_yticks(custom_ticks)

ax.set_ylabel("Number of Scaffolds \nwith N Molecules")
ax.set_xlabel("Number of Molecules per Scaffold")
save_fig(hist_fig, "scaffold_histogram")

# Write out the ligand-level scaffold table

In [None]:
# Export the per-ligand table (ligand columns plus scaffold count, first date
# and SMARTS) to the current working directory.
LIGAND_SCAFFOLD_CSV = Path("ligand_scaffold_data.csv")
ligand_df.to_csv(LIGAND_SCAFFOLD_CSV, index=False)