In [1]:
#!/usr/bin/env python3
import copy
import os
import random

from rdkit import Chem
from rdkit.Chem.Scaffolds import MurckoScaffold
from scipy.cluster.hierarchy import dendrogram, linkage, leaves_list, cut_tree
from rdkit.Chem import AllChem
from collections import Counter
import pandas as pd
import numpy as np
from collections import defaultdict
import joblib

In [2]:

# Copied from Xiong et al., AttentiveFP.
class ScaffoldGenerator:
    """
    Produce Murcko scaffold SMILES for molecules.

    Parameters
    ----------
    include_chirality : bool, optional (default False)
        Forwarded to RDKit as ``includeChirality`` when the scaffold
        SMILES is generated.
    """

    def __init__(self, include_chirality=False):
        self.include_chirality = include_chirality

    def get_scaffold(self, mol):
        """
        Return the Murcko scaffold of a single molecule.

        Murcko scaffolds (DOI: 10.1021/jm9602928) are essentially the ring
        systems of a molecule together with the linker atoms between them.

        Parameters
        ----------
        mol : RDKit molecule
            The molecule to reduce, e.g. the result of
            ``Chem.MolFromSmiles``.

        Returns
        -------
        The value produced by ``MurckoScaffold.MurckoScaffoldSmiles`` for
        ``mol`` (the scaffold as a SMILES string).
        """
        keep_chirality = self.include_chirality
        return MurckoScaffold.MurckoScaffoldSmiles(
            mol=mol, includeChirality=keep_chirality)


# Copied from Xiong et al., AttentiveFP.
def generate_scaffold(smiles, include_chirality=False):
    """Compute the Bemis-Murcko scaffold for a SMILES string.

    Parameters
    ----------
    smiles : str
        SMILES string of the molecule.
    include_chirality : bool, optional (default False)
        Passed through to ``ScaffoldGenerator``.

    Returns
    -------
    The scaffold returned by ``ScaffoldGenerator.get_scaffold`` for the
    parsed molecule (a scaffold SMILES string).

    Raises
    ------
    ValueError
        If RDKit cannot parse ``smiles``. The message names the offending
        input so that callers logging the error can identify the record.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None; fail here
        # with the input attached instead of deep inside the scaffold code.
        raise ValueError("Could not parse SMILES: {!r}".format(smiles))
    engine = ScaffoldGenerator(include_chirality=include_chirality)
    return engine.get_scaffold(mol)



In [3]:
# Load the drug-pair table; later cells read its 'Drug1' and 'Drug2' columns.
twosides_csv = os.path.join('./datasets/twoside/twosides_ge_500.csv')
df = pd.read_csv(twosides_csv)

# Scaffold -> count mapping, filled in by the scaffold-counting cell.
scaffolds = dict()
print(len(df))

4576287


In [4]:
# Every distinct drug SMILES that appears on either side of a pair. Building
# the sets column-wise avoids a Python-level loop with one `df.loc` lookup per
# row, which is very slow on a table with millions of rows.
drug_set = set(df['Drug1']) | set(df['Drug2'])

print(len(drug_set))

# Number of distinct drugs sharing each Murcko scaffold. The dict is reset
# here so that re-running this cell does not double the counts.
scaffolds = {}
for d in drug_set:
    try:
        scaffold = generate_scaffold(d)
    except Exception:
        # Best effort: report drugs whose scaffold cannot be computed and
        # skip them. `Exception` (not a bare except) keeps Ctrl-C working.
        print("error", d)
        continue
    scaffolds[scaffold] = scaffolds.get(scaffold, 0) + 1

645
4576287


In [5]:
# Lookup table: drug SMILES -> scaffold computed by generate_scaffold.
smile_scafold = {smi: generate_scaffold(smi) for smi in drug_set}

In [7]:
# Keys view over the distinct scaffolds counted so far.
all_key = scaffolds.keys()
scaffold_count = len(all_key)
print(scaffold_count)

415


In [22]:
# Split by scaffold rather than by row: 88% of the distinct scaffolds are
# drawn as "train" scaffolds. random.sample requires a real sequence (passing
# a dict keys view raises TypeError on Python 3.11+), hence the list().
train_scaffold = random.sample(list(all_key), round(len(all_key) * 0.88))

# Drugs whose scaffold was drawn. Sets give O(1) membership tests instead of
# scanning the sampled list for every row.
train_scaffold_set = set(train_scaffold)
train_drugs = {smi for smi, scaf in smile_scafold.items()
               if scaf in train_scaffold_set}

# Vectorised per-column membership replaces the per-row `df.loc` loop.
drug1_seen = df['Drug1'].isin(train_drugs).to_numpy()
drug2_seen = df['Drug2'].isin(train_drugs).to_numpy()

# Row labels 0..len(df)-1, as produced by the default index of read_csv.
row_ids = np.arange(len(df))
train_idx = row_ids[drug1_seen & drug2_seen].tolist()    # both drugs on train scaffolds
test2_idx = row_ids[~drug1_seen & ~drug2_seen].tolist()  # neither drug on a train scaffold
test1_idx = row_ids[drug1_seen ^ drug2_seen].tolist()    # exactly one drug on a train scaffold
print(len(train_idx), len(test1_idx), len(test2_idx), len(train_idx)+len(test1_idx)+len(test2_idx))

# Only persist the split when the random draw leaves enough training rows.
# NOTE(review): the 3600000 threshold looks like a hand-tuned acceptance
# criterion for re-running this cell until a suitable draw occurs - confirm.
if len(train_idx) > 3600000:
    split_outputs = (
        (train_idx, './datasets/twoside/twoside_train_val.csv'),
        (test1_idx, './datasets/twoside/twoside_test1.csv'),
        (test2_idx, './datasets/twoside/twoside_test2.csv'),
    )
    for split_rows, csv_path in split_outputs:
        # reset_index() without drop=True keeps the original row label as an
        # 'index' column, so the files have the same columns as before.
        df.loc[split_rows].reset_index().to_csv(csv_path)
    print("done")


3639019 885260 52008 4576287
done


In [31]:
def search_index(unique_smiles, df, num_class, num_limit):
    """Group the rows of ``df`` into classes of structurally similar ``Drug1``.

    The drugs in ``unique_smiles`` are clustered on Morgan fingerprints
    (radius 2, 256 bits) with average-linkage hierarchical clustering under
    the Jaccard metric, and the tree is cut with ``cut_tree(Z, num_class)``.
    For every cluster the rows of ``df`` whose ``Drug1`` belongs to it are
    collected. Clusters with more than ``num_limit`` rows each become a class
    of their own; the remaining clusters are shuffled (using the global
    ``random`` state) and merged greedily, closing a class once it has reached
    ``num_limit`` rows. If the last class created has fewer than 10 rows it
    is discarded.

    Parameters
    ----------
    unique_smiles : sequence of str
        Drug SMILES to cluster; must support integer indexing. Drugs that
        never occur in ``df["Drug1"]`` simply contribute no rows.
    df : pandas.DataFrame
        Table with a ``"Drug1"`` column.
    num_class : int
        Passed to ``cut_tree`` as the number of clusters.
    num_limit : int
        Row-count threshold described above.

    Returns
    -------
    dict
        Maps consecutive integers starting at 0 to DataFrames holding the
        rows of each class (original index labels are kept).

    Raises
    ------
    ValueError
        If RDKit cannot parse one of the SMILES strings.
    """
    fingerprints = []
    for smi in unique_smiles:
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            # MolFromSmiles returns None on a parse failure; name the input.
            raise ValueError("Could not parse SMILES: {!r}".format(smi))
        bits = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=256)
        fingerprints.append(list(bits))
    print("drug num", len(fingerprints))
    Z = linkage(fingerprints, 'average', metric='jaccard')
    cluster = cut_tree(Z, num_class).ravel()
    # Cluster labels ordered by number of member drugs, largest first.
    clusters_by_size = sorted(Counter(cluster).items(),
                              key=lambda item: item[1], reverse=True)

    # Row positions per Drug1 value, computed once. This replaces a full
    # boolean scan of `df` for every single drug.
    rows_by_drug = df.groupby("Drug1", sort=False).indices
    no_rows = np.empty(0, dtype=np.intp)

    num = 0
    data_dict = {}
    for k, _ in clusters_by_size:
        members = np.nonzero(cluster == k)[0]
        # Same row order as concatenating one per-drug selection after
        # another: drugs in member order, rows in frame order within a drug.
        positions = np.concatenate(
            [rows_by_drug.get(unique_smiles[idx], no_rows) for idx in members])
        data_dict[k] = df.iloc[positions]
        num += len(positions)
    print("@@@@@@@@@@@", len(data_dict.keys()), num)

    # Clusters that are big enough on their own become classes directly.
    num = 0
    class_num = -1
    meat_class = {}
    small_keys = []
    for k, rows in data_dict.items():
        if len(rows) > num_limit:
            class_num += 1
            meat_class[class_num] = rows
            num += len(rows)
        else:
            small_keys.append(k)

    random.shuffle(small_keys)

    # Greedily merge the small clusters: a bucket is closed as soon as it has
    # reached num_limit rows, and the next cluster starts a new bucket.
    bucket = []
    bucket_rows = 0
    for k in small_keys:
        if bucket_rows >= num_limit:
            class_num += 1
            meat_class[class_num] = pd.concat(bucket)
            num += len(meat_class[class_num])
            bucket = []
            bucket_rows = 0
        bucket.append(data_dict[k])
        bucket_rows += len(data_dict[k])
    if bucket:
        # Whatever is left over forms the final class.
        class_num += 1
        meat_class[class_num] = pd.concat(bucket)
        num += len(meat_class[class_num])

    print(class_num, len(meat_class[class_num]), num)

    # A tiny trailing class is dropped; its rows are not returned.
    if len(meat_class[class_num]) < 10:
        meat_class.pop(class_num)

    num = sum(len(v) for v in meat_class.values())
    print(num)

    return meat_class

In [34]:
# Cluster the train/validation rows into classes, using every drug that
# occurs in either column as clustering input.
df_train = pd.read_csv('datasets/twoside/twoside_train_val.csv')
unique_smi = set(df_train["Drug1"].unique())
drug2_smiles = set(df_train["Drug2"].unique())
unique_smi_aa = unique_smi.union(drug2_smiles)
meat_class = search_index(list(unique_smi_aa), df_train, 100, 150)
print(len(meat_class.keys()))

# Shuffle the class ids, then fill meta_train until it holds at least 80% of
# the classes; every remaining class goes to meta_val.
meta_keys = list(meat_class.keys())
random.shuffle(meta_keys)
train_quota = len(meta_keys) * 0.8

meta_train = {}
meta_val = {}
for k in meta_keys:
    target = meta_train if len(meta_train) < train_quota else meta_val
    target[k] = meat_class[k]

# Summary counters: number of classes and number of rows on each side.
meta_train_k_num = len(meta_train)
meta_val_k_num = len(meta_val)
meta_train_num = sum(len(rows) for rows in meta_train.values())
meta_val_num = sum(len(rows) for rows in meta_val.values())
print(meta_train_num, meta_train_k_num,meta_val_num,meta_val_k_num, meta_train_num+meta_val_num)
joblib.dump(meta_train, "datasets/twoside/meta_train.pkl")
joblib.dump(meta_val, "datasets/twoside/meta_val.pkl")


drug num 578
@@@@@@@@@@@ 100 3639019
94 164 3639019
3639019
95
2724872 76 914147 19 3639019


['datasets/twoside/twoside_val.pkl']

In [38]:
# Build the classes for the test1 split and pickle them.
# To process the test2 split instead, point the two paths below at
# 'datasets/twoside/twoside_test2.csv' and "datasets/twoside/meta_test2.pkl".
split_csv = 'datasets/twoside/twoside_test1.csv'
split_pkl = "datasets/twoside/meta_test1.pkl"

df = pd.read_csv(split_csv)
unique_smi = df["Drug1"].unique()
meat_class = search_index(unique_smi, df, 50, 200)
joblib.dump(meat_class, split_pkl)

drug num 535
@@@@@@@@@@@ 50 885260
47 170 885260
885260


['datasets/twoside/twoside_test1.pkl']

In [40]:
# Load the class dictionaries written by the train/val split cell. The paths
# must match the joblib.dump calls there (meta_train.pkl / meta_val.pkl);
# the twoside_*.pkl names used previously are not written by this notebook.
meta_train = joblib.load("datasets/twoside/meta_train.pkl")
meta_val = joblib.load("datasets/twoside/meta_val.pkl")

# Flatten each {class id: DataFrame} mapping into a single table, renumber
# the rows and write it out; the row count of each table is printed.
flat_outputs = (
    (meta_train, './datasets/twoside/twoside_train.csv'),
    (meta_val, './datasets/twoside/twoside_val.csv'),
)
for class_frames, csv_path in flat_outputs:
    df_tmp = pd.concat(list(class_frames.values())).reset_index(drop=True)
    df_tmp.to_csv(csv_path)
    print(len(df_tmp))

2724872
914147
