# MoleculeNet Task Using MolLM Embeddings + Random Forest Example

In [1]:
def load_bbbp(input_df):
    smiles_list = input_df['smiles']
    labels = input_df['p_np']
    labels = labels.replace(0, -1)
    return smiles_list, labels

def load_tox21(input_df):
    smiles_list = input_df['smiles']
    tasks = ['NR-AR', 'NR-AR-LBD', 'NR-AhR', 'NR-Aromatase', 'NR-ER', 'NR-ER-LBD',
       'NR-PPAR-gamma', 'SR-ARE', 'SR-ATAD5', 'SR-HSE', 'SR-MMP', 'SR-p53']
    labels = input_df[tasks]
    labels = labels.replace(0, -1)
    labels = labels.fillna(0)
    return smiles_list, labels

def load_toxcast(input_df):
    smiles_list = input_df['smiles']
    tasks = list(input_df.columns)[1:]
    labels = input_df[tasks]
    labels = labels.replace(0, -1)
    labels = labels.fillna(0)
    return smiles_list, labels

def load_sider(input_df):
    smiles_list = input_df['smiles']
    tasks = ['Hepatobiliary disorders',
       'Metabolism and nutrition disorders', 'Product issues', 'Eye disorders',
       'Investigations', 'Musculoskeletal and connective tissue disorders',
       'Gastrointestinal disorders', 'Social circumstances',
       'Immune system disorders', 'Reproductive system and breast disorders',
       'Neoplasms benign, malignant and unspecified (incl cysts and polyps)',
       'General disorders and administration site conditions',
       'Endocrine disorders', 'Surgical and medical procedures',
       'Vascular disorders', 'Blood and lymphatic system disorders',
       'Skin and subcutaneous tissue disorders',
       'Congenital, familial and genetic disorders',
       'Infections and infestations',
       'Respiratory, thoracic and mediastinal disorders',
       'Psychiatric disorders', 'Renal and urinary disorders',
       'Pregnancy, puerperium and perinatal conditions',
       'Ear and labyrinth disorders', 'Cardiac disorders',
       'Nervous system disorders',
       'Injury, poisoning and procedural complications']
    labels = input_df[tasks]
    labels = labels.replace(0, -1)
    return smiles_list, labels

def load_clintox(input_df):
    smiles_list = input_df['smiles']
    tasks = ['FDA_APPROVED', 'CT_TOX']
    labels = input_df[tasks]
    labels = labels.replace(0, -1)
    return smiles_list, labels

def load_muv(input_df):
    smiles_list = input_df['smiles']
    tasks = ['MUV-466', 'MUV-548', 'MUV-600', 'MUV-644', 'MUV-652', 'MUV-689',
       'MUV-692', 'MUV-712', 'MUV-713', 'MUV-733', 'MUV-737', 'MUV-810',
       'MUV-832', 'MUV-846', 'MUV-852', 'MUV-858', 'MUV-859']
    labels = input_df[tasks]
    labels = labels.replace(0, -1)
    labels = labels.fillna(0)
    return smiles_list, labels

def load_hiv(input_df):
    smiles_list = input_df['smiles']
    labels = input_df['HIV_active']
    labels = labels.replace(0, -1)
    return smiles_list, labels

def load_bace(input_df):
    smiles_list = input_df['mol']
    labels = input_df['Class']
    # convert 0 to -1
    labels = labels.replace(0, -1)
    return smiles_list, labels

datasets = {
    'bbbp': ('BBBP.csv', load_bbbp), #
    'tox21': ('tox21.csv', load_tox21), #
    'toxcast': ('toxcast_data.csv', load_toxcast), #
    'sider': ('sider.csv', load_sider), #
    'clintox': ('clintox.csv', load_clintox),
    'muv': ('muv.csv', load_muv),
    'hiv': ('HIV.csv', load_hiv), #sc
    'bace': ('bace.csv', load_bace)
}

In [3]:
import importlib
import sys

import pandas as pd
import torch
from tqdm import tqdm

dataset_base_path = '../../downstream/MoleculePrediction/dataset'

In [None]:
sys.path.insert(0, '../../downstream/graph-transformer')
MolLMPkg = importlib.import_module("MolLM")
MolLM = MolLMPkg.MolLM

# Same checkpoint as property prediction
model = MolLM('../../downstream/GraphTextRetrieval/all_checkpoints/model-epoch=394.ckpt', '../../downstream/GraphTextRetrieval/', '../../downstream/GraphTextRetrieval/bert_pretrained')
model = model.to('cuda')

for dataset, (dataset_csv, load_dataset_func) in datasets.items():
    if dataset == 'bbbp' or dataset == 'tox21':
        continue
    print(dataset)
    csv_path = f'{dataset_base_path}/{dataset}/raw/{dataset_csv}'
    input_df = pd.read_csv(csv_path, sep=',')
    smiles, labels = load_dataset_func(input_df)
    embeddings = []
    for i, smile in enumerate(tqdm(smiles)):
        try:
            embeddings.append(model.forward_molecule(smile))
        except Exception as e:
            print(f'Failed {dataset}-{i}: {e}')
            embeddings.append(torch.zeros((1, 768)))
    torch.save(embeddings, f'{dataset_base_path}/{dataset}/raw/molm.pt')
    torch.save(labels, f'{dataset_base_path}/{dataset}/raw/molm_labels.pt')

In [5]:

import statistics
from sklearn.multioutput import MultiOutputClassifier
from sklearn.ensemble import RandomForestClassifier
from itertools import compress
import numpy as np
from torch_geometric.data import InMemoryDataset, Data
from rdkit.Chem.Scaffolds import MurckoScaffold
from sklearn.metrics import roc_auc_score


class EmbeddingsDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super(EmbeddingsDataset, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ['molm.pt', 'molm_labels.pt']

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        pass

    def process(self):
        data_list = []

        embeddings = torch.load(f'{self.raw_dir}/molm.pt')
        labels = torch.load(f'{self.raw_dir}/molm_labels.pt')
        labels = labels.values

        for i, (emb, label) in enumerate(zip(embeddings, labels)):
            emb = emb.squeeze()
            if labels.shape == (1,2):
                y = torch.tensor(labels[i, :])
            else:
                y = torch.tensor([labels[i]])

            data = Data(x=emb.to('cpu'), y=y)
            data_list.append(data)

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

    def len(self):
        return len(self.data.y)

    def get(self, idx):
        data = Data()
    
        for key in self.data.keys:
            item, slices = self.data[key], self.slices[key]
            start, stop = slices[idx].item(), slices[idx + 1].item()
            data[key] = item[start:stop]
    
        return data
    
    
def scaffold_split(dataset, smiles_list, task_idx=None, null_value=0,
                   frac_train=0.8, frac_valid=0.1, frac_test=0.1,
                   return_smiles=False):
    """
    Adapted from https://github.com/deepchem/deepchem/blob/master/deepchem/splits/splitters.py
    Split dataset by Bemis-Murcko scaffolds
    This function can also ignore examples containing null values for a
    selected task when splitting. Deterministic split
    :param dataset: pytorch geometric dataset obj
    :param smiles_list: list of smiles corresponding to the dataset obj
    :param task_idx: column idx of the data.y tensor. Will filter out
    examples with null value in specified task column of the data.y tensor
    prior to splitting. If None, then no filtering
    :param null_value: float that specifies null value in data.y to filter if
    task_idx is provided
    :param frac_train:
    :param frac_valid:
    :param frac_test:
    :param return_smiles:
    :return: train, valid, test slices of the input dataset obj. If
    return_smiles = True, also returns ([train_smiles_list],
    [valid_smiles_list], [test_smiles_list])
    """
    np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.0)

    if task_idx != None:
        # filter based on null values in task_idx
        # get task array
        y_task = np.array([data.y[task_idx].item() for data in dataset])
        # boolean array that correspond to non null values
        non_null = y_task != null_value
        smiles_list = list(compress(enumerate(smiles_list), non_null))
    else:
        non_null = np.ones(len(dataset)) == 1
        smiles_list = list(compress(enumerate(smiles_list), non_null))

    # create dict of the form {scaffold_i: [idx1, idx....]}
    all_scaffolds = {}
    for i, smiles in smiles_list:
        scaffold = generate_scaffold(smiles, include_chirality=True)
        if scaffold is None:
            continue
        if scaffold not in all_scaffolds:
            all_scaffolds[scaffold] = [i]
        else:
            all_scaffolds[scaffold].append(i)

    # sort from largest to smallest sets
    all_scaffolds = {key: sorted(value) for key, value in all_scaffolds.items()}
    all_scaffold_sets = [
        scaffold_set for (scaffold, scaffold_set) in sorted(
            all_scaffolds.items(), key=lambda x: (len(x[1]), x[1][0]), reverse=True)
    ]

    # get train, valid test indices
    train_cutoff = frac_train * len(smiles_list)
    valid_cutoff = (frac_train + frac_valid) * len(smiles_list)
    train_idx, valid_idx, test_idx = [], [], []
    for scaffold_set in all_scaffold_sets:
        if len(train_idx) + len(scaffold_set) > train_cutoff:
            if len(train_idx) + len(valid_idx) + len(scaffold_set) > valid_cutoff:
                test_idx.extend(scaffold_set)
            else:
                valid_idx.extend(scaffold_set)
        else:
            train_idx.extend(scaffold_set)

    assert len(set(train_idx).intersection(set(valid_idx))) == 0
    assert len(set(test_idx).intersection(set(valid_idx))) == 0

    train_dataset = dataset[torch.tensor(train_idx)]
    valid_dataset = dataset[torch.tensor(valid_idx)]
    test_dataset = dataset[torch.tensor(test_idx)]

    if not return_smiles:
        return train_dataset, valid_dataset, test_dataset
    else:
        train_smiles = [smiles_list[i][1] for i in train_idx]
        valid_smiles = [smiles_list[i][1] for i in valid_idx]
        test_smiles = [smiles_list[i][1] for i in test_idx]
        return train_dataset, valid_dataset, test_dataset, (train_smiles,
                                                            valid_smiles,
                                                            test_smiles)
    
def generate_scaffold(smiles, include_chirality=False):
    """
    Obtain Bemis-Murcko scaffold from smiles
    :param smiles:
    :param include_chirality:
    :return: smiles of scaffold
    """
    try:
        scaffold = MurckoScaffold.MurckoScaffoldSmiles(
        smiles=smiles, includeChirality=include_chirality)
    except:
        # Invalid molecule, just use a random one
        scaffold = MurckoScaffold.MurckoScaffoldSmiles(
        smiles='CCCC', includeChirality=include_chirality)
    return scaffold
    

for dataset, (dataset_csv, load_dataset_func) in datasets.items():
    cdataset = EmbeddingsDataset(root=f'{dataset_base_path}/hiv/')
    csv_path = f'{dataset_base_path}/{dataset}/raw/{dataset_csv}'
    input_df = pd.read_csv(csv_path, sep=',')
    smiles_list, _ = load_dataset_func(input_df)
    smiles_list = list(smiles_list)
    train_dataset, valid_dataset, test_dataset = scaffold_split(cdataset, smiles_list, null_value=0, frac_train=0.8,frac_valid=0.1, frac_test=0.1)
    
    def extract_features_labels(dataset):
        X = []
        Y = []
        for data in dataset:
            X.append(data.x.numpy())
            Y.append((data.y.numpy() + 1) // 2)
        return np.array(X), np.array(Y)
    
    X_train, Y_train = extract_features_labels(train_dataset)
    X_test, Y_test = extract_features_labels(test_dataset)
    # Found after grid search on validation set
    hyperparams = {
        'bbbp': (100, 20),
        'tox21': (50, 20),
        'toxcast': (100, 9),
        'sider': (150, 9),
        'clintox': (50, 20),
        'muv': (100, 20),
        'hiv': (100, 9),
        'bace': (100, 20)
    }
    n_estimators, max_depth = hyperparams[dataset]
    scores = []
    for _ in range(15):
        forest = RandomForestClassifier(n_estimators=n_estimators, max_depth=max_depth)
        multi_output_rf = MultiOutputClassifier(forest, n_jobs=-1)
        multi_output_rf.fit(X_train, Y_train)
        Y_pred = multi_output_rf.predict(X_test)
        # Calculate the accuracy
        Y_scores = multi_output_rf.predict_proba(X_test)
    
        # Convert Y_scores to the correct format
        Y_scores_formatted = np.array([Y_scores[i][:, 1] for i in range(len(Y_scores))]).T
        
        roc_list = []
        for i in range(Y_test.shape[1]):
            # Check that there is at least one positive and one negative example
            if np.sum(Y_test[:, i] == 1) > 0 and np.sum(Y_test[:, i] == 0) > 0:
                # Calculate ROC AUC score for the i-th label
                roc_list.append(roc_auc_score(Y_test[:, i], Y_scores_formatted[:, i]))
        
        # Handle case where some labels may not have both positive and negative examples
        if len(roc_list) < Y_test.shape[1]:
            print("Some target is missing!")
            print("Missing ratio: %f" % (1 - float(len(roc_list)) / Y_test.shape[1]))
    
        # Calculate the average ROC AUC score across all labels
        average_roc_auc = sum(roc_list) / len(roc_list)
        scores.append(average_roc_auc)
    avg_score = sum(scores) / 15
    stddev = statistics.stdev(scores) 
    print(f"{dataset} :: {avg_score * 100:.2f} +- {stddev * 100:.1f}")

[15:56:13] Explicit valence for atom # 1 N, 4, is greater than permitted
[15:56:13] Explicit valence for atom # 6 N, 4, is greater than permitted
[15:56:13] Explicit valence for atom # 6 N, 4, is greater than permitted
[15:56:13] Explicit valence for atom # 11 N, 4, is greater than permitted
[15:56:13] Explicit valence for atom # 12 N, 4, is greater than permitted
[15:56:13] Explicit valence for atom # 5 N, 4, is greater than permitted
[15:56:13] Explicit valence for atom # 5 N, 4, is greater than permitted
[15:56:13] Explicit valence for atom # 5 N, 4, is greater than permitted
[15:56:13] Explicit valence for atom # 5 N, 4, is greater than permitted
[15:56:13] Explicit valence for atom # 5 N, 4, is greater than permitted
[15:56:13] Explicit valence for atom # 5 N, 4, is greater than permitted


bbbp :: 66.72 +- 5.5




tox21 :: 80.11 +- 1.8


[16:00:23] Explicit valence for atom # 0 F, 2, is greater than permitted
[16:00:23] Explicit valence for atom # 2 Cl, 2, is greater than permitted
[16:00:23] Explicit valence for atom # 0 Cl, 2, is greater than permitted
[16:00:23] Explicit valence for atom # 3 Si, 8, is greater than permitted
[16:00:23] Explicit valence for atom # 3 Si, 8, is greater than permitted
[16:00:24] SMILES Parse Error: syntax error while parsing: FAIL
[16:00:24] SMILES Parse Error: Failed parsing SMILES 'FAIL' for input: 'FAIL'
[16:00:24] SMILES Parse Error: syntax error while parsing: FAIL
[16:00:24] SMILES Parse Error: Failed parsing SMILES 'FAIL' for input: 'FAIL'
[16:00:24] SMILES Parse Error: syntax error while parsing: FAIL
[16:00:24] SMILES Parse Error: Failed parsing SMILES 'FAIL' for input: 'FAIL'
[16:00:24] SMILES Parse Error: syntax error while parsing: FAIL
[16:00:24] SMILES Parse Error: Failed parsing SMILES 'FAIL' for input: 'FAIL'
[16:00:24] SMILES Parse Error: syntax error while parsing: FAIL

toxcast :: 88.56 +- 0.9




sider :: 80.80 +- 4.5


[16:04:12] Explicit valence for atom # 0 N, 5, is greater than permitted
[16:04:12] Can't kekulize mol.  Unkekulized atoms: 9
[16:04:12] Explicit valence for atom # 10 N, 4, is greater than permitted
[16:04:12] Explicit valence for atom # 10 N, 4, is greater than permitted
[16:04:12] Can't kekulize mol.  Unkekulized atoms: 4
[16:04:12] Can't kekulize mol.  Unkekulized atoms: 4


clintox :: 70.71 +- 1.9




muv :: 74.33 +- 0.8




hiv :: 77.02 +- 0.7




bace :: 70.55 +- 4.2
