## Differential Expression (DE) and Differential Splicing (DS) analysis [M132TS Downsampling Analysis]

**Inputs and Outputs**
- Inputs:
  - Harmonized and annotated short-read and long-read AnnData (raw, SCT)
- Outputs:
  - Figures
  - Tables of global and per-cluster DE and DS pvalues for all genes.

In [1]:
# %matplotlib inline

import os
import sys
import matplotlib.pylab as plt
import matplotlib as mpl
import numpy as np
import pandas as pd
import logging
from operator import itemgetter

import scanpy as sc
import anndata
from umap import UMAP

SMALL_SIZE = 12
MEDIUM_SIZE = 14
BIGGER_SIZE = 16

plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=SMALL_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title

logger = logging.getLogger()
logger.setLevel(logging.INFO)
log_info = logger.warning

import warnings
warnings.filterwarnings("ignore")

sc.settings.set_figure_params(dpi=80, facecolor='white')
mpl.rcParams.update(mpl.rcParamsDefault)

In [2]:
# load sample metadata
import yaml

with open('./downsampling_series_sample_metadata.yaml', 'r') as f:
    sample_meta_dict = yaml.safe_load(f)    

notebook_mode = True

if not notebook_mode:
    sample_key = sys.argv[1]
else:
    sample_key = 'gencode_1m'
    
log_info(f'Processing {sample_key} ...')

Processing gencode_1m ...


In [7]:
repo_root = '/home/jupyter/mb-ml-data-disk/MAS-seq-analysis'
long_tx_counts_root ='data/t-cell-vdj/long/quant/ds'
short_h5_path = 'output/t-cell-vdj-cite-seq/M132TS_immune.h5ad'
output_root = 'output/t-cell-vdj-cite-seq/ds'
fig_output_root = 'output/t-cell-vdj-cite-seq/ds/figures'
misc_output_root = 'output/t-cell-vdj-cite-seq/ds/misc'
output_prefix = 'M132TS_immune'

output_prefix_full = f'{output_prefix}_{sample_key}'

final_long_adata_raw_h5_path = os.path.join(
    repo_root, output_root, f'{output_prefix_full}_final_long_raw.h5ad')

group_resolved_ds_df_csv_path = os.path.join(
    repo_root, output_root, f'{output_prefix_full}_group_resolved_ds.csv')

group_resolved_ds_de_df_csv_path =  os.path.join(
    repo_root, output_root, f'{output_prefix_full}_group_resolved_ds_de.csv')

global_ds_df_csv_path =  os.path.join(
    repo_root, output_root, f'{output_prefix_full}_global_ds.csv')

global_ds_de_df_csv_path =  os.path.join(
    repo_root, output_root, f'{output_prefix_full}_global_ds_de.csv')

# constants
GENE_IDS_KEY = 'gencode_overlap_gene_ids'
GENE_NAMES_KEY = 'gencode_overlap_gene_names'

## Configuration

In [8]:
min_cells_per_transcript = 1
min_cells_per_gene = 10

group_cells_by_obs_key = 'mehrtash_leiden'

n_mc_samples = 1_000_000

## Preprocess

In [9]:
adata_long = sc.read(os.path.join(repo_root, final_long_adata_raw_h5_path))

In [10]:
adata_long

AnnData object with n_obs × n_vars = 5276 × 234552
    obs: 'CD45_TotalSeqC', 'CD45R_B220_TotalSeqC', 'CD45RA_TotalSeqC', 'CD45RO_TotalSeqC', 'mehrtash_leiden'
    var: 'transcript_ids', 'gene_ids', 'gene_names', 'transcript_names', 'de_novo_gene_ids', 'de_novo_transcript_ids', 'is_de_novo', 'is_gene_id_ambiguous', 'is_tcr_overlapping', 'gencode_overlap_gene_names', 'gencode_overlap_gene_ids', 'is_gencode_gene_overlap_ambiguous'
    uns: 'mehrtash_leiden_colors'
    obsm: 'X_pca_SCT_long', 'X_pca_SCT_short', 'X_pca_raw_short', 'X_tsne_raw_short', 'X_umap_SCT_long', 'X_umap_SCT_short'

In [11]:
total_umis = adata_long.X.sum()
log_info(f'Total UMIs: {total_umis}')

Total UMIs: 329274.0


## Filtering

In [12]:
# remove genes that are lowly expressed
from collections import defaultdict
gene_id_to_tx_indices_map = defaultdict(list)
for i, gid in enumerate(adata_long.var[GENE_IDS_KEY].values):
    gene_id_to_tx_indices_map[gid].append(i)

included_gene_ids = []
tx_counts_i = np.asarray(adata_long.X.sum(0)).flatten()
for gid, tx_indices in gene_id_to_tx_indices_map.items():
    if np.sum(tx_counts_i[tx_indices]) >= min_cells_per_gene:
        included_gene_ids.append(gid)

adata_long = adata_long[:, adata_long.var[GENE_IDS_KEY].values.isin(included_gene_ids)]

# remove transcript that are very lowly expressed
sc.pp.filter_genes(adata_long, min_cells=min_cells_per_transcript)
tpm_threshold = 1_000_000 * min_cells_per_transcript / total_umis

log_info(f'Removing isoforms with TPM < {tpm_threshold:.2f}')

Trying to set attribute `.var` of view, copying.
Removing isoforms with TPM < 3.04


In [13]:
adata_long

AnnData object with n_obs × n_vars = 5276 × 8807
    obs: 'CD45_TotalSeqC', 'CD45R_B220_TotalSeqC', 'CD45RA_TotalSeqC', 'CD45RO_TotalSeqC', 'mehrtash_leiden'
    var: 'transcript_ids', 'gene_ids', 'gene_names', 'transcript_names', 'de_novo_gene_ids', 'de_novo_transcript_ids', 'is_de_novo', 'is_gene_id_ambiguous', 'is_tcr_overlapping', 'gencode_overlap_gene_names', 'gencode_overlap_gene_ids', 'is_gencode_gene_overlap_ambiguous', 'n_cells'
    uns: 'mehrtash_leiden_colors'
    obsm: 'X_pca_SCT_long', 'X_pca_SCT_short', 'X_pca_raw_short', 'X_tsne_raw_short', 'X_umap_SCT_long', 'X_umap_SCT_short'

## Isoform DE analysis

In [14]:
# mapping from gene id to spanning tx icatces
from collections import defaultdict
gene_id_to_tx_indices_map = defaultdict(list)
for i, gid in enumerate(adata_long.var[GENE_IDS_KEY].values):
    gene_id_to_tx_indices_map[gid].append(i)

# useful auxiliary data structures    
gene_ids = sorted(list(gene_id_to_tx_indices_map.keys()))
n_genes = len(gene_ids)
gene_id_to_gene_name_map = {
    gene_id: gene_name for gene_id, gene_name in zip(adata_long.var[GENE_IDS_KEY], adata_long.var[GENE_NAMES_KEY])}
gene_name_to_gene_id_map = {
    gene_name: gene_id for gene_id, gene_name in zip(adata_long.var[GENE_IDS_KEY], adata_long.var[GENE_NAMES_KEY])}
gene_names = list(map(gene_id_to_gene_name_map.get, gene_ids))

# mapping from gene id to spanning tx indices
group_ids = adata_long.obs[group_cells_by_obs_key].values.categories.values
group_id_to_obs_indices_map = defaultdict(list)
for group_id in group_ids:
    group_id_to_obs_indices_map[group_id] = [
        idx for idx in range(len(adata_long))
        if adata_long.obs[group_cells_by_obs_key].values[idx] == group_id]
    
# reduce tx expression by group (e.g. leiden clusters)
n_transcripts = adata_long.shape[1]
n_groups = len(group_id_to_obs_indices_map)
group_expr_gi = np.zeros((n_groups, n_transcripts), dtype=np.int)
for i_group, group_id in enumerate(group_ids):
    group_expr_gi[i_group, :] = np.asarray(adata_long.X[group_id_to_obs_indices_map[group_id], :].sum(0)).flatten()

## Global differential splicing analysis

In [18]:
from typing import Dict, Tuple, Any, List
from multiprocessing import Pool, cpu_count
from tqdm import tqdm
from statsmodels.stats.multitest import multipletests
from rpy2 import rinterface as ri
from rpy2 import robjects
from rpy2.rinterface_lib.embedded import RRuntimeError


ri.initr()


def get_global_ds_pval_fisher(
        gene_id: str,
        gene_id_to_tx_indices_map: Dict[str, List[int]],
        group_expr_gi: np.ndarray,
        n_mc_samples: int = 10_000) -> Dict[str, Any]:

    tx_indices = gene_id_to_tx_indices_map[gene_id]
    n_groups = group_expr_gi.shape[0]
    x_gi = group_expr_gi[:, tx_indices]
    x_g = np.sum(x_gi, -1)
    n_nnz_groups = np.sum(x_g > 0)

    if len(tx_indices) < 2 or n_nnz_groups < 2:
        return {
            'pval': 1.,
            'x_g': x_g
        }
    
    def rimport(packname):
        as_environment = ri.baseenv['as.environment']
        require = ri.baseenv['require']
        require(ri.StrSexpVector([packname]),
                quiet = ri.BoolSexpVector((True, )))
        packname = ri.StrSexpVector(['package:' + str(packname)])
        pack_env = as_environment(packname)
        return pack_env
    
    rstats = rimport('stats')
    
    args = (('x', robjects.r.matrix(robjects.IntVector(x_gi.flat), nrow=n_groups, byrow=True)),
            ('simulate.p.value', True),
            ('B', n_mc_samples))
    
    try:
        out = rstats['fisher.test'].rcall(args, ri.globalenv)
        pval = float(np.array(np.array(out)[0])[0])
        
    except RRuntimeError:
        log_info(f'Could not process {gene_id} -- returning p-value = 1. and continuing.')

    return {
        'pval': pval,
        'x_g': x_g
    }

In [19]:
gene_id_to_p_values_map = dict()
gene_id_to_x_g_map = dict()

In [20]:
from typing import Tuple

remaining_gene_ids = list(set(gene_ids).difference(gene_id_to_p_values_map.keys()))
num_processes = cpu_count()

def process_gene_id(gene_id: str) -> Dict[str, Any]:

    out = get_global_ds_pval_fisher(
        gene_id=gene_id,
        gene_id_to_tx_indices_map=gene_id_to_tx_indices_map,
        group_expr_gi=group_expr_gi,
        n_mc_samples=n_mc_samples)
    
    return {
        'gene_id': gene_id,
        'pval': out['pval'],
        'x_g': out['x_g']
    }

with Pool(processes=num_processes) as pool:
    for result in tqdm(pool.imap(func=process_gene_id, iterable=remaining_gene_ids), total=len(remaining_gene_ids)):
        gene_id_to_p_values_map[result['gene_id']] = result['pval']
        gene_id_to_x_g_map[result['gene_id']] = result['x_g']

100%|██████████| 2940/2940 [00:02<00:00, 1286.94it/s]


In [21]:
# generate dataframe
group_total_counts_dict = dict()
for i_group, group_iad in enumerate(group_ids):
    group_total_counts_dict[f'expr_{i_group}'] = list(map(lambda gene_id: gene_id_to_x_g_map[gene_id][i_group], gene_ids))
group_total_counts_dict['total_expr'] = list(map(lambda gene_id: np.sum(gene_id_to_x_g_map[gene_id]), gene_ids))

ds_pval_global = list(map(gene_id_to_p_values_map.get, gene_ids))
_, ds_pval_global_adj, _, _ = multipletests(ds_pval_global, alpha=0.05, method='fdr_bh')

global_ds_df = pd.DataFrame({
    **dict(
        gene_ids=gene_ids),
    **group_total_counts_dict,
    **dict(
        ds_pval_global=ds_pval_global,
        ds_pval_global_adj=ds_pval_global_adj)},
    index=list(map(gene_id_to_gene_name_map.get, gene_ids)))

global_ds_df.to_csv(global_ds_df_csv_path)

In [22]:
global_ds_df.head()

Unnamed: 0,gene_ids,expr_0,expr_1,expr_2,expr_3,expr_4,expr_5,expr_6,expr_7,total_expr,ds_pval_global,ds_pval_global_adj
BAD,ENSG00000002330.14,2,0,2,0,1,3,6,2,16,1.0,1.0
LAP3,ENSG00000002549.13,0,2,4,0,5,0,7,1,19,0.867133,1.0
CD99,ENSG00000002586.20,26,3,11,3,35,19,40,9,146,0.632368,1.0
CD99_PAR_Y,ENSG00000002586.20_PAR_Y,29,6,18,4,28,18,45,14,162,0.711289,1.0
MAD1L1,ENSG00000002822.16,3,2,0,0,3,1,3,6,18,0.05994,1.0


## Group-resolved differential splicing analysis

In [23]:
from typing import Dict, Tuple, Any
from multiprocessing import Pool, cpu_count
from tqdm import tqdm
from statsmodels.stats.multitest import multipletests
from rpy2 import rinterface as ri
from rpy2 import robjects
from rpy2.rinterface_lib.embedded import RRuntimeError


ri.initr()


def get_group_resolved_ds_pval_fisher(
        gene_id: str,
        gene_id_to_tx_indices_map: Dict[str, List[int]],
        group_expr_gi: np.ndarray,
        n_mc_samples: int = 10_000) -> Dict[str, Any]:

    tx_indices = gene_id_to_tx_indices_map[gene_id]
    n_groups = group_expr_gi.shape[0]
    x_gi = group_expr_gi[:, tx_indices]
    x_g = np.sum(x_gi, -1)
    
    if len(tx_indices) < 2:
        return {
            'pval_g': np.ones((n_groups,)),
            'x_g': x_g
        }

    # get leave-one-out sums
    y_ggi = np.ascontiguousarray(np.repeat(x_gi[None, ...], n_groups, axis=0))
    y_ggi[np.arange(n_groups), np.arange(n_groups), :] = 0.
    y_gi = np.sum(y_ggi, axis=1)

    def rimport(packname):
        as_environment = ri.baseenv['as.environment']
        require = ri.baseenv['require']
        require(ri.StrSexpVector([packname]),
                quiet = ri.BoolSexpVector((True, )))
        packname = ri.StrSexpVector(['package:' + str(packname)])
        pack_env = as_environment(packname)
        return pack_env
    
    rstats = rimport('stats')

    pval_g = np.zeros((n_groups,))
    
    for i_group in range(n_groups):
        
        x_i = x_gi[i_group, :]
        y_i = y_gi[i_group, :]
        
        if np.sum(x_i) == 0 or np.sum(y_i) == 0:
            pval = 1.
        
        else:            
            contingency_table = np.hstack((x_i, y_i))
            args = (('x', robjects.r.matrix(robjects.IntVector(contingency_table.flat), nrow=2, byrow=True)),
                    ('simulate.p.value', True),
                    ('B', n_mc_samples))
            try:
                out = rstats['fisher.test'].rcall(args, ri.globalenv)
                pval = float(np.array(np.array(out)[0])[0])
            except RRuntimeError:
                print(contingency_table)
                raise RuntimeError
        
        pval_g[i_group] = pval

    return {
        'pval_g': pval_g,
        'x_g': x_g
    }

In [24]:
gene_id_to_p_values_map = dict()
gene_id_to_x_g_map = dict()

In [25]:
from typing import Tuple

remaining_gene_ids = list(set(gene_ids).difference(gene_id_to_p_values_map.keys()))
num_processes = cpu_count()

def process_gene_id(gene_id: str) -> Dict[str, Any]:

    out = get_group_resolved_ds_pval_fisher(
        gene_id=gene_id,
        gene_id_to_tx_indices_map=gene_id_to_tx_indices_map,
        group_expr_gi=group_expr_gi,
        n_mc_samples=n_mc_samples)
    
    return {
        'gene_id': gene_id,
        'pval_g': out['pval_g'],
        'x_g': out['x_g']
    }

with Pool(processes=num_processes) as pool:
    for result in tqdm(pool.imap(func=process_gene_id, iterable=remaining_gene_ids), total=len(remaining_gene_ids)):
        gene_id_to_p_values_map[result['gene_id']] = result['pval_g']
        gene_id_to_x_g_map[result['gene_id']] = result['x_g']

100%|██████████| 2940/2940 [00:04<00:00, 672.97it/s]


In [26]:
# generate dataframe
pvalues_dict = dict()
for i_group, group_id in enumerate(group_ids):
    ds_pval_group = list(map(lambda gene_id: gene_id_to_p_values_map[gene_id][i_group], gene_ids))
    _, ds_pval_group_adj, _, _ = multipletests(ds_pval_group, alpha=0.05, method='fdr_bh')
    pvalues_dict[f'ds_pval_{i_group}'] = ds_pval_group
    pvalues_dict[f'ds_pval_adj_{i_group}'] = ds_pval_group_adj

group_total_counts_dict = dict()
for i_group, group_id in enumerate(group_ids):
    group_total_counts_dict[f'expr_{i_group}'] = list(map(lambda gene_id: gene_id_to_x_g_map[gene_id][i_group], gene_ids))
group_total_counts_dict['total_expr'] = list(map(lambda gene_id: np.sum(gene_id_to_x_g_map[gene_id]), gene_ids))

group_resolved_ds_df = pd.DataFrame({
    **dict(gene_ids=gene_ids),
    **pvalues_dict,
    **group_total_counts_dict},
    index=list(map(gene_id_to_gene_name_map.get, gene_ids)),)

group_resolved_ds_df.to_csv(group_resolved_ds_df_csv_path)

In [27]:
group_resolved_ds_df.head()

Unnamed: 0,gene_ids,ds_pval_0,ds_pval_adj_0,ds_pval_1,ds_pval_adj_1,ds_pval_2,ds_pval_adj_2,ds_pval_3,ds_pval_adj_3,ds_pval_4,...,ds_pval_adj_7,expr_0,expr_1,expr_2,expr_3,expr_4,expr_5,expr_6,expr_7,total_expr
BAD,ENSG00000002330.14,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,1.0,2,0,2,0,1,3,6,2,16
LAP3,ENSG00000002549.13,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,1.0,0,2,4,0,5,0,7,1,19
CD99,ENSG00000002586.20,0.509491,1.0,1.0,1.0,0.529471,1.0,0.337662,1.0,0.40959,...,1.0,26,3,11,3,35,19,40,9,146
CD99_PAR_Y,ENSG00000002586.20_PAR_Y,0.78022,1.0,1.0,1.0,0.847153,1.0,0.432567,1.0,0.422577,...,1.0,29,6,18,4,28,18,45,14,162
MAD1L1,ENSG00000002822.16,0.31968,1.0,0.36963,1.0,1.0,1.0,1.0,1.0,0.714286,...,1.0,3,2,0,0,3,1,3,6,18


## Group-resolved differential expression analysis

In [28]:
import scipy

# get gene expression from isoform expression
row_indices = []
col_indices = []
values = []
for j, gene_id in enumerate(gene_ids):
    tx_indices = gene_id_to_tx_indices_map[gene_id]
    row_indices += tx_indices
    col_indices += [j] * len(tx_indices)
    values += [1] * len(tx_indices)
Y_ij = scipy.sparse.coo_matrix((values, (row_indices, col_indices)), shape=(n_transcripts, n_genes)).tocsr()
gex_X_nj = adata_long.X @ Y_ij

# normalize
adata_long_gex = sc.AnnData(
    X=gex_X_nj,
    obs=adata_long.obs,
    var=pd.DataFrame(index=pd.Index(list(map(gene_id_to_gene_name_map.get, gene_ids)))))

adata_long_gex.var_names_make_unique()
sc.pp.normalize_per_cell(adata_long_gex)
sc.pp.log1p(adata_long_gex)

adata_short_sct = adata_long_gex

In [29]:
group_resolved_ds_df = pd.read_csv(group_resolved_ds_df_csv_path, index_col=0)

sc.tl.rank_genes_groups(adata_short_sct, 'mehrtash_leiden', method='t-test', use_raw=False)

result = adata_short_sct.uns['rank_genes_groups']
n_groups = len(group_ids)
group_id_to_group_idx_map = {group_id: group_idx for group_idx, group_id in enumerate(group_ids)}
key_to_col_name_map = {'names': 'gene_names', 'pvals': 'de_pval', 'pvals_adj': 'de_pval_adj_scanpy'}
de_df = pd.DataFrame(
    {key_to_col_name_map[key] + '_' + str(group_id_to_group_idx_map[group_id]): result[key][group_id]
     for group_id in group_ids
     for key in ['names', 'pvals', 'pvals_adj']})

de_gene_names_list = adata_short_sct.var.index.values
de_pval_list = np.zeros((len(de_gene_names_list),))
de_gene_name_to_index_map = {gene_name: idx for idx, gene_name in enumerate(de_gene_names_list)}

In [30]:
from collections import Counter

ds_gene_names_list = group_resolved_ds_df.index.values.tolist()
mutual_gene_names_list = list(set(ds_gene_names_list).intersection(de_gene_names_list))
group_resolved_ds_df = group_resolved_ds_df.loc[mutual_gene_names_list]

gene_name_counter = Counter(group_resolved_ds_df.index.values)
bad_gene_names = set()
for gene_name, multiplicity in gene_name_counter.items():
    if multiplicity > 1:
        bad_gene_names.add(gene_name)

mutual_gene_names_list = list(set(mutual_gene_names_list).difference(bad_gene_names))
group_resolved_ds_df = group_resolved_ds_df.loc[mutual_gene_names_list]

In [31]:
group_resolved_ds_de_df = group_resolved_ds_df.copy()

for group_idx, group_id in enumerate(group_ids):
    de_group_gene_names = de_df[f'gene_names_{group_idx}'].values
    de_group_pvals = de_df[f'de_pval_{group_idx}'].values
    _, de_group_pvals_adj, _, _ = multipletests(de_group_pvals, alpha=0.05, method='fdr_bh')
    de_group_gene_name_to_index_map = {gene_name: index for index, gene_name in enumerate(de_group_gene_names)}
    group_resolved_ds_de_df[f'de_pval_{group_idx}'] = de_group_pvals[
        list(map(de_group_gene_name_to_index_map.get, mutual_gene_names_list))]
    group_resolved_ds_de_df[f'de_pval_adj_{group_idx}'] = de_group_pvals_adj[
        list(map(de_group_gene_name_to_index_map.get, mutual_gene_names_list))]

In [32]:
group_resolved_ds_de_df.to_csv(group_resolved_ds_de_df_csv_path)

In [33]:
group_resolved_ds_de_df.head()

Unnamed: 0,gene_ids,ds_pval_0,ds_pval_adj_0,ds_pval_1,ds_pval_adj_1,ds_pval_2,ds_pval_adj_2,ds_pval_3,ds_pval_adj_3,ds_pval_4,...,de_pval_3,de_pval_adj_3,de_pval_4,de_pval_adj_4,de_pval_5,de_pval_adj_5,de_pval_6,de_pval_adj_6,de_pval_7,de_pval_adj_7
IFNAR2,ENSG00000159110.20,0.327672,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.427572,...,5.2e-05,0.00016,0.088051,0.350878,0.10454,0.365455,0.270833,0.496725,0.9795041,0.9872472
MYL12B,ENSG00000118680.14,0.243756,1.0,0.398601,1.0,1.0,1.0,1.0,1.0,0.015984,...,0.515332,0.578494,0.00588,0.054024,0.812816,0.95938,0.001316,0.014228,1.714555e-11,6.223198e-10
SLC38A2,ENSG00000134294.14,0.154799,1.0,1.0,1.0,1.0,1.0,0.263158,1.0,1.0,...,0.3955,0.48108,0.248958,0.584614,0.206025,0.568527,0.890368,0.945695,0.8103127,0.8981456
ATP5F1A,ENSG00000152234.16,1.0,1.0,1.0,1.0,0.748252,1.0,0.202797,1.0,0.436563,...,0.565134,0.619268,0.833579,0.925224,0.071002,0.281328,0.485082,0.676159,0.1256513,0.2854828
SDF2,ENSG00000132581.10,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,0.002861,0.004574,0.002859,0.031367,0.897391,0.979022,0.42009,0.629129,0.6710085,0.8115035


## Explore

In [34]:
from mpl_toolkits.axes_grid1 import make_axes_locatable

ncols = 4
nrows = int(np.ceil(len(group_ids) / ncols))
scale = 3.

fig, axs = plt.subplots(ncols=ncols, nrows=nrows, figsize=(scale * ncols, scale * nrows))

highlight_map = {
    # ('RPL', 'red'): [gene_name for gene_name in gene_names if gene_name.find('RPL') == 0],
    ('PTPRC', 'blue'): ['PTPRC'],
}

x_jitter_scale = 0.01
y_jitter_scale = 0.01
max_ds_log_pval = 3.
max_de_log_pval = 3.
log10_fdr_threshold = - np.log10(0.05)
rng_seed = 0

rng = np.random.RandomState(seed=rng_seed)
gene_names_list = group_resolved_ds_de_df.index.values.tolist()
log10_tpm = np.log10(
    1_000_000 * group_resolved_ds_de_df['total_expr'].values /
    np.sum(group_resolved_ds_de_df['total_expr'].values))

for group_idx, ax in zip(range(len(group_ids)), axs.flatten()):

    X_COL_NAME = f'de_pval_adj_{group_idx}'
    Y_COL_NAME = f'ds_pval_adj_{group_idx}'

    xx = np.minimum(-np.log10(group_resolved_ds_de_df[X_COL_NAME].values), max_de_log_pval)
    yy = np.minimum(-np.log10(group_resolved_ds_de_df[Y_COL_NAME].values), max_ds_log_pval)

    n_ds = np.sum((yy > log10_fdr_threshold) & (xx <= log10_fdr_threshold))
    n_de = np.sum((xx > log10_fdr_threshold) & (yy <= log10_fdr_threshold))
    n_de_ds = np.sum((xx > log10_fdr_threshold) & (yy > log10_fdr_threshold))
    n_boring = np.sum((xx <= log10_fdr_threshold) & (yy <= log10_fdr_threshold))
    
    xx = np.maximum(0, xx + x_jitter_scale * np.max(xx) * rng.randn(len(xx)))
    yy = np.maximum(0, yy + y_jitter_scale * np.max(yy) * rng.randn(len(yy)))

    scatter = ax.scatter(
        xx,
        yy,
        color='gray',
        # c=log10_tpm,
        cmap=plt.cm.Reds,
        s=2, alpha=0.5)

    ax.axhline(log10_fdr_threshold, lw=1, linestyle='--')
    ax.axvline(log10_fdr_threshold, lw=1, linestyle='--')

    for manifest, highlighted_gene_names in highlight_map.items():
        label = manifest[0]
        color = manifest[1]
        indices = [
            gene_names_list.index(gene_name)
            for gene_name in highlighted_gene_names]
        ax.scatter(
            xx[indices],
            yy[indices],
            s=20,
            color=color,
            label=label,
            marker='o',
            facecolor=[1, 1, 1, 0],
            linewidths=1)

    ax.set_xlabel(r'$-log_{10}~P_{DE}$')
    ax.set_ylabel(r'$-log_{10}~P_{DS}$')
    
    ax.text(
        0.5 * log10_fdr_threshold, 0.5 * (np.max(yy) + log10_fdr_threshold),
        str(n_ds),
        fontsize=8,
        rotation='horizontal',
        ha='center',
        color='red',
        bbox=dict(boxstyle='round', facecolor='white', alpha=0.5))

    ax.text(
        0.5 * log10_fdr_threshold, 0.5 * log10_fdr_threshold,
        str(n_boring),
        fontsize=8,
        rotation='horizontal',
        ha='center',
        color='red',
        bbox=dict(boxstyle='round', facecolor='white', alpha=0.5))

    ax.text(
        0.5 * (np.max(xx) + log10_fdr_threshold), 0.5 * log10_fdr_threshold,
        str(n_de),
        fontsize=8,
        rotation='horizontal',
        ha='center',
        color='red',
        bbox=dict(boxstyle='round', facecolor='white', alpha=0.5))

    ax.text(
        0.5 * (np.max(xx) + log10_fdr_threshold), 0.5 * (np.max(yy) + log10_fdr_threshold),
        str(n_de_ds),
        fontsize=8,
        rotation='horizontal',
        ha='center',
        color='red',
        bbox=dict(boxstyle='round', facecolor='white', alpha=0.5))

    ax.set_title(group_ids[group_idx] + " vs. rest")


fig.tight_layout()

plt.savefig(
    os.path.join(
        repo_root, fig_output_root, output_prefix_full + "_per_cluster_ds_de.pdf"))