# Cell Typing of Segmented Xenium Data for NSCLC

In [1]:
%load_ext autoreload
%autoreload 2

In [3]:
from requirements import *
from segger.data.parquet._utils import (
    filter_transcripts,
    load_settings,
)
from sg_utils.tl.phenograph_rapids import phenograph_rapids
from sg_utils.tl.xenium_utils import anndata_from_transcripts
from sg_utils.pp.preprocess_rapids import *
from sg_utils.pl.plot_embedding import plot_embedding
from sg_utils.tl.get_group_markers import *
from sg_utils.pl.plot_group_markers import plot_group_markers
import celltypist as ct
import gc

In [7]:
# Dataset name; joined onto `data_dir` as a subdirectory when building the
# atlas and transcript input paths in later cells.
dataset = 'xenium_nsclc'

## Build CellTypist Model

In [8]:
# Load the NSCLC reference atlas
filepath = 'h5ads/core_nsclc_atlas_panel_only.h5ad'
ad_atlas = sc.read_h5ad(data_dir / dataset / filepath)

# Restore the stored counts, downsample cells to a target of 100 counts each,
# then keep a copy normalized to a per-cell total of 100 (not 10K)
ad_atlas.X = ad_atlas.layers['count'].copy()
sc.pp.downsample_counts(ad_atlas, counts_per_cell=100)
ad_atlas.layers['norm_100'] = ad_atlas.X.copy()
sc.pp.normalize_total(ad_atlas, layer='norm_100', target_sum=1e2)

# Log-transform (log1p) a copy of the normalized layer.
# Any existing 'log1p' entry in `.uns` is dropped first -- presumably so that
# scanpy does not treat the data as already log-transformed (TODO confirm).
ad_atlas.uns.pop('log1p', None)
ad_atlas.layers['lognorm_100'] = ad_atlas.layers['norm_100'].copy()
sc.pp.log1p(ad_atlas, layer='lognorm_100')

In [163]:
# Subsample using the more granular cell types (so that no single cell type is
# lost), but train the classifier on the coarser compartment labels.
# Sampling is with replacement so that cell types with fewer than 2000 cells
# can still be drawn; duplicate draws are then dropped. A fixed seed keeps the
# training subset (and therefore the model) reproducible across runs.
gb = ad_atlas.obs.groupby('cell_type')
sample = gb.sample(2000, replace=True, random_state=0).index.drop_duplicates()

# Train on the log-normalized counts. The layer key must match the layer
# created in the atlas normalization cell ('lognorm_100').
ad_atlas.X = ad_atlas.layers['lognorm_100']
with HiddenPrints():
    ct_model = ct.train(
        ad_atlas[sample],
        labels='cell_compartment',
        check_expression=False,
        n_jobs=32,
        max_iter=100,
    )

# Save the trained model, creating the output directory if it does not exist
filepath = 'celltypist/nsclc_celltypist_model.pkl'
model_path = data_dir / dataset / filepath
model_path.parent.mkdir(parents=True, exist_ok=True)
ct_model.write(model_path)

## Transcripts to AnnData

In [4]:
# Segmentation columns to compare. Each key is the column passed as the cell
# label when building an AnnData from the transcripts; each value is
# presumably the display name for that method (not used in the cells shown).
# Insertion order determines the order in which segmentations are processed.
segmentations = dict([
    ('segger_cell_id_HDE46PBXJB', 'Segger'),
    ('baysor_cell_id_c=0.5', 'Baysor, c=0.5'),
    ('10x_cell_id', '10X'),
    ('cellpose_cell_id', 'CellPose'),
    ('10x_nucleus_id', '10X Nucleus'),
    ('baysor_cell_id_c=0.7', 'Baysor, c=0.7'),
    ('baysor_cell_id_c=0.9', 'Baysor, c=0.9'),
])

In [11]:
# Load the full labeled transcripts table for this dataset
filepath = 'labeled_transcripts.parquet'
transcripts = pd.read_parquet(data_dir / dataset / filepath)

# Drop control probes and low-QV transcripts, using the label column and
# filter substrings from the Xenium settings and a minimum QV of 30
xe_settings = load_settings('xenium')
tx_settings = xe_settings.transcripts

transcripts = filter_transcripts(
    transcripts,
    label=tx_settings.label,
    filter_substrings=tx_settings.filter_substrings,
    min_qv=30,
)

In [6]:
# Convert to AnnData and preprocess: one AnnData per segmentation column,
# each saved to its own .h5ad file.
for seg_col in segmentations.keys():

    # Subset Segger data to high-confidence transcripts (score > 0.75);
    # every other segmentation keeps all transcripts.
    mask = np.full(transcripts.shape[0], True)
    if 'segger' in seg_col:
        # The score column has the same name as the cell-ID column, with
        # 'cell_id' replaced by 'score'.
        score_col = seg_col.replace('cell_id', 'score')
        mask &= transcripts[score_col].gt(0.75).to_numpy()

    # Transcripts to anndata
    ad = anndata_from_transcripts(
        transcripts[mask],
        cell_label=seg_col,
        gene_label='feature_name',
    )

    # Add raw per-cell counts before filtering.
    # Summing along axis 1 directly handles sparse and dense matrices alike
    # and avoids densifying the whole matrix (the `.A` shortcut only exists on
    # some matrix types).
    ad.uns['raw_counts'] = dict(
        index=ad.obs.index.tolist(),
        count=np.asarray(ad.raw.X.sum(axis=1)).ravel(),
    )

    # Preprocess. `ad` is presumably modified in place, since the return
    # value is unused and `ad` is saved afterwards (TODO confirm).
    # Fixed minimum-count cutoff; the 5% quantile of the raw counts was
    # considered as an alternative.
    threshold = 5
    preprocess_rapids(
        ad,
        filter_min_counts=threshold,
        pca_total_var=0.75,
        umap_min_dist=0.25,
        umap_n_epochs=4000,
        pca_layer='lognorm',
        knn_neighbors=20,
        phenograph_resolution=1,
    )

    # Save to file
    ad.write_h5ad(data_dir / f'h5ads/{seg_col}.h5ad')

Done: 100%|██████████| 6/6 [02:59<00:00, 29.85s/it]         


## Cell Type with CellTypist

In [7]:
# Load the CellTypist model from the same location the training cell writes
# it to (data_dir / dataset / 'celltypist/nsclc_celltypist_model.pkl').
model_path = data_dir / dataset / 'celltypist/nsclc_celltypist_model.pkl'
ct_model = ct.Model.load(str(model_path))

# Cell type each segmentation
for seg_col in tqdm(segmentations.keys()):

    # Read in AnnData
    filepath = data_dir / f'h5ads/{seg_col}.h5ad'
    ad = sc.read_h5ad(filepath)

    # Re-normalize consistently with the CellTypist model's training data:
    # per-cell total of 100, then log1p. These layers are temporary and are
    # removed again before saving.
    ad.layers['norm_100'] = ad.raw.X.copy()
    sc.pp.normalize_total(ad, layer='norm_100', target_sum=1e2)
    ad.layers['lognorm_100'] = ad.layers['norm_100'].copy()
    ad.uns.pop('log1p', None)
    sc.pp.log1p(ad, layer='lognorm_100')

    # Re-cluster; the resulting 'phenograph_cluster' column is used below as
    # the over-clustering for majority voting.
    phenograph_rapids(ad, min_size=1, resolution=1)

    # Cell type
    # NOTE: this overwrites `ad.X` with the log-normalized values, and that
    # is what gets saved back to the file at the end of the iteration.
    with HiddenPrints():
        ad.X = ad.layers['lognorm_100']
        preds = ct.annotate(
            ad, model=ct_model, majority_voting=True,
            over_clustering='phenograph_cluster',
            min_prop=0.5,
        )

    # Label AnnData: per-cell label, cluster-level (majority-voted) label,
    # the top class probability, every per-class probability, and the entropy
    # of each cell's class-probability vector.
    ad.obs['celltypist_label'] = preds.predicted_labels['predicted_labels']
    ad.obs['celltypist_label_cluster'] = preds.predicted_labels['majority_voting']
    ad.obs['celltypist_probability'] = preds.probability_matrix.max(1)
    for col in preds.probability_matrix.columns:
        ad.obs[f'{col} Probability'] = preds.probability_matrix[col]
    entropy = sp.stats.entropy(preds.probability_matrix, axis=1)
    ad.obs['celltypist_entropy'] = entropy

    # Cleanup temporary layers before writing back to the same file
    del ad.layers['lognorm_100'], ad.layers['norm_100']

    ad.write_h5ad(filepath)

100%|██████████| 1/1 [00:25<00:00, 25.86s/it]
