In [1]:
%load_ext autoreload

In [2]:
%autoreload 2

In [3]:
import gc
import pickle
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import plotnine as p9
import scanpy as sc
import scipy as sp
from joblib import Parallel, delayed
from tqdm.auto import tqdm

In [4]:
# Use scanpy's default figure parameters.
sc.set_figure_params()

In [5]:
# Put ~/Code/sctoolkit on the import path so that `sctoolkit` can be imported.
sys.path.append(str(Path.home() / 'Code/sctoolkit'))

In [6]:
from sctoolkit.modules import find_modules, tag_with_score, sort_module_dict

## Load expression data

In [None]:
# Expression data. Later cells read `adata.var_names` and `adata.obs.tissue`
# and subset this object once per tissue.
adata = sc.read('adata.h5ad')

## Load GWAS genes

In [8]:
# GWAS gene table. The file is available at
# https://github.com/gokceneraslan/opentargets-genetics-python/tree/master/data
gwas_ad = sc.read('otg-gwascat-20210125.h5ad')
gwas_ad

AnnData object with n_obs × n_vars = 5674 × 12864
    obs: 'studyId', 'pmid', 'pubAuthor', 'pubDate', 'pubJournal', 'pubTitle', 'traitReported', 'ancestryInitial', 'gene_id', 'gene_symbol', 'nInitial', 'numAssocLoci', 'source', 'traitEfos', 'hasSumsStats', 'ancestryReplication', 'nReplication', 'nCases', 'nTotal', 'traitCategory', 'traitEfosStr'
    var: 'gene_id'

In [9]:
# Per GWAS row: sum of X over the genes that are also present in `adata`
# (a gene count if X is a 0/1 indicator -- TODO confirm).
# np.asarray(...).ravel() flattens the row sums whether `.X.sum(1)` returns an
# np.matrix or a plain ndarray; `.A1` only exists on np.matrix.
overlap = np.asarray(gwas_ad[:, gwas_ad.var_names.isin(adata.var_names)].X.sum(1)).ravel()

# Keep only rows whose overlap is strictly between 1 and 400.
gwas_ad = gwas_ad[(overlap > 1) & (overlap < 400)].copy()
gwas_ad

AnnData object with n_obs × n_vars = 4738 × 12864
    obs: 'studyId', 'pmid', 'pubAuthor', 'pubDate', 'pubJournal', 'pubTitle', 'traitReported', 'ancestryInitial', 'gene_id', 'gene_symbol', 'nInitial', 'numAssocLoci', 'source', 'traitEfos', 'hasSumsStats', 'ancestryReplication', 'nReplication', 'nCases', 'nTotal', 'traitCategory', 'traitEfosStr'
    var: 'gene_id'

## GWAS enrichment tests

In [10]:
def gwas_enrichment(adata, gwas_ad, obs_key=None):
    """Test every expression module against every GWAS gene set.

    Both inputs are expanded to boolean membership tables over a shared gene
    universe (the sorted union of ``gwas_ad.var_names`` and
    ``adata.var_names``) and handed to ``binary_test_statistics``.

    Parameters
    ----------
    adata
        Object exposing ``var_names`` and ``uns['paris']['module_dict']``,
        a mapping ``level -> {module: genes}``.
    gwas_ad
        Object exposing ``X``, ``obs_names`` and ``var_names``. Non-zero
        entries of ``X`` are treated as gene membership of the row's gene set.
    obs_key
        When not ``None``, the result is merged with
        ``adata.uns['paris']['score_tags']``. If that entry is missing it is
        first created by calling ``tag_with_score(adata, obs_key, zcutoff=0)``.

    Returns
    -------
    pandas.DataFrame
        Output of ``binary_test_statistics``. If ``obs_key`` is given, the
        index is reset and the frame is merged with the score tags on their
        shared columns.
    """
    background = sorted(set(gwas_ad.var_names) | set(adata.var_names))
    background_index = pd.Index(background, name='background')

    print('Building expression module df...')
    module_dict = {}
    for level, mod_dict in tqdm(adata.uns['paris']['module_dict'].items()):
        # Build each level's table in one go instead of inserting columns one
        # at a time, which fragments the frame when there are many modules.
        membership = {
            mod: background_index.isin(genes) for mod, genes in mod_dict.items()
        }
        module_dict[level] = pd.DataFrame(membership, index=background_index)

    # Rows: (level, module); columns: background genes.
    module_df = pd.concat(
        list(module_dict.values()), axis=1, keys=list(module_dict.keys())
    ).T

    gwas_hits = gwas_ad.X.astype(bool)
    if hasattr(gwas_hits, 'toarray'):
        # Sparse input: densify explicitly (works for sparse matrices and arrays).
        gwas_hits = gwas_hits.toarray()

    gwas_df = pd.DataFrame(
        gwas_hits,
        # Strip any index name so downstream code sees a plain, unnamed index.
        index=gwas_ad.obs_names.rename(None),
        columns=gwas_ad.var_names,
    )
    # Align to the module table's column order explicitly. Genes unknown to
    # the GWAS table become False and the columns stay boolean.
    gwas_df = gwas_df.reindex(columns=background, fill_value=False)

    print('Calculating binary test statistics...')
    final_df = binary_test_statistics(gwas_df, module_df)

    if obs_key is not None:
        if 'score_tags' not in adata.uns['paris']:
            print('Scoring modules and merging score info...')
            tag_with_score(adata, obs_key, zcutoff=0)

        final_df = final_df.reset_index().merge(adata.uns['paris']['score_tags'])

    return final_df

In [11]:
def binary_test_statistics(gwas_df, module_df):
    """Compute 2x2 overlap statistics for every (GWAS set, module) pair.

    Parameters
    ----------
    gwas_df
        Boolean-like frame, one row per GWAS gene set, one column per gene.
    module_df
        Boolean-like frame, one row per module, one column per gene. Its index
        must be a two-level MultiIndex ``(level, module)`` and its columns must
        equal ``gwas_df.columns``.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``(gwas, level, module)`` and restricted to pairs sharing
        at least one gene. Rows are ordered module-major: all GWAS sets for the
        first module, then the second, and so on. Columns:

        * ``pval`` - element ``[1]`` of ``fisher.pvalue_npy``'s return value
          (the right-tail p-value per that package's documentation)
        * ``logOR`` - ``log2(a*d / (b*c))``
        * ``pospos``/``posneg``/``negpos``/``negneg`` - the 2x2 counts
          ``a``/``b``/``c``/``d`` (in GWAS set & in module, in GWAS set only,
          in module only, in neither)
        * ``precision`` - ``a / (a + c)``
        * ``recall`` - ``a / (a + b)``
        * ``fscore`` - harmonic mean of precision and recall
        * ``intersections`` - tuple of the shared gene names

    Raises
    ------
    AssertionError
        If the two frames do not have identical columns.
    """
    assert np.all(gwas_df.columns == module_df.columns), 'Columns of dfs must match'

    from fisher import pvalue_npy
    import sparse

    gwas_mat = gwas_df.values.astype('uint32')          # (n_gwas, n_genes)
    module_mat = module_df.values.T.astype('uint32')    # (n_genes, n_modules)
    x = sparse.COO(gwas_mat)
    y = sparse.COO(module_mat)

    # 2x2 table entries. b and c follow from the set sizes, so the dense
    # complements (1 - x) and (1 - y) never have to be materialised:
    #   b = x . (1 - y) = |gwas set| - a
    #   c = (1 - x) . y = |module|   - a
    a = x.dot(y).todense()
    gwas_sizes = gwas_mat.sum(axis=1, dtype='uint32')
    module_sizes = module_mat.sum(axis=0, dtype='uint32')
    b = gwas_sizes[:, None] - a
    c = module_sizes[None, :] - a
    d = gwas_df.shape[1] - (a + b + c)

    module_levels = module_df.index.get_level_values(0)
    module_names = module_df.index.get_level_values(1)

    # Shared genes: non-zero cells of the (gene, gwas, module) product.
    gene_hit, gwas_hit, module_hit = (x.T[:, :, None] * y[:, None, :]).nonzero()
    nnz = pd.DataFrame({
        'gwas': gwas_df.index[gwas_hit],
        'level': module_levels[module_hit],
        'module': module_names[module_hit],
        'gene': gwas_df.columns[gene_hit],
    })
    nnz = nnz.groupby(['gwas', 'level', 'module'], observed=True)[['gene']].agg(tuple)

    # Fisher test on the flattened tables.
    pval = pvalue_npy(a.ravel().astype('uint'),
                      b.ravel().astype('uint'),
                      c.ravel().astype('uint'),
                      d.ravel().astype('uint'))[1]
    pval = pval.reshape(a.shape)

    # Zero denominators are expected here (empty overlaps, b*c == 0). They
    # yield inf/nan on purpose, so silence the corresponding warnings.
    with np.errstate(invalid='ignore', divide='ignore'):
        logodds = np.log2((a.astype(float) * d.astype(float))
                          / (b.astype(float) * c.astype(float)))
        precision = a / (a + c)
        recall = a / (a + b)
        fscore = 2 * precision * recall / (precision + recall)

    # Long format, one row per (gwas, module) pair. Column-major flattening
    # reproduces the ordering of a melt over the module columns.
    n_gwas, n_modules = a.shape
    gwas_pos = np.tile(np.arange(n_gwas), n_modules)
    module_pos = np.repeat(np.arange(n_modules), n_gwas)

    # Only pairs with at least one shared gene are reported.
    keep = a.ravel(order='F') > 0

    stats = {
        'pval': pval,
        'logOR': logodds,
        'pospos': a,
        'posneg': b,
        'negpos': c,
        'negneg': d,
        'precision': precision,
        'recall': recall,
        'fscore': fscore,
    }
    long_index = pd.MultiIndex.from_arrays(
        [
            gwas_df.index[gwas_pos[keep]],
            module_levels[module_pos[keep]],
            module_names[module_pos[keep]],
        ],
        names=['gwas', 'level', 'module'],
    )
    final_df = pd.DataFrame(
        {name: np.asarray(mat).ravel(order='F')[keep] for name, mat in stats.items()},
        index=long_index,
    )

    # Index-aligned join of the shared-gene tuples.
    final_df['intersections'] = nnz.gene

    return final_df

In [None]:
%%time

# Annotation columns used to score the modules. One shared list keeps the
# explicit tag_with_score call and gwas_enrichment's fallback scoring in the
# same key order.
score_keys = ['Broad cell type', 'Granular cell type']

# For each tissue: find modules, score them, run the GWAS enrichment and write
# three files. NOTE: the final gc.collect() needs `import gc` in the import cell.
for t in tqdm(adata.obs.tissue.cat.categories):
    print(t)
    ad = adata[adata.obs.tissue == t].copy()
    sc.pp.filter_genes(ad, min_cells=10)

    find_modules(ad)
    tag_with_score(ad, score_keys, zcutoff=0.)
    filename = f'modules-{t}'

    res = gwas_enrichment(ad, gwas_ad, score_keys).assign(tissue=t)
    res.to_pickle(f'{filename}-enrichment-v2.pkl')

    # The module results are pickled separately and removed from `ad` before
    # the .h5ad file is written.
    with open(f'{filename}-v2.pkl', 'wb') as f:
        pickle.dump(ad.uns['paris'], f)

    del ad.uns['paris']
    ad.write(f'{filename}.h5ad')

    # Drop the per-tissue objects before the next iteration to limit memory use.
    del ad
    del res

    gc.collect()