# Create training and test set

**Goal**: Create a test set and put it aside - it will only be used for benchmarking at the end. I will apply scikit-learn's `StratifiedShuffleSplit` to sample evenly from the different cell types, in case there are cell-type-specific signatures.

**Output**: Test dataset (used for model assessment only), training dataset (used for model training and hyperparameter selection)

**Assumptions (that could be tweaked)**:
- The normalization procedure (scanpy `normalize_total` + `log1p`) rescales the data sufficiently to facilitate effective model training

In [2]:

# Import needed libraries
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.preprocessing import OneHotEncoder

import scanpy as sc
import numpy as np
import os

import matplotlib.cm as cm

import pickle

from utils.config import *
from utils.analysis_variables import *

In [3]:
# Notebook-wide setup
import warnings

# Silence every warning category for the rest of the session
warnings.filterwarnings('ignore')

# Fix NumPy's global random seed
np.random.seed(15)

# Scanpy logging: level 3 corresponds to hints
sc.settings.verbosity = 3

In [4]:
# Input / output locations for this notebook
notebook_name = "03_create_test_train"

# path_outdir_base is not assigned in this notebook; it is expected to arrive through
# the star imports at the top (e.g. "../../output/20240221_import").
path_results = os.path.join(path_outdir_base, notebook_name)
path_input_data = os.path.join(path_outdir_base, "02_explore_data_cell_id", "adata_clusterlabels.h5ad")

# Make sure the output folder exists before anything is written to it
os.makedirs(path_results, exist_ok=True)

# Import data and data preparation

In [5]:
# Load the cluster-labeled AnnData object stored under the 02_explore_data_cell_id output folder
adata = sc.read_h5ad(path_input_data)

In [6]:
# Create the factor used for stratification: the target variable (Sex) joined with the
# cell-type label (leiden_labeled), e.g. "Unknown_Radial Glial Cell".
# Vectorised string concatenation replaces the row-wise DataFrame.apply(axis=1), which
# invoked a Python lambda once per cell.
# NOTE: astype(str) renders a missing value as the string "nan" instead of raising.
adata.obs['Sex_Celltype'] = (
    adata.obs['Sex'].astype(str) + "_" + adata.obs['leiden_labeled'].astype(str)
)

adata.obs.head(3)

Unnamed: 0_level_0,sample_ids,PRESENT_raw,sample_ids_letter,SampleID,True ID,Litter,Pooled,Genotype,Condition,Group,...,pct_counts_gene_mt,total_counts_gene_hsp,pct_counts_gene_hsp,total_counts_gene_ribo,pct_counts_gene_ribo,total_counts_gene_hemo,pct_counts_gene_hemo,n_genes,leiden_labeled,Sex_Celltype
index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
AAACCTGAGCAGCGTA-1-F_filtered_gene_bc_matrices,F_filtered_gene_bc_matrices,True,F,F,7011 #7-5 (F) and 7011 #7-4 (M),7011_7,Y,WT,LPS,WT.LPS,...,1.749744,123.0,1.049846,2133.0,18.205872,4.0,0.034141,3824,Radial Glial Cell,Unknown_Radial Glial Cell
AAACCTGAGCGATCCC-1-F_filtered_gene_bc_matrices,F_filtered_gene_bc_matrices,True,F,F,7011 #7-5 (F) and 7011 #7-4 (M),7011_7,Y,WT,LPS,WT.LPS,...,1.423794,115.0,1.098901,1898.0,18.136646,5.0,0.047778,3674,Radial Glial Cell,Unknown_Radial Glial Cell
AAACCTGAGCTGAACG-1-F_filtered_gene_bc_matrices,F_filtered_gene_bc_matrices,True,F,F,7011 #7-5 (F) and 7011 #7-4 (M),7011_7,Y,WT,LPS,WT.LPS,...,2.448111,89.0,0.789427,2000.0,17.739933,11.0,0.09757,3725,Radial Glial Cell,Unknown_Radial Glial Cell


# Create test dataset to put aside

Here, `y` is `Sex_Celltype` so that I can sample representatives from each Sex x cell type combination. Only `Sex` is kept as the label in the saved training and test sets.

In [7]:
# Hold-out configuration for the stratified shuffle split
test_size = 0.2  # fraction of the labeled cells set aside as the test set
n_splits = 1     # number of train/test partitions to draw

In [8]:
# Cells whose Sex is 'Unknown' are set aside as an unlabeled set;
# every other cell goes into the labeled set.
sex_is_unknown = adata.obs['Sex'] == 'Unknown'
adata_unlabeled = adata[sex_is_unknown]
adata_labeled = adata[~sex_is_unknown]

In [9]:

# Dense expression matrix for the labeled cells; rows follow the order of adata_labeled.
# NOTE(review): .toarray() assumes adata_labeled.X is a scipy sparse matrix - confirm.
data = adata_labeled.X.toarray()

# Sex x cell-type label per cell (used only for stratification) and the Sex label that
# is actually kept as the prediction target.
target_sex_celltype = adata_labeled.obs.Sex_Celltype
target_sex = adata_labeled.obs.Sex

# One-hot encode the stratification label. The encoder is left at its default sparse
# output and densified explicitly: the `sparse=` keyword was deprecated in
# scikit-learn 1.2 and later removed, whereas `.toarray()` does not depend on either
# spelling of that keyword and hands the splitter the same dense array as before.
one_hot_encoder = OneHotEncoder(drop='first', handle_unknown='ignore')
target_sex_celltype_encoded = one_hot_encoder.fit_transform(
    target_sex_celltype.astype('category').values.reshape(-1, 1)
).toarray()

# my_random_state is not assigned in this notebook; presumably it comes from the star
# imports at the top - verify.
stratified_splitter = StratifiedShuffleSplit(n_splits=n_splits, test_size=test_size, random_state=my_random_state)

# With n_splits == 1 the loop body runs exactly once; if n_splits is raised, only the
# last generated split is retained in the variables below.
for train_index, test_index in stratified_splitter.split(data, target_sex_celltype_encoded):
    # Get the training and test data (order of entries same as order in adata_labeled)
    X_train, X_test = data[train_index], data[test_index]
    # Only the one label we care about (Sex) is kept. train_index/test_index are integer
    # positions, so select with .iloc: plain [] on a Series whose index holds barcode
    # strings relies on a deprecated positional fallback.
    y_train, y_test = target_sex.iloc[train_index], target_sex.iloc[test_index]

# Sanity check: the two index sets together account for every labeled cell.
assert len(train_index) + len(test_index) == data.shape[0]


In [10]:
# Report the (cells, genes) shape of the dense training and test matrices
print(f"Size of training data (N cell x gene): {X_train.shape}")
print(f"Size of test data: {X_test.shape}")

Size of training data (N cell x gene): (7407, 20821)
Size of test data: (1852, 20821)


In [11]:
# Cell counts per Sex x cell-type combination in the training split
print("Training data N:")
train_obs = adata_labeled.obs.iloc[train_index]
train_obs.value_counts(subset=['Sex', 'leiden_labeled'])

Training data N:


Sex  leiden_labeled              
M    Radial Glial Cell               1860
F    Radial Glial Cell               1154
M    Intermediate Progenitor Cell     988
F    Intermediate Progenitor Cell     841
     Immature Excitatory Neuron       513
M    Excitatory Neuron                513
     Immature Excitatory Neuron       498
F    Excitatory Neuron                470
M    Inhibitory Neuron                223
F    Inhibitory Neuron                134
     Radial Glia with VIM             107
M    Radial Glia with VIM             106
Name: count, dtype: int64

In [12]:
# Cell counts per Sex x cell-type combination in the test split
print("Test data N:")
test_obs = adata_labeled.obs.iloc[test_index]
test_obs.value_counts(subset=['Sex', 'leiden_labeled'])

Test data N:


Sex  leiden_labeled              
M    Radial Glial Cell               465
F    Radial Glial Cell               289
M    Intermediate Progenitor Cell    247
F    Intermediate Progenitor Cell    210
     Immature Excitatory Neuron      128
M    Excitatory Neuron               128
     Immature Excitatory Neuron      125
F    Excitatory Neuron               118
M    Inhibitory Neuron                56
F    Inhibitory Neuron                33
M    Radial Glia with VIM             27
F    Radial Glia with VIM             26
Name: count, dtype: int64

# Save data

In [13]:

# Persist both splits as pickled {'X': features, 'Y': labels} dictionaries
split_payloads = {
    'training_data.pkl': {'X': X_train, 'Y': y_train},
    'test_data.pkl': {'X': X_test, 'Y': y_test},
}
for pickle_name, payload in split_payloads.items():
    with open(os.path.join(path_results, pickle_name), 'wb') as f:
        pickle.dump(payload, f)

# Store the labeled and unlabeled AnnData subsets in the same output folder
anndata_outputs = (
    ('adata_labeled.h5ad', adata_labeled),
    ('adata_unlabeled.h5ad', adata_unlabeled),
)
for h5ad_name, adata_subset in anndata_outputs:
    adata_subset.write(os.path.join(path_results, h5ad_name))


In [14]:
# Display the output directory the pickles and .h5ad files were written to
path_results

'../../output/20240221_import/03_create_test_train'

In [19]:
# Inspect the held-out Sex labels (a Series indexed by cell barcode).
# NOTE(review): this cell's execution count (19) does not follow the previous cell (14);
# re-run with Restart Kernel -> Run All before sharing so the counts are sequential.
y_test

index
GGACAGATCTCGAGTA-1-KM13_filtered_gene_bc_matrices    M
GAACCTACACGGCTAC-1-KM1_filtered_gene_bc_matrices     M
CTACGTCCAGCTGTAT-1-J_filtered_gene_bc_matrices       F
TGCTACCGTGCCTGTG-1-J_filtered_gene_bc_matrices       F
AACTCCCAGGACGAAA-1-I_filtered_gene_bc_matrices       F
                                                    ..
CTAGAGTTCTCAAACG-1-K_filtered_gene_bc_matrices       F
GTGAAGGAGTGAACAT-1-KM2_filtered_gene_bc_matrices     M
ACGCAGCTCGCATGGC-1-K_filtered_gene_bc_matrices       F
GCTGGGTCACGCCAGT-1-KM2_filtered_gene_bc_matrices     M
CAAGTTGAGAAGGCCT-1-L_filtered_gene_bc_matrices       F
Name: Sex, Length: 1852, dtype: category
Categories (2, object): ['F', 'M']