In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
# Per-TF counts read from a one-integer-per-line file; they are attached to
# merge_df as the `datasize` column (presumably training-set size per TF — confirm).
# Blank/whitespace-only lines (e.g. a trailing empty line) are skipped so they
# don't crash int(); int() itself tolerates the surrounding newline/whitespace.
with open('./data/GM12878.count.each.column', 'r') as fp:
    tf_counts = [int(line) for line in fp if line.strip()]

In [None]:
# Per-TF evaluation tables for SimpleCNN_2d on GM12878: the base model and the
# filter_fc variant. The variant's columns get a "filter_" prefix so both sets of
# metrics can sit side by side in one frame (e.g. AUC vs filter_AUC).
filter_fc_GM12878_df = pd.read_csv("./results/filter_fc_SimpleCNN_2d_GM12878_deepSEA.eval.repl.tsv", header=0, index_col=0, sep="\t")
GM12878_df = pd.read_csv("./results/SimpleCNN_2d_GM12878_deepSEA.eval.repl.tsv", header=0, index_col=0, sep="\t")
filter_fc_GM12878_df = filter_fc_GM12878_df.add_prefix("filter_")

# concat(axis=1) aligns rows on the index: differing indexes would silently pad
# rows with NaN and reorder them, breaking the positional `datasize` assignment below.
assert GM12878_df.index.equals(filter_fc_GM12878_df.index), \
    "base and filter_fc evaluation tables do not share the same row index"
merge_df = pd.concat([GM12878_df, filter_fc_GM12878_df], axis=1)

# tf_counts is positional (file order), so it needs exactly one entry per row.
assert len(tf_counts) == len(merge_df), \
    f"tf_counts has {len(tf_counts)} entries but merge_df has {len(merge_df)} rows"
merge_df['datasize'] = tf_counts

In [None]:
# Per-TF performance of SimpleCNN_2d on GM12878, four stacked panels:
#   ax[0] AUC per TF (hue = datasize)   ax[1] AUC distribution
#   ax[2] AUPR per TF (hue = datasize)  ax[3] AUPR distribution
# A previous `plt.rcParams.update({'font.size': 5})` was dropped: sns.set_theme()
# resets font sizes through its plotting context, so it never had any effect.
custom_params = {"axes.spines.right": False, "axes.spines.top": False}
sns.set_theme(style="ticks", rc=custom_params)
fig, ax = plt.subplots(4, 1, constrained_layout=True, figsize=(25, 40), sharex=False)

sns.barplot(data=merge_df, x="TF_name", y="AUC", hue="datasize", ax=ax[0])
ax[0].axhline(0.9, color='r', dashes=(2, 2))  # reference line at AUC = 0.9
ax[0].get_legend().remove()  # the datasize legend is shown once, on ax[2]

sns.histplot(data=merge_df, x="AUC", kde=True, ax=ax[1])

sns.barplot(data=merge_df, x="TF_name", y="AUPR", hue="datasize", ax=ax[2])
ax[2].axhline(0.35, color='r', dashes=(2, 2))  # reference line at AUPR = 0.35
# Move the existing seaborn legend (keeps its title/handles). The old
# plt.legend() call hit the current axes, ax[3], which has no labelled artists.
sns.move_legend(ax[2], "upper right")

sns.histplot(data=merge_df, x="AUPR", kde=True, ax=ax[3])

for panel in ax:
    panel.set_xlabel('')
# Rotate TF names on the two bar charts so they remain readable.
for panel in (ax[0], ax[2]):
    panel.tick_params(axis='x', labelrotation=45)

fig.savefig("results/test.pdf")

### Remove TFs with low AUC/AUPR

In [None]:
# TFs scoring below either threshold are treated as poorly predicted and dropped.
AUC_CUTOFF = 0.85
AUPR_CUTOFF = 0.1

# Means over all TFs before filtering. Both are printed: a bare expression only
# displays when it is the last line of a cell, so the AUC mean used to be lost.
print("AUC before filter: ", merge_df['AUC'].mean())
print("AUPR before filter: ", merge_df['AUPR'].mean())

# remove tfs in GM12878 due to low AUC
low_auc = merge_df['AUC'] < AUC_CUTOFF
print(f"Count of AUC lower than {AUC_CUTOFF}: ", low_auc.sum())
AUC_remove = merge_df.loc[low_auc, 'TF_name'].tolist()

# ... and due to low AUPR (this count previously re-used the AUC condition by mistake)
low_aupr = merge_df['AUPR'] < AUPR_CUTOFF
print(f"Count of AUPR lower than {AUPR_CUTOFF}: ", low_aupr.sum())
AUPR_remove = merge_df.loc[low_aupr, 'TF_name'].tolist()

# Drop by TF name, so every row sharing a removed TF_name is excluded.
removed_tfs = set(AUC_remove + AUPR_remove)
merge_df_filtered = merge_df[~merge_df['TF_name'].isin(removed_tfs)]
print("AUC after filter: ", merge_df_filtered['AUC'].mean())
print("AUPR after filter ", merge_df_filtered['AUPR'].mean())

In [None]:
# Side-by-side per-TF AUC of the base model ("AUC") and the filter_fc variant
# ("filter_AUC"). Long format: one row per (TF_name, datasize, variable), with the
# metric in "value" — computed once instead of twice.
auc_long_df = merge_df[['TF_name', 'AUC', 'filter_AUC', 'datasize']].melt(
    ['TF_name', 'datasize'], ['AUC', 'filter_AUC'])

custom_params = {"axes.spines.right": False, "axes.spines.top": False}
sns.set_theme(style="ticks", rc=custom_params)
# Separate names so the previous cell's `fig`/`ax` are not overwritten.
cmp_fig, cmp_ax = plt.subplots(dpi=150, figsize=(50, 8))
sns.barplot(data=auc_long_df, x="TF_name", y="value", hue="variable", ax=cmp_ax)
cmp_ax.axhline(0.9, color='r', dashes=(2, 2))  # reference line at AUC = 0.9
cmp_ax.tick_params(axis='x', labelrotation=90)
cmp_ax.legend(loc='upper right')
cmp_fig.savefig("results/test2.pdf")

### Compare models on the GM12878 test set

In [None]:
# Per-TF evaluation tables for the baseline models (scFAN, DeepATT, DanQ, TBiNet,
# Deepformer) on GM12878, "noweight" runs. These are .tsv files like the SimpleCNN
# results above, so they are read tab-separated; with the default comma separator
# every row would collapse into a single unparsed column.
# NOTE(review): unlike the SimpleCNN tables these paths have no ./results/ prefix —
# confirm the files really live in the working directory.
scFAN_GM12878_noweight = pd.read_csv("scFAN_GM12878_noweight.eval.repl.tsv", header=0, index_col=0, sep="\t")
DeepATT_GM12878_noweight = pd.read_csv("DeepATT_GM12878_noweight.eval.repl.tsv", header=0, index_col=0, sep="\t")
DanQ_GM12878_noweight = pd.read_csv("DanQ_GM12878_noweight.eval.repl.tsv", header=0, index_col=0, sep="\t")
TBiNet_GM12878_noweight = pd.read_csv("TBiNet_GM12878_noweight.eval.repl.tsv", header=0, index_col=0, sep="\t")
Deepformer_GM12878_noweight = pd.read_csv("Deepformer_GM12878_noweight.eval.repl.tsv", header=0, index_col=0, sep="\t")

In [None]:
# Per-TF evaluation tables for the same baseline models on GM12878 (the runs
# without the "noweight" suffix). Tab-separated like every other .eval.repl.tsv
# file in this notebook, so sep="\t" is required for the columns to parse.
# NOTE(review): paths are relative to the working directory, not ./results/ — confirm.
scFAN_GM12878 = pd.read_csv("scFAN_GM12878.eval.repl.tsv", header=0, index_col=0, sep="\t")
DeepATT_GM12878 = pd.read_csv("DeepATT_GM12878.eval.repl.tsv", header=0, index_col=0, sep="\t")
DanQ_GM12878 = pd.read_csv("DanQ_GM12878.eval.repl.tsv", header=0, index_col=0, sep="\t")
TBiNet_GM12878 = pd.read_csv("TBiNet_GM12878.eval.repl.tsv", header=0, index_col=0, sep="\t")
Deepformer_GM12878 = pd.read_csv("Deepformer_GM12878.eval.repl.tsv", header=0, index_col=0, sep="\t")