## CAS v1 Client Demo

This demo uses the `pbmc_10k_v3` dataset, downloaded from 10x Genomics:

https://support.10xgenomics.com/single-cell-gene-expression/datasets/3.0.0/pbmc_10k_v3

In [None]:
import os
import sys
import matplotlib.pylab as plt
import scanpy as sc
import numpy as np
import pandas as pd
import scipy.sparse as sp
sys.path.append('../src')

from cas_client_helper import *

sc.settings.set_figure_params(dpi=80, facecolor='white')

In [None]:
# Read the raw pbmc_10k_v3 count matrix (downloaded from the 10x Genomics website).
# Genes are indexed by symbol; cache=True asks scanpy to use its read cache.
pbmc_mtx_dir = '/home/jupyter/data/casp-cli-demo/pbmc_10k_v3'
adata_loaded = sc.read_10x_mtx(pbmc_mtx_dir, var_names='gene_symbols', cache=True)

In [None]:
# Validate the loaded AnnData against the CAS v1 feature list and reformat it.
# NOTE(review): the exact meaning of each option is defined by
# `validate_adata_for_cas` in cas_client_helper -- confirm there before changing.
cas_validation_kwargs = dict(
    int_count_matrix='X',
    gene_symbols_column_name='__index__',
    gene_ids_column_name='gene_ids',
    missing_features_policy='replace_with_zero',
    extra_features_policy='ignore',
    casp_feature_list_csv_path='../resources/casp_v1_feature_list.csv',
)
adata = validate_adata_for_cas(adata_loaded, **cas_validation_kwargs)

In [None]:
# save the raw counts: keep the validated, not-yet-normalized data in `adata.raw`
# so it can be recovered later with `adata.raw.to_adata()` before querying CAS
adata.raw = adata

## The standard `scanpy` workflow

Basic cell QC, normalization, feature selection, PCA, and UMAP embedding (clustering is performed later, after the CAS query).

In [None]:
# Imported explicitly so this cell does not depend on `warnings` leaking in
# through `from cas_client_helper import *`.
import warnings

# Silence UserWarnings for the rest of the session to keep the demo output short.
warnings.simplefilter('ignore', UserWarning)

# basic cell QC
sc.pp.filter_cells(adata, min_genes=1000)
sc.pp.filter_genes(adata, min_cells=20)
adata.var['mt'] = adata.var['gene_symbols'].str.startswith('MT-')
sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], percent_top=None, log1p=False, inplace=True)

# optional diagnostics
# sc.pl.scatter(adata, x='total_counts', y='pct_counts_mt')
# sc.pl.scatter(adata, x='total_counts', y='n_genes_by_counts')

# Apply both cell filters with a single mask and materialize the result with
# .copy(): the in-place preprocessing steps below would otherwise operate on an
# AnnData *view* and trigger an implicit copy (and a warning).
keep_cells = (adata.obs.n_genes_by_counts < 4000) & (adata.obs.pct_counts_mt < 20)
adata = adata[keep_cells, :].copy()

# normalization and feature selection
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)
sc.pp.highly_variable_genes(adata)

# optional diagnostics
# sc.pl.highly_variable_genes(adata)

# Keep only the highly variable genes; again copy so that scaling modifies a
# real AnnData object rather than a view.
adata = adata[:, adata.var.highly_variable].copy()
sc.pp.scale(adata, max_value=10)

# dimensionality reduction (PCA)
sc.tl.pca(adata, svd_solver='arpack')

# optional diagnostics
# sc.pl.pca(adata, color='CST3', gene_symbols='gene_symbols')
# sc.pl.pca_variance_ratio(adata, log=True)

# neighborhood graph and UMAP embedding
sc.pp.neighbors(adata, n_neighbors=100, n_pcs=40)
sc.tl.umap(adata)

In [None]:
# Persist the preprocessed AnnData so later sections can start from this file.
basic_h5ad_path = './data/pbmc_10k_v3_basic.h5ad'
sc.write(basic_h5ad_path, adata)

## Exploring the raw data

In [None]:
# Reload the preprocessed AnnData written by the previous section.
basic_h5ad_path = './data/pbmc_10k_v3_basic.h5ad'
adata = sc.read(basic_h5ad_path)

In [None]:
# display a summary of the reloaded AnnData object (notebook shows the last expression)
adata

In [None]:
# Color the UMAP by a handful of genes, looked up by gene symbol;
# use_raw=False plots the processed values in `adata.X` instead of `adata.raw`.
marker_genes = ['NKG7', 'CST3', 'CD79A', 'CD8A', 'CCR7']
sc.pl.umap(adata, color=marker_genes, gene_symbols='gene_symbols', use_raw=False)

## Cell type assignment with CAS

In [None]:
!pip uninstall -y cell-annotation-service-client

In [None]:
!pip install --upgrade git+https://github.com/broadinstitute/cell-annotation-service-client.git@fg-annotate

In [None]:
# CAS client package installed in the previous cell
from casp_cli import service

# client object used below to submit the AnnData for annotation
cli = service.CASPClientService()

In [None]:
# Rebuild a standalone AnnData from `adata.raw` (integer counts, no gene filter).
adata_raw = adata.raw.to_adata()
adata_raw = adata_raw.copy()
# Also expose the same data through `.raw` on the object that is sent out.
adata_raw.raw = adata_raw

# Submit the raw counts to the Cell Annotation Service.
cas_query_res = cli.annotate_anndata(adata_raw)

## Explore CAS output

In [None]:
# inspect the first entry of the CAS response
# (presumably the result for the first query cell -- confirm against the client docs)
cas_query_res[0]

In [None]:
# collapse the CAS response into one label per cell by majority vote
# (presumably written to adata.obs['cas_cell_type'], which is plotted next -- verify in cas_client_helper)
reduce_cas_query_result_by_majority_vote(adata, cas_query_res)
# color the UMAP by the per-cell CAS label
sc.pl.umap(adata, color='cas_cell_type')

In [None]:
# cluster the cells with the Leiden algorithm; the labels are read back below
# from adata.obs['leiden']
sc.tl.leiden(adata, resolution=5.0)

In [None]:
# color the UMAP by Leiden cluster label
sc.pl.umap(adata, color='leiden')

In [None]:
# reduce annotations per cluster: majority vote over the cells of each Leiden cluster.
# The returned object is indexed by cluster id further below and yields
# (cell_type, frequency) pairs for that cluster.
cluster_detailed_info_dict = reduce_cas_query_result_by_majority_vote_per_cluster(
    adata, cas_query_res, cluster_key='leiden')

# visualize the per-cluster label
# (presumably stored by the helper in adata.obs['cas_per_cluster_cell_type'] -- verify in cas_client_helper)
sc.pl.umap(adata, color='cas_per_cluster_cell_type')

In [None]:
def highlight_cluster(adata, cluster_id, top_k=10, cluster_key='leiden', cluster_info=None):
    """Highlight one cluster on the UMAP and print its leading CAS cell types.

    All cells are drawn in light gray, the cells of ``cluster_id`` in red, and
    a table of ``(cell type, frequency)`` rows for that cluster is printed.

    Args:
        adata: AnnData with a 2-D embedding in ``adata.obsm['X_umap']`` and
            cluster labels in ``adata.obs[cluster_key]``.
        cluster_id: Label of the cluster to highlight.
        top_k: Maximum number of table rows to print; the first ``top_k``
            entries of the cluster's list are shown.
        cluster_key: Name of the ``adata.obs`` column holding cluster labels.
        cluster_info: Mapping from cluster label to a sequence of
            ``(cell_type, frequency)`` pairs. When ``None``, the notebook-level
            ``cluster_detailed_info_dict`` is used.

    Raises:
        KeyError: If ``cluster_id`` is not a key of the mapping (assuming a
            dict-like ``cluster_info``).
    """
    if cluster_info is None:
        # Backward-compatible default: fall back to the notebook global.
        cluster_info = cluster_detailed_info_dict

    umap_xy = adata.obsm['X_umap']
    # A boolean mask on the embedding avoids building an AnnData subset just to plot.
    in_cluster = (adata.obs[cluster_key] == cluster_id).to_numpy()

    _, ax = plt.subplots()
    ax.scatter(umap_xy[:, 0], umap_xy[:, 1],
               s=2, edgecolor='none', color='gray', alpha=0.25)
    ax.scatter(umap_xy[in_cluster, 0], umap_xy[in_cluster, 1],
               s=2, edgecolor='none', color='red', alpha=1.)
    ax.grid(False)
    ax.set_xlabel('UMAP 1')
    ax.set_ylabel('UMAP 2')
    plt.show()

    print(f'{"CELL TYPE":100s} {"FREQUENCY"}')
    for cell_type, freq in cluster_info[cluster_id][:top_k]:
        print(f'{cell_type:100s} {freq:.4f}')
    print()
        
# Produce the highlight plot and cell-type table for every Leiden cluster.
leiden_categories = adata.obs['leiden'].cat.categories
for cluster_id in leiden_categories:
    highlight_cluster(adata, cluster_id)