# Imports

In [None]:
import pandas as pd
from asapdiscovery.data.schema.ligand import Ligand
from asapdiscovery.data.readers.molfile import MolFileFactory
from rdkit.Chem import Draw, rdMolAlign, rdDepictor
from rdkit import Chem

from rdkit.Chem.Scaffolds import MurckoScaffold
from collections import defaultdict

# Load the Ligands

In [None]:
# Load the ligands stored in combined_3d.sdf (presumably one Ligand per SDF
# record -- confirm against MolFileFactory.load).
mff = MolFileFactory(filename="combined_3d.sdf")
ligs = mff.load()

In [None]:
# Distinct (SMILES, compound name) pairs across all loaded ligands.
unique_smiles = {(lig.smiles, lig.compound_name) for lig in ligs}

In [None]:
def _ligand_record(lig):
    """Flatten one ligand into the row dict used to build the summary table."""
    xtal_name = lig.tags['xtal_name']
    # Assumes crystal names look like "Mpro-x1234_0A": the character at index 5
    # is the series code, the rest of the first "_"-separated field is the
    # structure number, and the second field is the crystal id -- TODO confirm.
    name_fields = xtal_name.split("_")
    return {
        "smiles": lig.smiles,
        "compound_name": lig.compound_name,
        "series": xtal_name[5],
        "number": name_fields[0][6:],
        "xtal_id": name_fields[1],
        "lig": lig,
    }


dict_data = [_ligand_record(lig) for lig in ligs]

In [None]:
# One row per ligand; the columns come from the keys of the dicts in dict_data.
df = pd.DataFrame.from_records(dict_data)

In [None]:
# Keep only rows whose series code is "x" or "P"; every other series is dropped.
df = df[df["series"].isin(["x", "P"])]

In [None]:
# Rank the rows so that the preferred structure of each compound comes first:
# series ascending, number descending, crystal id ascending.
# NOTE(review): "number" is produced by string slicing, so this descending
# sort is lexicographic rather than numeric -- confirm that is intended.
_ranked = df.sort_values(
    ["series", "number", "xtal_id"],
    ascending=[True, False, True],
)

# De-duplicate in two passes: first keep the top-ranked row for every compound
# name, then the top-ranked remaining row for every SMILES string.
_first_per_name = _ranked.groupby("compound_name").head(1)
unique_compounds = _first_per_name.groupby("smiles").head(1)

In [None]:
# Display, per series, the count of non-null values in each remaining column.
unique_compounds.groupby("series").count()

In [None]:
# Rebind `ligs` to the de-duplicated ligands only; the full list loaded from
# the SDF file is no longer reachable through this name.
ligs = unique_compounds["lig"].tolist()

In [None]:
# Display the de-duplicated table for visual inspection.
unique_compounds

# Function Definitions

In [None]:
def generate_scaffold(ligand: Ligand, generic=False, include_chirality=True):
    """
    Return the Bemis-Murcko scaffold of a ligand as a SMILES string.
    Implementation adapted from https://github.com/chemprop/chemprop.

    :param ligand: Ligand whose scaffold is computed.
    :param generic: If True, pass the scaffold through
        ``MurckoScaffold.MakeScaffoldGeneric`` before writing the SMILES.
    :param include_chirality: If True, the molecule comes from
        ``ligand.to_rdkit()``; otherwise it is parsed from
        ``ligand.non_iso_smiles`` (presumably the SMILES without
        stereochemistry -- confirm against the Ligand schema).
    :return: SMILES string of the (optionally generic) scaffold.
    """
    mol = (
        ligand.to_rdkit()
        if include_chirality
        else Chem.MolFromSmiles(ligand.non_iso_smiles)
    )
    scaffold = MurckoScaffold.GetScaffoldForMol(mol)
    if not generic:
        return Chem.MolToSmiles(scaffold)
    return Chem.MolToSmiles(MurckoScaffold.MakeScaffoldGeneric(scaffold))

In [None]:
def split_by_scaffold(ligands, generic=False, include_chirality=True):
    """
    Group ligands by the SMILES of their scaffold.

    :param ligands: Iterable of ligands accepted by ``generate_scaffold``.
    :param generic: Forwarded to ``generate_scaffold``.
    :param include_chirality: Forwarded to ``generate_scaffold``.
    :return: List of ``{"scaffold": smiles, "ligands": [...]}`` dicts ordered
        from the largest group to the smallest; groups of equal size keep the
        order in which their scaffold was first seen.
    """
    members_by_scaffold = {}
    for lig in ligands:
        key = generate_scaffold(lig, generic=generic, include_chirality=include_chirality)
        members_by_scaffold.setdefault(key, []).append(lig)

    groups = [
        {"scaffold": key, "ligands": members}
        for key, members in members_by_scaffold.items()
    ]
    # list.sort is stable, also with reverse=True, so ties keep insertion order.
    groups.sort(key=lambda group: len(group["ligands"]), reverse=True)
    return groups

In [None]:
def align_to_scaffold(scaffold: Chem.Mol, mols: list[Chem.Mol]):
    """
    Give the scaffold and every molecule 2D coordinates, superimposing each
    molecule's scaffold atoms onto the scaffold depiction.

    All coordinates are computed in place on the objects passed in. A molecule
    that does not contain the scaffold as a substructure (or any molecule when
    the scaffold has no atoms, e.g. the empty scaffold of an acyclic ligand)
    keeps its freshly computed, unaligned 2D coordinates instead of raising.

    :param scaffold: RDKit molecule used as the alignment reference.
    :param mols: RDKit molecules to lay out and align.
    :return: The same ``scaffold`` and ``mols`` objects, modified in place.
    """
    rdDepictor.Compute2DCoords(scaffold)
    # Self-match gives the scaffold atom index paired with each query atom.
    template_match = scaffold.GetSubstructMatch(scaffold)
    for mol in mols:
        rdDepictor.Compute2DCoords(mol)
        query_match = mol.GetSubstructMatch(scaffold)
        if not query_match or not template_match:
            # An empty atomMap would make AlignMol fall back to its own
            # substructure search and raise when that fails; skip alignment.
            continue
        rdMolAlign.AlignMol(mol, scaffold, atomMap=list(zip(query_match, template_match)))
    return scaffold, mols

In [None]:
def draw_scaffold_cluster(scaffold: str, ligands: list[Ligand], filename: str):
    """
    Draw all ligands of one scaffold cluster on a grid and write it to a file.

    Each ligand is drawn without hydrogens, aligned to the scaffold via
    ``align_to_scaffold``, with its scaffold atoms highlighted and a legend of
    the form ``"<compound_name> (<xtal_name>)"``.

    :param scaffold: Scaffold SMILES shared by the ligands.
    :param ligands: Ligands belonging to the cluster.
    :param filename: Path of the SVG file to write.
    :return: None. The grid's ``.data`` attribute is written as text, so the
        object returned by ``MolsToGridImage`` must expose one (presumably the
        notebook SVG wrapper -- confirm before using outside Jupyter).
    """
    # RDKit molecules without hydrogens, plus the scaffold as a molecule
    heavy_mols = []
    for ligand in ligands:
        heavy_mols.append(Chem.RemoveHs(ligand.to_rdkit()))
    scaffold_mol = Chem.MolFromSmiles(scaffold)

    # Shared 2D frame so every depiction shows the scaffold in the same pose
    scaffold_mol, heavy_mols = align_to_scaffold(scaffold_mol, heavy_mols)

    # Per-molecule atom indices that make up the scaffold
    scaffold_atoms = [m.GetSubstructMatch(scaffold_mol) for m in heavy_mols]

    # Highlight styling
    options = Draw.rdMolDraw2D.MolDrawOptions()
    options.setHighlightColour((68/256, 178/256, 212/256))
    options.highlightBondWidthMultiplier = 16

    # Render the grid and save it
    captions = [f"{ligand.compound_name} ({ligand.tags['xtal_name']})" for ligand in ligands]
    grid = Draw.MolsToGridImage(
        heavy_mols,
        molsPerRow=6,
        subImgSize=(200, 200),
        highlightAtomLists=scaffold_atoms,
        useSVG=True,
        legends=captions,
        drawOptions=options,
    )
    with open(filename, 'w') as f:
        f.write(grid.data)

In [None]:
def draw_scaffolds(scaffold_smiles, first_n=-1, mols_per_row=-1):
    """
    Draw the scaffolds of the first clusters on a grid image.

    :param scaffold_smiles: List of ``{"scaffold": smiles, "ligands": [...]}``
        dicts, as returned by ``split_by_scaffold``.
    :param first_n: Number of leading clusters to draw. A negative value
        (the default) or None draws all of them.
    :param mols_per_row: Grid width. A non-positive value (the default) or
        None leaves the choice to ``MolsToGridImage``'s own default.
    :return: The object returned by ``Draw.MolsToGridImage`` with
        ``useSVG=True``.
    """
    # A negative first_n means "no limit". Slicing with the raw default of -1
    # would silently drop the last cluster.
    limit = None if first_n is None or first_n < 0 else first_n

    # Slice before building Ligand objects so unused clusters cost nothing;
    # enumerate still numbers clusters by their position in the full list.
    scaffolds = [
        Ligand.from_smiles(
            scaffold_dict['scaffold'],
            compound_name=f"Cluster {i} - {len(scaffold_dict['ligands'])} molecules",
        )
        for i, scaffold_dict in enumerate(scaffold_smiles[:limit])
    ]
    scaffold_rdmols = [Chem.RemoveHs(ligand.to_rdkit()) for ligand in scaffolds]

    # Set Draw Options
    dopts = Draw.rdMolDraw2D.MolDrawOptions()
    dopts.setHighlightColour((68/256, 178/256, 212/256))
    dopts.highlightBondWidthMultiplier = 16

    # Give every scaffold a straightened 2D depiction; the grid drawing below
    # reuses these coordinates.
    for mol in scaffold_rdmols:
        rdDepictor.Compute2DCoords(mol)
        rdDepictor.StraightenDepiction(mol)

    # Only forward a usable row width; -1 is a "not set" sentinel, not a width.
    grid_kwargs = {}
    if mols_per_row is not None and mols_per_row > 0:
        grid_kwargs["molsPerRow"] = mols_per_row

    img = Draw.MolsToGridImage(
        scaffold_rdmols,
        subImgSize=(200, 200),
        useSVG=True,
        legends=[ligand.compound_name for ligand in scaffolds],
        drawOptions=dopts,
        **grid_kwargs,
    )
    return img

# Execution

In [None]:
# Cluster the de-duplicated ligands four ways: generic vs. exact scaffold,
# each computed with and without chirality.
generic_scaffolds = split_by_scaffold(ligs, generic=True)
scaffolds = split_by_scaffold(ligs, generic=False)
generic_achiral = split_by_scaffold(ligs, generic=True, include_chirality=False)
achiral = split_by_scaffold(ligs, generic=False, include_chirality=False)

In [None]:
# Display the number of clusters produced by each of the four scaffold definitions.
len(generic_scaffolds), len(scaffolds), len(generic_achiral), len(achiral)

In [None]:
def _cluster_by_compound_name(clusters):
    """Map each ligand's compound name to the cluster dict that contains it."""
    lookup = {}
    for cluster in clusters:
        for member in cluster['ligands']:
            lookup[member.compound_name] = cluster
    return lookup


scaffolds_labeled = _cluster_by_compound_name(scaffolds)
generic_scaffolds_labeled = _cluster_by_compound_name(generic_scaffolds)

In [None]:
# Compound name used in the following cells to spot-check scaffold assignments.
example = 'MAT-POS-7dfc56d9-1'

In [None]:
# Index the de-duplicated ligands by compound name for direct lookup.
lig_dict = {ligand.compound_name: ligand for ligand in ligs}

In [None]:
# Display the example compound itself, parsed from its SMILES.
Chem.MolFromSmiles(lig_dict[example].smiles)

In [None]:
# Display the generic scaffold of the cluster that contains the example compound.
Chem.MolFromSmiles(generic_scaffolds_labeled[example]['scaffold'])

In [None]:
# Display the non-generic scaffold of the cluster that contains the example compound.
Chem.MolFromSmiles(scaffolds_labeled[example]['scaffold'])

# Save Label Dict

In [None]:
import pandas as pd

In [None]:
# Rename "compound_name" to "name" so the table can be merged with the cluster
# labels on their "name" key.
# rename() returns a new frame rather than adding and dropping columns in place
# on a frame derived from a filtered selection. That avoids pandas'
# SettingWithCopyWarning and keeps this cell safe to re-run, because a column
# that is already renamed is ignored by default.
unique_compounds = unique_compounds.rename(columns={"compound_name": "name"})

# Draw Scaffolds

In [None]:
# For every scaffold definition, write a CSV of per-ligand cluster labels and
# an SVG grid showing the scaffolds of the first 96 clusters (8 per row).
for name, scaffold_list in (
    ("generic", generic_scaffolds),
    ("default", scaffolds),
    ("generic_achiral", generic_achiral),
    ("achiral", achiral),
):
    # One label row per ligand; the cluster id is the cluster's position in
    # scaffold_list.
    cluster_labels = [
        {
            "name": ligand.compound_name,
            "Cluster": i,
            "Scaffold_Smiles": scaffold_dict['scaffold'],
        }
        for i, scaffold_dict in enumerate(scaffold_list)
        for ligand in scaffold_dict['ligands']
    ]

    # Attach the per-compound columns and rebuild the structure name.
    # NOTE(review): the merged frame still carries the "lig" object column, so
    # the CSV contains its string form -- confirm that is wanted.
    cluster_df = pd.DataFrame.from_records(cluster_labels)
    cluster_df = pd.merge(cluster_df, unique_compounds, on="name")
    cluster_df['structure_name'] = "Mpro-" + cluster_df['series'] + cluster_df['number']
    cluster_df.to_csv(f"{name}_cluster_labels.csv", index=False)

    img = draw_scaffolds(scaffold_list, first_n=96, mols_per_row=8)
    with open(f"{name}_scaffold_images_8x12.svg", 'w') as f:
        f.write(img.data)