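## Marker Selection

Select features that significantly separate each pair of clusters (labels in `cluster_col`), then save the pooled marker list, per-pair marker counts, and the marker-only AnnData to `Markers/`.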

In [1]:
from concurrent.futures import ProcessPoolExecutor, as_completed
from itertools import combinations

import anndata
import matplotlib.pyplot as plt
import pandas as pd
import scanpy as sc
import seaborn as sns
import pathlib
from cemba_data.tools.hdf5.anndata import rank_features_groups

## Parameters

In [2]:
cluster_col = "SubType"        # obs column whose labels define the clusters
min_cluster_cell_number = 10   # clusters with fewer cells are left out of the tests
exclude_str = ["Outlier"]      # clusters whose label contains any of these substrings are left out
adj_p_cutoff = 1e-05           # significance: adjusted p-value must be below this ...
log2fc_cutoff = 1              # ... and |log2 fold change| must be above this
top_n = 10                     # n_genes passed to rank_features_groups for each pair
cpu = 10                       # worker processes used for the pairwise tests

In [3]:
# Parameters
# (presumably injected by papermill; these values override the defaults above)
cluster_col = 'SubType'
min_cluster_cell_number = 10
exclude_str = ['Outlier', 'NonN']
adj_p_cutoff = 1e-5
log2fc_cutoff = 1
top_n = 20
cpu = 10


### Stable Parameters

In [4]:
adata_path = 'Adata/cell_by_feature.cov_filter.rate.h5ad'  # input cell-by-feature matrix
max_test_cell_population = 1000  # each cluster is down-sampled to at most this many cells per test
random_seed = 0                  # seed for that down-sampling, so repeated runs pick the same cells
chunk_size = 200                 # cluster pairs handled by one worker task
# Bind output_dir once as a Path (rather than str, then Path).
output_dir = pathlib.Path('Markers')
# parents=True so a nested output directory (e.g. 'results/Markers') also works.
output_dir.mkdir(parents=True, exist_ok=True)

## Load Data

In [5]:
# Load the cell-by-feature matrix from adata_path; the bare name on the last line renders its summary.
adata = anndata.read_h5ad(adata_path)
adata

AnnData object with n_obs × n_vars = 2637 × 24027 
    obs: 'AllcPath', 'CCC_Rate', 'CG_Rate', 'CG_RateAdj', 'CH_Rate', 'CH_RateAdj', 'FinalReads', 'InputReads', 'MappedReads', 'Region', 'index_name', 'uid', 'BamFilteringRate', 'MappingRate', 'Pos96', 'Plate', 'Col96', 'Row96', 'Col384', 'Row384', 'FACS_Date', 'Slice', 'CellClass', 'l1-umap_0', 'l1-umap_1', 'l1-tsne_0', 'l1-tsne_1', 'MajorType', 'l2-umap_0', 'l2-umap_1', 'l2-tsne_0', 'l2-tsne_1', 'SubType', 'l3-umap_0', 'l3-umap_1', 'l3-tsne_0', 'l3-tsne_1', 'L1CellClass', 'class_tsne_0', 'class_tsne_1', 'class_umap_0', 'class_umap_1', 'Order', 'RegionName', 'MajorRegion', 'SubRegion', 'DetailRegion', 'PotentialOverlap (MMB)', 'Anterior (CCF coords)', 'Posterior (CCF coords)', 'MajorRegionColor', 'SubRegionColor', 'DissectionRegionColor'
    var: 'chrom', 'start', 'end'

In [6]:
# Cluster label of every cell as a plain string (the substring checks in
# check_cluster need str labels), plus the number of cells per cluster.
cluster_series = adata.obs[cluster_col].astype(str)
cluster_counts = cluster_series.value_counts()


def check_cluster(cluster, count):
    """Decide whether a cluster takes part in the pairwise tests.

    Returns False when the cluster has fewer than ``min_cluster_cell_number``
    cells or its label contains any substring listed in ``exclude_str``;
    True otherwise.
    """
    if count < min_cluster_cell_number:
        return False
    return not any(token in cluster for token in exclude_str)


unique_clusters = [
    label
    for label, n_cells in cluster_counts.items()
    if check_cluster(label, n_cells)
]
cluster_pairs = list(combinations(unique_clusters, 2))

print(len(unique_clusters), 'pass filter.')
print(len(cluster_pairs), 'pairwise comparison to test.')

9 pass filter.
36 pairwise comparison to test.


In [7]:
# Restrict adata to cells whose cluster passed the filter above.
in_kept_cluster = adata.obs[cluster_col].isin(unique_clusters)
adata = adata[in_kept_cluster, :]

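## Pairwise tests

Each cluster pair is ranked independently; pairs are split into chunks of `chunk_size` and processed in parallel worker processes.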

In [8]:
def get_sig_features(rank_gene_dict, p_cutoff=None, fc_cutoff=None):
    """Collect the feature names that pass both significance cutoffs.

    Parameters
    ----------
    rank_gene_dict : mapping
        Ranking result as stored in ``.uns['rank_genes_groups']`` by
        ``rank_features_groups``. Its ``'names'``, ``'pvals_adj'`` and
        ``'logfoldchanges'`` entries must each convert to DataFrames of the
        same layout (presumably one column per group — confirm).
    p_cutoff : float, optional
        Adjusted p-value must be strictly below this. Defaults to the
        notebook-level ``adj_p_cutoff``.
    fc_cutoff : float, optional
        ``|log2 fold change|`` must be strictly above this. Defaults to the
        notebook-level ``log2fc_cutoff``.

    Returns
    -------
    set of str
        Passing feature names pooled over all groups. The NaN placeholders
        produced by masking are dropped here, so the set holds only names.
    """
    # Resolve defaults at call time so later parameter changes are honoured.
    if p_cutoff is None:
        p_cutoff = adj_p_cutoff
    if fc_cutoff is None:
        fc_cutoff = log2fc_cutoff

    pvals_adj = pd.DataFrame(rank_gene_dict['pvals_adj'])
    names = pd.DataFrame(rank_gene_dict['names'])
    logfoldchanges = pd.DataFrame(rank_gene_dict['logfoldchanges'])

    passing = (pvals_adj < p_cutoff) & (logfoldchanges.abs() > fc_cutoff)
    # `where` replaces every non-passing entry with NaN; keep only real names.
    return {name for name in names.where(passing).values.flat
            if isinstance(name, str)}


def _sample_cluster_cells(cluster):
    """Return obs names of ``cluster``, down-sampled to at most
    ``max_test_cell_population`` cells (seeded with ``random_seed``)."""
    cells = cluster_series[cluster_series == cluster]
    if cells.size > max_test_cell_population:
        cells = cells.sample(max_test_cell_population, random_state=random_seed)
    return cells.index


def pairwise_tests(adata_path, pairs):
    """Rank features between the two clusters of every pair in ``pairs``.

    Designed to run in a worker process: the AnnData is re-read from
    ``adata_path`` here instead of being passed to the worker.

    Parameters
    ----------
    adata_path : str or path-like
        h5ad file to load.
    pairs : iterable of (str, str)
        Cluster label pairs (values of ``cluster_col``) to compare.

    Returns
    -------
    total_markers : set of str
        Union of significant features over all pairs.
    pair_marker_count : dict
        Maps each pair tuple to its number of significant features.

    NOTE(review): reads notebook globals (cluster_series, cluster_col, top_n,
    max_test_cell_population, random_seed, and the cutoffs via
    get_sig_features). Under ProcessPoolExecutor this presumably relies on the
    'fork' start method; confirm on macOS/Windows, where 'spawn' is the default.
    """
    adata = anndata.read_h5ad(adata_path)
    total_markers = set()
    pair_marker_count = {}
    for pair in pairs:
        cluster_a, cluster_b = pair
        # Index.union rather than `|`: pandas deprecated `Index | Index` as a
        # set union and since 2.0 evaluates it as an element-wise logical OR.
        cells = _sample_cluster_cells(cluster_a).union(_sample_cluster_cells(cluster_b))

        pair_adata = adata[cells, :].copy()
        pair_adata.obs['cluster'] = pair_adata.obs[cluster_col].astype(str).astype('category')

        rank_features_groups(pair_adata, groupby='cluster', n_genes=top_n)
        gene_set = get_sig_features(pair_adata.uns['rank_genes_groups'])
        # Drop NaN placeholders left by DataFrame.where (no-op if already clean).
        gene_set = {gene for gene in gene_set if isinstance(gene, str)}

        total_markers.update(gene_set)
        pair_marker_count[pair] = len(gene_set)
    return total_markers, pair_marker_count

In [9]:
total_markers = set()
pair_marker_counts = {}
with ProcessPoolExecutor(cpu) as executor:
    # One task per chunk of `chunk_size` cluster pairs, submitted in order.
    futures = [
        executor.submit(pairwise_tests, adata_path,
                        cluster_pairs[start:start + chunk_size])
        for start in range(0, len(cluster_pairs), chunk_size)
    ]
    # Merge each chunk's results as soon as it finishes.
    for done in as_completed(futures):
        chunk_markers, chunk_counts = done.result()
        total_markers.update(chunk_markers)
        pair_marker_counts.update(chunk_counts)

## Save results

In [10]:
# Keep only feature names (drop any NaN placeholders).
total_markers = {marker for marker in total_markers if isinstance(marker, str)}
# One feature per line, sorted so the file is identical across runs: set
# iteration order for str depends on per-process hash randomization.
with open(output_dir / 'cluster_markers.txt', 'w') as f:
    f.write('\n'.join(sorted(total_markers)))

In [11]:
pair_marker_counts = pd.Series(pair_marker_counts, dtype='int64')
# Build the table from ((cluster_a, cluster_b), count) items instead of
# reset_index(): with no pairs the Series has no 2-level index, reset_index
# yields only two columns and assigning three column names would raise.
marker_counts = pd.DataFrame(
    [(cluster_a, cluster_b, count)
     for (cluster_a, cluster_b), count in pair_marker_counts.items()],
    columns=['ClusterA', 'ClusterB', 'GeneCount'],
)
marker_counts.to_csv(output_dir / 'cluster_pair_marker_counts.csv', index=None)

In [12]:
# Cluster pairs separated by fewer than 3 significant features.
marker_counts.loc[marker_counts['GeneCount'].lt(3)]

Unnamed: 0,ClusterA,ClusterB,GeneCount
1,PT-L5 Tenm2,PT-L5 Ptprt,0
16,PT-L5 Ptprt,PT-L5 Plcb4,0


In [13]:
# Subset to the selected marker features. sorted() fixes the feature order so
# the written file does not depend on set iteration order; .copy() writes an
# actual AnnData instead of a view of a view.
marker_adata = adata[:, sorted(total_markers)].copy()
marker_adata.write_h5ad(str(output_dir / 'cluster_markers.h5ad'))

In [14]:
# Show the summary of the AnnData restricted to the selected marker features.
marker_adata

AnnData object with n_obs × n_vars = 2637 × 322 
    obs: 'AllcPath', 'CCC_Rate', 'CG_Rate', 'CG_RateAdj', 'CH_Rate', 'CH_RateAdj', 'FinalReads', 'InputReads', 'MappedReads', 'Region', 'index_name', 'uid', 'BamFilteringRate', 'MappingRate', 'Pos96', 'Plate', 'Col96', 'Row96', 'Col384', 'Row384', 'FACS_Date', 'Slice', 'CellClass', 'l1-umap_0', 'l1-umap_1', 'l1-tsne_0', 'l1-tsne_1', 'MajorType', 'l2-umap_0', 'l2-umap_1', 'l2-tsne_0', 'l2-tsne_1', 'SubType', 'l3-umap_0', 'l3-umap_1', 'l3-tsne_0', 'l3-tsne_1', 'L1CellClass', 'class_tsne_0', 'class_tsne_1', 'class_umap_0', 'class_umap_1', 'Order', 'RegionName', 'MajorRegion', 'SubRegion', 'DetailRegion', 'PotentialOverlap (MMB)', 'Anterior (CCF coords)', 'Posterior (CCF coords)', 'MajorRegionColor', 'SubRegionColor', 'DissectionRegionColor'
    var: 'chrom', 'start', 'end'