# Generating the Random Splits for Model Evaluation: GSE144236

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import pickle
import scanpy as sc
import seaborn as sns
from sklearn.model_selection import ShuffleSplit

In [2]:
# Prep: names and output locations shared by the cells below
dataset_name = 'GSE144236' # label for the dataset
dictionary_dir = 'SavedSplitDicts' # dir where we save the split dictionaries
metadata_dir = 'Metadata_Splits' # dir where we save metadata

In [3]:
# Make sure the necessary output directories are available.
# Each entry pairs a directory with the wording used in its "created" message.
for directory, purpose in (
    (dictionary_dir, "split dictionaries"),
    (metadata_dir, "metadata with the 5 splits"),
):
    if os.path.isdir(directory):
        print('Directory already exists!')
    else:
        # exist_ok=True tolerates another process creating the directory between
        # the check and this call; if the path exists but is a regular file,
        # makedirs still raises FileExistsError instead of failing later on save.
        os.makedirs(directory, exist_ok=True)
        print(f"Directory {directory} created for saving {purpose}")

Directory already exists!
Directory already exists!


Load the dataset

In [4]:
# Read the AnnData object from an .h5ad file (path is relative to the working directory).
# NOTE(review): the file name suggests this is already the training portion of an
# earlier train/test split — confirm against the notebook that produced it.
adata = sc.read_h5ad('GSE144236_qc_hvg_anno_5k_raw_train_split.h5ad')

In [5]:
# X is only handed to ShuffleSplit, which uses nothing but its number of rows.
# Keep the raw matrix in its stored form rather than calling .todense(), which
# would materialise every cell x gene entry in memory (and return an np.matrix).
X = adata.raw.X

Generate the 5 splits using sklearn `ShuffleSplit`
- standard 80/20 split

In [6]:
# Five independent random train/test partitions with 20% held out each time;
# the fixed random_state makes the partitions repeatable across runs.
rs = ShuffleSplit(
    n_splits=5,
    test_size=.20,
    random_state=2022,
)
rs.get_n_splits(X)

5

Create a list of dictionaries containing the generated splits

In [7]:
# One {"train": ..., "test": ...} dictionary of row indices per generated split.
split_list = [
    {"train": train_index, "test": test_index}
    for train_index, test_index in rs.split(X)
]

In [8]:
# Key each split dictionary by its 0-based position in split_list.
dataset_splits = {
    split_number: split for split_number, split in enumerate(split_list)
}

In [9]:
# Display the split dictionary (long index arrays are abbreviated in the repr).
dataset_splits

{0: {'train': array([11054, 30965, 42248, ..., 16557,  1244, 21373]),
  'test': array([34717, 27820, 41304, ...,  7538, 35197,  1780])},
 1: {'train': array([37433, 28797, 17101, ..., 28290, 33927, 23822]),
  'test': array([ 3797, 33124, 34866, ..., 44651,   283, 14810])},
 2: {'train': array([41661, 33172,  9029, ...,   116, 15794, 27078]),
  'test': array([33852, 44196,  9632, ...,   458, 11384, 29039])},
 3: {'train': array([36299, 17846, 25235, ..., 15014, 16076, 17602]),
  'test': array([24579, 43797,  5638, ..., 41347, 10263, 25275])},
 4: {'train': array([32429,   609, 36693, ..., 16899,  3625, 44968]),
  'test': array([11461, 41460,  7667, ..., 42606,   803, 37816])}}

Rename the dictionary keys from 0–4 to "Split_1" … "Split_5"

In [10]:
# Build one 1-based "Split_<n>" label per split. The count is taken from the
# dictionary itself so that zip() cannot silently drop splits if n_splits changes.
new_keys = [f'Split_{i}' for i in range(1, len(dataset_splits) + 1)]
dataset_splits = dict(zip(new_keys, dataset_splits.values()))

In [11]:
# save the dictionary
def Pickler(data, filename):
    """Serialize ``data`` to ``filename`` using :mod:`pickle`.

    Parameters
    ----------
    data : object
        Any picklable object (in this notebook, the dictionary of split indices).
    filename : str or path-like
        Destination file; it is created if missing and overwritten if present.

    Returns
    -------
    None
    """
    # The context manager closes the handle even if pickle.dump raises,
    # which a manual open()/close() pair does not guarantee.
    with open(filename, 'wb') as outfile:
        pickle.dump(data, outfile)

# Write the split dictionary next to the other saved split dictionaries.
split_dict_path = f"{dictionary_dir}/{dataset_name}_SplitDict.pickle"
Pickler(dataset_splits, filename=split_dict_path)

In [12]:
# to load

def Unpickler(filename):
    """Load and return the object stored in a pickle file.

    Parameters
    ----------
    filename : str or path-like
        Path of the pickle file to read.

    Returns
    -------
    object
        Whatever object was pickled into ``filename``.

    Notes
    -----
    ``pickle.load`` can execute arbitrary code; only use this on files you
    created yourself or otherwise trust.
    """
    # Open read-only ('rb'): update mode ('rb+') needs write permission and
    # would fail on read-only files even though nothing is written.
    # The context manager closes the handle even if unpickling fails.
    with open(filename, 'rb') as infile:
        return pickle.load(infile)

# Read the dictionary back from disk to confirm the saved file is loadable.
saved_split_path = f"{dictionary_dir}/{dataset_name}_SplitDict.pickle"
test_loaddict = Unpickler(filename=saved_split_path)

In [13]:
# Display the reloaded dictionary for a visual comparison with dataset_splits.
test_loaddict

{'Split_1': {'train': array([11054, 30965, 42248, ..., 16557,  1244, 21373]),
  'test': array([34717, 27820, 41304, ...,  7538, 35197,  1780])},
 'Split_2': {'train': array([37433, 28797, 17101, ..., 28290, 33927, 23822]),
  'test': array([ 3797, 33124, 34866, ..., 44651,   283, 14810])},
 'Split_3': {'train': array([41661, 33172,  9029, ...,   116, 15794, 27078]),
  'test': array([33852, 44196,  9632, ...,   458, 11384, 29039])},
 'Split_4': {'train': array([36299, 17846, 25235, ..., 15014, 16076, 17602]),
  'test': array([24579, 43797,  5638, ..., 41347, 10263, 25275])},
 'Split_5': {'train': array([32429,   609, 36693, ..., 16899,  3625, 44968]),
  'test': array([11461, 41460,  7667, ..., 42606,   803, 37816])}}

#### Add split information to the adata object metadata

In [14]:
# reset index to have numeric rownames instead of barcodes
# Guarded so the cell is safe to re-run: calling reset_index() a second time
# would push the numeric index into an extra `level_0` column.
if not isinstance(adata.obs.index, pd.RangeIndex):
    adata.obs = adata.obs.reset_index()

In [15]:
# add one 'train'/'test' column per split
# The split arrays hold row positions, so membership is tested against
# positions 0..n_obs-1 rather than against the obs index labels; this stays
# correct whatever the obs index contains and scales to any number of splits.
row_positions = np.arange(adata.n_obs)
for split_name, split_indices in dataset_splits.items():
    in_train = np.isin(row_positions, split_indices['train'])
    adata.obs[split_name] = np.where(in_train, 'train', 'test')

#### Quick sanity check on Split_1

In [16]:
# Count cells per label in Split_1; with test_size=.20 roughly 80% should be 'train'.
adata.obs['Split_1'].value_counts()

train    37616
test      9405
Name: Split_1, dtype: int64

In [17]:
# quick sanity check for train samples
# Pick one row position at random from the Split_1 training indices.
rand_train = np.random.choice(dataset_splits['Split_1']['train'])
print(f'Random train cell: {rand_train}')
# Fail loudly on a mismatch instead of relying on reading the output by eye.
assert adata.obs['Split_1'].iloc[rand_train] == 'train'
adata.obs.iloc[rand_train] # supposed to be a 'train cell'

Random test cell: 2036


index                P1_Normal_CGACTTCCAGCTGGCT
nFeature_RNA                              430.0
nCount_RNA                               2375.0
barcodes             P1_Normal_CGACTTCCAGCTGGCT
patient                                      P1
tum.norm                                 Normal
celltypes                            Epithelial
celltypes_lvl2                   Normal_KC_Diff
celltypes_lvl3                   Normal_KC_Diff
percent_mito                           4.749009
RNA_snn_res.0.2                               0
RNA_snn_res.0.4                               0
kmeans_14                                     7
encoded_celltypes                             5
cluster                                       5
split                                     train
Split_1                                   train
Split_2                                   train
Split_3                                   train
Split_4                                   train
Split_5                                 

In [18]:
# quick sanity check for test samples
# Pick one row position at random from the Split_1 test indices.
rand_test = np.random.choice(dataset_splits['Split_1']['test'])
print(f'Random test cell: {rand_test}')
# Fail loudly on a mismatch instead of relying on reading the output by eye.
assert adata.obs['Split_1'].iloc[rand_test] == 'test'
adata.obs.iloc[rand_test] # supposed to be a 'test cell'

Random test cell: 5490


index                P2_Normal_ACTGATGAGCCCAATT
nFeature_RNA                              596.0
nCount_RNA                               2087.0
barcodes             P2_Normal_ACTGATGAGCCCAATT
patient                                      P2
tum.norm                                 Normal
celltypes                            Epithelial
celltypes_lvl2                  Normal_KC_Basal
celltypes_lvl3                  Normal_KC_Basal
percent_mito                           4.696417
RNA_snn_res.0.2                               9
RNA_snn_res.0.4                              14
kmeans_14                                    14
encoded_celltypes                             5
cluster                                       5
split                                     train
Split_1                                    test
Split_2                                   train
Split_3                                   train
Split_4                                    test
Split_5                                 

In [19]:
# Export the cell metadata, including the Split_1..Split_5 columns, as CSV.
metadata_csv_path = f'{metadata_dir}/{dataset_name}_metadata_splits.csv'
adata.obs.to_csv(metadata_csv_path)