In [None]:
import scanpy as sc
import pegasus as pg
from collections import Counter
import numpy as np
import pandas as pd

In [None]:
# Root of the single-cell data tree; dataset paths below are built from it.
datadir = '/ahg/regevdata/projects/scgwas/data/singlecell'
# Directory that write_matrix() writes the per-disease pval/logfold/score CSVs into.
filedir = '/ahg/regevdata/projects/scgwas/data/singlecell/modules/disease/celltypeenriched/'

In [None]:
# Disease name -> (path of the annotated .h5ad file, name of the `adata.obs`
# column that identifies the patient/sample in that dataset).
disease_datasets = dict([
    ('UC', (datadir + '/z_disease/uc/ucdata.h5ad', 'subject')),
    ('MS', (datadir + '/z_disease/multiplesclerosis/alldata-processed-annotated.h5ad', 'sample')),
    ('Fibrosis', (datadir + '/lung/kropski-annotated.h5ad', 'patient')),
    ('Alzheimers', (datadir + '/z_disease/alzheimers/totaldata-meta.h5ad', 'Subject')),
    ('Covid_idoamit', (datadir + '/z_disease/covid/covid-processed-annotated-harmony.h5ad', 'sample')),
    ('Asthma', (datadir + '/lung/teichmann/teichmann-processed-annotated.h5ad', 'sampleid')),
])

In [None]:
# Raw diagnosis value (as stored in a dataset's diagnosis column) -> harmonised
# status. Only 'Healthy' and 'Disease' take part in the DE comparisons; rows
# mapped to 'Unknown' are dropped from the count tables.
diseaselabel_mapping = {
    'Control': 'Healthy',
    'MS': 'Disease',
    'Healthy': 'Healthy',
    'Inflamed': 'Disease',
    'Non-inflamed': 'Unknown',
    'Fibrosis': 'Disease',
    'late-pathology': 'Disease',
    'no-pathology': 'Healthy',
    'early-pathology': 'Unknown',
    'Moderate': 'Healthy',
    'Severe': 'Disease',
    'Y': 'Disease',
    'N': 'Healthy',
    'Asthma': 'Disease',
    'Unknown': 'Unknown',
    'Card9_infected': 'Disease',
    'WT_infected': 'Healthy',
    True: 'Disease',
    False: 'Healthy',
}

# Dataset name -> `adata.obs` column holding the raw diagnosis value.
# Contains entries for datasets that are not listed in `disease_datasets`.
diagnosisdict = {
    'viral_Covid_idoamit': 'hasnCoV',
    'Covid_regev': 'Viral+',
    'skincard9kodata': 'status',
    'Asthma': 'disease',
    'Covid_idoamit': 'severity',
    'UC': 'hui',
    'MS': 'diagnosis',
    'Fibrosis': 'CleanDiagnosis',
    'Alzheimers': 'DiseaseStatus',
}

# Dataset name -> value of the diagnosis column that selects the diseased cells
# used for the contamination ranking.
# (The original literal listed 'Covid_idoamit' twice with the same value; the
# redundant duplicate key has been removed.)
diseaselabeldict = {
    'viral_Covid_idoamit': 'Y',
    'Covid_regev': True,
    'skincard9kodata': 'Card9_infected',
    'Asthma': 'Asthma',
    'Covid_idoamit': 'Severe',
    'Covid_idoamit_infected': 'Y',
    'UC': 'Inflamed',
    'MS': 'MS',
    'Fibrosis': 'Fibrosis',
    'Alzheimers': 'late-pathology',
}

In [None]:
def compute_contamination(adata, n_std=6):
    """Flag, per group, the genes whose ranking score is an extreme low outlier.

    Reads the result stored under ``adata.uns['rank_genes_groups']`` and, for
    every group present in that result, returns the gene names whose score is
    below ``mean(scores) - n_std * std(scores)`` of that group's scores.

    The groups are taken from the ranking result itself rather than from
    notebook-level variables, so the function depends only on its argument.

    Parameters
    ----------
    adata
        Object exposing ``uns['rank_genes_groups']`` with ``'scores'`` and
        ``'names'`` entries that ``pandas.DataFrame`` can turn into one column
        per group.
    n_std : float, default 6
        Number of standard deviations below the mean used as the cutoff.

    Returns
    -------
    dict
        Group name -> ``pandas.Series`` of flagged gene names (possibly empty).
    """
    ranked = adata.uns['rank_genes_groups']
    # Build the per-group tables once instead of once per group.
    scores_by_group = pd.DataFrame(ranked['scores'])
    names_by_group = pd.DataFrame(ranked['names'])

    contamination = {}
    for ct in names_by_group.columns:
        scores = scores_by_group[ct]
        names = names_by_group[ct]
        threshold = np.mean(scores) - n_std*np.std(scores)
        contamination[ct] = names[scores < threshold]
    return contamination

In [None]:
# Build, for every dataset, a table of cell counts per (patient, cell type,
# diagnosis) together with the fraction of that patient's cells it represents,
# and write it to <counts dir>/<disease>_counts.txt.
celltypelabel = 'annot_level_2'
cell_counts = []     # number of cells per dataset, in `disease_datasets` order
patient_counts = []  # number of distinct patient ids per dataset, same order
for disease, (filename, patientkey) in disease_datasets.items():
    diagnosislabel = diagnosisdict[disease]

    adata = sc.read(filename)
    n_cells = adata.shape[0]
    n_patients = len(set(adata.obs[patientkey]))
    print(disease, filename, n_cells, n_patients)
    cell_counts.append(n_cells)
    patient_counts.append(n_patients)

    # Work on an explicit copy: adding columns to a column-slice of adata.obs
    # triggers pandas' chained-assignment warning and has ambiguous semantics.
    metadata = adata.obs[[celltypelabel, patientkey, diagnosislabel]].copy()
    metadata['count'] = 1
    # Marker column: after the aggregation below, groups that contain no rows
    # do not carry 'True' here and are filtered out (presumably these are the
    # unobserved combinations produced when the grouping columns are
    # categorical -- TODO confirm).
    metadata['original'] = 'True'

    # (patient, diagnosis) -> number of cells.
    total_counts = metadata.groupby([patientkey, diagnosislabel])['count'].count().to_dict()

    count_df = (
        metadata
        .groupby([patientkey, celltypelabel, diagnosislabel])
        .agg({'count': 'count', 'original': lambda x: x.iloc[0]})
        .fillna(0)
        .reset_index()
    )
    # .copy() so the column assignments below act on an independent frame.
    count_df = count_df[count_df['original'] == 'True'].copy()
    count_df['total'] = [total_counts[(pid, dstatus)]
                         for pid, dstatus in zip(count_df[patientkey], count_df[diagnosislabel])]
    # Fraction of the patient's cells (within that diagnosis value) in this cell type.
    count_df['normed'] = count_df['count']/count_df['total']
    # Diagnosis values missing from the mapping are treated as 'Unknown' and dropped.
    count_df['disease_status'] = [diseaselabel_mapping.get(raw_status, 'Unknown')
                                  for raw_status in count_df[diagnosislabel]]
    count_df = count_df.loc[count_df['disease_status'] != 'Unknown', :]
    count_df.to_csv('/ahg/regevdata/projects/scgwas/data/singlecell/counts/%s_counts.txt'%disease)

In [None]:
# Disease name -> AnnData object carrying the differential-expression results.
deadatas = dict()

In [None]:
# For every dataset and annotation level:
#   1. rank genes per cell type within the diseased cells only and store the
#      resulting contamination gene lists in adata.uns['contamination_<level column>'];
#   2. run a Disease-vs-Healthy Wilcoxon test per cell type and store it under
#      adata.uns['<celltype>_L<level>_DE'].
min_cells_per_group = 5  # a cell type is skipped if either arm has fewer cells
for disease, (filename, patientkey) in disease_datasets.items():
    print(disease)
    adata = sc.read(filename)
    diagnosis_column = diagnosisdict[disease]

    # The diseased cells do not depend on the annotation level, so select them
    # once. .copy() materialises the selection: rank_genes_groups writes into
    # .uns, which on a view would force an implicit copy (with a warning).
    subset = adata[adata.obs[diagnosis_column] == diseaselabeldict[disease]].copy()

    for ctlabel in ['annot_level_2', 'annot_level_3']:
        sc.tl.rank_genes_groups(subset, groupby=ctlabel, reference='rest', n_genes=subset.shape[1], method='wilcoxon')

        adata.uns['contamination_'+ctlabel] = compute_contamination(subset)
        level = ctlabel.split('_')[-1]
        # One label per cell: '<Healthy|Disease|Unknown>_<celltype>_L<level>'.
        # The column is overwritten for each annotation level.
        adata.obs['DEstatus'] = [diseaselabel_mapping[diagnosis] + '_' + ct + '_L%s'%level
                                 for diagnosis, ct in zip(adata.obs[diagnosis_column], adata.obs[ctlabel])]
        destatus_counts = Counter(adata.obs['DEstatus'])

        for ct in set(adata.obs[ctlabel]):
            ct = ct + '_L%s'%level
            healthy_group = 'Healthy_'+ct
            disease_group = 'Disease_'+ct
            discard = any(destatus_counts.get(k, 0) < min_cells_per_group
                          for k in (healthy_group, disease_group))
            print(ct, destatus_counts.get(healthy_group), destatus_counts.get(disease_group), discard)
            if discard:
                continue
            sc.tl.rank_genes_groups(adata, groupby='DEstatus',
                                    reference=healthy_group, groups=[disease_group], key_added=ct+'_DE', n_genes=adata.shape[1], method='wilcoxon')
    deadatas[disease] = adata

In [None]:
def _celltype_from_delabel(delabel):
    """Recover the cell-type name from a DE key such as '<celltype>_L2_DE'.

    Strips the trailing '_DE' and, if present, a trailing '_L<digits>' level
    tag. Underscores inside the cell-type name itself are preserved.
    """
    base = delabel[:-len('_DE')] if delabel.endswith('_DE') else delabel
    head, sep, tail = base.rpartition('_')
    if sep and tail[:1] == 'L' and tail[1:].isdigit():
        return head
    return base


def write_matrix(adata, filename):
    """Write gene x comparison matrices of DE statistics, masking contamination genes.

    Collects every ``adata.uns`` entry whose key ends in ``'_DE'`` and builds
    three matrices (adjusted p-value, log fold change, score) with one row per
    gene of ``adata.var_names`` and one column per tested group. Genes listed
    as contamination for the comparison's cell type (in
    ``adata.uns['contamination_' + celltypelabel]`` or
    ``adata.uns['contamination_annot_level_3']``) are written with p=1,
    logfold=0, score=0. Genes without any result keep the same neutral values.

    Uses the notebook-level names ``filedir`` (output directory) and
    ``celltypelabel``.

    Parameters
    ----------
    adata
        Object exposing ``var_names``, ``obs.columns`` and ``uns`` as described.
    filename : str
        Prefix of the output files ``<filedir>/<filename>_{pval,logfold,score}.csv``.
    """
    # Keep var_names order and drop duplicates; a plain set() would give a
    # different row order on every run.
    genes = list(dict.fromkeys(adata.var_names))
    gene2idx = {gene: i for i, gene in enumerate(genes)}

    pvalmtxs, logfoldmtxs, scoremtxs = [], [], []
    ctlabels = [key for key in adata.uns if key.endswith('_DE')]
    print(adata.obs.columns)
    print(ctlabels)
    if not ctlabels:
        # pd.concat([]) would raise; report the situation explicitly instead.
        print('no *_DE results found for %s; nothing written' % filename)
        return

    for delabel in ctlabels:
        ct = _celltype_from_delabel(delabel)
        # Set for O(1) membership tests in the per-gene loop below.
        contamination = set(adata.uns['contamination_'+celltypelabel].get(ct, [])) | \
                        set(adata.uns['contamination_annot_level_3'].get(ct, []))

        result = adata.uns[delabel]
        cellsubsets = list(result['names'].dtype.names)

        # Neutral defaults: a gene with no test result must not look significant.
        pvalmtx = np.ones((len(genes), len(cellsubsets)))
        logfoldmtx = np.zeros((len(genes), len(cellsubsets)))
        scoremtx = np.zeros((len(genes), len(cellsubsets)))

        # Fill column by column; within a column the original row order is
        # kept, so for repeated gene names the last entry still wins.
        for col, cell_subset in enumerate(cellsubsets):
            for gene, pval, logfold, score in zip(result['names'][cell_subset],
                                                  result['pvals_adj'][cell_subset],
                                                  result['logfoldchanges'][cell_subset],
                                                  result['scores'][cell_subset]):
                row = gene2idx.get(gene)
                if row is None:
                    continue
                if gene in contamination:
                    pval, logfold, score = 1, 0, 0
                pvalmtx[row, col] = pval
                logfoldmtx[row, col] = logfold
                scoremtx[row, col] = score

        pvalmtxs.append(pd.DataFrame(pvalmtx, index=genes, columns=cellsubsets))
        logfoldmtxs.append(pd.DataFrame(logfoldmtx, index=genes, columns=cellsubsets))
        scoremtxs.append(pd.DataFrame(scoremtx, index=genes, columns=cellsubsets))

    pvalmtxs = pd.concat(pvalmtxs, axis=1)
    logfoldmtxs = pd.concat(logfoldmtxs, axis=1)
    scoremtxs = pd.concat(scoremtxs, axis=1)

    # write matrix to file
    pvalmtxs.to_csv("%s/%s_pval.csv"%(filedir, filename))
    logfoldmtxs.to_csv("%s/%s_logfold.csv"%(filedir, filename))
    scoremtxs.to_csv("%s/%s_score.csv"%(filedir, filename))

In [None]:
# Export the DE matrices of every processed dataset, using the disease name
# as the output file prefix.
for disease, disease_adata in deadatas.items():
    write_matrix(disease_adata, disease)

In [None]:
# Sanity print of the configured datasets; each value is the (path, patient key) tuple.
for disease, dataset_entry in disease_datasets.items():
    print(disease, dataset_entry)

In [None]:
# Re-run the Disease-vs-Healthy Wilcoxon test per cell type, this time grouping
# cells by the 'predictions' column. Each dataset is read again from disk and
# its entry in `deadatas` is replaced; results are stored under
# adata.uns['<celltype>_DE'].
min_cells_per_group = 5  # a cell type is skipped if either arm has fewer cells
for disease, (filename, patientkey) in disease_datasets.items():
    print(disease)
    adata = sc.read(filename)
    for ctlabel in ['predictions']:
        # One label per cell: '<Healthy|Disease|Unknown>_<celltype>'.
        adata.obs['DEstatus'] = ['{}_{}'.format(diseaselabel_mapping[diagnosis], ct)
                                 for diagnosis, ct in zip(adata.obs[diagnosisdict[disease]], adata.obs[ctlabel])]
        destatus_counts = Counter(adata.obs['DEstatus'])

        for ct in set(adata.obs[ctlabel]):
            healthy_group = 'Healthy_{}'.format(ct)
            disease_group = 'Disease_{}'.format(ct)
            discard = any(destatus_counts.get(k, 0) < min_cells_per_group
                          for k in (healthy_group, disease_group))
            print(ct, destatus_counts.get(healthy_group), destatus_counts.get(disease_group), discard)
            if discard:
                continue
            sc.tl.rank_genes_groups(adata, groupby='DEstatus',
                                    reference=healthy_group, groups=[disease_group], key_added='{}_DE'.format(ct), n_genes=adata.shape[1], method='wilcoxon')
    deadatas[disease] = adata

In [None]:
def write_matrix(adata, filename):
    """Write gene x comparison matrices of DE statistics (no contamination masking).

    This redefinition replaces the earlier ``write_matrix`` of the notebook; it
    performs the same export but does not consult any contamination gene lists.

    Collects every ``adata.uns`` entry whose key ends in ``'_DE'`` and builds
    three matrices (adjusted p-value, log fold change, score) with one row per
    gene of ``adata.var_names`` and one column per tested group. Genes without
    any result are written with p=1, logfold=0, score=0.

    Uses the notebook-level name ``filedir`` as output directory.

    Parameters
    ----------
    adata
        Object exposing ``var_names``, ``obs.columns`` and ``uns`` as described.
    filename : str
        Prefix of the output files ``<filedir>/<filename>_{pval,logfold,score}.csv``.
    """
    # Keep var_names order and drop duplicates; a plain set() would give a
    # different row order on every run.
    genes = list(dict.fromkeys(adata.var_names))
    gene2idx = {gene: i for i, gene in enumerate(genes)}

    pvalmtxs, logfoldmtxs, scoremtxs = [], [], []
    ctlabels = [key for key in adata.uns if key.endswith('_DE')]
    print(adata.obs.columns)
    print(ctlabels)
    if not ctlabels:
        # pd.concat([]) would raise; report the situation explicitly instead.
        print('no *_DE results found for %s; nothing written' % filename)
        return

    for delabel in ctlabels:
        result = adata.uns[delabel]
        cellsubsets = list(result['names'].dtype.names)

        # Neutral defaults: a gene with no test result must not look significant.
        pvalmtx = np.ones((len(genes), len(cellsubsets)))
        logfoldmtx = np.zeros((len(genes), len(cellsubsets)))
        scoremtx = np.zeros((len(genes), len(cellsubsets)))

        # Fill column by column; within a column the original row order is
        # kept, so for repeated gene names the last entry still wins.
        for col, cell_subset in enumerate(cellsubsets):
            for gene, pval, logfold, score in zip(result['names'][cell_subset],
                                                  result['pvals_adj'][cell_subset],
                                                  result['logfoldchanges'][cell_subset],
                                                  result['scores'][cell_subset]):
                row = gene2idx.get(gene)
                if row is None:
                    continue
                pvalmtx[row, col] = pval
                logfoldmtx[row, col] = logfold
                scoremtx[row, col] = score

        pvalmtxs.append(pd.DataFrame(pvalmtx, index=genes, columns=cellsubsets))
        logfoldmtxs.append(pd.DataFrame(logfoldmtx, index=genes, columns=cellsubsets))
        scoremtxs.append(pd.DataFrame(scoremtx, index=genes, columns=cellsubsets))

    pvalmtxs = pd.concat(pvalmtxs, axis=1)
    logfoldmtxs = pd.concat(logfoldmtxs, axis=1)
    scoremtxs = pd.concat(scoremtxs, axis=1)

    # write matrix to file
    pvalmtxs.to_csv("%s/%s_pval.csv"%(filedir, filename))
    logfoldmtxs.to_csv("%s/%s_logfold.csv"%(filedir, filename))
    scoremtxs.to_csv("%s/%s_score.csv"%(filedir, filename))