In [33]:
from rdkit import Chem
from rdkit.Chem.rdMMPA import FragmentMol
import pandas as pd
import random
import os
import torch

def check_mmpa_linker(linker_smi, min_size, max_link_size=None):
    """Return True if the linker's size lies within the allowed range.

    The size is the number of atoms RDKit parses from ``linker_smi`` minus
    the number of occurrences of the substring ``'*:'`` in the SMILES, which
    this module treats as the number of exit (attachment) points.

    Args:
        linker_smi: SMILES string of the linker.
        min_size: Inclusive lower bound on the size.
        max_link_size: Inclusive upper bound on the size, or ``None`` for no
            upper bound.

    Returns:
        bool: whether ``min_size <= size`` and, if an upper bound was given,
        ``size <= max_link_size``.
    """
    mol = Chem.MolFromSmiles(linker_smi)
    num_exits = linker_smi.count('*:')
    size = mol.GetNumAtoms() - num_exits
    if size < min_size:
        return False
    # The default of None previously reached `size <= None` and raised
    # TypeError; treat it as "unbounded" instead.
    return max_link_size is None or size <= max_link_size

def check_mmpa_fragment(fragment_smi, min_size):
    """Return True if a single fragment is at least ``min_size`` atoms large.

    The size is the atom count of the molecule parsed from ``fragment_smi``
    minus the number of ``'*:'`` substrings found in the SMILES text.
    """
    parsed = Chem.MolFromSmiles(fragment_smi)
    exit_count = fragment_smi.count('*:')
    real_atom_count = parsed.GetNumAtoms() - exit_count
    return real_atom_count >= min_size

def check_mmpa_fragments(fragments_smi, min_size):
    """Return True only if every '.'-separated fragment passes the size check.

    ``fragments_smi`` is split on ``'.'`` and each piece is tested with
    ``check_mmpa_fragment``; evaluation stops at the first failing piece.
    """
    return all(
        check_mmpa_fragment(piece, min_size)
        for piece in fragments_smi.split('.')
    )

def fragment_by_mmpa(mol, mol_name, mol_smiles, min_cuts, max_cuts, min_frag_size, min_link_size, protein_fn, ligand_fn, affinity, max_link_size=None):
    """Fragment ``mol`` with RDKit's MMPA routine and keep size-valid splits.

    ``FragmentMol`` is called once for every cut count in
    ``[min_cuts, max_cuts]``. Each result is a pair of SMILES strings that is
    handled as ``(linker, fragments)``. For single-cut results the first
    element is discarded and the second is split on ``'.'``: the piece with
    the shorter SMILES string (string length, not atom count) takes the
    "linker" slot and the other takes the "fragments" slot.

    Args:
        mol: RDKit molecule to fragment.
        mol_name: Identifier copied into every output row.
        mol_smiles: SMILES of ``mol``, copied into every output row.
        min_cuts: Smallest number of cuts to try (inclusive).
        max_cuts: Largest number of cuts to try (inclusive).
        min_frag_size: Minimum size each fragment must have
            (see ``check_mmpa_fragments``).
        min_link_size: Minimum size of the linker (see ``check_mmpa_linker``).
        protein_fn: Protein file name, copied into every output row.
        ligand_fn: Ligand file name, copied into every output row.
        affinity: Affinity value, copied into every output row.
        max_link_size: Maximum linker size, forwarded unchanged to
            ``check_mmpa_linker``.

    Returns:
        list[list]: rows of ``[mol_name, mol_smiles, linker_smiles,
        fragments_smiles, 'mmpa', protein_fn, ligand_fn, affinity]``.
    """
    mmpa_results = []
    for num_cuts in range(min_cuts, max_cuts + 1):
        cut_results = FragmentMol(
            mol,
            minCuts=num_cuts,
            maxCuts=num_cuts,
            maxCutBonds=100,
            pattern="[#6+0;!$(*=,#[!#6])]!@!=!#[*]",
            resultsAsMols=False
        )
        if num_cuts == 1:
            # Only single-cut results are re-split. Gating this on
            # `min_cuts == 1` for the whole list (as before) also rewrote
            # multi-cut results whenever max_cuts > 1.
            normalized = []
            for result in cut_results:
                pieces = result[1].split('.')
                rgroup, scaffold = pieces[0], pieces[1]
                if len(rgroup) > len(scaffold):
                    rgroup, scaffold = scaffold, rgroup
                normalized.append((rgroup, scaffold))
            cut_results = normalized
        mmpa_results.extend(cut_results)

    filtered_mmpa_results = []
    for linker_smiles, fragments_smiles in mmpa_results:
        linker_ok = check_mmpa_linker(linker_smiles, min_link_size, max_link_size)
        if linker_ok and check_mmpa_fragments(fragments_smiles, min_frag_size):
            filtered_mmpa_results.append([
                mol_name,
                mol_smiles,
                linker_smiles,
                fragments_smiles,
                'mmpa',
                protein_fn,
                ligand_fn,
                affinity,
            ])
    return filtered_mmpa_results

def mmpa(data, cuts, num, save_path, crossdock_dir='/path/to/crossdock2020/crossdocked_pocket10'):
    """Fragment every usable ligand in ``data`` and save a sampled dictionary.

    For each entry the ligand file is read with RDKit, stripped of hydrogens
    and sanitized. Ligands that fail to load, contain an element outside the
    allowed set, or have more than 40 atoms are skipped. The remaining ones
    are fragmented with ``fragment_by_mmpa`` using exactly ``cuts`` cuts.
    All rows are shuffled with a fixed seed, de-duplicated on
    (molecule_name, molecule, linker), truncated to ``num`` rows and written
    to ``save_path`` with ``torch.save``.

    Args:
        data: Sequence of entries where index 0 is the ligand path relative
            to ``crossdock_dir``, index 1 is the protein file name and
            index 2 is the affinity. The second '/'-separated component of
            the ligand path is used as the molecule name.
        cuts: Number of cuts, used as both the minimum and the maximum.
        num: Maximum number of rows kept in the saved dictionary.
        save_path: Destination passed to ``torch.save``.
        crossdock_dir: Root directory the ligand paths are joined onto.

    Returns:
        None. The saved object is a dict with the keys ``protein_filename``,
        ``ligand_filename``, ``retain_smi``, ``mask_smi`` and ``affinity``,
        each mapping to a list.
    """
    min_frag_size = 5
    min_link_size = 3
    max_link_size = 20
    max_mol_atoms = 40

    # Atomic numbers of C, N, O, F, P, S, Cl, Br, I.
    allowed_elements = frozenset((6, 7, 8, 9, 15, 16, 17, 35, 53))

    mol_results = []
    for entry in data:
        ligand_fn = os.path.join(crossdock_dir, entry[0])
        try:
            mol = Chem.MolFromMolFile(ligand_fn)
            if mol is None:
                continue
            mol = Chem.RemoveAllHs(mol)
            Chem.SanitizeMol(mol)
        except Exception:
            # Best effort: unreadable or unsanitizable ligands are skipped.
            # Unlike a bare `except`, this lets KeyboardInterrupt through.
            continue
        if any(atom.GetAtomicNum() not in allowed_elements for atom in mol.GetAtoms()):
            continue
        mol_name = entry[0].split('/')[1]
        if mol.GetNumAtoms() > max_mol_atoms:
            continue
        try:
            res = fragment_by_mmpa(
                mol,
                mol_smiles=Chem.MolToSmiles(mol),
                mol_name=mol_name,
                min_cuts=cuts,
                max_cuts=cuts,
                min_link_size=min_link_size,
                min_frag_size=min_frag_size,
                protein_fn=entry[1],
                ligand_fn=entry[0],
                affinity=entry[2],
                max_link_size=max_link_size,
            )
        except Exception:
            # Fragmentation failures skip the ligand rather than abort the run.
            continue
        mol_results += res

    # NOTE: this reseeds the global `random` state, as the original did.
    random.seed(2024)
    random.shuffle(mol_results)
    table = pd.DataFrame(
        mol_results,
        columns=['molecule_name', 'molecule', 'linker', 'fragments', 'method', 'protein_filename', 'ligand_filename', 'affinity'],
    )
    table = table.drop_duplicates(['molecule_name', 'molecule', 'linker'])
    table_dict = table.to_dict(orient='list')

    # Output key -> source column; the fragments are what is retained and
    # the linker is what is masked.
    output_columns = (
        ('protein_filename', 'protein_filename'),
        ('ligand_filename', 'ligand_filename'),
        ('retain_smi', 'fragments'),
        ('mask_smi', 'linker'),
        ('affinity', 'affinity'),
    )
    final_dict = {key: table_dict[column][:num] for key, column in output_columns}
    torch.save(final_dict, save_path)

In [None]:
# Load the entry list for the 'dec' demo (each entry is indexed by mmpa as
# ligand path, protein file name, affinity), fragment with exactly one cut,
# and save at most 500 rows.
dec_lig_pro_aff_list = torch.load('./demo/dec/dec_lig_pro_aff_list.pt')
mmpa(dec_lig_pro_aff_list, 1, 500, './demo/dec/demo_dict.pt')

In [None]:
# Load the entry list for the 'linker' demo (each entry is indexed by mmpa as
# ligand path, protein file name, affinity), fragment with exactly two cuts,
# and save at most 500 rows.
linker_lig_pro_aff_list = torch.load('./demo/linker/linker_lig_pro_aff_list.pt')
mmpa(linker_lig_pro_aff_list, 2, 500, './demo/linker/demo_dict.pt')