# Generating the Random Splits for Model Evaluation: Lukassen2020_Lung

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import pickle
import scanpy as sc
import seaborn as sns
from sklearn.model_selection import ShuffleSplit

In [2]:
# Configuration: dataset label and output locations used throughout the notebook.

# Label for the dataset; used as the prefix of every output file name.
dataset_name = "Lukassen2020_Lung"

# Directory that receives the pickled split dictionaries.
dictionary_dir = "SavedSplitDicts"

# Directory that receives the metadata table annotated with the splits.
metadata_dir = "Metadata_Splits"

In [3]:
# Make sure the necessary output directories are available before anything
# is written to them. Each entry pairs a directory with the purpose that is
# reported when the directory has to be created.
output_dirs = (
    (dictionary_dir, "saving split dictionaries"),
    (metadata_dir, "saving metadata with the 5 splits"),
)

for out_dir, purpose in output_dirs:
    if os.path.exists(out_dir):
        print('Directory already exists!')
    else:
        os.makedirs(out_dir)
        print(f"Directory {out_dir} created for {purpose}")

Directory already exists!
Directory already exists!


Load the dataset

In [4]:
# Read the annotated dataset from the .h5ad file in the working directory.
h5ad_path = 'Lukassen2020_Lung_qc_hvg_anno_5k_split.h5ad'
adata = sc.read_h5ad(h5ad_path)

In [5]:
# X is only handed to ShuffleSplit below, which needs nothing but the number
# of rows (cells). Passing the raw matrix as stored avoids materialising a
# dense copy of the full expression matrix just to count cells, and also
# works when raw.X is already a dense array (which has no .todense()).
X = adata.raw.X

Generate the 5 splits using sklearn `ShuffleSplit`
- standard 80/20 split

In [6]:
# Five independent random train/test partitions, each holding out 20% of the
# cells; the fixed random_state makes the partitions reproducible.
rs = ShuffleSplit(
    n_splits=5,
    test_size=.20,
    random_state=2022,
)

# Displayed as the cell output: the number of splits the splitter will yield.
rs.get_n_splits(X)

5

Create a list of dictionaries containing the generated splits

In [7]:
# One {"train": ..., "test": ...} dictionary of index arrays per generated split,
# in the order the splitter yields them.
split_list = [
    {"train": train_idx, "test": test_idx}
    for train_idx, test_idx in rs.split(X)
]

In [8]:
# Key every split by its 0-based position in split_list.
dataset_splits = dict(enumerate(split_list))

In [9]:
# Display the splits; at this point the keys are the 0-based positions from enumerate.
dataset_splits

{0: {'train': array([16836, 38093,  2758, ..., 16557,  1244, 21373]),
  'test': array([39515,  3655, 15269, ...,  6522, 31211, 28694])},
 1: {'train': array([ 2320, 23041, 33517, ..., 20414, 20159, 23050]),
  'test': array([ 6960,  5428, 13812, ..., 24631,  8048, 36966])},
 2: {'train': array([20240, 14440, 25109, ..., 37536,  4980, 10899]),
  'test': array([20934, 34640, 11453, ..., 28944, 27040, 17777])},
 3: {'train': array([23748, 33186,  8835, ..., 20974, 19237, 13469]),
  'test': array([32308, 29640, 28666, ..., 10301,  9927, 20243])},
 4: {'train': array([18325, 15391, 37549, ...,  4276, 27976,  2743]),
  'test': array([22092,  6455,  1990, ..., 33767, 22763, 19798])}}

Rename the dictionary keys from the integer positions to "Split_#" (1-based)

In [10]:
# Re-key the splits as 'Split_1' ... 'Split_N' (1-based), preserving order.
# The number of keys is derived from the dictionary itself so that no split
# is silently dropped by zip() if the number of splits is ever changed.
new_keys = [f'Split_{i}' for i in range(1, len(dataset_splits) + 1)]
dataset_splits = dict(zip(new_keys, dataset_splits.values()))

In [11]:
# save the dictionary
def Pickler(data, filename):
    """Serialize ``data`` to ``filename`` using :mod:`pickle`.

    Parameters
    ----------
    data : object
        Any picklable Python object (here: the dictionary of splits).
    filename : str or path-like
        Destination file; it is created, or overwritten if it already exists.

    Returns
    -------
    None
    """
    # The context manager guarantees the file is closed even if pickling fails.
    with open(filename, 'wb') as outfile:
        pickle.dump(data, outfile)

# Write the split dictionary to <dictionary_dir>/<dataset_name>_SplitDict.pickle
pickle_path = f"{dictionary_dir}/{dataset_name}_SplitDict.pickle"
Pickler(dataset_splits, filename=pickle_path)

In [12]:
# to load

def Unpickler(filename):
    """Load and return the object stored in the pickle file ``filename``.

    Parameters
    ----------
    filename : str or path-like
        Path of an existing pickle file.

    Returns
    -------
    object
        The unpickled object.

    Notes
    -----
    ``pickle.load`` can execute arbitrary code; only load files from a
    trusted source.
    """
    # Open read-only ('rb'): loading must not require write permission on the
    # file. The context manager closes the file even if unpickling fails.
    with open(filename, 'rb') as infile:
        return pickle.load(infile)

# Read the saved dictionary back in to confirm the file can be loaded.
saved_path = f"{dictionary_dir}/{dataset_name}_SplitDict.pickle"
test_loaddict = Unpickler(filename=saved_path)

In [13]:
# Display the dictionary that was read back from the pickle file.
test_loaddict

{'Split_1': {'train': array([16836, 38093,  2758, ..., 16557,  1244, 21373]),
  'test': array([39515,  3655, 15269, ...,  6522, 31211, 28694])},
 'Split_2': {'train': array([ 2320, 23041, 33517, ..., 20414, 20159, 23050]),
  'test': array([ 6960,  5428, 13812, ..., 24631,  8048, 36966])},
 'Split_3': {'train': array([20240, 14440, 25109, ..., 37536,  4980, 10899]),
  'test': array([20934, 34640, 11453, ..., 28944, 27040, 17777])},
 'Split_4': {'train': array([23748, 33186,  8835, ..., 20974, 19237, 13469]),
  'test': array([32308, 29640, 28666, ..., 10301,  9927, 20243])},
 'Split_5': {'train': array([18325, 15391, 37549, ...,  4276, 27976,  2743]),
  'test': array([22092,  6455,  1990, ..., 33767, 22763, 19798])}}

#### Add split information to the adata object metadata

In [14]:
# reset index to have numeric rownames instead of barcodes
# (reset_index moves the previous index into a regular column of adata.obs).
# Guarded so that re-running this cell is a no-op: calling reset_index again
# on an already-numeric index would insert yet another index column.
if not isinstance(adata.obs.index, pd.RangeIndex):
    adata.obs = adata.obs.reset_index()

In [15]:
# add all splits: one obs column per entry of dataset_splits, holding
# 'train' or 'test' for every cell.
# Membership is decided by row position (the split arrays hold the positional
# indices produced by ShuffleSplit), so the labels do not depend on the index
# labels of adata.obs happening to equal the row positions.
row_positions = np.arange(len(adata.obs))
for split_name, split in dataset_splits.items():
    is_train = np.isin(row_positions, split['train'])
    adata.obs[split_name] = np.where(is_train, 'train', 'test')

#### Quick sanity check on Split_1

In [16]:
# Count the cells per label in Split_1; with test_size=.20 roughly 80% should be 'train'.
adata.obs['Split_1'].value_counts()

train    31822
test      7956
Name: Split_1, dtype: int64

In [17]:
# quick sanity check for train samples
# Draw one positional index from the Split_1 training set. A locally seeded
# generator keeps the drawn cell the same on every run of the notebook.
rand_train = np.random.default_rng(2022).choice(dataset_splits['Split_1']['train'])
print(f'Random train cell: {rand_train}')

# Fail loudly if the metadata label disagrees with the split dictionary.
assert adata.obs['Split_1'].iloc[rand_train] == 'train'

adata.obs.iloc[rand_train] # supposed to be a 'train cell'

Random test cell: 19830


index                JVV9L8ng_CATCAAGGTGTCGCTG-1
orig.ident                              JVV9L8ng
nCount_RNA                                 821.0
nFeature_RNA                                 505
barcodes             JVV9L8ng_CATCAAGGTGTCGCTG-1
ID                                      JVV9L8ng
Sex                                            M
Age                                  -2147483648
Smoking                               NonSmoking
UMI.count                                   3157
Gene.count                                  1919
Cell.type                            Fibroblasts
MT.ratio                                0.001901
celltypes                            Fibroblasts
percent_mt                              0.190054
RNA_snn_res.0.2                                6
RNA_snn_res.0.4                                7
kmeans_9                                       3
cluster                                        5
encoded_celltypes                              5
split               

In [18]:
# quick sanity check for test samples
# Draw one positional index from the Split_1 test set. A locally seeded
# generator keeps the drawn cell the same on every run of the notebook.
rand_test = np.random.default_rng(2022).choice(dataset_splits['Split_1']['test'])
print(f'Random test cell: {rand_test}')

# Fail loudly if the metadata label disagrees with the split dictionary.
assert adata.obs['Split_1'].iloc[rand_test] == 'test'

adata.obs.iloc[rand_test] # supposed to be a 'test cell'

Random test cell: 10974


index                9JQK55ng_CACATAGTCTGACCTC-1
orig.ident                              9JQK55ng
nCount_RNA                                1273.0
nFeature_RNA                                 562
barcodes             9JQK55ng_CACATAGTCTGACCTC-1
ID                                      9JQK55ng
Sex                                            F
Age                                           45
Smoking                                  Smoking
UMI.count                                   4893
Gene.count                                  2643
Cell.type                                    AT1
MT.ratio                                0.003066
celltypes                                    AT1
percent_mt                               0.30656
RNA_snn_res.0.2                                2
RNA_snn_res.0.4                                1
kmeans_9                                       1
cluster                                        0
encoded_celltypes                              0
split               

In [19]:
# Export the cell metadata, including the Split_# columns, as a CSV file.
metadata_path = f'{metadata_dir}/{dataset_name}_metadata_splits.csv'
adata.obs.to_csv(metadata_path)