In [1]:
from typing import List, Tuple, Dict

import random
import pandas as pd
import numpy as np
import torch
from os.path import join
import datamol as dm

In [2]:
# Directory that holds the dataset files loaded by this notebook
# (relative path, resolved against the notebook's working directory).
BASE_PATH = "../neurips2023/large-dataset/"

In [3]:
# Read the four source tables from BASE_PATH: two gzipped CSV files and two parquet files.
df_L1000_VCAP = pd.read_csv(join(BASE_PATH, "LINCS_L1000_VCAP_0-4.csv.gz"))
df_L1000_MCF7 = pd.read_csv(join(BASE_PATH, "LINCS_L1000_MCF7_0-4.csv.gz"))
df_PCBA = pd.read_parquet(join(BASE_PATH, "PCBA_1328_1564k.parquet"))
df_PCQM4M = pd.read_parquet(join(BASE_PATH, "PCQM4M_G25_N4.parquet"))

# Report the (rows, columns) shape of every table, in load order.
print("df_L1000_VCAP.shape", df_L1000_VCAP.shape)
print("df_L1000_MCF7.shape", df_L1000_MCF7.shape)
print("df_PCBA.shape", df_PCBA.shape)
print("df_PCQM4M.shape", df_PCQM4M.shape)

df_L1000_VCAP.shape (15220, 984)
df_L1000_MCF7.shape (11622, 984)
df_PCBA.shape (1563664, 1332)
df_PCQM4M.shape (3810323, 31)


In [4]:
def smiles_to_unique_ids(smiles: List, batch_size: int = 100, n_jobs: int = 32):
    """Compute a datamol unique id for every SMILES entry, in parallel.

    The work is delegated to `dm.parallelized_with_batches`, which calls
    `loop_smiles_to_unique_ids` on batches of `smiles` and shows a progress bar.

    Args:
        smiles: Sequence of SMILES entries (this notebook passes DataFrame columns).
        batch_size: Number of entries handed to each worker call. Defaults to 100.
        n_jobs: Number of parallel jobs. Defaults to 32; lower it on machines
            with fewer cores.

    Returns:
        The result of `dm.parallelized_with_batches`; presumably a flat list
        holding one id (or None) per input entry, in input order -- confirm
        against the datamol documentation.
    """
    return dm.parallelized_with_batches(
        loop_smiles_to_unique_ids,
        smiles,
        batch_size=batch_size,
        n_jobs=n_jobs,
        progress=True,
    )

def loop_smiles_to_unique_ids(smiles: List):
    """Map a batch of SMILES entries to their datamol unique ids.

    Args:
        smiles: Iterable of entries; each is expected to be a SMILES string,
            but any other value is tolerated.

    Returns:
        A list with one element per input entry, in input order: the value of
        `dm.unique_id` for the parsed molecule, or None when the entry is not
        a string or `dm.to_mol` returns None for it.
    """

    def _unique_id_or_none(entry):
        # Non-string entries are never handed to the parser.
        if not isinstance(entry, str):
            return None
        parsed = dm.to_mol(entry)
        return None if parsed is None else dm.unique_id(parsed)

    return [_unique_id_or_none(entry) for entry in smiles]

In [5]:
# Unique molecule ids for the PCQM4M table (SMILES live in "ordered_smiles").
unique_ids_QM = smiles_to_unique_ids(smiles=df_PCQM4M["ordered_smiles"])

# Unique molecule ids for the three downstream tables (SMILES live in "SMILES").
unique_ids_L1000_VCAP = smiles_to_unique_ids(smiles=df_L1000_VCAP["SMILES"])
unique_ids_L1000_MCF7 = smiles_to_unique_ids(smiles=df_L1000_MCF7["SMILES"])
unique_ids_PCBA = smiles_to_unique_ids(smiles=df_PCBA["SMILES"])

  0%|          | 0/38103 [00:00<?, ?it/s]



  0%|          | 0/152 [00:00<?, ?it/s]

[16:03:49] SMILES Parse Error: syntax error while parsing: restricted
[16:03:49] SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted'
[16:03:49] SMILES Parse Error: syntax error while parsing: restricted
[16:03:49] SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted'
[16:03:49] SMILES Parse Error: syntax error while parsing: restricted
[16:03:49] SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted'
[16:03:49] SMILES Parse Error: syntax error while parsing: restricted
[16:03:49] SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted'
[16:03:49] SMILES Parse Error: syntax error while parsing: restricted
[16:03:49] SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted'
[16:03:49] SMILES Parse Error: syntax error while parsing: restricted
[16:03:49] SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted'
[16:03:49] SMILES Parse Error: syntax er

  0%|          | 0/116 [00:00<?, ?it/s]

[16:03:49] SMILES Parse Error: syntax error while parsing: restricted
[16:03:49] SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted'
[16:03:49] SMILES Parse Error: syntax error while parsing: restricted
[16:03:49] SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted'
[16:03:49] SMILES Parse Error: syntax error while parsing: restricted
[16:03:49] SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted'
[16:03:49] SMILES Parse Error: syntax error while parsing: restricted
[16:03:49] SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted'
[16:03:49] SMILES Parse Error: syntax error while parsing: restricted
[16:03:49] SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted'
[16:03:49] SMILES Parse Error: syntax error while parsing: restricted
[16:03:49] SMILES Parse Error: Failed parsing SMILES 'restricted' for input: 'restricted'
[16:03:49] SMILES Parse Error: syntax er

  0%|          | 0/15636 [00:00<?, ?it/s]



In [6]:
# Count the distinct unique ids that PCQM4M shares with each of the other datasets.
# Entries that could not be converted carry a `None` id; remove it up front so that
# unparseable rows are never counted as a shared molecule. The PCQM4M id set is
# built once and reused for all three comparisons.
unique_ids_QM_set = set(unique_ids_QM) - {None}

intersection_VCAP = unique_ids_QM_set & set(unique_ids_L1000_VCAP)
print("L1000_VCAP", len(intersection_VCAP))

intersection_MCF7 = unique_ids_QM_set & set(unique_ids_L1000_MCF7)
print("L1000_MCF7", len(intersection_MCF7))

intersection_PCBA = unique_ids_QM_set & set(unique_ids_PCBA)
print("PCBA", len(intersection_PCBA))

L1000_VCAP 726
L1000_MCF7 1023
PCBA 56512


In [7]:
def find_indices(list_1, list_2):
    """Flag which entries of each sequence also occur in the other one.

    `None` entries are never considered shared, even when both sequences
    contain `None`.

    Args:
        list_1: First sequence of hashable ids (may contain None).
        list_2: Second sequence of hashable ids (may contain None).

    Returns:
        Tuple `(is_1_in_2, is_2_in_1)` of boolean numpy arrays.
        `is_1_in_2[i]` is True when `list_1[i]` is a non-None value that also
        occurs in `list_2`; `is_2_in_1` is the same for `list_2`.
    """
    shared = (set(list_1) & set(list_2)) - {None}
    # Membership is tested against the Python set directly. Id lists that mix
    # None and str become object arrays, for which np.isin compares every
    # candidate against the whole array; a set lookup per element is linear.
    is_1_in_2 = np.array([elem in shared for elem in list_1], dtype=bool)
    is_2_in_1 = np.array([elem in shared for elem in list_2], dtype=bool)
    return is_1_in_2, is_2_in_1

# Boolean masks marking, for each pair of datasets, the rows whose unique id
# also occurs in the other dataset. Counts use ndarray.sum(), which reduces the
# array natively instead of iterating over it element by element like the
# builtin sum() does.
QM_in_VCAP, VCAP_in_QM = find_indices(unique_ids_QM, unique_ids_L1000_VCAP)
print("QM_in_VCAP", QM_in_VCAP.sum(), "VCAP_in_QM", VCAP_in_QM.sum())

QM_in_MCF7, MCF7_in_QM = find_indices(unique_ids_QM, unique_ids_L1000_MCF7)
print("QM_in_MCF7", QM_in_MCF7.sum(), "MCF7_in_QM", MCF7_in_QM.sum())

QM_in_PCBA, PCBA_in_QM = find_indices(unique_ids_QM, unique_ids_PCBA)
print("QM_in_PCBA", QM_in_PCBA.sum(), "PCBA_in_QM", PCBA_in_QM.sum())

QM_in_VCAP 1443 VCAP_in_QM 748
QM_in_MCF7 2373 MCF7_in_QM 1065
QM_in_PCBA 91174 PCBA_in_QM 56512


In [8]:
# Row positions, in each downstream dataset, of the entries flagged as also
# occurring in PCQM4M.
test_seen_VCAP = np.flatnonzero(VCAP_in_QM).tolist()
test_seen_MCF7 = np.flatnonzero(MCF7_in_QM).tolist()
test_seen_PCBA = np.flatnonzero(PCBA_in_QM).tolist()

# Row positions in PCQM4M flagged as occurring in at least one of the three
# downstream datasets.
train_QM_seen = np.flatnonzero(QM_in_VCAP | QM_in_MCF7 | QM_in_PCBA).tolist()

In [9]:
def make_random_splits(num_elem, val_ratio, test_ratio, seed=42, ignore_idx=None):
    """Make random splits of the data into train, validation, and test sets.

    Indices listed in `ignore_idx` are removed before splitting, and the
    ratios are applied to the number of remaining indices. The remaining
    indices are put in ascending order before shuffling, so for a given seed
    the result does not depend on set iteration order. Shuffling uses a
    private `random.Random(seed)` instance, so the global `random` state is
    left untouched.

    Args:
        num_elem (int): Number of elements in the dataset.
        val_ratio (float): Ratio of the dataset to use for validation.
        test_ratio (float): Ratio of the dataset to use for testing.
        seed (int): Random seed.
        ignore_idx (list): List of indices to ignore.

    Returns:
        train_idx (list): List of indices for the training set.
        val_idx (list): List of indices for the validation set.
        test_idx (list): List of indices for the test set.

    Raises:
        ValueError: If a ratio is negative, or if the validation and test
            sets together would need more elements than are available.
    """
    if val_ratio < 0 or test_ratio < 0:
        raise ValueError(
            f"val_ratio and test_ratio must be non-negative, got {val_ratio} and {test_ratio}"
        )

    # Build the candidate indices in ascending order, skipping ignored ones.
    if ignore_idx is None:
        idx = list(range(num_elem))
    else:
        ignored = set(ignore_idx)
        idx = [i for i in range(num_elem) if i not in ignored]
        num_elem = len(idx)

    # Compute the number of elements in each set.
    num_val = int(num_elem * val_ratio)
    num_test = int(num_elem * test_ratio)
    num_train = num_elem - num_val - num_test
    if num_train < 0:
        # A negative train size would make the slices below overlap silently.
        raise ValueError(
            f"val_ratio + test_ratio is too large: {num_val} + {num_test} > {num_elem} elements"
        )

    # Shuffle with a local generator so the module-level RNG is not reseeded.
    random.Random(seed).shuffle(idx)

    # Split the shuffled indices into three consecutive chunks.
    train_idx = idx[:num_train]
    val_idx = idx[num_train:(num_train + num_val)]
    test_idx = idx[(num_train + num_val):]
    return train_idx, val_idx, test_idx

In [10]:
# Split fractions and shuffle seed shared by all four datasets.
VAL_RATIO = 0.04
TEST_RATIO = 0.04
SPLIT_SEED = 42

# Downstream datasets: rows flagged as also occurring in PCQM4M (test_seen_*) are
# left out of the random train/val/test split.
train_VCAP, val_VCAP, test_VCAP = make_random_splits(len(df_L1000_VCAP), VAL_RATIO, TEST_RATIO, seed=SPLIT_SEED, ignore_idx=test_seen_VCAP)
train_MCF7, val_MCF7, test_MCF7 = make_random_splits(len(df_L1000_MCF7), VAL_RATIO, TEST_RATIO, seed=SPLIT_SEED, ignore_idx=test_seen_MCF7)
train_PCBA, val_PCBA, test_PCBA = make_random_splits(len(df_PCBA), VAL_RATIO, TEST_RATIO, seed=SPLIT_SEED, ignore_idx=test_seen_PCBA)

# PCQM4M: rows shared with a downstream dataset are left out of the random split
# and then appended to the training indices.
train_QM, val_QM, test_QM = make_random_splits(len(df_PCQM4M), VAL_RATIO, TEST_RATIO, seed=SPLIT_SEED, ignore_idx=train_QM_seen)
# Both operands are plain lists, so list concatenation keeps train_QM a list of
# Python ints like every other split. np.concatenate would turn it into an
# ndarray, and into a float64 one if train_QM_seen were empty.
train_QM = train_QM + train_QM_seen

In [11]:
def make_random_splits_file(df, out_file, train_idx, val_idx, test_idx, test_seen_idx=None):
    """Validate index splits against `df` and save them to `out_file` with `torch.save`.

    The saved object is a dict with the keys "train", "val", "test" and, when
    `test_seen_idx` is given, "test_seen". A one-line summary of the split
    sizes is printed before saving.

    Args:
        df: Dataset the indices refer to; only `len(df)` is used.
        out_file: Destination handed to `torch.save`.
        train_idx: Indices of the training split.
        val_idx: Indices of the validation split.
        test_idx: Indices of the test split.
        test_seen_idx: Optional indices of an additional "test_seen" split.

    Raises:
        AssertionError: If a split is empty, repeats an index, or overlaps
            another split; if the split sizes do not add up to `len(df)`; or
            if an index lies outside `[0, len(df))`. Together these checks
            guarantee the splits form an exact partition of the rows of `df`.
    """
    splits_dict = {"train": train_idx, "val": val_idx, "test": test_idx}
    if test_seen_idx is not None:
        splits_dict["test_seen"] = test_seen_idx

    # Build every set once; they are reused by all the checks below.
    split_sets = {name: set(indices) for name, indices in splits_dict.items()}
    names = list(splits_dict)

    for name in names:
        assert len(splits_dict[name]) > 0, f"{name} split is empty"
        # Without this check a repeated index can hide a missing one while
        # the total length still matches len(df).
        assert len(split_sets[name]) == len(splits_dict[name]), f"{name} split contains duplicate indices"

    # Every pair of splits must be disjoint.
    for pos, first in enumerate(names):
        for second in names[pos + 1:]:
            assert split_sets[first].isdisjoint(split_sets[second]), f"{first} and {second} splits overlap"

    # All indices must address an existing row.
    lowest = min(min(indices) for indices in split_sets.values())
    highest = max(max(indices) for indices in split_sets.values())
    assert lowest >= 0 and highest < len(df), f"indices must lie in [0, {len(df)}), found range [{lowest}, {highest}]"

    if test_seen_idx is None:
        assert len(df) == len(train_idx) + len(val_idx) + len(test_idx), f"{len(df)} != {len(train_idx)} + {len(val_idx)} + {len(test_idx)}"
        print(out_file, "train", len(train_idx), "val", len(val_idx), "test", len(test_idx))
    else:
        assert len(df) == len(train_idx) + len(val_idx) + len(test_idx) + len(test_seen_idx), f"{len(df)} != {len(train_idx)} + {len(val_idx)} + {len(test_idx)} + {len(test_seen_idx)}"
        print(out_file, "train", len(train_idx), "val", len(val_idx), "test", len(test_idx), "test_seen", len(test_seen_idx))

    torch.save(splits_dict, out_file)

In [12]:
# Downstream datasets: save train/val/test plus the "test_seen" rows.
make_random_splits_file(
    df=df_L1000_VCAP,
    out_file="l1000_vcap_random_splits.pt",
    train_idx=train_VCAP,
    val_idx=val_VCAP,
    test_idx=test_VCAP,
    test_seen_idx=test_seen_VCAP,
)
make_random_splits_file(
    df=df_L1000_MCF7,
    out_file="l1000_mcf7_random_splits.pt",
    train_idx=train_MCF7,
    val_idx=val_MCF7,
    test_idx=test_MCF7,
    test_seen_idx=test_seen_MCF7,
)
make_random_splits_file(
    df=df_PCBA,
    out_file="pcba_1328_random_splits.pt",
    train_idx=train_PCBA,
    val_idx=val_PCBA,
    test_idx=test_PCBA,
    test_seen_idx=test_seen_PCBA,
)

# PCQM4M: no "test_seen" split is passed.
make_random_splits_file(
    df=df_PCQM4M,
    out_file="pcqm4m_g25_n4_random_splits.pt",
    train_idx=train_QM,
    val_idx=val_QM,
    test_idx=test_QM,
)

l1000_vcap_random_splits.pt train 13316 val 578 test 578 test_seen 748
l1000_mcf7_random_splits.pt train 9713 val 422 test 422 test_seen 1065
pcba_1328_random_splits.pt train 1386580 val 60286 test 60286 test_seen 56512
pcqm4m_g25_n4_random_splits.pt train 3512805 val 148759 test 148759
