# Dataset Splitting Strategy Notebook

This notebook demonstrates how the different evaluation settings were generated for a systematic study of a method's generalizability. Given a split ratio, running the whole notebook produces the random split, the unseen-ligand-scaffold split and the unseen-protein split for a given dataset. For the unseen protein-ligand split, we refer readers to the DrugBAN work (https://github.com/peizhenbai/DrugBAN).

In [1]:
## Dataset csv file (input to every split below)
full_dataset_path = 'human_dataset.csv'

## Split configuration
split_ratio = [0.7, 0.1, 0.2]  # train / valid / test fractions
output_directory = 'human'     # each split strategy writes into a subfolder here
seed = 8                       # random seed passed to every split function

In [2]:
## import libraries
import pandas as pd

In [3]:
# Load the complete dataset; all three split strategies partition this frame.
full = pd.read_csv(filepath_or_buffer=full_dataset_path)

## Random Split

In [4]:
def create_fold(df, fold_seed, frac, val_seed=1):
    """Create a random train/valid/test split.

    The test set is sampled first from the whole dataframe. The validation set
    is then sampled from the remaining rows, and whatever is left becomes the
    training set.

    Args:
        df (pd.DataFrame): dataset dataframe
        fold_seed (int): random seed used to sample the test set
        frac (list): a list of train/valid/test fractions
        val_seed (int): random seed used to sample the validation set.
            Defaults to 1, which reproduces the splits generated before this
            parameter existed (the validation seed used to be hard-coded, so
            changing ``fold_seed`` did not change how valid is drawn). Pass
            ``val_seed=fold_seed`` to make the whole split depend on one seed.

    Returns:
        dict: a dictionary of split dataframes, where keys are train/valid/test
        and values correspond to each dataframe (index reset)
    """
    # Unpacking also checks that exactly three fractions were given; the train
    # fraction is implied by the other two.
    _train_frac, val_frac, test_frac = frac
    test = df.sample(frac=test_frac, replace=False, random_state=fold_seed)
    train_val = df[~df.index.isin(test.index)]
    # Rescale so that valid is val_frac of the *full* dataset, not of train_val.
    val = train_val.sample(
        frac=val_frac / (1 - test_frac), replace=False, random_state=val_seed
    )
    train = train_val[~train_val.index.isin(val.index)]

    return {
        "train": train.reset_index(drop=True),
        "valid": val.reset_index(drop=True),
        "test": test.reset_index(drop=True),
    }

In [5]:
# Random split of the full dataset into train/valid/test.
splits = create_fold(df=full, fold_seed=seed, frac=split_ratio)

In [6]:
import os
random_dir = os.path.join(output_directory, 'random')
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(random_dir, exist_ok=True)

# Persist each partition as <split>.csv without the pandas index column.
for split_name in ('train', 'valid', 'test'):
    splits[split_name].to_csv(os.path.join(random_dir, f'{split_name}.csv'), index=False)

In [7]:
## CHECK ratio - a random split reproduces the requested ratio almost exactly
n_total = full.shape[0]
for split_file in ('train.csv', 'valid.csv', 'test.csv'):
    n_rows = pd.read_csv(os.path.join(random_dir, split_file)).shape[0]
    print(n_rows / n_total)

0.7000166750041688
0.10005002501250625
0.19993329998332499


## Unseen Ligand Scaffolds

In [8]:
def create_scaffold_split(df, seed, frac, entity):
    """Create a scaffold split: molecules sharing a Murcko scaffold stay in one partition.

    It first computes the Murcko scaffold (chirality ignored) of every molecule
    in ``df[entity]``, groups row positions by scaffold, and then assigns whole
    scaffold groups to train/valid/test.
    reference: https://github.com/chemprop/chemprop/blob/master/chemprop/data/scaffold.py

    Scaffold groups larger than half of the valid or test target size are
    ordered first (shuffled among themselves), followed by the remaining
    smaller groups (also shuffled). Each group then goes to train if it still
    fits, otherwise to valid if it fits there, otherwise to test. When the test
    fraction is 0, every group that does not fit in train goes to valid.

    Args:
        df (pd.DataFrame): dataset dataframe
        seed (int): the random seed used to shuffle the scaffold groups
        frac (list): a list of train/valid/test fractions
        entity (str): the column name holding the molecules as SMILES strings

    Returns:
        dict: a dictionary of split dataframes, where keys are train/valid/test
        and values correspond to each dataframe (index reset). Rows whose
        SMILES RDKit cannot process are excluded from all three.

    Raises:
        ImportError: if RDKit is not installed.
    """
    try:
        from rdkit import Chem
        from rdkit.Chem.Scaffolds import MurckoScaffold
        from rdkit import RDLogger
    except ImportError as err:
        raise ImportError(
            "Please install rdkit by 'conda install -c conda-forge rdkit'! "
        ) from err
    # Silence RDKit's own per-molecule logging; failures are reported below.
    RDLogger.DisableLog("rdApp.*")

    from collections import defaultdict
    from random import Random

    from tqdm import tqdm

    rng = Random(seed)

    smiles_list = df[entity].values
    # scaffold SMILES -> set of row *positions* (used with df.iloc below)
    scaffolds = defaultdict(set)

    error_smiles = 0
    for i, smiles in tqdm(enumerate(smiles_list), total=len(smiles_list)):
        try:
            scaffold = MurckoScaffold.MurckoScaffoldSmiles(
                mol=Chem.MolFromSmiles(smiles), includeChirality=False
            )
        except Exception:
            # Unparsable entries (RDKit raises) are reported and skipped rather
            # than aborting the whole split; f-string also handles non-str cells.
            print(f"{smiles} returns RDKit error and is thus omitted...")
            error_smiles += 1
            continue
        scaffolds[scaffold].add(i)

    # Target sizes are computed over the molecules that produced a scaffold.
    n_usable = len(df) - error_smiles
    train_size = int(n_usable * frac[0])
    val_size = int(n_usable * frac[1])
    test_size = n_usable - train_size - val_size

    # Large groups are placed first so they tend to be absorbed by train
    # (filled first); smaller groups then fill the remaining room.
    big_index_sets, small_index_sets = [], []
    for index_set in scaffolds.values():
        if len(index_set) > val_size / 2 or len(index_set) > test_size / 2:
            big_index_sets.append(index_set)
        else:
            small_index_sets.append(index_set)
    rng.shuffle(big_index_sets)
    rng.shuffle(small_index_sets)

    train, val, test = [], [], []
    for index_set in big_index_sets + small_index_sets:
        if len(train) + len(index_set) <= train_size:
            train += index_set
        elif frac[2] == 0 or len(val) + len(index_set) <= val_size:
            # With no test fraction, everything that misses train goes to valid.
            val += index_set
        else:
            test += index_set

    return {
        "train": df.iloc[train].reset_index(drop=True),
        "valid": df.iloc[val].reset_index(drop=True),
        "test": df.iloc[test].reset_index(drop=True),
    }

In [9]:
# Unseen-ligand-scaffold split over the SMILES stored in the 'Ligand' column.
scaffold_splits = create_scaffold_split(df=full, seed=seed, frac=split_ratio, entity='Ligand')

100%|██████████| 5997/5997 [00:02<00:00, 2593.37it/s]


In [10]:
import os
scaffold_dir = os.path.join(output_directory, 'scaffold')
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(scaffold_dir, exist_ok=True)

# Persist each partition as <split>.csv without the pandas index column.
for split_name in ('train', 'valid', 'test'):
    scaffold_splits[split_name].to_csv(os.path.join(scaffold_dir, f'{split_name}.csv'), index=False)

In [11]:
## CHECK ratio - the scaffold split also lands very close to the requested ratio
n_total = full.shape[0]
for split_file in ('train.csv', 'valid.csv', 'test.csv'):
    n_rows = pd.read_csv(os.path.join(scaffold_dir, split_file)).shape[0]
    print(n_rows / n_total)

0.6998499249624812
0.09988327497081874
0.20026680006670003


## Unseen Protein Split

In [12]:
def create_fold_setting_cold(df, fold_seed, frac, entities):
    """Create a cold split in which held-out entities never appear in training.

    Given one or more columns, it first samples the distinct values
    ("entities") of each column into test and validation groups, then maps all
    associated data points to the corresponding partition. The fractions are
    applied to the number of distinct entities, not to rows, so the resulting
    row ratios can deviate from ``frac``. With several columns, a row goes to
    test (or valid) only if *all* of its entities were sampled for that
    partition. Rows that mix held-out and other entities are dropped from
    every partition.

    Args:
        df (pd.DataFrame): dataset dataframe
        fold_seed (int): the random seed
        frac (list): a list of train/valid/test fractions
        entities (Union[str, List[str]]): either a single "cold" entity or a
            list of "cold" entities on which the split is done

    Returns:
        dict: a dictionary of split dataframes, where keys are train/valid/test
        and values correspond to each dataframe (index reset)

    Raises:
        ValueError: if the sampled test or validation entities select no rows.
    """
    if isinstance(entities, str):
        entities = [entities]

    _train_frac, val_frac, test_frac = frac

    # For each entity column, sample the distinct values held out for test.
    test_entity_instances = [
        df[e]
        .drop_duplicates()
        .sample(frac=test_frac, replace=False, random_state=fold_seed)
        .values
        for e in entities
    ]

    # Test rows: every entity column takes one of its held-out test values.
    test = df.copy()
    for entity, instances in zip(entities, test_entity_instances):
        test = test[test[entity].isin(instances)]

    if len(test) == 0:
        raise ValueError(
            "No test samples found. Try another seed, increasing the test frac or a "
            "less stringent splitting strategy."
        )

    # Remove every row touching a test entity before drawing validation.
    train_val = df.copy()
    for entity, instances in zip(entities, test_entity_instances):
        train_val = train_val[~train_val[entity].isin(instances)]

    # Rescale so valid is roughly val_frac of the original entity count.
    val_entity_instances = [
        train_val[e]
        .drop_duplicates()
        .sample(frac=val_frac / (1 - test_frac), replace=False, random_state=fold_seed)
        .values
        for e in entities
    ]
    val = train_val.copy()
    for entity, instances in zip(entities, val_entity_instances):
        val = val[val[entity].isin(instances)]

    if len(val) == 0:
        raise ValueError(
            "No validation samples found. Try another seed, increasing the valid frac "
            "or a less stringent splitting strategy."
        )

    # Training rows: no entity belongs to the test or validation groups.
    train = train_val.copy()
    for entity, instances in zip(entities, val_entity_instances):
        train = train[~train[entity].isin(instances)]

    return {
        "train": train.reset_index(drop=True),
        "valid": val.reset_index(drop=True),
        "test": test.reset_index(drop=True),
    }

In [13]:
# Cold split on the 'Protein' column: valid/test proteins are unseen in train.
protein_splits = create_fold_setting_cold(df=full, fold_seed=seed, frac=split_ratio, entities='Protein')

In [14]:
import os
protein_dir = os.path.join(output_directory, 'protein')
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(protein_dir, exist_ok=True)

# Persist each partition as <split>.csv without the pandas index column.
for split_name in ('train', 'valid', 'test'):
    protein_splits[split_name].to_csv(os.path.join(protein_dir, f'{split_name}.csv'), index=False)

In [15]:
## CHECK ratio - the protein split deviates slightly, since fractions are drawn over distinct proteins rather than rows
n_total = full.shape[0]
for split_file in ('train.csv', 'valid.csv', 'test.csv'):
    n_rows = pd.read_csv(os.path.join(protein_dir, split_file)).shape[0]
    print(n_rows / n_total)

0.6706686676671669
0.10005002501250625
0.22928130732032684
