In [None]:
from rdkit.Chem.Scaffolds import MurckoScaffold
from asapdiscovery.data.schema.ligand import Ligand
from asapdiscovery.data.readers.molfile import MolFileFactory
from rdkit import Chem

from pydantic import BaseModel
from abc import abstractmethod
# PATT matches a terminal atom (exactly one explicit connection, D1) that is attached
# to its neighbour by a double bond, e.g. an exocyclic =O. REPL is a bare wildcard atom
# used as the replacement for PATT matches.
PATT, REPL = map(Chem.MolFromSmarts, ("[$([D1]=[*])]", "[*]"))

In [None]:
class BaseBemisMurckoScaffold(BaseModel):
    """
    Interface for one named scaffold definition.

    Each subclass reduces a single ligand to the SMILES of its scaffold under one
    particular Bemis-Murcko variant. ``name`` is the human-readable label of the
    definition (used as the legend when scaffolds are drawn).
    """

    # Human-readable label for this scaffold definition.
    name: str

    @abstractmethod
    def run(self, ligand: Ligand) -> str:
        """
        Compute the scaffold of a single ligand.

        :param ligand: the ligand to reduce to its scaffold.
        :return: the scaffold as a SMILES string.
        """
        ...


class DefaultRDKitBemisMurckoScaffold(BaseBemisMurckoScaffold):
    """
    Plain RDKit Bemis-Murcko scaffold (``MurckoScaffold.GetScaffoldForMol``), with atom
    and bond types left as RDKit produces them.
    """

    # Annotated so the override is a proper pydantic field: pydantic v2 rejects a
    # base-class field overridden by an un-annotated attribute (model-field-overridden).
    name: str = "RDKit Bemis-Murcko"

    def run(self, ligand: Ligand) -> str:
        """
        Compute the RDKit Murcko scaffold of a single ligand.

        :param ligand: ligand to process; converted with ``ligand.to_rdkit()``.
        :return: SMILES of the scaffold.
        """
        mol = ligand.to_rdkit()
        scaff = MurckoScaffold.GetScaffoldForMol(mol)
        return Chem.MolToSmiles(scaff)


class BajorathBemisMurckoScaffold(BaseBemisMurckoScaffold):
    """
    Bajorath-style scaffold: the RDKit Murcko scaffold with every terminal atom that is
    double-bonded to the framework (the ``PATT`` query, e.g. an exocyclic =O) deleted.
    """

    # Annotated so the override is a proper pydantic field: pydantic v2 rejects a
    # base-class field overridden by an un-annotated attribute (model-field-overridden).
    name: str = "Bajorath Bemis-Murcko"

    def run(self, ligand: Ligand) -> str:
        """
        Compute the Bajorath-style scaffold of a single ligand.

        :param ligand: ligand to process; converted with ``ligand.to_rdkit()``.
        :return: SMILES of the scaffold with PATT matches removed.
        """
        mol = ligand.to_rdkit()
        scaff = MurckoScaffold.GetScaffoldForMol(mol)
        # Drop terminal double-bonded atoms left attached to the framework.
        scaff = Chem.rdmolops.DeleteSubstructs(scaff, PATT)
        return Chem.MolToSmiles(scaff)


class GenericBemisMurckoScaffold(BaseBemisMurckoScaffold):
    """
    Generic Murcko scaffold: the RDKit Murcko scaffold passed through
    ``MurckoScaffold.MakeScaffoldGeneric`` (all atoms carbon, all bonds single).
    """

    # Annotated so the override is a proper pydantic field: pydantic v2 rejects a
    # base-class field overridden by an un-annotated attribute (model-field-overridden).
    name: str = "RDKit Generic"

    def run(self, ligand: Ligand) -> str:
        """
        Compute the generic Murcko scaffold of a single ligand.

        :param ligand: ligand to process; converted with ``ligand.to_rdkit()``.
        :return: SMILES of the generic scaffold.
        """
        mol = ligand.to_rdkit()
        scaff = MurckoScaffold.GetScaffoldForMol(mol)
        scaff = MurckoScaffold.MakeScaffoldGeneric(scaff)
        return Chem.MolToSmiles(scaff)


class CSKBemisMurckoScaffold(BaseBemisMurckoScaffold):
    """
    Cyclic skeleton (CSK): a generic Murcko scaffold computed after replacing terminal
    double-bonded atoms (``PATT``) with wildcard atoms (``REPL``), then re-reduced to a
    Murcko scaffold.
    """

    # Annotated so the override is a proper pydantic field: pydantic v2 rejects a
    # base-class field overridden by an un-annotated attribute (model-field-overridden).
    name: str = "Cyclic Skeletons"

    def run(self, ligand: Ligand) -> str:
        """
        Compute the cyclic skeleton of a single ligand.

        :param ligand: ligand to process; converted with ``ligand.to_rdkit()``.
        :return: SMILES of the cyclic skeleton.
        """
        mol = ligand.to_rdkit()
        scaff = MurckoScaffold.GetScaffoldForMol(mol)
        # With replaceAll=True ReplaceSubstructs returns a one-element tuple holding the
        # molecule with every PATT match swapped for REPL.
        scaff = Chem.rdmolops.ReplaceSubstructs(scaff, PATT, REPL, replaceAll=True)[0]
        scaff = MurckoScaffold.MakeScaffoldGeneric(scaff)
        # Second scaffold pass: the replaced atoms are now single-bonded terminal atoms,
        # so presumably this strips them as side chains.
        scaff = MurckoScaffold.GetScaffoldForMol(scaff)
        return Chem.MolToSmiles(scaff)

In [None]:
import os

# SMILES file of unique compounds. Defaults to the original author's local copy; set the
# UNIQUE_COMPOUNDS_SMI environment variable to run the notebook on another machine.
input_smi = os.environ.get(
    "UNIQUE_COMPOUNDS_SMI",
    "/Users/alexpayne/Scientific_Projects/asapdiscovery-sars-retrospective/science/20241025_ligand_analysis/data/unique_compounds.smi",
)

In [None]:
# Load every ligand from the input file, then create one instance of each scaffold
# definition (the order here is the column order of the drawn grid).
mff = MolFileFactory(filename=input_smi)
ligands = mff.load()

scaffold_types = [
    scaffold_cls()
    for scaffold_cls in (
        DefaultRDKitBemisMurckoScaffold,
        BajorathBemisMurckoScaffold,
        GenericBemisMurckoScaffold,
        CSKBemisMurckoScaffold,
    )
]

In [None]:
# Pick the first loaded ligand for a quick inspection in the next cell.
ligand = ligands[0]

In [None]:
# Show the currently bound `ligand` as this cell's output.
ligand

In [None]:
# For every ligand, collect (legend, mol) pairs: the parent molecule first, then one entry
# per scaffold definition, in scaffold_types order. Each scaffold is round-tripped through
# SMILES, so an entry's mol is None if RDKit cannot re-parse that scaffold SMILES.
scaffolds = {}
for lig in ligands:  # distinct name so the `ligand` inspected earlier is not rebound
    scaffolds[lig.compound_name] = [(lig.compound_name, lig.to_rdkit())] + [
        (scaffold_type.name, Chem.MolFromSmiles(scaffold_type.run(lig)))
        for scaffold_type in scaffold_types
    ]

In [None]:
def draw_scaffolds(
    scaffold_list: list[tuple[str, Chem.Mol]], first_n=-1, mols_per_row=-1, use_svg=True
):
    """
    Draw (legend, molecule) pairs as an RDKit grid image.

    :param scaffold_list: sequence of ``(legend, mol)`` pairs; every mol that is drawn
        must be a real RDKit molecule (not None).
    :param first_n: number of leading entries to draw. ``None`` or a negative value
        (the default, -1) draws every entry.
    :param mols_per_row: grid width. ``None`` or a non-positive value (the default, -1)
        falls back to 3, RDKit's own default.
    :param use_svg: forwarded to ``Draw.MolsToGridImage`` as ``useSVG``.
    :return: whatever ``Draw.MolsToGridImage`` returns for these settings.
    """
    from rdkit.Chem import Draw, rdDepictor

    # Slicing with a negative first_n silently drops entries from the end ([:-1] loses
    # the last one), so treat negative/None as "draw everything".
    if first_n is None or first_n < 0:
        first_n = len(scaffold_list)
    # A non-positive row width cannot be laid out as a grid.
    if mols_per_row is None or mols_per_row <= 0:
        mols_per_row = 3

    selected = scaffold_list[:first_n]
    legends = [entry[0] for entry in selected]
    scaffold_rdmols = [Chem.RemoveHs(entry[1]) for entry in selected]

    # Replace any incoming conformer (e.g. 3D coordinates) with a fresh 2D depiction,
    # then straighten it, so the grid is drawn from consistent 2D layouts.
    print("Preparing depictions")
    for mol in scaffold_rdmols:
        rdDepictor.Compute2DCoords(mol)
        rdDepictor.StraightenDepiction(mol)

    # Draw options kept from the original styling; note that no highlight lists are
    # passed to the grid call below.
    dopts = Draw.rdMolDraw2D.MolDrawOptions()
    dopts.setHighlightColour((68 / 256, 178 / 256, 212 / 256))
    dopts.highlightBondWidthMultiplier = 16

    print("Creating image")
    return Draw.MolsToGridImage(
        scaffold_rdmols,
        molsPerRow=mols_per_row,
        subImgSize=(200, 200),
        useSVG=use_svg,
        legends=legends,
        drawOptions=dopts,
    )

In [None]:
# Preview a handful of compounds instead of dumping every entry's RDKit molecules.
N_PREVIEW = 5
{name: scaffolds[name] for name in list(scaffolds)[:N_PREVIEW]}

In [None]:
import pandas as pd

# CSK cluster label per compound.
df = pd.read_csv("data/csk_cluster_labels.csv")
# Fail fast with a clear message if the columns grouped on below are missing.
assert {"cluster_id", "compound_name"}.issubset(df.columns), (
    f"expected 'cluster_id' and 'compound_name' columns, got {list(df.columns)}"
)

In [None]:
# One representative per cluster: the first compound of each cluster_id group, kept in
# the order rows appear in the CSV.
separated_mols = (
    df.groupby("cluster_id")["compound_name"]
    .head(1)
    .to_list()
)

In [None]:
# Take every SAMPLE_EVERY-th cluster representative, and keep it only when scaffolds
# were computed for it and every (legend, mol) entry holds a real molecule -- None
# entries (unparseable scaffold SMILES) cannot be drawn.
SAMPLE_EVERY = 10

mols_to_plot = []
for compound_name in separated_mols[::SAMPLE_EVERY]:
    entries = scaffolds.get(compound_name)
    if entries is not None and all(entry[1] is not None for entry in entries):
        mols_to_plot.extend(entries)

In [None]:
# Grid layout: one compound per row -- its parent molecule followed by one cell per
# scaffold definition, matching how each compound's entries were collected.
base_n_examples = 1 + len(scaffold_types)
n_rows = 3  # number of compounds shown
mols_per_row = base_n_examples
first_n = mols_per_row * n_rows

img = draw_scaffolds(mols_to_plot, first_n=first_n, mols_per_row=mols_per_row, use_svg=False)
# If RDKit's IPythonConsole has patched MolsToGridImage the result is an IPython Image
# holding PNG bytes in .data; otherwise it is a PIL image that can save itself.
if isinstance(getattr(img, "data", None), bytes):
    with open("scaffold_examples.png", "wb") as f:
        f.write(img.data)
else:
    img.save("scaffold_examples.png")

img = draw_scaffolds(mols_to_plot, first_n=first_n, mols_per_row=mols_per_row, use_svg=True)
# SVG output is either an IPython SVG object (markup in .data) or a plain string.
with open("scaffold_examples.svg", "w") as f:
    f.write(img if isinstance(img, str) else img.data)