# How to Train a Spatial CellTypist Model using scRNAseq Data

#### **Prerequisites**: 
- A virtual environment (e.g. a conda environment) with CellTypist from the Teichmann Lab installed (refer to [the GitHub](https://github.com/Teichlab/celltypist))
- Preprocessed and cell typed single-cell dataset saved in AnnData format
- 10X Xenium dataset or the list of genes in the panel. 

For this workflow, I will be using the **Lung** dataset (available [here]()) and the **10X Lung Xenium** dataset (available [here]()). Both are used in the Segger paper (Heidari, Moorman et al. [*bioRxiv*](https://www.biorxiv.org/content/10.1101/2025.03.14.643160v1) 2025, [GitHub](https://github.com/dpeerlab/segger-analysis/)).

```python
import celltypist as ct
import ast
from matplotlib import pyplot as plt
from pathlib import Path
import seaborn as sns
from tqdm import tqdm
import scanpy as sc
import pandas as pd
import numpy as np
import scipy as sp
import warnings
import json
import sys
import os
```

## Training a CellTypist Model on scRNAseq Data
For a more in-depth explanation of CellTypist and model training, refer to this [SAIL GitHub](https://github.com/joadams1/celltypist/blob/main/celltypist/How%20To%20Train%20a%20CellTypist%20Model.ipynb).

In some cases, it is appropriate to train a CellTypist model for spatial data on an annotated scRNA-seq dataset. To do this, you first need to make the scRNA-seq data look more like Xenium data. The first step is to subset the scRNA-seq data to only the genes in the panel of the dataset you want to annotate. Otherwise, the model may learn to rely on genes that are absent from the Xenium dataset, and it will perform poorly when those genes are missing.

```python
# NSCLC Atlas
filepath_ad = '../data_spatial/core_nsclc_atlas.h5ad'  # replace with your own path
ad_atlas_all = sc.read_h5ad(filepath_ad)

# Ensure var_names are gene symbols, not Ensembl IDs
ad_atlas_all.var_names = ad_atlas_all.var['feature_name']

# Xenium Dataset
filepath_xen = '../data_spatial/10x_lung_cell_id.h5ad'  # replace with your own path
ad_xen = sc.read_h5ad(filepath_xen)
```

```python
ad_atlas = ad_atlas_all[:, ad_atlas_all.var_names.isin(ad_xen.var_names)].copy()
ad_atlas  # check that the subset atlas has the expected number of genes
```

```
AnnData object with n_obs × n_vars = 892296 × 383
    obs: 'sample', 'uicc_stage', 'ever_smoker', 'age', 'donor_id', 'origin', 'dataset', 'ann_fine', 'cell_type_predicted', 'doublet_status', 'leiden', 'n_genes_by_counts', 'total_counts', 'total_counts_mito', 'pct_counts_mito', 'ann_coarse', 'cell_type_tumor', 'tumor_stage', 'EGFR_mutation', 'TP53_mutation', 'ALK_mutation', 'BRAF_mutation', 'ERBB2_mutation', 'KRAS_mutation', 'ROS_mutation', 'origin_fine', 'study', 'platform', 'cell_type_major', 'suspension_type', 'assay_ontology_term_id', 'cell_type_ontology_term_id', 'development_stage_ontology_term_id', 'disease_ontology_term_id', 'self_reported_ethnicity_ontology_term_id', 'is_primary_data', 'organism_ontology_term_id', 'sex_ontology_term_id', 'tissue_ontology_term_id', 'tissue_type', 'cell_type', 'assay', 'disease', 'organism', 'sex', 'tissue', 'self_reported_ethnicity', 'development_stage', 'observation_joinid'
    var: 'is_highly_variable', 'mito', 'n_cells_by_counts', 'mean_count
```
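It can help to check which panel genes actually matched, because a naming mismatch (for example, Ensembl IDs instead of gene symbols in `var_names`) silently produces an empty or much smaller subset. Below is a minimal sketch using a hypothetical helper, `panel_overlap`, that works on plain gene-name lists with pandas so it does not depend on AnnData; the gene names in the toy example are illustrative only.

```python
import pandas as pd

def panel_overlap(reference_genes, panel_genes):
    """Split panel genes into those found in the reference and those missing.

    Hypothetical helper: both arguments are any iterable of gene names,
    e.g. ad_atlas_all.var_names and ad_xen.var_names.
    """
    ref = pd.Index(reference_genes)
    panel = pd.Index(panel_genes)
    found = panel.isin(ref)
    return panel[found], panel[~found]

# Toy example: reference indexed by gene symbols, one panel gene absent
present, missing = panel_overlap(["EPCAM", "CD3E", "MS4A1", "ACTB"], ["EPCAM", "CD3E", "FOXP3"])
print(f"{len(present)} panel genes retained, missing: {list(missing)}")
```

On the real data you might call `panel_overlap(ad_atlas_all.var_names, ad_xen.var_names)`, assuming both hold gene symbols. An empty `present` usually points to a naming mismatch; entries in `missing` may also include non-gene features if the Xenium object retains them (an assumption about how that object was built).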

Additionally, Xenium captures far fewer counts per cell than scRNA-seq, so to make the two more comparable you need to downsample the scRNA-seq data to Xenium-like depth. Make note of the number of counts per cell you scale to (here, 100), because the dataset you apply the model to must be scaled the same way. In practice, the model is fairly robust to the exact value you choose.
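One way to pick a downsampling target is to look at how many counts per cell the Xenium data actually has. The sketch below computes per-cell totals from a raw count matrix and takes the median. Using the median as the target is an assumption for illustration (this notebook simply uses 100), and the toy matrix stands in for something like `ad_xen.X`, assuming it holds raw counts.

```python
import numpy as np
from scipy import sparse

def per_cell_totals(X):
    """Total counts per cell (row sums) for a dense or sparse cells x genes matrix."""
    return np.asarray(X.sum(axis=1)).ravel()

# Toy Xenium-like raw count matrix: 4 cells x 3 panel genes (made-up numbers)
X_xen = sparse.csr_matrix(np.array([
    [10, 5, 0],
    [40, 20, 30],
    [3, 0, 2],
    [60, 25, 15],
]))

totals = per_cell_totals(X_xen)
target = int(np.median(totals))  # hypothetical choice of downsampling target
print(totals, target)
```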

```python
# Start from raw counts and downsample each cell to at most 100 counts (in place on .X)
ad_atlas.X = ad_atlas.layers['count'].copy()
sc.pp.downsample_counts(ad_atlas, counts_per_cell=100)

# Renormalize every cell to a total of 100 counts
ad_atlas.layers['norm_100'] = ad_atlas.X.copy()
sc.pp.normalize_total(ad_atlas, layer='norm_100', target_sum=1e2)

# Log-transform; drop stale log1p metadata so scanpy does not warn the data look already logged
ad_atlas.layers['lognorm_100'] = ad_atlas.layers['norm_100'].copy()
if 'log1p' in ad_atlas.uns:
    del ad_atlas.uns['log1p']
sc.pp.log1p(ad_atlas, layer='lognorm_100')
```
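To make the cell above less of a black box, here is what happens to a single cell, sketched in NumPy. Downsampling to a fixed number of counts amounts to drawing that many reads from the cell without replacement, which is a multivariate hypergeometric draw. `downsample_cell` is a hypothetical helper written for illustration, not scanpy's implementation. Per scanpy's documentation, cells that already have fewer counts than the target are left unchanged; the `normalize_total` step then scales them up to 100.

```python
import numpy as np

def downsample_cell(counts, target, rng):
    """Draw `target` reads without replacement from one cell's gene counts.

    Hypothetical helper. Cells with `target` or fewer total counts are
    returned unchanged.
    """
    counts = np.asarray(counts, dtype=np.int64)
    if counts.sum() <= target:
        return counts.copy()
    # Sampling reads without replacement is a multivariate hypergeometric draw
    return rng.multivariate_hypergeometric(counts, target)

rng = np.random.default_rng(0)
cell = np.array([300, 150, 40, 10])     # toy cell: 500 counts over 4 genes
down = downsample_cell(cell, 100, rng)  # exactly 100 counts remain
norm = down / down.sum() * 100          # analogue of normalize_total(target_sum=1e2)
lognorm = np.log1p(norm)                # analogue of log1p
```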

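To keep rare cell types from being swamped by abundant ones during training, you can subsample the data so that each cell type contributes a similar number of cells. The code below draws 2,000 cells per cell type with replacement and then drops repeated cells, so each type keeps at most 2,000 distinct cells and rare types keep most or all of theirs.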

```python
gb = ad_atlas.obs.groupby('cell_type')
sample = gb.sample(2000, replace=True).index.drop_duplicates()
```
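The effect of `sample(..., replace=True)` followed by `drop_duplicates()` is easy to check on a toy table. Because sampling with replacement repeats cells, abundant types end up with somewhat fewer than 2,000 distinct cells, while rare types keep at most their full count. The sketch also shows an alternative (not what this notebook uses) that shuffles once and keeps the first `n_per_type` cells of each type, giving an exact cap without repeats. The cell-type labels and counts are made up.

```python
import pandas as pd

# Toy obs table: one abundant and one rare cell type (made-up labels)
obs = pd.DataFrame(
    {"cell_type": ["A"] * 5000 + ["B"] * 50},
    index=[f"cell_{i}" for i in range(5050)],
)
n_per_type = 2000

# Approach used above: sample with replacement, then drop repeated cell IDs
gb = obs.groupby("cell_type")
with_replacement = gb.sample(n_per_type, replace=True, random_state=0).index.drop_duplicates()

# Alternative: shuffle once, then keep the first n_per_type cells of each type
capped = obs.sample(frac=1, random_state=0).groupby("cell_type").head(n_per_type).index

print(obs.loc[with_replacement, "cell_type"].value_counts().to_dict())
print(obs.loc[capped, "cell_type"].value_counts().to_dict())
```

The exact cap is more predictable; the with-replacement version keeps the notebook's behavior and is what produced the training set below.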

Now that you have adjusted the scRNA-seq data, you can train a CellTypist model.

```python
# Train on the downsampled, log-normalized counts
ad_atlas.X = ad_atlas.layers['lognorm_100']

ct_model = ct.train(
    ad_atlas[sample],
    labels='cell_type',
    check_expression=False,  # data are normalized to 100 counts/cell, not CellTypist's expected 10,000
    n_jobs=32,
    max_iter=100,
)

filepath_ct = 'models/nsclc_celltypist_model.pkl'  # replace with your own path
Path(filepath_ct).parent.mkdir(parents=True, exist_ok=True)  # make sure the output folder exists
ct_model.write(filepath_ct)
```

```
🍳 Preparing data before training
🔬 Input data has 55198 cells and 383 genes
⚖️ Scaling input data
🏋️ Training data using logistic regression
✅ Model training done!
```

Now you can use this model to annotate other datasets. For a walkthrough of how to do so, see the [`Using Xenium CellTypist Models`](https://github.com/joadams1/spatial_celltypist/blob/main/notebooks/Using%20Xenium%20CellTypist%20Model.ipynb) notebook.
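Whatever workflow you use to annotate, the query data should be scaled the same way as the training data: normalized to 100 counts per cell and then `log1p`-transformed. Below is a minimal sketch of that step for a raw count matrix; the function name `lognormalize_like_reference` is hypothetical, and it assumes its input holds raw counts.

```python
import numpy as np
from scipy import sparse

def lognormalize_like_reference(X, target_sum=100.0):
    """Scale each cell to `target_sum` total counts, then apply log1p.

    Hypothetical helper mirroring normalize_total(target_sum=1e2) + log1p
    as applied to the reference. X is a raw cells x genes count matrix
    (dense array or sparse matrix); returns a CSR matrix.
    """
    X = sparse.csr_matrix(X, dtype=np.float64)
    totals = np.asarray(X.sum(axis=1)).ravel()
    # Empty cells get a scale factor of 0 and stay all-zero (no division by zero)
    scale = np.divide(target_sum, totals, out=np.zeros_like(totals), where=totals > 0)
    X = (sparse.diags(scale) @ X).tocsr()
    X.data = np.log1p(X.data)
    return X
```

For example, `ad_xen.layers['lognorm_100'] = lognormalize_like_reference(ad_xen.X)`, assuming `ad_xen.X` holds raw Xenium counts; the layer name is only a suggestion.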