Create two subset trying to keep as much as possible the same proportions of cell states and cell type across stages

In [2]:
import numpy as np 
import scanpy as sc
from anndata import AnnData
from scipy import sparse
from tqdm.notebook import tqdm

In [4]:
#Load preprocessed data
data_path = 'data/preprocessed_data.h5ad'
adata = sc.read_h5ad(data_path)
adata

AnnData object with n_obs × n_vars = 205153 × 16906
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'Stage_Code', 'Tissue', 'Risk_Category', 'First_Avail_TP', 'MYCN_Status', 'ALK_Status', 'TP53_Status', 'Response', 'Vital_Status', 'Age_at_IDX_in_months', 'Treatment', 'First_Avail_Time_Point', 'sample_name', 'biospecimen_id', 'percent.mt', 'seurat_clusters', 'sample_label_wo_prefix', 'S.Score', 'G2M.Score', 'Phase', 'malignancy', 'cell_state', 'RNA_snn_res.0.2', 'MES_Score', 'ADRN_Score', 'MES_ADRN_diff', 'Event', 'organism_ontology_term_id', 'donor_id', 'development_stage_ontology_term_id', 'sex_ontology_term_id', 'disease_ontology_term_id', 'tissue_type', 'cell_type_ontology_term_id', 'assay_ontology_term_id', 'suspension_type', 'tissue_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'is_primary_data', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid'
    uns: 'log1p'

In [125]:
# Subsample the data, making sure to keep the same proportions of cell states and cell type across stages
N_subsample = 100000
stage_codes = ['DX', 'PTX']
original_dist_cell_state = adata.obs['cell_state'].value_counts(normalize=True)
original_cell_type_dist = adata.obs['cell_type'].value_counts(normalize=True)
adata_stage = []

for stage in stage_codes:
    stage_cells = adata[adata.obs['Stage_Code'] == stage].copy()
    sampled_indices = []

    for state, proportion_state in original_dist_cell_state.items():
        # Filter for cells in this state
        state_cells = stage_cells[stage_cells.obs['cell_state'] == state]
        
        # Calculate the number of samples for this state
        n_to_sample_state = int(proportion_state * (N_subsample / 2))
        n_to_sample_state = min(n_to_sample_state, len(state_cells))
        

        for cell_type, proportion_type in original_cell_type_dist.items():
            # Subsample based on cell_type proportion
            type_cells = state_cells[state_cells.obs['cell_type'] == cell_type]
            n_to_sample_type = int(proportion_type * n_to_sample_state)
            n_to_sample_type = min(n_to_sample_type, len(type_cells))

            sampled = type_cells.obs.sample(n=n_to_sample_type, random_state=0)
            sampled_indices.extend(sampled.index.tolist())
            
    # Create the subsampled dataset for this stage
    stage_cells_matched = stage_cells[sampled_indices].copy()
    adata_stage.append(stage_cells_matched)

# Combine both stages
subset = adata_stage[0].concatenate(adata_stage[1], batch_key="Stage_Code", batch_categories=stage_codes)

subset


  subset = adata_stage[0].concatenate(adata_stage[1], batch_key="Stage_Code", batch_categories=stage_codes)


AnnData object with n_obs × n_vars = 96690 × 16906
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'Stage_Code', 'Tissue', 'Risk_Category', 'First_Avail_TP', 'MYCN_Status', 'ALK_Status', 'TP53_Status', 'Response', 'Vital_Status', 'Age_at_IDX_in_months', 'Treatment', 'First_Avail_Time_Point', 'sample_name', 'biospecimen_id', 'percent.mt', 'seurat_clusters', 'sample_label_wo_prefix', 'S.Score', 'G2M.Score', 'Phase', 'malignancy', 'cell_state', 'RNA_snn_res.0.2', 'MES_Score', 'ADRN_Score', 'MES_ADRN_diff', 'Event', 'organism_ontology_term_id', 'donor_id', 'development_stage_ontology_term_id', 'sex_ontology_term_id', 'disease_ontology_term_id', 'tissue_type', 'cell_type_ontology_term_id', 'assay_ontology_term_id', 'suspension_type', 'tissue_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'is_primary_data', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid'

In [None]:
# Split the subsampled data in two sets, preserving the same proportions of cell states
# Lists to collect indices
indices_test = []
indices_train = []

# Loop through each cell_state and split it
for state, group in subset.obs.groupby('cell_state'):
    # Shuffle and split indices
    shuffled = group.sample(frac=1.0, random_state=42)
    n = len(shuffled)
    split = n // 2

    indices_test.extend(shuffled.iloc[:split].index)
    indices_train.extend(shuffled.iloc[split:].index)

# Create two AnnData objects
adata_train = subset[indices_test].copy()
adata_test = subset[indices_train].copy()

  for state, group in subset.obs.groupby('cell_state'):


In [None]:
indices_test.write_h5ad('./data/tst_data.h5ad', compression='gzip')
adata_train.write_h5ad('./data/finetunning_data.h5ad', compression='gzip')

### Distributions

In [122]:
#Original distribution of cell states
print(adata.obs['Stage_Code'].value_counts(normalize=True) * 100)
print(adata.obs['cell_type'].value_counts(normalize=True) * 100)
adata.obs['cell_state'].value_counts(normalize=True) * 100

Stage_Code
PTX    52.070406
DX     47.929594
Name: proportion, dtype: float64
cell_type
neuroblast (sensu Vertebrata)    97.117273
fibroblast                        2.430381
Schwann cell                      0.452345
Name: proportion, dtype: float64


cell_state
ADRN-Calcium          25.567747
ADRN-Baseline         23.743255
Interm-OxPhos         18.084064
ADRN-Dopaminergic     17.401647
ADRN-Proliferating     9.820475
MES                    5.382812
Name: proportion, dtype: float64

In [None]:
#Train distribution of cell states
print(adata_test.obs['Stage_Code'].value_counts(normalize=True) * 100)
print(adata_test.obs['cell_type'].value_counts(normalize=True) * 100)
print(adata_test.obs['cell_state'].value_counts(normalize=True) * 100)
print('----------------------')
print(adata_test[adata_test.obs['Stage_Code'] == 'DX'].obs['cell_state'].value_counts(normalize=True) * 100)
print(adata_test[adata_test.obs['Stage_Code'] == 'PTX'].obs['cell_state'].value_counts(normalize=True) * 100)
print('----------------------')
print(adata_test[adata_test.obs['Stage_Code'] == 'DX'].obs['cell_type'].value_counts(normalize=True) * 100)
print(adata_test[adata_test.obs['Stage_Code'] == 'PTX'].obs['cell_type'].value_counts(normalize=True) * 100)

Stage_Code
DX     50.018617
PTX    49.981383
Name: proportion, dtype: float64
cell_type
neuroblast (sensu Vertebrata)    98.841635
fibroblast                        1.083899
Schwann cell                      0.074466
Name: proportion, dtype: float64
cell_state
ADRN-Calcium          25.883253
ADRN-Baseline         23.839566
Interm-OxPhos         18.484197
ADRN-Dopaminergic     17.491312
ADRN-Proliferating    10.011584
MES                    4.290088
Name: proportion, dtype: float64
----------------------
cell_state
ADRN-Calcium          26.012158
ADRN-Baseline         24.585418
Interm-OxPhos         18.522807
ADRN-Dopaminergic     17.517886
ADRN-Proliferating    10.318018
MES                    3.043712
Name: proportion, dtype: float64
cell_state
ADRN-Calcium          25.754252
ADRN-Baseline         23.093159
Interm-OxPhos         18.445557
ADRN-Dopaminergic     17.464719
ADRN-Proliferating     9.704921
MES                    5.537392
Name: proportion, dtype: float64
-------------------

In [None]:
#Test distribution of cell states
print(adata_train.obs['Stage_Code'].value_counts(normalize=True) * 100)
print(adata_train.obs['cell_type'].value_counts(normalize=True) * 100)
print(adata_train.obs['cell_state'].value_counts(normalize=True) * 100)
print('----------------------')
print(adata_train[adata_train.obs['Stage_Code'] == 'DX'].obs['cell_state'].value_counts(normalize=True) * 100)
print(adata_train[adata_train.obs['Stage_Code'] == 'PTX'].obs['cell_state'].value_counts(normalize=True) * 100)
print('----------------------')
print(adata_train[adata_train.obs['Stage_Code'] == 'DX'].obs['cell_type'].value_counts(normalize=True) * 100)
print(adata_train[adata_train.obs['Stage_Code'] == 'PTX'].obs['cell_type'].value_counts(normalize=True) * 100)

Stage_Code
PTX    50.730153
DX     49.269847
Name: proportion, dtype: float64
cell_type
neuroblast (sensu Vertebrata)    98.831341
fibroblast                        1.092128
Schwann cell                      0.076532
Name: proportion, dtype: float64
cell_state
ADRN-Calcium          25.882183
ADRN-Baseline         23.840649
Interm-OxPhos         18.483432
ADRN-Dopaminergic     17.490589
ADRN-Proliferating    10.011169
MES                    4.291979
Name: proportion, dtype: float64
----------------------
cell_state
ADRN-Calcium          25.890008
ADRN-Baseline         24.647355
Interm-OxPhos         18.677582
ADRN-Dopaminergic     17.707809
ADRN-Proliferating    10.041982
MES                    3.035264
Name: proportion, dtype: float64
cell_state
ADRN-Calcium          25.874582
ADRN-Baseline         23.057164
Interm-OxPhos         18.294871
ADRN-Dopaminergic     17.279622
ADRN-Proliferating     9.981244
MES                    5.512517
Name: proportion, dtype: float64
-------------------