In [1]:
import numpy as np
import pandas as pd
import glob
import json
import os
from pathlib import Path
from rdkit import Chem
from tqdm.auto import tqdm

In [3]:
# Seeds that make the splits reproducible. split_dr indexes this mapping by
# dataset name and produces one train/validate/test split per seed it finds.
random_seeds = json.loads(Path('random_seeds/MF_PCBA_random_seeds.json').read_text())

In [2]:
# From https://stackoverflow.com/questions/38250710/how-to-split-data-into-3-sets-train-validation-and-test
def train_validate_test_split(df, train_percent=.8, validate_percent=.1, seed=None):
    """Randomly partition the rows of ``df`` into train, validation and test frames.

    The rows are shuffled with NumPy's global legacy RNG after seeding it with
    ``seed``. The first ``int(train_percent * len(df))`` shuffled rows become
    the training set, the next ``int(validate_percent * len(df))`` rows the
    validation set, and every remaining row the test set.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame to split. Its index may hold any labels (non-contiguous,
        duplicated, non-integer); rows are selected by position.
    train_percent : float
        Fraction of rows assigned to the training set.
    validate_percent : float
        Fraction of rows assigned to the validation set.
    seed : int or None
        Value passed to ``np.random.seed``. ``None`` reseeds NumPy's global
        RNG unpredictably, so the split is then not reproducible.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(train, validate, test)``; each keeps the original index labels of
        the rows it contains.
    """
    # NOTE: this reseeds NumPy's *global* RNG, exactly as the original did, so
    # random draws made by the caller after this function are affected too.
    np.random.seed(seed)
    m = len(df.index)
    # Shuffle row positions, not index labels: the result is fed to .iloc,
    # which is purely positional. Permuting the labels only worked when the
    # index happened to be 0..m-1; for such an index this yields the same
    # shuffle as before for a given seed.
    perm = np.random.permutation(m)
    train_end = int(train_percent * m)
    validate_end = int(validate_percent * m) + train_end
    train = df.iloc[perm[:train_end]]
    validate = df.iloc[perm[train_end:validate_end]]
    test = df.iloc[perm[validate_end:]]
    return train, validate, test

In [7]:
def split_dr(dataset_path: str, dataset_name:str, save_path: str, is_dr_separate: bool):
    """Write one train/validate/test CSV triple per random seed for a dataset.

    The CSV at ``dataset_path`` is read and, for every seed stored under
    ``random_seeds[dataset_name]``, split 80/10/10 with
    ``train_validate_test_split``. The three parts of the i-th seed are saved as
    ``train.csv``, ``validate.csv`` and ``test.csv`` inside
    ``<save_path>/<dataset_name>/<i>/`` (directories are created as needed).

    Parameters
    ----------
    dataset_path : str
        Path of the CSV file to split.
    dataset_name : str
        Key into the module-level ``random_seeds`` mapping; also used as the
        name of the output sub-directory.
    save_path : str
        Root directory under which the splits are written.
    is_dr_separate : bool
        If truthy, the file is split as-is. Otherwise only the rows whose
        ``'DR'`` column is not null are kept before splitting.
    """
    def write_seeded_splits(frame):
        # The sub-directory is named after the seed's position in the list,
        # not after the seed value itself.
        for split_no, seed in enumerate(random_seeds[dataset_name]):
            parts = dict(zip(
                ('train', 'validate', 'test'),
                train_validate_test_split(frame, train_percent=.8, validate_percent=.1, seed=seed),
            ))

            out_dir = os.path.join(save_path, dataset_name, str(split_no))
            Path(out_dir).mkdir(exist_ok=True, parents=True)

            for part_name, part in parts.items():
                part.to_csv(f'{out_dir}/{part_name}.csv', index=False)

    data = pd.read_csv(dataset_path)

    if not is_dr_separate:
        # Drop rows without a DR value, then renumber so the splitter receives
        # a fresh 0..n-1 index.
        # NOTE(review): reset_index() without drop=True keeps the old row
        # labels as an extra column (typically named 'index'), which therefore
        # also ends up in the written CSVs - confirm this is intended.
        data = data[data['DR'].notna()].reset_index()

    write_seeded_splits(data)

In [8]:
# AID1445: SD.csv is first reduced to the rows that have a non-null 'DR' value
# (is_dr_separate=False); splits are written to train_val_test_splits/1445/<i>/.
split_dr(
    dataset_path='filtered_datasets/AID1445/SD.csv',
    dataset_name='1445',
    save_path='train_val_test_splits',
    is_dr_separate=False,
)

In [9]:
# AID624273-588549: DR.csv is split as-is (is_dr_separate=True); splits are
# written to train_val_test_splits/624273-588549/<i>/.
split_dr(
    dataset_path='filtered_datasets/AID624273-588549/DR.csv',
    dataset_name='624273-588549',
    save_path='train_val_test_splits',
    is_dr_separate=True,
)