In [3]:
import random
import pandas as pd
import torch

In [7]:
def make_random_splits(num_elem: int, val_ratio: float, test_ratio: float, seed: int = 42):
    """Randomly partition ``range(num_elem)`` into train, validation, and test indices.

    The indices are shuffled with a private, seeded ``random.Random`` instance, so
    the result is reproducible for a given ``seed`` and the state of the
    module-level ``random`` generator is left untouched.

    The validation and test sizes are ``int(num_elem * ratio)`` (truncated toward
    zero); the training set receives all remaining elements.

    Args:
        num_elem (int): Number of elements in the dataset. Must be >= 0.
        val_ratio (float): Ratio of the dataset to use for validation, in [0, 1].
        test_ratio (float): Ratio of the dataset to use for testing, in [0, 1].
        seed (int): Seed for the shuffle. Defaults to 42.

    Returns:
        train_idx (list): List of indices for the training set.
        val_idx (list): List of indices for the validation set.
        test_idx (list): List of indices for the test set.

    Raises:
        ValueError: If ``num_elem`` is negative, if either ratio lies outside
            [0, 1] (NaN included), or if ``val_ratio + test_ratio`` exceeds 1.
    """
    if num_elem < 0:
        raise ValueError(f"num_elem must be non-negative, got {num_elem}")
    # Written as `not (0 <= r <= 1)` so that NaN is rejected as well.
    if not (0 <= val_ratio <= 1) or not (0 <= test_ratio <= 1):
        raise ValueError(
            f"val_ratio and test_ratio must be in [0, 1], got {val_ratio} and {test_ratio}"
        )
    # Without this guard num_train becomes negative and the slices below overlap.
    if val_ratio + test_ratio > 1:
        raise ValueError(
            f"val_ratio + test_ratio must not exceed 1, got {val_ratio + test_ratio}"
        )

    # Shuffle with a local generator: same permutation as seeding the global
    # one, but without clobbering random state used elsewhere.
    idx = list(range(num_elem))
    random.Random(seed).shuffle(idx)

    # Compute the number of elements in each set
    num_val = int(num_elem * val_ratio)
    num_test = int(num_elem * test_ratio)
    num_train = num_elem - num_val - num_test

    # Consecutive, non-overlapping slices of the shuffled indices
    val_end = num_train + num_val
    train_idx = idx[:num_train]
    val_idx = idx[num_train:val_end]
    test_idx = idx[val_end:]
    return train_idx, val_idx, test_idx

In [10]:
def make_random_splits_from_file(in_file, out_file, val_ratio, test_ratio, seed=42):
    """Create random train/val/test index splits for a CSV dataset and save them.

    Only the "smiles" column of ``in_file`` is read; its row count determines how
    many indices are split. The resulting dictionary, with keys "train", "val",
    and "test" mapping to lists of row indices, is written to ``out_file`` with
    ``torch.save``.

    Args:
        in_file: Path of the CSV file to read (must contain a "smiles" column).
        out_file: Destination path for the saved splits dictionary.
        val_ratio (float): Ratio of the rows to use for validation.
        test_ratio (float): Ratio of the rows to use for testing.
        seed (int): Seed forwarded to ``make_random_splits``. Defaults to 42.
    """
    # Only the row count is needed, so a single column is loaded.
    num_rows = len(pd.read_csv(in_file, usecols=["smiles"]))
    train_idx, val_idx, test_idx = make_random_splits(
        num_rows, val_ratio, test_ratio, seed=seed
    )

    # Sanity checks: the three splits are pairwise disjoint and cover every row.
    train_set, val_set, test_set = set(train_idx), set(val_idx), set(test_idx)
    assert train_set.isdisjoint(val_set)
    assert train_set.isdisjoint(test_set)
    assert val_set.isdisjoint(test_set)
    assert len(train_idx) + len(val_idx) + len(test_idx) == num_rows

    # Persist the splits
    torch.save({"train": train_idx, "val": val_idx, "test": test_idx}, out_file)

In [12]:
# Write random splits (10% validation, 10% test, remainder train) for each dataset.
for dataset_csv, splits_file in (
    ("qm9.csv.gz", "qm9_random_splits.pt"),
    ("Tox21-7k-12-labels.csv.gz", "Tox21_random_splits.pt"),
    ("ZINC12k.csv.gz", "ZINC12k_random_splits.pt"),
):
    make_random_splits_from_file(dataset_csv, splits_file, 0.1, 0.1)