# Visualization of pathway enrichment analysis for topics

In [None]:
import pandas as pd
import gseapy as gp
import matplotlib.pyplot as plt
from gseapy import gseaplot

In [None]:
dataname = 'Wang'
# Model output matrix (written under ../output/<dataname>/) and the preprocessed
# expression table whose column labels are copied onto it. The two files are
# assumed to have the same number of columns, in the same order — TODO confirm.
tg_path = f'../output/{dataname}/{dataname}_tg.csv'
expr_path = f'../data/{dataname}_HIGHPRE_5000.csv'

data = pd.read_csv(tg_path, sep=',', index_col=0)
csv = pd.read_csv(expr_path, sep=',', index_col=0)
data.columns = csv.columns
# Transpose so each column is one row of the original tg matrix; the next cell
# ranks each column separately (presumably one topic per column, genes as index).
data = data.transpose()

In [None]:
# Run gseapy prerank once per column of `data`, using that column's values as
# the ranking metric, and stack every per-column result table into `all_res`.
# A `topic_index` column (the column's position in `data`) is prepended so rows
# can be grouped by topic later. `all_res` is None if `data` has no columns.
gene_sets_path = '../data/c2.all.v2024.1.Hs.symbols.gmt'
n_threads = 100  # worker count passed to gseapy; lower this on small machines

per_topic_res = []
for index in range(data.shape[1]):
    data_temp = data.iloc[:, index]  # ranking scores for this topic, indexed by gene
    pre_res = gp.prerank(rnk=data_temp,
                         gene_sets=gene_sets_path,
                         threads=n_threads,
                         outdir=None,  # don't write to disk
                         )
    pre_res.res2d.insert(0, 'topic_index', index)
    per_topic_res.append(pre_res.res2d)

# Concatenate once: calling pd.concat inside the loop re-copies the whole
# accumulated frame on every iteration (quadratic in the number of topics).
all_res = pd.concat(per_topic_res, ignore_index=True) if per_topic_res else None

In [None]:
from pandas import DataFrame
from scipy.stats import uniform
from scipy.stats import randint
import numpy as np
import matplotlib.pyplot as plt

# Manhattan-style plot: one point per (topic, pathway) row of `all_res`, with
# x = running position after sorting by topic and y = -log10(FDR q + 1e-6).
# Points are coloured by topic (cycling through `colors`). Finally prints how
# many rows have FDR q <= 0.01.
font = {'family': 'Times New Roman',
        'size': 20,
        }
fonty = {'family': 'Times New Roman',
         'size': 14,
         }

# Stable sort keeps gseapy's within-topic row order deterministic.
all_res = all_res.sort_values('topic_index', kind='stable')

# Cast to float *before* adding the offset, so the arithmetic is numeric even
# if the column arrives with object dtype. The 1e-6 floor avoids log10(0) and
# caps y at -log10(1e-6) = 6, hence the 6.3 y-limit below.
q_vals = all_res['FDR q-val'].astype(float)
all_res['$-log_{10}(q)$'] = -np.log10(q_vals + 1e-6)

all_res['ind'] = range(len(all_res))
df_grouped = all_res.groupby('topic_index')

fig, ax = plt.subplots(figsize=(20, 4), dpi=100)

colors = ["#b03d26", "#005f81", "#9ccfe6", "#e0897e", "#a5a7ab"]
x_labels = []
x_labels_pos = []
for num, (name, group) in enumerate(df_grouped):
    group.plot(kind='scatter', x='ind', y='$-log_{10}(q)$',
               color=colors[num % len(colors)], ax=ax)
    x_labels.append(name)
    # Centre the topic's tick on its run of consecutive x positions.
    x_labels_pos.append((group['ind'].iloc[0] + group['ind'].iloc[-1]) / 2)

ax.tick_params(direction='in', labelsize=13)
# Label every other topic to keep the axis readable. Slicing adapts to however
# many topics actually produced results; the old hard-coded range(100) raised
# IndexError with fewer than 100 topics.
x_labels_2 = x_labels[::2]
x_labels_pos_2 = x_labels_pos[::2]
ax.set_xticks(x_labels_pos_2)
ax.set_xticklabels(x_labels_2)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

ax.set_xlim([0, len(all_res)])
ax.set_ylim([0, 6.3])

ax.set_xlabel('Topic', fontdict=font)
ax.set_ylabel('$-log_{10}(q)$', fontdict=fonty)
# Uncomment to export the figure:
#plt.savefig('scE2TM_'+dataname+'.PDF',format='PDF',bbox_inches = 'tight',facecolor='white')

# q_vals shares all_res's current index and row order, so it is a valid mask.
all_res_table = all_res[q_vals <= 0.01]
print(len(all_res_table))
plt.show()