In [None]:
import pandas as pd
import numpy as np
import regex as re
import plotly.express as px
import seaborn as sns
import os, logging
from gseapy import enrichr
# Named logger used for this notebook's progress and result messages.
log = logging.getLogger("topic_preserve")
# The handler line below is disabled, so no handler is attached to this logger
# here; its records are handed on to whatever handlers the root logger has.
#log.addHandler(logging.StreamHandler())
# Let every record from DEBUG upwards through at the logger level.
log.setLevel(logging.DEBUG)
from overlap import get_pval

In [None]:
# Switch to the gtex10 dataset folder so that the relative CSV paths used by
# the cells below resolve against it.
# NOTE(review): the target is a relative path, so executing this cell a second
# time in the same kernel resolves it from the new location — run it once.
os.chdir("../topics/datasets/gtex10/")

In [None]:
# Algorithms compared pairwise below; each name is also the sub-folder and the
# file-name prefix of that algorithm's CSV outputs.
algorithms = [
    "topsbm",
    "lda",
    "wgcna",
    "tm",
]

In [None]:
# Row labels of the topsbm level-0 word-distribution table; handed to
# get_pval as its third argument in the comparison loop
# (presumably the background gene population — confirm against overlap.get_pval).
pop = pd.read_csv(
    "topsbm/topsbm_level_0_word-dist.csv",
    index_col=0,
).index

In [None]:
def get_df_topic_tissue(algorithm, l=3):
    """Build a topic-by-tissue table of mean standardised topic weights.

    Reads ``{algorithm}/{algorithm}_level_{l}_topic-dist.csv`` and
    ``files.dat`` relative to the current working directory. The topic
    table is indexed by its second column and its ``i_doc`` column is
    discarded. Every remaining column is converted to the absolute deviation
    from its column mean divided by its column standard deviation, and the
    result is averaged over the rows sharing the same ``SMTS`` value in
    ``files.dat``.

    Parameters
    ----------
    algorithm : str
        Folder name and file prefix of the algorithm's output.
    l : int, default 3
        Level number used in the file name.

    Returns
    -------
    pandas.DataFrame
        One row per topic column, one column per ``SMTS`` value.
    """
    # Keyword arguments are used for ``columns``/``axis``: passing the axis
    # positionally to ``drop`` is rejected by pandas >= 2.0.
    topic_dist = pd.read_csv(
        f"{algorithm}/{algorithm}_level_{l}_topic-dist.csv", index_col=1
    ).drop(columns="i_doc")

    # |x - column mean| / column std, computed column by column.
    deviation = topic_dist.sub(topic_dist.mean(axis=0), axis="columns").abs()
    scores = deviation.div(topic_dist.std(axis=0), axis="columns")

    # Attach the SMTS label of each row, aligning on the row index.
    sample_info = pd.read_csv("files.dat", index_col=0)
    scores["tissue"] = sample_info.reindex(index=scores.index)["SMTS"]

    return scores.groupby("tissue").mean().transpose()

def get_most_significant_topictissue(algorithm, tissue, l=3):
    """Return the label of the topic with the highest score for ``tissue``.

    The scores come from ``get_df_topic_tissue(algorithm, l)``; the column
    named ``tissue`` is sorted in descending order and the first row label
    is returned.
    """
    tissue_scores = get_df_topic_tissue(algorithm, l)[tissue]
    ranked = tissue_scores.sort_values(ascending=False)
    return ranked.index[0]

def get_most_significant_topictissue_genelist(algorithm, tissue, l=3):
    """Return the non-null entries of the top-scoring topic column for ``tissue``.

    Loads ``{algorithm}/{algorithm}_level_{l}_topics.csv``, looks up the topic
    chosen by ``get_most_significant_topictissue`` and returns that column's
    values, with missing entries removed, as a NumPy array.
    """
    topics_table = pd.read_csv(f"{algorithm}/{algorithm}_level_{l}_topics.csv")
    best_topic = get_most_significant_topictissue(algorithm, tissue, l)
    members = topics_table[best_topic].dropna()
    return members.values

In [None]:
# Level number handed to the gene-list loader for each algorithm.
level_map = dict(topsbm=2, lda=3, wgcna=0, tm=0)

In [None]:
# File names of the MSigDB collections to query: "c1".."c8" or "h" (optionally
# numbered), release v7.1 or v7.2, gene-symbol .gmt files. The dots are escaped
# so they only match a literal ".", and the whole name must match, so files
# such as "c2.all.v7.1.symbols.gmt.bak" are not picked up.
pattern = r"[ch][1-8]?\.all\.v7\.[12]\.symbols\.gmt"
# NOTE(review): absolute, machine-specific location — adjust when running elsewhere.
mdb_dir = "/home/jovyan/work/phd/MSigDB/"
# os.listdir returns entries in arbitrary order; sort for a reproducible list.
gene_sets = [
    os.path.join(mdb_dir, file)
    for file in sorted(os.listdir(mdb_dir))
    if re.fullmatch(pattern, file) is not None
]

In [None]:
#['Adipose Tissue' 'Blood' 'Blood Vessel' 'Brain' 'Colon' 'Esophagus' 'Heart' 'Muscle' 'Skin' 'Thyroid']
tissue = "Heart"

# Load the top-topic gene list of every algorithm once, instead of re-reading
# the CSV files for each pair inside the nested loop.
gene_lists = {
    algorithm: get_most_significant_topictissue_genelist(algorithm, tissue, level_map[algorithm])
    for algorithm in algorithms
}

# Compare every unordered pair of algorithms exactly once.
for ialg, algorithm in enumerate(algorithms):
    first = gene_lists[algorithm]
    for compare_algo in algorithms[1+ialg:]:
        second = gene_lists[compare_algo]

        p = get_pval(first, second, pop)
        if p < 0.01:
            # Header for this pair; logged once (it used to be emitted twice).
            log.info(algorithm+"\t"+str(p)+"\t"+compare_algo)
            # Entries of the first list that also occur in the second.
            commons = first[np.isin(first, second)]
            # Only spell out the shared entries when there are more than ten.
            if len(commons) > 10:
                for g in commons:
                    print(g)
            log.info(len(commons))
            log.info(enrichr(list(commons), gene_sets).results)