## Marker Selection

In [1]:
import pathlib
import random
import warnings
warnings.filterwarnings('ignore')

from concurrent.futures import ProcessPoolExecutor, as_completed
from itertools import combinations

import anndata
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
import xarray as xr
from cemba_data.tools.hdf5.anndata import rank_features_groups
from sklearn.metrics import roc_auc_score


## Parameter

In [2]:
cluster_col = 'MajorType'
study_name = 'Test'

use_clusters = [
    'MGE-Sst', 'CA3', 'CA1', 'CA3-St18', 'Unc5c', 'Gfra1', 'IT-L5', 'NP-L6',
    'CGE-Lamp5', 'CT-L6', 'IG-CA2', 'DG-po', 'DG', 'CGE-Vip', 'PAL-Inh',
    'PT-L5', 'MGE-Pvalb', 'OLF', 'MSN-D2', 'L6b', 'IT-L6', 'IT-L23', 'IT-L4',
    'OLF-Exc', 'CLA', 'Foxp2', 'MSN-D1', 'LSX-Inh', 'D1L-Fstl4', 'D1L-PAL',
    'Chd7'
]

cpu = 30
top_n = 10000
adj_p_cutoff = 1e-2
delta_rate_cutoff = -0.3
min_cluster_cell_number = 10
max_test_cell_population = 10
auroc_cutoff = 0.8

### Stable Parameter

In [3]:
random_seed = 0
chunk_size = 1
output_dir = 'Markers'
tidy_data_path = '/home/hanliu/project/mouse_rostral_brain/study/ClusteringSummary/Summary/TotalClusteringResults.msg'
output_dir = pathlib.Path(output_dir)
output_dir.mkdir(exist_ok=True)

## Load Data

### Cell Tidy Data

In [4]:
cell_tidy_data = pd.read_msgpack(tidy_data_path)

if use_clusters is not None:
    cell_tidy_data = cell_tidy_data[cell_tidy_data[cluster_col].isin(use_clusters)]
cell_tidy_data.shape[0]

95149

In [5]:
records = []
for cluster, sub_df in cell_tidy_data.groupby(cluster_col):
    if sub_df.shape[0] < max_test_cell_population:
        records.append(sub_df)
    else:
        records.append(sub_df.sample(max_test_cell_population, random_state=random_seed))
cell_tidy_data = pd.concat(records)
cell_tidy_data[cluster_col].value_counts()

IG-CA2       10
Unc5c        10
Gfra1        10
IT-L4        10
MGE-Pvalb    10
PAL-Inh      10
D1L-Fstl4    10
IT-L23       10
NP-L6        10
D1L-PAL      10
DG-po        10
Foxp2        10
CA1          10
IT-L6        10
Chd7         10
OLF-Exc      10
PT-L5        10
CGE-Vip      10
CLA          10
DG           10
L6b          10
CGE-Lamp5    10
CT-L6        10
OLF          10
MSN-D1       10
CA3-St18     10
IT-L5        10
LSX-Inh      10
MSN-D2       10
MGE-Sst      10
CA3          10
Name: MajorType, dtype: int64

### Gene meta

In [6]:
gene_meta = pd.read_csv(
    '/home/hanliu/project/mouse_rostral_brain/study/ClusterMethylMarker/gencode.vM22.annotation.gene.flat.filtered_white_genes.tsv.gz',
    index_col='gene_id',
    sep='\t')
gene_meta.index.name = 'gene'
gene_name_to_id = {v: k for k, v in gene_meta['gene_name'].iteritems()}
gene_idbase_to_id = {i.split('.')[0]: i for i in gene_meta.index}

### Adata

In [7]:
gene_mcds = xr.open_mfdataset(
    '/home/hanliu/project/mouse_rostral_brain/study/Level1-CellClass/ALL_manual/Adata/SelectedCell.gene_da_rate.*.mcds'
)
use_gene = gene_mcds.get_index('gene') & gene_meta.index
gene_meta = gene_meta.reindex(use_gene)

In [8]:
gene_mcds

<xarray.Dataset>
Dimensions:      (cell: 104340, gene: 55487, mc_type: 2)
Coordinates:
  * mc_type      (mc_type) object 'CGN' 'CHN'
  * gene         (gene) object 'ENSMUSG00000102693.1' ... 'ENSMUSG00000064372.1'
    strand_type  <U4 'both'
    gene_chrom   (gene) object 'chr1' 'chr1' 'chr1' ... 'chrM' 'chrM' 'chrM'
    gene_start   (gene) int64 3073252 3102015 3205900 ... 14144 15288 15355
    gene_end     (gene) int64 3074321 3102124 3671497 ... 15287 15354 15421
  * cell         (cell) object '1A_M_0' '1A_M_1' ... '8J_M_1291' '8J_M_1292'
Data variables:
    gene_da      (cell, gene, mc_type) float64 dask.array<shape=(104340, 55487, 2), chunksize=(10000, 55487, 2)>

In [9]:
gene_mcds = gene_mcds['gene_da'].sel(mc_type='CHN', 
                                     cell=cell_tidy_data.index,
                                     gene=gene_meta.index)
gene_mcds

<xarray.DataArray 'gene_da' (cell: 310, gene: 19828)>
dask.array<shape=(310, 19828), dtype=float64, chunksize=(2, 19828)>
Coordinates:
    mc_type      <U3 'CHN'
  * gene         (gene) object 'ENSMUSG00000051951.5' ... 'ENSMUSG00000095950.2'
    strand_type  <U4 'both'
    gene_chrom   (gene) object 'chr1' 'chr1' 'chr1' ... 'chrY' 'chrY' 'chrY'
    gene_start   (gene) int64 3205900 3999556 4490930 ... 10412736 10533608
    gene_end     (gene) int64 3671497 4409240 4497353 ... 10444690 10536041
  * cell         (cell) object '9H_M_1742' '9H_M_2429' ... '8B_M_2733'

In [10]:
adata = anndata.AnnData(X=gene_mcds.values,
                        obs=pd.DataFrame([], gene_mcds.get_index('cell')),
                        var=pd.DataFrame([], gene_mcds.get_index('gene')))
adata

## Calculate Cluster Mean

In [12]:
records = {}
for cluster, sub_df in cell_tidy_data.groupby(cluster_col):
    sub_adata = adata[sub_df.index, :]
    gene_mean = sub_adata.X.mean(axis=0)
    records[cluster] = pd.Series(gene_mean, index=sub_adata.var_names)
cluster_mean_gene_df = pd.DataFrame(records)

## Filter by cluster delta

In [13]:
gene_cluster_delta = cluster_mean_gene_df.max(axis=1) - cluster_mean_gene_df.min(axis=1)
delta_judge = gene_cluster_delta > delta_rate_cutoff

In [14]:
adata = adata[:, delta_judge].copy()

## One-vs-rest test

In [15]:
def get_sig_features(rank_gene_dict):
    pvals_adj = pd.DataFrame(rank_gene_dict['pvals_adj'])
    names = pd.DataFrame(rank_gene_dict['names'])
    return pvals_adj, names

In [16]:
adata.obs['cluster'] = cell_tidy_data[cluster_col].copy()

In [17]:
# reverse_adata, centered by 1 because after normalization all prior is center to 1
adata.X = (adata.X - 1) * -1 + 1

In [18]:
sc.tl.rank_genes_groups(adata, 'cluster', 
                        n_genes=top_n, method='wilcoxon')
pvals_adj, names = get_sig_features(adata.uns['rank_genes_groups'])

... storing 'cluster' as categorical


In [19]:
# sc.pl.rank_genes_groups_dotplot(adata)

## Get sig gene df

In [20]:
results = []
for col in cell_tidy_data[cluster_col].unique():
    df = pd.DataFrame({'pvals_adj': pvals_adj[col].tolist(), 
                       'gene_id': names[col].tolist()})
    df['cluster'] = col
    results.append(df)
total_results = pd.concat(results)
total_results['gene_name'] = total_results['gene_id'].map(gene_meta['gene_name'])
total_results['-lgp'] = -np.log10(total_results['pvals_adj'])
total_results['-lgp'] = total_results['-lgp'].replace(np.inf, 1000)
total_results = total_results[total_results['pvals_adj'] < adj_p_cutoff].copy()

In [21]:
total_results.shape[0]

4541

## add rate_delta

In [23]:
def get_delta(cluster, gene):
    row = cluster_mean_gene_df.loc[gene].copy()
    cluster_value = row.pop(cluster)
    other_mean = row.mean()
    delta = cluster_value - other_mean
    return delta

In [24]:
total_results['cluster_delta'] = total_results.apply(lambda i: get_delta(i['cluster'], i['gene_id']), axis=1)

In [25]:
total_results = total_results[total_results['cluster_delta'] < delta_rate_cutoff].copy()

In [26]:
total_results.shape[0]

4417

## Add AUROC and filter by AUROC

In [27]:
def get_auroc(gene_id, cluster):
    yscore = adata.obs_vector(gene_id)
    ylabel = adata.obs['cluster'] == cluster
    score = roc_auc_score(ylabel, yscore)
    score = abs(score - 0.5) + 0.5
    return score

In [28]:
total_results['AUROC'] = total_results[['gene_id', 'cluster']].apply(
    lambda i: get_auroc(i['gene_id'], i['cluster']), axis=1)

In [29]:
total_results = total_results[total_results['AUROC'] > auroc_cutoff]

## Save final list

In [32]:
total_results['ref'] = 'Rest'
final = total_results[['gene_name', 'gene_id', 'cluster',
                       'ref', 'pvals_adj', 'cluster_delta', 'AUROC']].copy()
final.to_csv(f'{cluster_col}.{study_name}.one_vs_rest_DEG.csv')

In [33]:
print(final['cluster'].unique().size)
final['cluster'].value_counts()

30


D1L-PAL      665
Gfra1        560
OLF          376
Unc5c        305
DG           272
MGE-Pvalb    254
CA3          246
CA3-St18     194
L6b          160
MSN-D1       160
DG-po        138
NP-L6        136
CLA          123
PT-L5        111
CGE-Lamp5     97
IG-CA2        87
Chd7          79
MSN-D2        74
CA1           64
CT-L6         50
IT-L4         49
MGE-Sst       39
D1L-Fstl4     35
OLF-Exc       33
IT-L23        26
IT-L6         25
Foxp2         18
IT-L5         15
LSX-Inh       15
CGE-Vip       11
Name: cluster, dtype: int64