## Step 9: Create CellTypist model from Allen Brain Cell atlas data 

In Step 8, we manually annotated many of the clusters in the Wheeler et al. mouse EAE single-cell RNA-seq dataset as non-astrocyte cell types. Next we'll use [CellTypist (PMID: 35549406)](https://pubmed.ncbi.nlm.nih.gov/35549406/), a machine learning tool for automated cell type annotation, to check whether it agrees with our manual annotations.

The first step in using CellTypist is to create a model from a reference dataset. Here we use the Allen Brain Cell atlas ([Yao et al. 2023, PMID: 38092916](https://pubmed.ncbi.nlm.nih.gov/38092916/)) as a reference for typical gene expression across brain cell types.

In [1]:
import os
os.chdir('..') # changing working directory to parent 'EpiMemAstros' directory, adjust as needed
import pandas as pd
from pathlib import Path
import numpy as np
import anndata as ad
from scipy import sparse
import scanpy as sc
import celltypist

### Part 1: Load the ABC atlas, create cell type groups, and downsample

First we'll load the atlas AnnData object. Note that this H5AD file can require hundreds of GB of RAM when it is read fully into memory. Depending on your hardware, you may need to load it in 'backed' mode (`backed='r'`), which leaves the expression matrix on disk instead of loading it into memory. [See the Scanpy documentation for details.](https://scanpy.readthedocs.io/en/stable/generated/scanpy.read_h5ad.html)

*Alternatively*, you may skip ahead to **Part 2** below and load a pre-made downsampled anndata object which we provided in the Zenodo repository you downloaded in Step 0.

In [2]:
mouse_atlas = sc.read_h5ad('outputs/allen_brain_cell_atlas-RAW.h5ad')
mouse_atlas

AnnData object with n_obs × n_vars = 4042976 × 32285
    obs: 'cell_barcode_x', 'library_label_x', 'anatomical_division_label', 'cell_barcode_y', 'barcoded_cell_sample_label', 'library_label_y', 'feature_matrix_label', 'entity', 'brain_section_label', 'library_method', 'region_of_interest_acronym', 'donor_label', 'donor_genotype', 'donor_sex', 'dataset_label', 'x', 'y', 'cluster_alias', 'abc_sample_id', 'neurotransmitter', 'class', 'subclass', 'supertype', 'cluster'
    var: 'gene_symbol', 'name', 'mapped_ncbi_identifier', 'comment'
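If the full atlas does not fit in memory, one possible pattern is to open the file in backed mode, choose cells from `.obs`, and load only those rows. The helper below is a minimal sketch, not the workflow used in this notebook: `load_obs_subset` and its arguments are hypothetical names, and it assumes a recent anndata release in which `to_memory()` can be called on a subset of a backed object. `sc.read_h5ad` accepts the same `backed` argument.

```python
def load_obs_subset(h5ad_path, obs_key, keep_values):
    """Open an H5AD file in backed mode and load only the selected cells.

    Hypothetical helper for illustration; not part of the original notebook.
    """
    # Imported here so the helper can be defined without anndata installed
    import anndata as ad

    # backed='r' opens the file read-only; X stays on disk, obs and var are read
    atlas = ad.read_h5ad(h5ad_path, backed='r')
    try:
        # Boolean mask over cells, built from the in-memory obs table
        mask = atlas.obs[obs_key].isin(keep_values).to_numpy()
        # Assumption: to_memory() on the subset reads only the selected rows
        return atlas[mask].to_memory()
    finally:
        # Release the handle on the backing file
        atlas.file.close()
```

Because `backed='r'` is read-only, in-place steps such as `sc.pp.normalize_total` and `sc.pp.log1p` would likely need to run on the in-memory subset rather than on the backed object.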

Next, we replace the Ensembl IDs in the AnnData `.var` index with gene symbols (keeping the original IDs in `.var['ensembl_id']`), then normalize each cell to 10,000 total counts and log-transform the count matrix.

In [3]:
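# Keep the original Ensembl IDs before replacing the index
mouse_atlas.var['ensembl_id'] = mouse_atlas.var.index
mouse_atlas.var['gene_symbol'] = mouse_atlas.var['gene_symbol'].astype(str)
mouse_atlas.var['original_gene_symbol'] = mouse_atlas.var.gene_symbol
mouse_atlas.var.index = mouse_atlas.var.gene_symbol
# Add suffixes to duplicated gene symbols so that var names are unique
mouse_atlas.var_names_make_unique()

# Scale each cell to 10,000 total counts, then apply log(1 + x)
sc.pp.normalize_total(mouse_atlas, target_sum=10000)
sc.pp.log1p(mouse_atlas)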
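To make the two preprocessing calls concrete, the sketch below reproduces their arithmetic in plain Python on a made-up count table of two cells and three genes. It illustrates the default behavior of `sc.pp.normalize_total` (scale each cell so its counts sum to `target_sum`) followed by `sc.pp.log1p` (natural log of one plus the value). The toy counts and the helper name `normalize_and_log` are assumptions for illustration, not part of the pipeline.

```python
import math

def normalize_and_log(counts, target_sum=10_000):
    """Scale each cell (row) to target_sum total counts, then apply log(1 + x)."""
    transformed = []
    for cell in counts:
        total = sum(cell)
        # Guard against empty cells to avoid dividing by zero
        scale = target_sum / total if total > 0 else 0.0
        transformed.append([math.log1p(value * scale) for value in cell])
    return transformed

# Toy raw counts: two cells (rows) x three genes (columns), hypothetical values
toy_counts = [
    [1, 3, 0],    # shallowly sequenced cell, 4 counts in total
    [10, 30, 0],  # same composition at 10x the sequencing depth
]

toy_log = normalize_and_log(toy_counts)

# After depth normalization the two cells have the same profile
print(toy_log[0])
print(toy_log[1])
```

The two toy cells differ only in sequencing depth, so they end up with the same normalized profile, which is the point of scaling before the log transform.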

We next group the atlas clusters into broader cell type categories. Using the atlas `class` labels, we aggregate the neuronal clusters into six categories: glutamatergic neurons, immature glutamatergic neurons, GABAergic neurons, immature GABAergic neurons, dopaminergic neurons, and serotonergic neurons.

In [4]:
obs_df = mouse_atlas.obs
obs_df['celltype'] = np.where(obs_df['class'].isin(['01 IT-ET Glut',
                                                     '02 NP-CT-L6b Glut',
                                                     '03 OB-CR Glut',
                                                     '13 CNU-HYa Glut',
                                                     '14 HY Glut',
                                                     '15 HY Gnrh1 Glut',
                                                     '16 HY MM Glut',
                                                     '17 MH-LH Glut',
                                                     '18 TH Glut',
                                                     '19 MB Glut',
                                                     '23 P Glut',
                                                     '24 MY Glut',
                                                     '25 Pineal Glut',
                                                     '29 CB Glut']),
                              "Glutamatergic neuron", "")
obs_df['celltype'] = np.where(obs_df['class'].isin([
                                                     '06 CTX-CGE GABA',
                                                     '07 CTX-MGE GABA',
                                                     '08 CNU-MGE GABA',
                                                     '09 CNU-LGE GABA',
                                                     '10 LSX GABA',
                                                     '11 CNU-HYa GABA',
                                                     '12 HY GABA',
                                                     '20 MB GABA',
                                                     '26 P GABA',
                                                     '27 MY GABA',
                                                     '28 CB GABA']),
                              "GABAergic neuron", obs_df['celltype'])
obs_df['celltype'] = np.where(obs_df['class'].isin(['04 DG-IMN Glut']),
                              "Immature glutamatergic neuron", obs_df['celltype'])
obs_df['celltype'] = np.where(obs_df['class'].isin(['05 OB-IMN GABA']),
                              "Immature GABAergic neuron", obs_df['celltype'])
obs_df['celltype'] = np.where(obs_df['class'].isin(['21 MB Dopa']),
                              "Dopaminergic neuron", obs_df['celltype'])
obs_df['celltype'] = np.where(obs_df['class'].isin(['22 MB-HB Sero']),
                              "Serotonergic neuron", obs_df['celltype'])
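The same grouping can also be written as a lookup table, which keeps each class-to-label pair on one line and makes unmapped classes easy to spot. The sketch below is an alternative formulation, not the code used in this notebook. It runs on a small hypothetical `toy_obs` table, lists only a few of the class names from the cell above, and includes a placeholder class name to show the fallback.

```python
import pandas as pd

# Partial lookup table using class names from the cell above; extend as needed
class_to_celltype = {
    '01 IT-ET Glut': 'Glutamatergic neuron',
    '14 HY Glut': 'Glutamatergic neuron',
    '06 CTX-CGE GABA': 'GABAergic neuron',
    '04 DG-IMN Glut': 'Immature glutamatergic neuron',
    '05 OB-IMN GABA': 'Immature GABAergic neuron',
    '21 MB Dopa': 'Dopaminergic neuron',
    '22 MB-HB Sero': 'Serotonergic neuron',
}

# Hypothetical stand-in for mouse_atlas.obs; the last entry is a placeholder
toy_obs = pd.DataFrame({'class': ['01 IT-ET Glut', '06 CTX-CGE GABA',
                                  '21 MB Dopa', 'unlisted class']})

# Cast to str in case the column is categorical, then look up each class.
# Series.map returns NaN for classes missing from the table;
# fillna('') mirrors the empty-string default used with np.where.
toy_obs['celltype'] = (toy_obs['class'].astype(str)
                       .map(class_to_celltype)
                       .fillna(''))

print(toy_obs)
```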

Next we aggregate all astrocyte clusters into a single astrocyte category.

In [5]:
obs_df['celltype'] = np.where(obs_df['subclass'].isin(['317 Astro-CB NN', '318 Astro-NT NN',
       '319 Astro-TE NN', '320 Astro-OLF NN']), "Astrocyte", obs_df['celltype'])

Now we'll rename the remaining non-neuronal subclasses so they're easier to read:

In [6]:
obs_df['celltype'] = np.where(obs_df['subclass']=='316 Bergmann NN', "Bergmann glia", obs_df['celltype'])
obs_df['celltype'] = np.where(obs_df['subclass']=='321 Astroependymal NN', "Astroependymal", obs_df['celltype'])
obs_df['celltype'] = np.where(obs_df['subclass']=='322 Tanycyte NN', "Tanycyte", obs_df['celltype'])
obs_df['celltype'] = np.where(obs_df['subclass']=='323 Ependymal NN', "Ependymal", obs_df['celltype'])
obs_df['celltype'] = np.where(obs_df['subclass']=='324 Hypendymal NN', "Hypendymal", obs_df['celltype'])
obs_df['celltype'] = np.where(obs_df['subclass']=='325 CHOR NN', "Choroid plexus", obs_df['celltype'])
obs_df['celltype'] = np.where(obs_df['subclass']=='326 OPC NN', "OPC", obs_df['celltype'])
obs_df['celltype'] = np.where(obs_df['subclass']=='327 Oligo NN', "Oligodendrocyte", obs_df['celltype'])
obs_df['celltype'] = np.where(obs_df['subclass']=='328 OEC NN', "Olfactory epithelial", obs_df['celltype'])
obs_df['celltype'] = np.where(obs_df['subclass']=='329 ABC NN', "Arachnoid barrier", obs_df['celltype'])
obs_df['celltype'] = np.where(obs_df['subclass']=='330 VLMC NN', "Vascular leptomeningeal", obs_df['celltype'])
obs_df['celltype'] = np.where(obs_df['subclass']=='331 Peri NN', "Pericyte", obs_df['celltype'])
obs_df['celltype'] = np.where(obs_df['subclass']=='332 SMC NN', "Smooth muscle", obs_df['celltype'])
obs_df['celltype'] = np.where(obs_df['subclass']=='333 Endo NN', "Endothelial", obs_df['celltype'])
obs_df['celltype'] = np.where(obs_df['subclass']=='334 Microglia NN', "Microglia", obs_df['celltype'])
obs_df['celltype'] = np.where(obs_df['subclass']=='335 BAM NN', "BAM", obs_df['celltype'])
obs_df['celltype'] = np.where(obs_df['subclass']=='336 Monocytes NN', "Monocyte", obs_df['celltype'])
obs_df['celltype'] = np.where(obs_df['subclass']=='337 DC NN', "Dendritic", obs_df['celltype'])
obs_df['celltype'] = np.where(obs_df['subclass']=='338 Lymphoid NN', "Lymphoid", obs_df['celltype'])

mouse_atlas.obs = obs_df
mouse_atlas.obs.celltype = mouse_atlas.obs.celltype.astype('category')

Now let's check how many cells we have in each renamed cell type category:

In [7]:
obs_df.celltype.value_counts()

celltype
Glutamatergic neuron             1978218
GABAergic neuron                  781072
Oligodendrocyte                   422574
Astrocyte                         299182
OPC                               122605
Immature GABAergic neuron         107502
Endothelial                        88011
Microglia                          86232
Immature glutamatergic neuron      84352
Pericyte                           24907
Smooth muscle                      14614
Vascular leptomeningeal             9104
BAM                                 5626
Dopaminergic neuron                 4301
Bergmann glia                       3321
Ependymal                           3259
Serotonergic neuron                 2466
Astroependymal                      1232
Olfactory epithelial                1132
Tanycyte                            1072
Arachnoid barrier                    857
Choroid plexus                       476
Lymphoid                             404
Dendritic                            285
Hypendy

The counts are highly uneven: the largest neuronal categories contain hundreds of thousands to millions of cells, whereas categories such as monocytes and hypendymal cells contain very few. Because this reference atlas would be prohibitively large to use with CellTypist as is, we randomly downsample each cell type category to a maximum of 5,000 cells. As in Step 3, we also remove any category with fewer than 400 cells, since each such category makes up less than 0.01% of the dataset.
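The 0.01% figure follows from the atlas size reported above (4,042,976 cells). A quick check of the arithmetic:

```python
total_cells = 4_042_976  # n_obs of the full atlas, as printed above
min_cells = 400          # minimum category size used in this step

# Share of the atlas represented by a category of exactly min_cells cells
fraction_percent = 100 * min_cells / total_cells
print(f"{fraction_percent:.4f}%")  # prints 0.0099%
```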

In [8]:
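target_cells = 5000  # maximum number of cells kept per category
min_cells = 400      # categories smaller than this are dropped entirely
cluster_key = "celltype"

# Split the atlas into one AnnData view per cell type category
adatas = [mouse_atlas[mouse_atlas.obs[cluster_key].isin([clust])] for clust in mouse_atlas.obs[cluster_key].cat.categories]

for dat in adatas:
    if dat.n_obs > target_cells:
        # Randomly subsample large categories down to the cap
        sc.pp.subsample(dat, n_obs=target_cells)
    elif dat.n_obs < min_cells:
        # Empty out rare categories so they contribute no cells
        sc.pp.subsample(dat, n_obs=0)

# Note: AnnData.concatenate is deprecated in recent anndata releases in favor of anndata.concat
adata_downsampled = adatas[0].concatenate(*adatas[1:])
del adata_downsampled.var['gene_symbol']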

adata_downsampled

AnnData object with n_obs × n_vars = 83520 × 32285
    obs: 'cell_barcode_x', 'library_label_x', 'anatomical_division_label', 'cell_barcode_y', 'barcoded_cell_sample_label', 'library_label_y', 'feature_matrix_label', 'entity', 'brain_section_label', 'library_method', 'region_of_interest_acronym', 'donor_label', 'donor_genotype', 'donor_sex', 'dataset_label', 'x', 'y', 'cluster_alias', 'abc_sample_id', 'neurotransmitter', 'class', 'subclass', 'supertype', 'cluster', 'celltype', 'batch'
    var: 'name', 'mapped_ncbi_identifier', 'comment', 'ensembl_id', 'original_gene_symbol'
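The cap-and-floor rule applied above can be illustrated without AnnData. The sketch below uses only the standard library on hypothetical labels: categories above `target_cells` are randomly reduced to the cap, categories below `min_cells` are dropped, and everything in between is kept whole. The helper `downsample_by_label`, the toy thresholds, and the toy labels are illustrative assumptions.

```python
import random
from collections import Counter, defaultdict

def downsample_by_label(labels, target_cells, min_cells, seed=0):
    """Return sorted indices of the cells kept under the cap-and-floor rule."""
    rng = random.Random(seed)  # fixed seed for reproducible sampling
    by_label = defaultdict(list)
    for index, label in enumerate(labels):
        by_label[label].append(index)

    kept = []
    for label, indices in by_label.items():
        if len(indices) < min_cells:
            continue  # too rare: drop the whole category
        if len(indices) > target_cells:
            # Cap large categories with a random sample without replacement
            indices = rng.sample(indices, target_cells)
        kept.extend(indices)
    return sorted(kept)

# Toy labels: one abundant, one mid-sized, and one rare category
toy_labels = ['Neuron'] * 50 + ['Astrocyte'] * 8 + ['Monocyte'] * 2
kept = downsample_by_label(toy_labels, target_cells=10, min_cells=4)

counts = Counter(toy_labels[i] for i in kept)
print(counts)
```

With these toy thresholds, the abundant category is capped at 10 cells, the mid-sized one is kept whole, and the rare one disappears, mirroring what happens to the atlas categories at 5,000 and 400 cells.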

Now, let's check that downsampling worked as expected:

In [9]:
adata_downsampled.obs.celltype.value_counts()

celltype
Astrocyte                        5000
BAM                              5000
Microglia                        5000
Immature GABAergic neuron        5000
Glutamatergic neuron             5000
GABAergic neuron                 5000
Endothelial                      5000
Vascular leptomeningeal          5000
Smooth muscle                    5000
Pericyte                         5000
OPC                              5000
Oligodendrocyte                  5000
Immature glutamatergic neuron    5000
Dopaminergic neuron              4301
Bergmann glia                    3321
Ependymal                        3259
Serotonergic neuron              2466
Astroependymal                   1232
Olfactory epithelial             1132
Tanycyte                         1072
Arachnoid barrier                 857
Choroid plexus                    476
Lymphoid                          404
Name: count, dtype: int64

Now we'll save the downsampled anndata object:

In [10]:
adata_downsampled.write_h5ad('outputs/downsampled_abc_atlas.h5ad')

... storing 'original_gene_symbol' as categorical


### Part 2: Train a new CellTypist model

Lastly, we'll train a new CellTypist model on the downsampled AnnData object and save the model.

If you skipped **Part 1** because of its large memory and disk space requirements and wish to use the pre-made downsampled AnnData object instead, uncomment the first code cell below. Otherwise, leave it commented out and proceed to the training cell.

In [11]:
#adata_downsampled = sc.read_h5ad('inputs/zenodo/downsampled_abc_atlas.h5ad')

In [12]:
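# check_expression=False skips CellTypist's check that the input is
# log1p-normalized to 10,000 counts per cell; adjust n_jobs to your CPU count
model = celltypist.train(adata_downsampled, labels="celltype", n_jobs=8,
                         check_expression=False)
model.write('outputs/celltypist_ABC_model.pkl')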

🍳 Preparing data before training
✂️ 2823 non-expressed genes are filtered out
🔬 Input data has 83520 cells and 29462 genes
⚖️ Scaling input data
🏋️ Training data using logistic regression
✅ Model training done!
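Once saved, the model can be reloaded and applied to a query dataset. The helper below is a sketch under stated assumptions: `annotate_with_abc_model` is a hypothetical name, and the query AnnData is assumed to use gene symbols as `.var_names` and to be log1p-normalized to 10,000 counts per cell (the same preprocessing applied to the atlas above).

```python
def annotate_with_abc_model(query_adata,
                            model_path='outputs/celltypist_ABC_model.pkl'):
    """Predict cell type labels for query_adata with the saved model.

    Hypothetical helper for illustration; not part of the original notebook.
    """
    # Imported here so the helper can be defined without CellTypist installed
    import celltypist
    from celltypist import models

    # Load the model written by model.write(...) in the training cell
    model = models.Model.load(model=model_path)

    # Per-cell prediction; majority_voting=True would additionally refine
    # labels within over-clustered groups of similar cells
    predictions = celltypist.annotate(query_adata, model=model,
                                      majority_voting=False)

    # DataFrame indexed by cell, with a 'predicted_labels' column
    return predictions.predicted_labels
```

The returned table can then be compared against manual cluster annotations, for example with `pd.crosstab`.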
