In [1]:
from typing import List, Tuple, Dict

import random
import pandas as pd
import numpy as np
import torch
from os.path import join

In [2]:
# Total number of molecules to split into train/val/test. Presumably the size of
# the PM6 dataset (the splits are saved as "pm6_random_splits.pt") — TODO confirm
# against the source data.
NUM_MOLS = 83_746_835

In [3]:
def make_random_splits(num_elem, val_ratio, test_ratio, seed=42, ignore_idx=None):
    """Make random splits of the data into train, validation, and test sets.

    The indices ``0 .. num_elem - 1`` (minus any listed in ``ignore_idx``) are
    shuffled with ``seed`` and cut into three consecutive chunks. Split sizes
    are computed from the number of remaining indices ``n``:
    ``int(n * val_ratio)`` for validation, ``int(n * test_ratio)`` for test,
    and everything else for training.

    Args:
        num_elem (int): Number of elements in the dataset.
        val_ratio (float): Ratio of the dataset to use for validation.
        test_ratio (float): Ratio of the dataset to use for testing.
        seed (int): Random seed. Note that this reseeds the global ``random``
            module state.
        ignore_idx (iterable of int, optional): Indices to exclude from every
            split. Values outside ``[0, num_elem)`` are simply ignored.

    Returns:
        train_idx (list): List of indices for the training set.
        val_idx (list): List of indices for the validation set.
        test_idx (list): List of indices for the test set.

    Raises:
        ValueError: If a ratio is negative or ``val_ratio + test_ratio > 1``.
    """
    if val_ratio < 0 or test_ratio < 0 or val_ratio + test_ratio > 1:
        raise ValueError(
            f"Invalid split ratios: val_ratio={val_ratio}, test_ratio={test_ratio}; "
            "both must be >= 0 and sum to at most 1"
        )

    if ignore_idx is None:
        idx = list(range(num_elem))
    else:
        # Filter in ascending order rather than via a set difference: set iteration
        # order is an implementation detail, so the pre-shuffle order (and hence the
        # seeded split) would not be guaranteed reproducible. This also avoids
        # building a set over all `num_elem` indices.
        ignore = set(ignore_idx)
        idx = [i for i in range(num_elem) if i not in ignore]
    num_remaining = len(idx)

    # Shuffle with the global RNG (kept for compatibility with existing callers).
    random.seed(seed)
    random.shuffle(idx)

    # Split sizes are floored; any rounding remainder goes to the training set.
    num_val = int(num_remaining * val_ratio)
    num_test = int(num_remaining * test_ratio)
    num_train = num_remaining - num_val - num_test

    train_idx = idx[:num_train]
    val_idx = idx[num_train:num_train + num_val]
    test_idx = idx[num_train + num_val:]
    return train_idx, val_idx, test_idx

In [4]:
# 98% train / 1% validation / 1% test over all molecules; fixed seed for reproducibility.
train_PM6, val_PM6, test_PM6 = make_random_splits(
    NUM_MOLS,
    val_ratio=0.01,
    test_ratio=0.01,
    seed=42,
    ignore_idx=None,
)

In [5]:
def make_random_splits_file(num_mols, out_file, train_idx, val_idx, test_idx, test_seen_idx=None):
    """Validate index splits and save them as a dict with ``torch.save``.

    The splits are checked to form an exact partition of ``range(num_mols)``.
    Each split must be non-empty, have no duplicate indices, and only contain
    indices in ``[0, num_mols)``. The splits must be pairwise disjoint, and
    their sizes must add up to ``num_mols``. Together these imply that every
    index appears in exactly one split. The size check alone cannot catch a
    duplicated index masking a missing one.

    Args:
        num_mols (int): Total number of elements the splits must cover.
        out_file (str): Destination path passed to ``torch.save``.
        train_idx (list): Indices for the training set.
        val_idx (list): Indices for the validation set.
        test_idx (list): Indices for the test set.
        test_seen_idx (list, optional): Indices for an extra split. When given,
            it is validated like the others and stored under the "test_seen" key.

    Raises:
        AssertionError: If any check fails; nothing is written in that case.
    """
    splits_dict = {"train": train_idx, "val": val_idx, "test": test_idx}
    if test_seen_idx is not None:
        splits_dict["test_seen"] = test_seen_idx

    # Build each index set exactly once. The training split can hold tens of
    # millions of indices, so repeated set(...) construction is expensive.
    split_sets = {}
    for name, idx in splits_dict.items():
        assert len(idx) > 0, f"split '{name}' is empty"
        idx_set = set(idx)
        assert len(idx_set) == len(idx), f"split '{name}' contains duplicate indices"
        assert 0 <= min(idx_set) and max(idx_set) < num_mols, (
            f"split '{name}' has indices outside [0, {num_mols})"
        )
        split_sets[name] = idx_set

    # Every pair of splits must be disjoint.
    names = list(split_sets)
    for pos, first in enumerate(names):
        for second in names[pos + 1:]:
            assert split_sets[first].isdisjoint(split_sets[second]), (
                f"splits '{first}' and '{second}' overlap"
            )
    # Release the (potentially huge) sets before serialising.
    del split_sets

    sizes = [len(idx) for idx in splits_dict.values()]
    assert num_mols == sum(sizes), f"{num_mols} != " + " + ".join(str(size) for size in sizes)

    # Prints "<out_file> train N val N test N" (plus "test_seen N" when present).
    summary = [item for name, size in zip(names, sizes) for item in (name, size)]
    print(out_file, *summary)

    torch.save(splits_dict, out_file)

In [6]:
# Validate the PM6 splits and write them to disk; the split sizes are printed below.
make_random_splits_file(
    NUM_MOLS,
    out_file="pm6_random_splits.pt",
    train_idx=train_PM6,
    val_idx=val_PM6,
    test_idx=test_PM6,
)

pm6_random_splits.pt train 82071899 val 837468 test 837468
