In [1]:
import os
# Opt in to SciPy's array-API support. SciPy reads this variable when it is
# first imported, so this cell must run before anything that (presumably
# via sklearn / imblearn in the next cells) imports scipy.
os.environ["SCIPY_ARRAY_API"] = "1"

In [2]:

from imblearn.over_sampling import RandomOverSampler


In [3]:
import anndata as ad
import pandas as pd
import numpy as np
from os.path import join
import os
from sklearn.model_selection import train_test_split
# from imblearn.over_sampling import RandomOverSampler
from sklearn.model_selection import StratifiedKFold, KFold

import random
import json

In [4]:
# Planning notes for the split tasks built in the sections below (chemo,
# pre_post, Cancer_type, outcome, cell type). This is a bare string, so the
# cell simply echoes it.
'''
Pre vs post
        no filter

Cancer type
	Analyze pre only
	Restrict to ER+ vs TNBC
	
Outcome:
	Analyze pre only
	Restrict to ['NE', 'E']
	Oversample minority class
	
	
Chemo: 
    no filter
	Analyze pre only
    
Split by patient 

Cell type:
	Analyze pre only
	Split by patient 


'''

"\nPre vs post\n        no filter\n\nCancer type\n\tAnalyze pre only\n\tRestrict to ER+ vs TNBC\n\t\nOutcome:\n\tAnalyze pre only\n\tRestrict to ['NE', 'E']\n\tOversample minority class\n\t\n\t\nChemo: \n    no filter\n\tAnalyze pre only\n    \nSplit by patient \n\nCell type:\n\tAnalyze pre only\n\tSplit by patient \n\n\n"

In [4]:
def read_data(DATA_FILE):
    """Load an ``.h5ad`` file and return the AnnData object.

    Thin wrapper around ``anndata.read_h5ad``; ``DATA_FILE`` is passed through
    unchanged.
    """
    return ad.read_h5ad(DATA_FILE)

In [5]:

def get_patient_info(adata, patient_key, label_key):
    """Collapse the cell-level ``adata.obs`` table to one row per patient.

    Parameters
    ----------
    adata : AnnData-like
        Anything with an ``obs`` DataFrame.
    patient_key : str
        ``obs`` column identifying the grouping unit (patient/sample).
    label_key : str
        ``obs`` column holding the label; the first non-null value per group is used.

    Returns
    -------
    ids : array-like
        Group identifiers that actually occur in ``adata.obs``.
    labels : pandas.Series
        Per-group label, indexed by group id and aligned with ``ids``.
    counts : numpy.ndarray
        Number of ``obs`` rows (cells) per group, aligned with ``ids``.
    """
    obs = adata.obs
    # observed=True: for a categorical key, only groups present in this (possibly
    # subset) AnnData are returned. Otherwise unused categories would appear as
    # patients with a NaN label and zero cells. This also silences pandas'
    # FutureWarning about the changing default.
    patient_labels = obs.groupby(patient_key, observed=True)[label_key].first()
    # Name the counts explicitly: value_counts() is only called 'count' on pandas >= 2.
    sample_counts = obs[patient_key].value_counts().rename('count')

    # Left join keeps exactly the observed groups, in groupby order.
    patient_info = patient_labels.to_frame().join(sample_counts)
    ids = patient_info.index.values
    labels = patient_info[label_key]
    counts = patient_info['count'].values
    return ids, labels, counts


def oversample_data(train_ids, y_train, random_state):
    """Balance classes by randomly duplicating minority-class training IDs.

    Parameters
    ----------
    train_ids : array-like
        Training IDs. They are treated as a single feature column because
        ``RandomOverSampler.fit_resample`` expects a 2-D ``X``.
    y_train : array-like
        Labels aligned with ``train_ids``.
    random_state : int
        Seed for ``RandomOverSampler``.

    Returns
    -------
    tuple
        ``(ids, labels)`` after resampling. ``ids`` is flattened back to 1-D;
        ``labels`` is whatever ``fit_resample`` returns for ``y_train``.
    """
    id_matrix = np.array(train_ids).reshape(-1, 1)
    sampler = RandomOverSampler(random_state=random_state)
    resampled_matrix, resampled_labels = sampler.fit_resample(id_matrix, y_train)
    return resampled_matrix.ravel(), resampled_labels

def has_class_imbalance(y, threshold):
    """Return True when the minority/majority class-count ratio of ``y`` is below ``threshold``.

    Parameters
    ----------
    y : iterable
        Class labels.
    threshold : float
        Ratio ``min(count) / max(count)`` below which ``y`` counts as imbalanced.

    Returns
    -------
    bool
        False for an empty ``y``, which has no classes to compare.
    """
    class_counts = Counter(y).values()
    if not class_counts:
        return False
    return min(class_counts) / max(class_counts) < threshold

def save_cv_splits(patient_ids, labels, save_dir, n_splits=5, random_state=42, imbalance_threshold=0.8,
                   patient_key=None, label_key=None):
    """Write patient-level stratified K-fold splits to ``save_dir`` as JSON.

    Files written:
      * ``fold_{k}.json`` -- train/test ids and labels for fold ``k``.
      * ``cv_splits.json`` -- all folds plus metadata.
      * ``cv_splits_oversampled.json`` -- the same folds with each training side
        randomly oversampled via ``oversample_data``; test sides are unchanged.

    Parameters
    ----------
    patient_ids : array-like
        Patient identifiers, positionally aligned with ``labels``.
    labels : pandas.Series
        Per-patient labels (indexed positionally via ``.iloc``).
    save_dir : str
        Output directory (created if missing).
    n_splits : int
        Number of folds.
    random_state : int
        Seed for the fold shuffle and the oversampler.
    imbalance_threshold : float
        Currently unused. Oversampled folds are always written; the parameter
        is kept for signature parity with ``save_train_test_split``.
    patient_key, label_key : str, optional
        Column names recorded in the metadata. When omitted, the notebook-level
        globals of the same name are used (legacy behaviour).

    Returns
    -------
    str
        ``save_dir``.
    """
    os.makedirs(save_dir, exist_ok=True)
    # Legacy fallback: earlier calls relied on the notebook globals for these names.
    if patient_key is None:
        patient_key = globals().get('patient_key')
    if label_key is None:
        label_key = globals().get('label_key')

    # Accept lists as well as arrays/Categoricals (fancy indexing below needs an array).
    patient_ids = np.asarray(patient_ids)
    splitter = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    cv_splits = {}
    cv_splits_oversampled = {}

    for i, (train_idx, test_idx) in enumerate(splitter.split(patient_ids, labels)):
        fold = f'fold_{i + 1}'
        print(f"=== Fold {i + 1}/{n_splits} ===")
        train_ids = patient_ids[train_idx]
        test_ids = patient_ids[test_idx]
        y_train = labels.iloc[train_idx]
        y_test = labels.iloc[test_idx]

        train_ids_oversampled, y_train_oversampled = oversample_data(train_ids, y_train, random_state)

        fold_data = {
            'train_ids': train_ids.tolist(),
            'test_ids': test_ids.tolist(),
            'train_labels': y_train.tolist(),
            'test_labels': y_test.tolist(),
        }
        cv_splits[fold] = fold_data
        cv_splits_oversampled[fold] = {
            'train_ids': train_ids_oversampled.tolist(),
            'test_ids': fold_data['test_ids'],
            'train_labels': y_train_oversampled.tolist(),
            'test_labels': fold_data['test_labels'],
        }

        with open(os.path.join(save_dir, f'{fold}.json'), 'w') as f:
            json.dump(fold_data, f)

    metadata = {
        'id_column': patient_key,
        'label_column': label_key,
        'random_state': random_state,
        'n_splits': n_splits,
    }
    cv_splits.update(metadata)
    cv_splits_oversampled.update(metadata)

    with open(os.path.join(save_dir, 'cv_splits.json'), 'w') as f:
        json.dump(cv_splits, f)

    with open(os.path.join(save_dir, 'cv_splits_oversampled.json'), 'w') as f:
        json.dump(cv_splits_oversampled, f)

    print(f"Cross-validation splits saved to {save_dir}")
    return save_dir


def safe_stratified_split(patient_ids, counts, labels, test_size=0.3, random_state=42, max_tries=100):
    """Stratified patient split that guarantees at least two classes on each side.

    ``train_test_split`` is called with ``stratify=labels``. If either side ends
    up with a single class, the split is retried with the next seed
    (``random_state + attempt``), up to ``max_tries`` attempts.

    Parameters
    ----------
    patient_ids, counts, labels : array-like
        Positionally aligned per-patient ids, cell counts and labels.
    test_size : float
        Fraction of patients held out.
    random_state : int or None
        Base seed. ``None`` makes every attempt unseeded.
    max_tries : int
        Maximum number of attempts.

    Returns
    -------
    tuple
        ``(train_ids, test_ids, train_counts, test_counts, y_train, y_test)``.

    Raises
    ------
    ValueError
        If no attempt yields two or more classes on both sides. sklearn may also
        raise ValueError itself, e.g. when a class is too small to stratify.
    """
    for attempt in range(max_tries):
        # Vary the seed per attempt. Guard None, which previously raised TypeError.
        seed = None if random_state is None else random_state + attempt
        train_ids, test_ids, train_counts, test_counts, y_train, y_test = train_test_split(
            patient_ids,
            counts,
            labels,
            test_size=test_size,
            random_state=seed,
            stratify=labels,
        )
        if len(np.unique(y_train)) > 1 and len(np.unique(y_test)) > 1:
            return train_ids, test_ids, train_counts, test_counts, y_train, y_test
    raise ValueError("Could not generate a valid stratified split with at least 2 classes per split after multiple attempts.")

    
import os
import json
import pandas as pd
from collections import Counter

def save_train_test_split(patient_ids, labels, counts, save_dir, test_size=0.33, random_state=42, max_tries=10,
                          imbalance_threshold=0.8, patient_key=None, label_key=None):
    """Create a patient-level stratified train/test split and write it to ``save_dir``.

    Files written:
      * ``train.csv`` / ``test.csv`` -- columns ``patient_id``, ``label``, ``cell_count``.
      * ``train_test_split.json`` -- ids, labels and split metadata.
      * ``train_oversampled.csv`` and ``train_test_split_oversampled.json`` -- only
        when the training labels are imbalanced (see ``imbalance_threshold``). The
        training side is randomly oversampled; the test side is left untouched.

    Parameters
    ----------
    patient_ids, labels, counts :
        Positionally aligned per-patient ids, labels and cell counts
        (e.g. as returned by ``get_patient_info``).
    save_dir : str
        Output directory (created if missing).
    test_size : float
        Fraction of patients held out.
    random_state : int
        Base seed for the split and the oversampler.
    max_tries : int
        Retry budget forwarded to ``safe_stratified_split``.
    imbalance_threshold : float
        Oversample when ``min(class count) / max(class count)`` is below this.
    patient_key, label_key : str, optional
        Column names recorded in the JSON metadata. When omitted, the notebook-level
        globals of the same name are used (legacy behaviour).

    Returns
    -------
    str
        ``save_dir``.
    """
    os.makedirs(save_dir, exist_ok=True)
    # Legacy fallback: earlier calls relied on the notebook globals for these names.
    if patient_key is None:
        patient_key = globals().get('patient_key')
    if label_key is None:
        label_key = globals().get('label_key')

    # Accept lists as well as arrays/Categoricals; .ravel()/.tolist() below need an array.
    patient_ids = np.asarray(patient_ids)

    train_ids, test_ids, train_counts, test_counts, y_train, y_test = safe_stratified_split(
        patient_ids, counts, labels, test_size=test_size, random_state=random_state, max_tries=max_tries
    )

    def save_df(data, filename):
        # One CSV per split; the row index is not meaningful, so it is dropped.
        pd.DataFrame(data).to_csv(os.path.join(save_dir, filename), index=False)

    save_df({'patient_id': train_ids.ravel(), 'label': y_train, 'cell_count': train_counts}, 'train.csv')
    save_df({'patient_id': test_ids.ravel(), 'label': y_test, 'cell_count': test_counts}, 'test.csv')

    def build_metadata(split_train_ids, split_y_train, oversampled):
        # The test side is the same for the plain and the oversampled metadata.
        return {
            'train_test_split': {
                'train_ids': split_train_ids.tolist(),
                'test_ids': test_ids.tolist(),
                'train_labels': split_y_train.tolist(),
                'test_labels': y_test.tolist(),
            },
            'id_column': patient_key,
            'label_column': label_key,
            'random_state': random_state,
            'test_size': test_size,
            'oversampled': oversampled,
        }

    with open(os.path.join(save_dir, 'train_test_split.json'), 'w') as f:
        json.dump(build_metadata(train_ids, y_train, oversampled=False), f)

    # Oversampled artifacts are produced only for imbalanced training labels
    # (uses the module-level has_class_imbalance helper).
    if has_class_imbalance(y_train, imbalance_threshold):
        train_ids_oversampled, y_train_oversampled = oversample_data(train_ids, y_train, random_state)
        save_df({'patient_id': train_ids_oversampled.ravel(), 'label': y_train_oversampled}, 'train_oversampled.csv')
        with open(os.path.join(save_dir, 'train_test_split_oversampled.json'), 'w') as f:
            json.dump(build_metadata(train_ids_oversampled, y_train_oversampled, oversampled=True), f)

    print(f"Train-test split saved to {save_dir}")
    return save_dir


## chemo vs treatment naive


In [18]:
# --- Chemo (neoadjuvant_chemo vs treatment_naive) task config ---
# Note: this task reads a cancer-cells-only file, unlike the pre_post /
# Cancer_type / outcome tasks.
DATA_FILE = '/home/jupyter/DATA/brca_full/cancer_cells_only.h5ad'
save_dir ='/home/jupyter/sceval/data_splits/brca_full/brca_chemo'

patient_key = 'donor_id'
# Target column, defined here so the config cell is complete on its own.
label_key = 'cohort'

test_size = 0.33
random_state = 42
# test_size = 0.4
max_tries = 10
n_splits = 5
os.makedirs(save_dir, exist_ok=True)

In [23]:
# Target column for the chemo task: the `cohort` obs field.
label_key = 'cohort'

In [19]:
# Load this task's input file (cancer_cells_only.h5ad per the config cell).
adata_chemo = read_data(DATA_FILE)


In [20]:
# Cell-level cohort counts, before restricting to pre-treatment cells.
adata_chemo.obs.cohort.value_counts()

cohort
treatment_naive      29613
neoadjuvant_chemo     5079
Name: count, dtype: int64

In [21]:
# Keep pre-treatment cells only (rebinds adata_chemo to the filtered subset).
adata_chemo = adata_chemo[adata_chemo.obs['timepoint']=='Pre']

In [25]:
# One entry per patient: ids, per-patient cohort label, and number of cells.
patient_ids, labels, counts = get_patient_info(adata_chemo, patient_key, label_key)

  patient_labels = adata.obs.groupby(patient_key)[label_key].first()


In [26]:
# Sanity check: patients remaining after the Pre filter.
patient_ids

['BIOKEY_1', 'BIOKEY_2', 'BIOKEY_3', 'BIOKEY_4', 'BIOKEY_5', ..., 'BIOKEY_38', 'BIOKEY_39', 'BIOKEY_40', 'BIOKEY_41', 'BIOKEY_42']
Length: 39
Categories (39, object): ['BIOKEY_1', 'BIOKEY_2', 'BIOKEY_3', 'BIOKEY_4', ..., 'BIOKEY_39', 'BIOKEY_40', 'BIOKEY_41', 'BIOKEY_42']

In [27]:
# Patient-level class balance (the splits below are stratified on these labels).
print(labels.value_counts())

cohort
treatment_naive      31
neoadjuvant_chemo     8
Name: count, dtype: int64


In [28]:
# Output directory for this task's split files.
save_dir

'/home/jupyter/sceval/data_splits/brca_full/brca_chemo'

In [29]:
# Hold-out split. Config values are passed by keyword so max_tries from the
# config cell is actually used.
save_train_test_split(patient_ids, labels, counts, save_dir, test_size=test_size,
                      random_state=random_state, max_tries=max_tries)

Train-test split saved to /home/jupyter/sceval/data_splits/brca_full/brca_chemo


'/home/jupyter/sceval/data_splits/brca_full/brca_chemo'

In [30]:
# Patient-level stratified K-fold splits (plain + oversampled) for the same task.
save_cv_splits(patient_ids,labels, save_dir, n_splits=n_splits, random_state=random_state) 


=== Fold 1/5 ===
=== Fold 2/5 ===
=== Fold 3/5 ===
=== Fold 4/5 ===
=== Fold 5/5 ===
Cross-validation splits saved to /home/jupyter/sceval/data_splits/brca_full/brca_chemo


'/home/jupyter/sceval/data_splits/brca_full/brca_chemo'

## pre_post

In [6]:

# --- Pre vs post task config ---
# NOTE(review): the recorded outputs of this section's split cells report
# '.../brca_pre_potst', not this save_dir. They presumably come from an earlier
# run with a typo'd path. Re-run the section so the files land in brca_pre_post.
DATA_FILE = '/home/jupyter/DATA/brca_full/brca_full_3000cell_2048gene.h5ad'
save_dir ='/home/jupyter/sceval/data_splits/brca_full/brca_pre_post'

patient_key = 'donor_id_pre_post'
label_key = 'pre_post'

test_size = 0.33
random_state = 42
# test_size = 0.4
max_tries = 10
n_splits = 5
os.makedirs(save_dir, exist_ok=True)

In [7]:
# Load the full dataset (all timepoints). `adata` is also inspected by the cells below.
adata = read_data(DATA_FILE)
adata

AnnData object with n_obs × n_vars = 87326 × 2048
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'cell_id', 'donor_id', 'timepoint', 'outcome', 'Cancer_type', 'cell_types', 'cohort', 'pre_post', 'donor_id_pre_post', 'donor_id_outcome', 'donor_id_cell_types', 'donor_id_cell_types_pre_post', 'sample_id_pre_post_outcome', 'enough_cells', 'Study_name', 'Primary_or_met', 'RNA_snn_res.0.8', 'seurat_clusters', 'ident', 'n_genes_by_counts', 'total_counts', 'n_genes'
    var: 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts', 'n_cells', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'X_name', 'hvg', 'log1p'
    obsm: 'PCA', 'UMAP'
    layers: 'logcounts', 'scaledata'

In [8]:
# Cells per timepoint. 'On' cells carry the 'Post' pre_post label (see obs.head below).
adata.obs['timepoint'].value_counts()

timepoint
On     46224
Pre    41102
Name: count, dtype: int64

In [9]:
# Cells per subtype. The Cancer_type task below keeps only ER+ and TNBC.
adata.obs['Cancer_type'].value_counts()

Cancer_type
ER+      43114
TNBC     35212
HER2+     9000
Name: count, dtype: int64

In [10]:
# Cells per annotated cell type (the label of the cell-type task below).
adata.obs['cell_types'].value_counts()

cell_types
T_cell              27007
Cancer_cell         26822
Fibroblast          14793
Myeloid_cell         8317
B_cell               5890
Endothelial_cell     3828
Mast_cell             438
pDC                   231
Name: count, dtype: int64

In [11]:
# Cells per response label. 'n/a' cells are dropped in the outcome task below.
adata.obs['outcome'].value_counts()

outcome
NE     54326
E      27000
n/a     6000
Name: count, dtype: int64

In [12]:
# Get patient IDs and labels
# Groups are donor_id_pre_post values (presumably one per donor x timepoint,
# given the 31 Pre + 31 Post groups below).
# NOTE(review): a donor's Pre and Post samples can land on different sides of
# the split -- confirm this is intended.
patient_ids, labels, counts = get_patient_info(adata, patient_key, label_key)

  patient_labels = adata.obs.groupby(patient_key)[label_key].first()


In [13]:
# Group-level class balance.
print(labels.value_counts())

pre_post
Post    31
Pre     31
Name: count, dtype: int64


In [14]:
# Hold-out split. Config values are passed by keyword so max_tries from the
# config cell is actually used.
save_train_test_split(patient_ids, labels, counts, save_dir, test_size=test_size,
                      random_state=random_state, max_tries=max_tries)

Train-test split saved to /home/jupyter/sceval/data_splits/brca_full/brca_pre_potst


'/home/jupyter/sceval/data_splits/brca_full/brca_pre_potst'

In [15]:

# Group-level stratified K-fold splits (plain + oversampled).
save_cv_splits(patient_ids,labels, save_dir, n_splits=n_splits, random_state=random_state) 


=== Fold 1/5 ===
=== Fold 2/5 ===
=== Fold 3/5 ===
=== Fold 4/5 ===
=== Fold 5/5 ===
Cross-validation splits saved to /home/jupyter/sceval/data_splits/brca_full/brca_pre_potst


'/home/jupyter/sceval/data_splits/brca_full/brca_pre_potst'

In [16]:
# Peek at the cell-level metadata.
adata.obs.head()

Unnamed: 0,orig.ident,nCount_RNA,nFeature_RNA,cell_id,donor_id,timepoint,outcome,Cancer_type,cell_types,cohort,...,sample_id_pre_post_outcome,enough_cells,Study_name,Primary_or_met,RNA_snn_res.0.8,seurat_clusters,ident,n_genes_by_counts,total_counts,n_genes
BIOKEY_13_Pre_GCGGGTTCAATGAATG-1,BIOKEY,4986.0,1918,BIOKEY_13_Pre_GCGGGTTCAATGAATG-1,BIOKEY_13,Pre,,HER2+,Myeloid_cell,treatment_naive,...,BIOKEY_13_Pre_n/a,enough,Bassez,Primary,6,6,6,1918,4986.0,1918
BIOKEY_13_On_GTCTCGTCACCATCCT-1,BIOKEY,876.0,640,BIOKEY_13_On_GTCTCGTCACCATCCT-1,BIOKEY_13,On,,HER2+,Cancer_cell,treatment_naive,...,BIOKEY_13_Post_n/a,enough,Bassez,Primary,0,0,0,640,876.0,640
BIOKEY_13_Pre_TGGCTGGAGATCCGAG-1,BIOKEY,10818.0,3438,BIOKEY_13_Pre_TGGCTGGAGATCCGAG-1,BIOKEY_13,Pre,,HER2+,Fibroblast,treatment_naive,...,BIOKEY_13_Pre_n/a,enough,Bassez,Primary,3,3,3,3438,10818.0,3438
BIOKEY_13_Pre_TGAGCCGAGATGTGTA-1,BIOKEY,646.0,411,BIOKEY_13_Pre_TGAGCCGAGATGTGTA-1,BIOKEY_13,Pre,,HER2+,Fibroblast,treatment_naive,...,BIOKEY_13_Pre_n/a,enough,Bassez,Primary,8,8,8,411,646.0,411
BIOKEY_13_Pre_GCGAGAAAGTGCGTGA-1,BIOKEY,3564.0,1588,BIOKEY_13_Pre_GCGAGAAAGTGCGTGA-1,BIOKEY_13,Pre,,HER2+,Fibroblast,treatment_naive,...,BIOKEY_13_Pre_n/a,enough,Bassez,Primary,3,3,3,1588,3564.0,1588


In [17]:
# List every obs column name, one per line.
for c in adata.obs.columns:
    print(c)

orig.ident
nCount_RNA
nFeature_RNA
cell_id
donor_id
timepoint
outcome
Cancer_type
cell_types
cohort
pre_post
donor_id_pre_post
donor_id_outcome
donor_id_cell_types
donor_id_cell_types_pre_post
sample_id_pre_post_outcome
enough_cells
Study_name
Primary_or_met
RNA_snn_res.0.8
seurat_clusters
ident
n_genes_by_counts
total_counts
n_genes


## Cancer_type

In [18]:
# --- Cancer_type (ER+ vs TNBC) task config ---
DATA_FILE = '/home/jupyter/DATA/brca_full/brca_full_3000cell_2048gene.h5ad'
save_dir ='/home/jupyter/sceval/data_splits/brca_full/brca_subtype'

patient_key = 'donor_id'
label_key = 'Cancer_type'

test_size = 0.33
random_state = 42
# test_size = 0.4
max_tries = 10
n_splits = 5
os.makedirs(save_dir, exist_ok=True)

In [19]:
# Separate load for this task so its filters don't touch `adata`.
adata_subtype = read_data(DATA_FILE)


In [20]:
# Pre-treatment cells only.
adata_subtype = adata_subtype[adata_subtype.obs['timepoint']=='Pre']

In [21]:
# Restrict to the two subtypes being classified (drops HER2+).
adata_subtype= adata_subtype[adata_subtype.obs['Cancer_type'].isin(['ER+', 'TNBC'])]

In [22]:
# Per-patient subtype label and cell count.
patient_ids, labels, counts = get_patient_info(adata_subtype, patient_key, label_key)

  patient_labels = adata.obs.groupby(patient_key)[label_key].first()


In [23]:
# Patients remaining after both filters.
patient_ids

['BIOKEY_1', 'BIOKEY_2', 'BIOKEY_3', 'BIOKEY_4', 'BIOKEY_5', ..., 'BIOKEY_26', 'BIOKEY_27', 'BIOKEY_29', 'BIOKEY_30', 'BIOKEY_31']
Length: 28
Categories (28, object): ['BIOKEY_1', 'BIOKEY_2', 'BIOKEY_3', 'BIOKEY_4', ..., 'BIOKEY_27', 'BIOKEY_29', 'BIOKEY_30', 'BIOKEY_31']

In [24]:
# Patient-level class balance.
print(labels.value_counts())

Cancer_type
ER+     15
TNBC    13
Name: count, dtype: int64


In [25]:
# Hold-out split. Config values are passed by keyword so max_tries from the
# config cell is actually used.
save_train_test_split(patient_ids, labels, counts, save_dir, test_size=test_size,
                      random_state=random_state, max_tries=max_tries)

Train-test split saved to /home/jupyter/sceval/data_splits/brca_full/brca_subtype


'/home/jupyter/sceval/data_splits/brca_full/brca_subtype'

In [26]:
# Patient-level stratified K-fold splits (plain + oversampled).
save_cv_splits(patient_ids,labels, save_dir, n_splits=n_splits, random_state=random_state) 


=== Fold 1/5 ===
=== Fold 2/5 ===
=== Fold 3/5 ===
=== Fold 4/5 ===
=== Fold 5/5 ===
Cross-validation splits saved to /home/jupyter/sceval/data_splits/brca_full/brca_subtype


'/home/jupyter/sceval/data_splits/brca_full/brca_subtype'

## outcome

In [27]:
# --- Outcome (E vs NE) task config ---
DATA_FILE = '/home/jupyter/DATA/brca_full/brca_full_3000cell_2048gene.h5ad'
save_dir ='/home/jupyter/sceval/data_splits/brca_full/brca_outcome'

# patient_key = 'donor_id_outcome'
patient_key = 'donor_id'

label_key = 'outcome'

test_size = 0.33
random_state = 42
# test_size = 0.4
max_tries = 10
n_splits = 5
os.makedirs(save_dir, exist_ok=True)

In [28]:
# Separate load for this task so its filters don't touch `adata`.
adata_outcome = read_data(DATA_FILE)


In [29]:
# Pre-treatment cells only.
adata_outcome = adata_outcome[adata_outcome.obs['timepoint']=='Pre']

In [30]:
# Keep labelled responders/non-responders only (drops 'n/a').
adata_outcome= adata_outcome[adata_outcome.obs['outcome'].isin(['NE', 'E'])]

In [31]:
# Per-patient outcome label and cell count.
patient_ids, labels, counts = get_patient_info(adata_outcome, patient_key, label_key)

  patient_labels = adata.obs.groupby(patient_key)[label_key].first()


In [32]:
# Patients remaining after both filters.
list(patient_ids)

['BIOKEY_2',
 'BIOKEY_3',
 'BIOKEY_4',
 'BIOKEY_5',
 'BIOKEY_6',
 'BIOKEY_7',
 'BIOKEY_8',
 'BIOKEY_9',
 'BIOKEY_10',
 'BIOKEY_11',
 'BIOKEY_12',
 'BIOKEY_14',
 'BIOKEY_15',
 'BIOKEY_16',
 'BIOKEY_17',
 'BIOKEY_18',
 'BIOKEY_19',
 'BIOKEY_20',
 'BIOKEY_21',
 'BIOKEY_22',
 'BIOKEY_23',
 'BIOKEY_24',
 'BIOKEY_25',
 'BIOKEY_26',
 'BIOKEY_27',
 'BIOKEY_28',
 'BIOKEY_29',
 'BIOKEY_30',
 'BIOKEY_31']

In [33]:
# Patient-level class balance. NE outnumbers E here, so the split helpers
# also write oversampled variants.
print(labels.value_counts())

outcome
NE    20
E      9
Name: count, dtype: int64


In [34]:
# Hold-out split. Config values are passed by keyword so max_tries from the
# config cell is actually used.
save_train_test_split(patient_ids, labels, counts, save_dir, test_size=test_size,
                      random_state=random_state, max_tries=max_tries)

Train-test split saved to /home/jupyter/sceval/data_splits/brca_full/brca_outcome


'/home/jupyter/sceval/data_splits/brca_full/brca_outcome'

In [35]:
# Patient-level stratified K-fold splits (plain + oversampled).
save_cv_splits(patient_ids,labels, save_dir, n_splits=n_splits, random_state=random_state) 


=== Fold 1/5 ===
=== Fold 2/5 ===
=== Fold 3/5 ===
=== Fold 4/5 ===
=== Fold 5/5 ===
Cross-validation splits saved to /home/jupyter/sceval/data_splits/brca_full/brca_outcome


'/home/jupyter/sceval/data_splits/brca_full/brca_outcome'

## cell type

In [8]:
# --- Cell-type task config ---
# Split items are individual cells (cell_id); grouping for the split is done
# by donor_id_cell_types in the cells below, so no patient_key is set here.
DATA_FILE = '/home/jupyter/DATA/brca_full/brca_cells_only_3000cell_4096gene.h5ad'
save_dir ='/home/jupyter/sceval/data_splits/brca_full/brca_cell_type'

# patient_key = 'donor_id_outcome'
cell_key = 'cell_id'

label_key = 'cell_types'

test_size = 0.33
random_state = 42
# test_size = 0.4
max_tries = 10
n_splits = 5
os.makedirs(save_dir, exist_ok=True)

In [9]:
# Load the cell-type dataset. Note: this rebinds the notebook-wide `adata`.
adata = read_data(DATA_FILE)


In [10]:
# Inspect the loaded object.
adata

AnnData object with n_obs × n_vars = 87326 × 4096
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'cell_id', 'donor_id', 'timepoint', 'outcome', 'Cancer_type', 'cell_types', 'cohort', 'pre_post', 'donor_id_pre_post', 'donor_id_outcome', 'donor_id_cell_types', 'donor_id_cell_types_pre_post', 'sample_id_pre_post_outcome', 'enough_cells', 'Study_name', 'Primary_or_met', 'RNA_snn_res.0.8', 'seurat_clusters', 'ident', 'n_genes_by_counts', 'total_counts', 'n_genes'
    var: 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts', 'n_cells', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'X_name', 'hvg', 'log1p'
    obsm: 'PCA', 'UMAP'
    layers: 'counts', 'logcounts', 'scaledata'

In [11]:
# Pre-treatment cells only.
adata_celltype = adata[adata.obs.timepoint =='Pre']

In [12]:
# Inspect cell-level metadata of the Pre subset.
adata_celltype.obs

Unnamed: 0,orig.ident,nCount_RNA,nFeature_RNA,cell_id,donor_id,timepoint,outcome,Cancer_type,cell_types,cohort,...,sample_id_pre_post_outcome,enough_cells,Study_name,Primary_or_met,RNA_snn_res.0.8,seurat_clusters,ident,n_genes_by_counts,total_counts,n_genes
BIOKEY_13_Pre_GCGGGTTCAATGAATG-1,BIOKEY,4986.0,1918,BIOKEY_13_Pre_GCGGGTTCAATGAATG-1,BIOKEY_13,Pre,,HER2+,Myeloid_cell,treatment_naive,...,BIOKEY_13_Pre_n/a,enough,Bassez,Primary,6,6,6,1918,4986.0,1918
BIOKEY_13_Pre_TGGCTGGAGATCCGAG-1,BIOKEY,10818.0,3438,BIOKEY_13_Pre_TGGCTGGAGATCCGAG-1,BIOKEY_13,Pre,,HER2+,Fibroblast,treatment_naive,...,BIOKEY_13_Pre_n/a,enough,Bassez,Primary,3,3,3,3438,10818.0,3438
BIOKEY_13_Pre_TGAGCCGAGATGTGTA-1,BIOKEY,646.0,411,BIOKEY_13_Pre_TGAGCCGAGATGTGTA-1,BIOKEY_13,Pre,,HER2+,Fibroblast,treatment_naive,...,BIOKEY_13_Pre_n/a,enough,Bassez,Primary,8,8,8,411,646.0,411
BIOKEY_13_Pre_GCGAGAAAGTGCGTGA-1,BIOKEY,3564.0,1588,BIOKEY_13_Pre_GCGAGAAAGTGCGTGA-1,BIOKEY_13,Pre,,HER2+,Fibroblast,treatment_naive,...,BIOKEY_13_Pre_n/a,enough,Bassez,Primary,3,3,3,1588,3564.0,1588
BIOKEY_13_Pre_CTAGTGATCTGTGCAA-1,BIOKEY,7016.0,2180,BIOKEY_13_Pre_CTAGTGATCTGTGCAA-1,BIOKEY_13,Pre,,HER2+,Fibroblast,treatment_naive,...,BIOKEY_13_Pre_n/a,enough,Bassez,Primary,5,5,5,2180,7016.0,2180
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
BIOKEY_24_Pre_GATTCAGTCAAGGTAA-1,BIOKEY,8436.0,2689,BIOKEY_24_Pre_GATTCAGTCAAGGTAA-1,BIOKEY_24,Pre,NE,ER+,Endothelial_cell,treatment_naive,...,BIOKEY_24_Pre_NE,enough,Bassez,Primary,10,10,10,2689,8436.0,2689
BIOKEY_24_Pre_ACATGGTGTTGGGACA-1,BIOKEY,13452.0,3313,BIOKEY_24_Pre_ACATGGTGTTGGGACA-1,BIOKEY_24,Pre,NE,ER+,T_cell,treatment_naive,...,BIOKEY_24_Pre_NE,enough,Bassez,Primary,20,20,20,3313,13452.0,3313
BIOKEY_24_Pre_GGGACCTTCTCGCTTG-1,BIOKEY,506.0,344,BIOKEY_24_Pre_GGGACCTTCTCGCTTG-1,BIOKEY_24,Pre,NE,ER+,Myeloid_cell,treatment_naive,...,BIOKEY_24_Pre_NE,enough,Bassez,Primary,6,6,6,344,506.0,344
BIOKEY_24_Pre_GTCGGGTGTTGATTGC-1,BIOKEY,3094.0,1072,BIOKEY_24_Pre_GTCGGGTGTTGATTGC-1,BIOKEY_24,Pre,NE,ER+,T_cell,treatment_naive,...,BIOKEY_24_Pre_NE,enough,Bassez,Primary,0,0,0,1072,3094.0,1072


In [13]:
# Cell-level lists, one entry per row of adata_celltype.obs.
celltype = list(adata_celltype.obs.cell_types)
cell_ids = list(adata_celltype.obs.cell_id)
# Grouping ids used for the splits below (one donor_id_cell_types value per cell).
patient_ids =  list(adata_celltype.obs.donor_id_cell_types)
# NOTE(review): value counts of cell_id -- presumably all 1s if cell ids are
# unique, and ordered by count rather than by cell; appears unused in this section.
counts = list(adata_celltype.obs.cell_id.value_counts())

In [14]:
# Sorted unique grouping ids -- the unit that the holdout and K-fold splits partition.
patient_ids_unique = np.unique(patient_ids)

In [24]:
# Group-level hold-out split: no donor_id_cell_types value is shared between
# train and test.
# NOTE(review): the planning notes say "split by patient", but these groups are
# donor_id_cell_types values (presumably donor x cell type). One donor's cells of
# different types can therefore land on both sides. Confirm donor_id wasn't meant.
patient_train, patient_test = train_test_split(patient_ids_unique,
            test_size=test_size,
            random_state=random_state 
    
        )

In [25]:
# Expand the group-level hold-out split to cell ids and their cell-type labels.
idx = adata_celltype.obs.donor_id_cell_types.isin(patient_train)
train_ids = list(adata_celltype[idx].obs.cell_id)
y_train = list(adata_celltype[idx].obs.cell_types)


idx = adata_celltype.obs.donor_id_cell_types.isin(patient_test)
test_ids = list(adata_celltype[idx].obs.cell_id)
y_test = list(adata_celltype[idx].obs.cell_types)



In [26]:
# Number of training / test cells.
len(y_train), len(y_test)

(25355, 15747)

In [27]:
# Hold-out split metadata. id_column / label_column name the actual obs columns
# (cell_key / label_key from this section's config), as the other tasks record
# real column names. 'celltype' is not an obs column.
tobe_saved = {
    'train_test_split': {
        'train_ids': train_ids,
        'test_ids': test_ids,
        'train_labels': y_train,
        'test_labels': y_test,
    },
    'id_column': cell_key,
    'label_column': label_key,
    'random_state': random_state,
    'test_size': test_size,
    'oversampled': False,
}

# Save original split
with open(os.path.join(save_dir, 'train_test_split.json'), 'w') as f:
    json.dump(tobe_saved, f)


In [28]:
# Shuffled (non-stratified) K-fold over the unique grouping ids.
splitter = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)


In [29]:
cv_splits = {}
for i, (train_idx, test_idx) in enumerate(splitter.split(patient_ids_unique)):
    print (i)
    # KFold.split yields positional indices into patient_ids_unique, not ids.
    # Map them back to group ids before matching against obs. Matching the raw
    # indices compares strings with ints, which leaves every train set empty.
    # The test side must come from this fold, not from the hold-out patient_test.
    fold_train_groups = patient_ids_unique[train_idx]
    fold_test_groups = patient_ids_unique[test_idx]

    idx = adata_celltype.obs.donor_id_cell_types.isin(fold_train_groups)
    train_ids = list(adata_celltype[idx].obs.cell_id)
    y_train = list(adata_celltype[idx].obs.cell_types)

    idx = adata_celltype.obs.donor_id_cell_types.isin(fold_test_groups)
    test_ids = list(adata_celltype[idx].obs.cell_id)
    y_test = list(adata_celltype[idx].obs.cell_types)

    assert train_ids and test_ids, f"fold {i + 1} produced an empty train or test set"

    # Store the splits for this fold
    cv_splits[f'fold_{i+1}'] = {
        'train_ids': train_ids,
        'test_ids': test_ids,
        'train_labels': y_train,
        'test_labels': y_test
    }



0
1
2
3
4


In [30]:
# CV metadata. Record the actual obs column names (cell_key / label_key from this
# section's config); 'cell_ids' and 'celltype' are not columns of adata.obs.
cv_splits['id_column'] = cell_key
cv_splits['label_column'] = label_key
cv_splits['random_state'] = random_state
cv_splits['n_splits'] = n_splits


# Save all splits in one file
with open(os.path.join(save_dir, 'cv_splits.json'), 'w') as f:
    json.dump(cv_splits, f)