In [None]:
# autoreload
%load_ext autoreload
%autoreload 2

import scTRS.util as util
import scTRS.method as md
import scTRS.data_loader as dl
import submitit
import numpy as np
from anndata import read_h5ad
from statsmodels.stats.multitest import multipletests

import pandas as pd
import os
from os.path import join
import itertools
import pickle
from IPython.display import display, Markdown, Latex
import matplotlib.pyplot as plt

# constants
# NOTE(review): absolute cluster path; adjust when running in another environment.
DATA_PATH = '/n/holystore01/LABS/price_lab/Users/mjzhang/scTRS_data'
SCORE_FILE_DIR = join(DATA_PATH, "score_file")

# Traits whose score files ("<trait>.score.gz") are read from SCORE_FILE_DIR later on.
TRAIT_LIST = [
    'PASS_Schizophrenia_Ruderfer2018',
    'PASS_BipolarDisorder_Ruderfer2018',
    'PASS_Alzheimers_Jansen2019',
    'PASS_AdultOnsetAsthma_Ferreira2019',
    'PASS_Coronary_Artery_Disease',
    'PASS_LargeArteryStroke_Malik2018',
    'PASS_HDL',
    'PASS_LDL',
    'PASS_Rheumatoid_Arthritis',
    'PASS_Lupus',
    'PASS_FastingGlucose_Manning',
    'PASS_IBD_deLange2017',
    'PASS_Type_1_Diabetes',
    'PASS_Type_2_Diabetes',
]

# Raw TMS object for each dataset, relative to DATA_PATH.
TMS_RAW_H5AD = {
    "tms_facs": 'tabula_muris_senis/tabula-muris-senis-facs-official-raw-obj.h5ad',
    "tms_droplet": 'tabula_muris_senis/tabula-muris-senis-droplet-official-raw-obj.h5ad',
}

# Keep only a copy of the cell metadata table (.obs) of each raw object.
df_obs_dict = {}
for dataset, rel_path in TMS_RAW_H5AD.items():
    file_name = join(DATA_PATH, rel_path)
    df_obs_dict[dataset] = read_h5ad(file_name).obs.copy()

In [None]:
import glob

# Geary's C group statistics. File names are expected to start with
# "<trait>@<dataset>.<tissue>." (see the parsing below).
result_dir = "out/celltype_hetero/group_stats/"
# glob returns files in arbitrary order; sort so the concatenated tables (and the
# order of everything derived from them) are reproducible across runs/filesystems.
file_list = sorted(glob.glob(join(result_dir, "*gearysc.gz")))


def _gearysc_pval_df(stats_df, tissue):
    """Summarize one tissue's Geary's C table against its control columns.

    Parameters
    ----------
    stats_df : pd.DataFrame
        One row per cell type. Column "norm_score" holds the trait statistic and
        every column prefixed "ctrl_norm_score_" holds one control statistic.
    tissue : str
        Tissue name; the output index is "<tissue>.<stats_df index>".

    Returns
    -------
    pd.DataFrame
        Columns "pval", "trait", "ctrl_mean", "ctrl_std". The p-value is
        (number of controls the trait statistic exceeds + 1) / (number of controls + 1).
        It is therefore small when the trait statistic is smaller than most controls,
        and it is NaN wherever the trait statistic is NaN.
    """
    trait_col = "norm_score"
    ctrl_cols = [col for col in stats_df.columns if col.startswith("ctrl_norm_score_")]
    trait_values = stats_df[trait_col].values
    ctrl_values = stats_df[ctrl_cols].values
    # (n_ct, 1) vs (n_ct, n_ctrl): per cell type, count controls below the trait value
    n_exceeded = (trait_values[:, None] > ctrl_values).sum(axis=1)
    pval = (n_exceeded + 1) / (len(ctrl_cols) + 1)
    pval[np.isnan(trait_values)] = np.nan
    return pd.DataFrame({"pval": pval,
                         "trait": trait_values,
                         "ctrl_mean": stats_df[ctrl_cols].mean(axis=1).values,
                         "ctrl_std": stats_df[ctrl_cols].std(axis=1).values},
                        index=tissue + '.' + stats_df.index)


# read group stats: one concatenated table per "<trait>@<dataset>"
group_stats_dict = dict()
for name in file_list:
    trait_dataset, tissue = os.path.basename(name).split('.')[0:2]
    if trait_dataset.endswith("liver_atlas"):
        continue
    stats_df = pd.read_csv(name, index_col=0, sep='\t')
    group_stats_dict.setdefault(trait_dataset, []).append(_gearysc_pval_df(stats_df, tissue))

for trait_dataset in group_stats_dict:
    group_stats_dict[trait_dataset] = pd.concat(group_stats_dict[trait_dataset])

# read scores
# score directory (under SCORE_FILE_DIR) for each dataset
SCORE_DIR_NAME_DICT = {
    "tms_facs": "score.tms_facs.gwas_max_abs_z.top500",
    "tms_droplet": "score.tms_droplet.gwas_max_abs_z.top500.weight_1en2",
    "liver_atlas": "score.liver_atlas.gwas_max_abs_z.top500",
}
score_df_dict = dict()
for dataset in ["tms_facs", "tms_droplet"]:
    if dataset not in SCORE_DIR_NAME_DICT:
        raise NotImplementedError(dataset)
    score_dir = join(SCORE_FILE_DIR, SCORE_DIR_NAME_DICT[dataset])

    df_obs = df_obs_dict[dataset]
    # group key matching the "<tissue>.<cell type>" index of the group-stats tables
    df_obs["tissue.ct"] = df_obs["tissue"].astype(str) + '.' + df_obs["cell_ontology_class"].astype(str)

    for trait in TRAIT_LIST:
        trait_dataset = f"{trait}@{dataset}"
        score_df = pd.read_csv(join(score_dir, f"{trait}.score.gz"), sep='\t', index_col=0)
        score_df_dict[trait_dataset] = score_df
        # per "tissue.ct" group, count cells with zscore > 3; cells without a score
        # (NaN after reindexing) compare False and so are not counted
        is_sig = score_df['zscore'].reindex(df_obs.index) > 3.
        n_sig = is_sig.groupby(df_obs["tissue.ct"].values).sum()
        if trait_dataset not in group_stats_dict:
            raise KeyError(f"no Geary's C group stats for {trait_dataset} in {result_dir}")
        group_stats_dict[trait_dataset]["n_sig"] = n_sig.reindex(group_stats_dict[trait_dataset].index)
    

# A scan for heterogeneity
We first scan for within-cell-type heterogeneity across all trait–cell-type pairs. A trait–cell-type pair is flagged as interesting if it satisfies both of the following criteria:
1. The Geary's C statistic is more than 3 standard deviations below the control mean (a smaller Geary's C indicates stronger spatial autocorrelation of the scores).
2. There are more than 10 cells with a cell-specific z-score larger than 3.

In [None]:
# Thresholds for flagging a trait–cell-type pair.
GEARYSC_ZSC_THRESHOLD = 3  # trait Geary's C more than this many control SDs below the control mean
MIN_N_SIG_CELLS = 10       # and more than this many cells with cell-level z-score > 3

# For each "<trait>@<dataset>", the "<tissue>.<cell type>" labels passing both thresholds.
plot_tissue_ct_dict = dict()
for trait_dataset, df in group_stats_dict.items():
    # positive zsc: trait statistic lies below the control mean, in control-SD units
    df['zsc'] = (df['ctrl_mean'] - df['trait']) / df['ctrl_std']
    is_flagged = (df['zsc'] > GEARYSC_ZSC_THRESHOLD) & (df['n_sig'] > MIN_N_SIG_CELLS)
    plot_tissue_ct_dict[trait_dataset] = df[is_flagged].index

In [None]:
# Tissues to scan in each TMS dataset.
dic_tissue_list = {'tms_facs': ['Aorta', 'BAT', 'Bladder', 'Brain_Myeloid', 'Brain_Non-Myeloid',
                            'Diaphragm', 'GAT', 'Heart', 'Kidney', 'Large_Intestine',
                            'Limb_Muscle', 'Liver', 'Lung', 'MAT', 'Mammary_Gland', 'Marrow', 
                            'Pancreas', 'SCAT', 'Skin', 'Spleen', 'Thymus', 'Tongue', 'Trachea'],
                   'tms_droplet': ['Bladder', 'Fat', 'Heart_and_Aorta', 'Kidney', 'Large_Intestine',
                               'Limb_Muscle', 'Liver', 'Lung', 'Mammary_Gland', 'Marrow',
                               'Pancreas', 'Skin', 'Spleen', 'Thymus', 'Tongue', 'Trachea']}

# For every dataset/tissue, plot UMAPs of cell-level z-scores for each trait that has
# at least one flagged cell type in that tissue (from plot_tissue_ct_dict).
for dataset in ["tms_facs", "tms_droplet"]:
    for tissue in dic_tissue_list[dataset]:
        # collect, per trait, the flagged cell types that belong to this tissue
        plot_dict = dict()
        for trait in TRAIT_LIST:
            for plot_tissue_ct in plot_tissue_ct_dict[f"{trait}@{dataset}"]:
                # split on the first '.' only, in case a cell-type name contains '.'
                plot_tissue, plot_ct = plot_tissue_ct.split('.', 1)
                if plot_tissue == tissue:
                    plot_dict.setdefault(trait, []).append(plot_ct)

        if len(plot_dict) == 0:
            continue

        # load the processed tissue data only once we know there is something to plot
        tissue_adata = dl.load_tms_processed(DATA_PATH, data_name=dataset.split('_')[1], tissue=tissue)[tissue]
        display(Markdown(f"## {dataset} {tissue}"))
        zsc_dict = dict()
        for trait in plot_dict:
            title = f"{trait}:\n{', '.join(plot_dict[trait])}"
            zsc_dict[title] = score_df_dict[f"{trait}@{dataset}"]["zscore"].reindex(tissue_adata.obs.index).values
        score_index = tissue_adata.obs.index
        util.plot_score_umap(zsc_dict, score_index, tissue_adata, n_col=3)

From the results of the scan on the TMS FACS data, we select a subset of tissue–trait pairs for further inspection and for cross-validation with the TMS droplet data.

In [None]:
# Selected tissue -> traits pairs, plotted side by side for both TMS datasets.
pair_dict = {"Liver": ["PASS_HDL", "PASS_LDL"],
             "Lung": ["PASS_Coronary_Artery_Disease", "PASS_Type_2_Diabetes"],
             "Marrow": ["PASS_Alzheimers_Jansen2019"],
             "Thymus": ["PASS_Rheumatoid_Arthritis"],
             "Trachea": ["PASS_Rheumatoid_Arthritis", "PASS_Lupus"]}

for tissue, traits in pair_dict.items():
    for dataset in ["tms_facs", "tms_droplet"]:
        display(Markdown(f"## {dataset} {tissue}"))
        data_name = dataset.split('_')[1]
        tissue_adata = dl.load_tms_processed(DATA_PATH, data_name=data_name, tissue=tissue)[tissue]
        score_index = tissue_adata.obs.index
        # cell-level z-scores of each selected trait, aligned to this tissue's cells
        zsc_dict = {
            trait: score_df_dict[f"{trait}@{dataset}"]["zscore"].reindex(score_index).values
            for trait in traits
        }
        util.plot_score_umap(zsc_dict, score_index, tissue_adata, n_col=3)