# Imports

In [None]:
import pandas as pd
from asapdiscovery.data.schema.ligand import Ligand
from asapdiscovery.data.readers.molfile import MolFileFactory
from rdkit.Chem import Draw, rdMolAlign, rdDepictor
from rdkit import Chem

from rdkit.Chem.Scaffolds import MurckoScaffold
from collections import defaultdict

In [None]:
# Read every ligand from the combined 2D SDF file.
sdf_path = "../data/combined_2d.sdf"
mff = MolFileFactory(filename=sdf_path)
ligs = mff.load()

# Function Definitions

In [None]:
def generate_scaffold(ligand: Ligand, include_chirality=True):
    """
    Compute the Bemis-Murcko scaffold SMILES of a ligand.

    Adapted from the scaffold helper in https://github.com/chemprop/chemprop.

    :param ligand: The ligand whose scaffold is computed; it is converted to an
        RDKit molecule with ``ligand.to_rdkit()``.
    :param include_chirality: Whether stereochemistry is kept in the scaffold
        SMILES (passed to RDKit as ``includeChirality``).
    :return: The scaffold as a SMILES string. NOTE(review): molecules without
        rings presumably yield an empty string -- confirm against RDKit.
    """
    mol = ligand.to_rdkit()
    return MurckoScaffold.MurckoScaffoldSmiles(mol=mol, includeChirality=include_chirality)

In [None]:
def split_by_scaffold(ligands, scaffold_fn=None) -> list[dict]:
    """
    Group ligands by scaffold.

    :param ligands: Iterable of ligands to cluster.
    :param scaffold_fn: Callable mapping one ligand to a hashable scaffold key.
        Defaults to ``generate_scaffold`` (chirality-aware Murcko SMILES).
    :return: A list of ``{"scaffold": key, "ligands": [...]}`` dicts, one per
        distinct scaffold, sorted by cluster size, largest first. Clusters of
        equal size keep the order in which their scaffold was first seen.
    """
    if scaffold_fn is None:
        scaffold_fn = generate_scaffold

    groups = defaultdict(list)
    for ligand in ligands:
        groups[scaffold_fn(ligand)].append(ligand)

    # Distinct names so the comprehension does not shadow the ``ligands`` argument.
    clusters = [{"scaffold": key, "ligands": members} for key, members in groups.items()]
    # sorted() is stable even with reverse=True, so equal-size clusters keep
    # their first-seen order.
    return sorted(clusters, key=lambda cluster: len(cluster["ligands"]), reverse=True)

In [None]:
def align_to_scaffold(scaffold: Chem.Mol, mols: list[Chem.Mol]):
    """
    Give ``scaffold`` and every molecule in ``mols`` fresh 2D coordinates, then
    rigidly align each molecule onto the scaffold through its scaffold atoms.

    Coordinates are modified in place; the same objects are returned.

    :param scaffold: Scaffold molecule used as the alignment template.
    :param mols: Molecules expected to contain ``scaffold`` as a substructure.
        Molecules without a match keep their unaligned 2D depiction.
    :return: The ``(scaffold, mols)`` pair that was passed in.
    """
    rdDepictor.Compute2DCoords(scaffold)
    # Scaffold atom indices in match order, paired with each molecule's match below.
    template_match = scaffold.GetSubstructMatch(scaffold)
    for mol in mols:
        rdDepictor.Compute2DCoords(mol)
        query_match = mol.GetSubstructMatch(scaffold)
        # Empty matches (e.g. the "" scaffold of an acyclic molecule, or no hit)
        # would pass AlignMol an empty atomMap, which RDKit treats as "no mapping
        # supplied" -- not a scaffold alignment, and liable to raise. Skip them.
        if not query_match or not template_match:
            continue
        rdMolAlign.AlignMol(mol, scaffold, atomMap=list(zip(query_match, template_match)))
    return scaffold, mols

In [None]:
def draw_scaffold_cluster(scaffold: str, ligands: list[Ligand], filename: str):
    """
    Draw all ligands of one scaffold cluster as an SVG grid, aligned to and
    highlighting their shared scaffold.

    :param scaffold: Scaffold SMILES shared by the ``ligands``.
    :param ligands: Cluster members; each is labelled with its ``compound_name``.
    :param filename: Path of the SVG file to write (overwritten if it exists).
    """
    # Convert to RDKit, dropping explicit hydrogens for a cleaner 2D drawing.
    rdkitmols = [Chem.RemoveHs(ligand.to_rdkit()) for ligand in ligands]
    rdkit_scaffold = Chem.MolFromSmiles(scaffold)

    # Orient every molecule like the scaffold depiction.
    rdkit_scaffold, rdkitmols = align_to_scaffold(rdkit_scaffold, rdkitmols)

    # Atoms of each molecule that belong to the scaffold get highlighted.
    highlight = [mol.GetSubstructMatch(rdkit_scaffold) for mol in rdkitmols]

    dopts = Draw.rdMolDraw2D.MolDrawOptions()
    # RGB (68, 178, 212) scaled to the 0-1 range RDKit colours use.
    dopts.setHighlightColour((68 / 255, 178 / 255, 212 / 255))
    dopts.highlightBondWidthMultiplier = 16

    img = Draw.MolsToGridImage(
        rdkitmols,
        molsPerRow=6,
        subImgSize=(200, 200),
        highlightAtomLists=highlight,
        useSVG=True,
        legends=[ligand.compound_name for ligand in ligands],
        drawOptions=dopts,
    )
    # In a notebook RDKit returns an IPython SVG object (text in ``.data``);
    # outside one it may return the SVG string directly.
    svg = getattr(img, "data", img)
    with open(filename, "w", encoding="utf-8") as f:
        f.write(svg)

# Execution

In [None]:
# Cluster all loaded ligands by scaffold, largest cluster first.
scaffolds = split_by_scaffold(ligands=ligs)

In [None]:
cluster_labels = []
for i, scaffold_dict in enumerate(scaffolds):
    smiles = scaffold_dict["scaffold"]
    members = scaffold_dict["ligands"]
    # One SVG grid per cluster, numbered by the cluster's position in ``scaffolds``.
    draw_scaffold_cluster(smiles, members, f"scaffold_{i}.svg")
    # Record the cluster assignment of every member ligand.
    cluster_labels.extend(
        {"Name": ligand.compound_name, "Cluster": i, "Scaffold_Smiles": smiles}
        for ligand in members
    )

# Save Cluster Labels

In [None]:
import pandas as pd

In [None]:
# One row per ligand: name, cluster index and scaffold SMILES.
cluster_df = pd.DataFrame.from_records(data=cluster_labels)

In [None]:
# Write the labels without the DataFrame's integer index column.
cluster_df.to_csv(path_or_buf="cluster_labels.csv", index=False)

# Draw Scaffolds

In [None]:
# One Ligand per scaffold, named with its cluster index and member count.
scaffold_smiles = []
for cluster_idx, cluster in enumerate(scaffolds):
    label = f"Cluster {cluster_idx} - {len(cluster['ligands'])} molecules"
    scaffold_smiles.append(Ligand.from_smiles(cluster["scaffold"], compound_name=label))

In [None]:
# RDKit versions of the scaffold ligands, explicit hydrogens removed for drawing.
scaffold_rdmols = [
    Chem.RemoveHs(scaffold_ligand.to_rdkit())
    for scaffold_ligand in scaffold_smiles
]

In [None]:
# Preview the first ten scaffolds as the cell output.
scaffold_rdmols[0:10]

In [None]:
# Set Draw Options shared by the scaffold grid images below.
dopts = Draw.rdMolDraw2D.MolDrawOptions()
# RGB (68, 178, 212) scaled to the 0-1 range RDKit colours use.
dopts.setHighlightColour((68 / 255, 178 / 255, 212 / 255))
dopts.highlightBondWidthMultiplier = 16

In [None]:
# Give each of the first 12 scaffolds a fresh, straightened 2D depiction.
# The Mol objects in ``scaffold_rdmols`` are modified in place, so the grid
# images below presumably draw with these coordinates.
for mol in scaffold_rdmols[:12]:
    rdDepictor.Compute2DCoords(mol)
    rdDepictor.StraightenDepiction(mol)

In [None]:
# 3 x 4 grid of the 12 largest scaffold clusters, labelled with index and size.
img = Draw.MolsToGridImage(
    scaffold_rdmols[:12],
    molsPerRow=3,
    subImgSize=(200, 200),
    useSVG=True,
    legends=[ligand.compound_name for ligand in scaffold_smiles[:12]],
    drawOptions=dopts,
)
# Notebook: IPython SVG object with the text in ``.data``; otherwise it may be
# the SVG string itself.
with open("scaffold_images_3x4.svg", "w", encoding="utf-8") as f:
    f.write(getattr(img, "data", img))

In [None]:
# 6 x 2 grid of the 12 largest scaffold clusters, labelled with index and size.
img = Draw.MolsToGridImage(
    scaffold_rdmols[:12],
    molsPerRow=6,
    subImgSize=(200, 200),
    useSVG=True,
    legends=[ligand.compound_name for ligand in scaffold_smiles[:12]],
    drawOptions=dopts,
)
# Notebook: IPython SVG object with the text in ``.data``; otherwise it may be
# the SVG string itself.
with open("scaffold_images_6x2.svg", "w", encoding="utf-8") as f:
    f.write(getattr(img, "data", img))

In [None]:
# TODO: repeat the scaffold clustering and drawing above for the x-series.