In [1]:
import os
import sys
import pandas as pd
sys.path.append('../../')
from deepfold.data.utils.ontology import Ontology
from tools.evaluate_diamondscore import get_diamond_preds, get_diamond_scores
from tools.evaluate_deepmodel import get_model_preds
from sklearn.metrics import f1_score, recall_score, precision_score


In [2]:
base_path = '/home/niejianzheng/xbiome/datasets/protein'
# 存储所有数据文件路径
data_ls = os.walk(base_path).__next__()[2]
data_path_dict = {}
for data in data_ls:
    file_name = data.split('.')[0] + '_' + data.split('.')[1]
    data_path_dict[file_name] = os.path.join(base_path, data)

In [3]:
data_path_dict

{'test_data_fa': '/home/niejianzheng/xbiome/datasets/protein/test_data.fa',
 'uniprot_sprot_dat': '/home/niejianzheng/xbiome/datasets/protein/uniprot_sprot.dat',
 'test_emb_h5': '/home/niejianzheng/xbiome/datasets/protein/test_emb.h5',
 'test_diamond_res': '/home/niejianzheng/xbiome/datasets/protein/test_diamond.res',
 'test_data_pkl': '/home/niejianzheng/xbiome/datasets/protein/test_data.pkl',
 'go_obo': '/home/niejianzheng/xbiome/datasets/protein/go.obo',
 'train_data_pkl': '/home/niejianzheng/xbiome/datasets/protein/train_data.pkl',
 'esm1b_t33_650M_UR50S_embeddings_mean_train_pkl': '/home/niejianzheng/xbiome/datasets/protein/esm1b_t33_650M_UR50S_embeddings_mean_train.pkl',
 'test_annotations_txt': '/home/niejianzheng/xbiome/datasets/protein/test_annotations.txt',
 'train_data_dmnd': '/home/niejianzheng/xbiome/datasets/protein/train_data.dmnd',
 'predictions_pkl': '/home/niejianzheng/xbiome/datasets/protein/predictions.pkl',
 'train_annotations_txt': '/home/niejianzheng/xbiome/datas

In [4]:
go_rels = Ontology(data_path_dict['go_obo'], with_rels=True)

train_df = pd.read_pickle(data_path_dict['train_data_pkl'])
annotations = train_df['prop_annotations'].values
annotations = list(map(lambda x: set(x), annotations))

test_df = pd.read_pickle(data_path_dict['test_data_pkl'])
test_annotations = test_df['prop_annotations'].values
test_annotations = list(map(lambda x: set(x), test_annotations))
go_rels.calculate_ic(annotations + test_annotations)


In [5]:
test_df.head()

Unnamed: 0,index,proteins,accessions,sequences,annotations,interpros,orgs,exp_annotations,prop_annotations,cafa_target
1253,6710,ACK1_YEAST,Q07622; D6VRF1;,MVNQGQPQPNLYDKHINMFPPARARESSHKLGNANSDRHGLPAQNI...,"[GO:0005739|HDA, GO:0008047|IBA, GO:0031505|IM...","[IPR006597, IPR011990]",559292,"[GO:0005739, GO:0031505, GO:0009967]","[GO:0043227, GO:0005737, GO:0110165, GO:003150...",True
33330,210625,KIFC1_CRIGR,Q60443;,MKEALEPAKKRTRGLGAVTKIDTSRSKGPLLSSLSQPQGPTAAQKG...,"[GO:0005769|IEA, GO:0005874|IEA, GO:0005815|IE...","[IPR019821, IPR001752, IPR036961, IPR027417]",10029,[GO:0030496],"[GO:0110165, GO:0005575, GO:0030496]",False
18394,109472,DSC3_HUMAN,Q14574; A6NN35; Q14200; Q9HAZ9;,MAAAGPRRSVRGAVCLHLLLTLVIFSRAGEACKKVILNVPSKLEAD...,"[GO:0030054|IDA, GO:0005911|IBA, GO:0001533|TA...","[IPR002126, IPR015919, IPR020894, IPR000233, I...",9606,"[GO:0030054, GO:0001533, GO:0016020, GO:000588...","[GO:0009913, GO:0140096, GO:0022610, GO:004358...",True
95,687,2AAA_SCHPO,Q9UT08; Q10293;,MQTENQVNDLYPIAVLIDELKHDEITYRLNALERLSTIALALGPER...,"[GO:0005737|IBA, GO:0005829|HDA, GO:0090443|ID...","[IPR011989, IPR016024, IPR000357, IPR021133]",284812,"[GO:0005829, GO:0090443, GO:0110085, GO:004473...","[GO:0043227, GO:0022402, GO:0035556, GO:004578...",True
11368,62522,CH1CO_SYNAS,Q2LQN9;,MKGPIKFNALSLQGRSVMSNQSNDTTITQRRDTMNELTEEQKLLME...,"[GO:0003995|IEA, GO:0050660|IDA, GO:0052890|ID...","[IPR006089, IPR006091, IPR036250, IPR009075, I...",56780,"[GO:0050660, GO:0052890, GO:0051262]","[GO:0043167, GO:1901363, GO:0051262, GO:005125...",False


## Diamond model

In [6]:

diamond_scores = get_diamond_scores(data_path_dict['test_diamond_res'])
blast_preds = get_diamond_preds(train_df, test_df, diamond_scores)

## DeeGoPlus Model

In [7]:
predictions = pd.read_pickle(data_path_dict['predictions_pkl'])

terms_df = pd.read_pickle(data_path_dict['terms_pkl'])
terms = terms_df['terms'].values.flatten()

ics = {}
for term in terms:
    ics[term] = go_rels.get_ic(term)

prot_index = {}
for i, row in enumerate(train_df.itertuples()):
    prot_index[row.proteins] = i



In [8]:
predictions['prot_len'] = predictions['sequences'].str.len()
predictions.head()

Unnamed: 0,index,proteins,accessions,sequences,annotations,interpros,orgs,exp_annotations,prop_annotations,cafa_target,labels,preds,prot_len
1253,6710,ACK1_YEAST,Q07622; D6VRF1;,MVNQGQPQPNLYDKHINMFPPARARESSHKLGNANSDRHGLPAQNI...,"[GO:0005739|HDA, GO:0008047|IBA, GO:0031505|IM...","[IPR006597, IPR011990]",559292,"[GO:0005739, GO:0031505, GO:0009967]","[GO:0043227, GO:0005737, GO:0110165, GO:003150...",True,"[1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, ...","[0.9028676, 0.012715803, 0.0014727293, 0.00745...",623
33330,210625,KIFC1_CRIGR,Q60443;,MKEALEPAKKRTRGLGAVTKIDTSRSKGPLLSSLSQPQGPTAAQKG...,"[GO:0005769|IEA, GO:0005874|IEA, GO:0005815|IE...","[IPR019821, IPR001752, IPR036961, IPR027417]",10029,[GO:0030496],"[GO:0110165, GO:0005575, GO:0030496]",False,"[1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, ...","[0.9455861, 0.0015807436, 3.999437e-05, 0.0003...",622
18394,109472,DSC3_HUMAN,Q14574; A6NN35; Q14200; Q9HAZ9;,MAAAGPRRSVRGAVCLHLLLTLVIFSRAGEACKKVILNVPSKLEAD...,"[GO:0030054|IDA, GO:0005911|IBA, GO:0001533|TA...","[IPR002126, IPR015919, IPR020894, IPR000233, I...",9606,"[GO:0030054, GO:0001533, GO:0016020, GO:000588...","[GO:0009913, GO:0140096, GO:0022610, GO:004358...",True,"[1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, ...","[0.8087489, 0.0016527537, 0.00028462024, 0.000...",896
95,687,2AAA_SCHPO,Q9UT08; Q10293;,MQTENQVNDLYPIAVLIDELKHDEITYRLNALERLSTIALALGPER...,"[GO:0005737|IBA, GO:0005829|HDA, GO:0090443|ID...","[IPR011989, IPR016024, IPR000357, IPR021133]",284812,"[GO:0005829, GO:0090443, GO:0110085, GO:004473...","[GO:0043227, GO:0022402, GO:0035556, GO:004578...",True,"[1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, ...","[0.9657183, 0.0043522124, 0.0008709415, 0.0034...",590
11368,62522,CH1CO_SYNAS,Q2LQN9;,MKGPIKFNALSLQGRSVMSNQSNDTTITQRRDTMNELTEEQKLLME...,"[GO:0003995|IEA, GO:0050660|IDA, GO:0052890|ID...","[IPR006089, IPR006091, IPR036250, IPR009075, I...",56780,"[GO:0050660, GO:0052890, GO:0051262]","[GO:0043167, GO:1901363, GO:0051262, GO:005125...",False,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.13663456, 0.00061553024, 0.00029366414, 0.0...",414


## threshold = 0.2

In [9]:
def get_diamond_preds_multilabel(blast_preds, terms):
    diamond_preds = []
    terms_dict = {prot:idx for idx, prot in enumerate(terms)}
    for protein_res in blast_preds:
        multi_label = [0] * len(terms)
        for prot_id, score in protein_res.items():
            if prot_id in terms_dict:
                idx = terms_dict[prot_id]
                multi_label[int(idx)] = score
        diamond_preds.append(multi_label)
    return  diamond_preds

In [10]:
def idx_to_term(preds, terms, threshod):
    model_preds = []
    for prot_pred in preds:
        term_name_list = terms[prot_pred > threshod]
        model_preds.append(term_name_list)
    return model_preds

In [11]:
def filter_preds_multilabel(preds, threshold):
    filter_res = []
    for pred in preds:
        res = [1 if pp > threshold else 0 for pp in pred]
        filter_res.append(res)
    return filter_res

def filter_preds(preds, threshold):
    filter_res = []
    for pred in preds:
        res = [k for k,v in pred.items() if v > threshold]
        filter_res.append(res)
    return  filter_res

In [12]:
blast_preds_multilabel = get_diamond_preds_multilabel(blast_preds, terms)
blast_preds_multilabel = filter_preds_multilabel(blast_preds_multilabel, threshold=0.2)
blast_preds_filter = filter_preds(blast_preds, threshold=0.2)

In [13]:
model_preds = get_model_preds(predictions, terms)
model_preds_filter = filter_preds(model_preds, threshold=0.2)
model_pred_multilabel = predictions['preds'].tolist()
model_pred_multilabel = filter_preds_multilabel(model_pred_multilabel, threshold=0.2)
predictions['model_pred_multilabel'] = list(model_pred_multilabel)
predictions['model_preds'] = list(model_preds_filter)

In [14]:
predictions['balst_pred_multilabel'] = list(blast_preds_multilabel)
predictions['blast_preds'] = list(blast_preds_filter)

In [15]:
predictions.head()

Unnamed: 0,index,proteins,accessions,sequences,annotations,interpros,orgs,exp_annotations,prop_annotations,cafa_target,labels,preds,prot_len,model_pred_multilabel,model_preds,balst_pred_multilabel,blast_preds
1253,6710,ACK1_YEAST,Q07622; D6VRF1;,MVNQGQPQPNLYDKHINMFPPARARESSHKLGNANSDRHGLPAQNI...,"[GO:0005739|HDA, GO:0008047|IBA, GO:0031505|IM...","[IPR006597, IPR011990]",559292,"[GO:0005739, GO:0031505, GO:0009967]","[GO:0043227, GO:0005737, GO:0110165, GO:003150...",True,"[1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, ...","[0.9028676, 0.012715803, 0.0014727293, 0.00745...",623,"[1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[GO:0110165, GO:0005575, GO:0003674, GO:000815...","[1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[GO:0000131, GO:0003674, GO:0003824, GO:000548..."
33330,210625,KIFC1_CRIGR,Q60443;,MKEALEPAKKRTRGLGAVTKIDTSRSKGPLLSSLSQPQGPTAAQKG...,"[GO:0005769|IEA, GO:0005874|IEA, GO:0005815|IE...","[IPR019821, IPR001752, IPR036961, IPR027417]",10029,[GO:0030496],"[GO:0110165, GO:0005575, GO:0030496]",False,"[1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, ...","[0.9455861, 0.0015807436, 3.999437e-05, 0.0003...",622,"[1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[GO:0110165, GO:0005575, GO:0003674, GO:000815...","[1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, ...","[GO:0000003, GO:0000070, GO:0000226, GO:000027..."
18394,109472,DSC3_HUMAN,Q14574; A6NN35; Q14200; Q9HAZ9;,MAAAGPRRSVRGAVCLHLLLTLVIFSRAGEACKKVILNVPSKLEAD...,"[GO:0030054|IDA, GO:0005911|IBA, GO:0001533|TA...","[IPR002126, IPR015919, IPR020894, IPR000233, I...",9606,"[GO:0030054, GO:0001533, GO:0016020, GO:000588...","[GO:0009913, GO:0140096, GO:0022610, GO:004358...",True,"[1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, ...","[0.8087489, 0.0016527537, 0.00028462024, 0.000...",896,"[1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, ...","[GO:0110165, GO:0005575, GO:0032501, GO:000727...","[1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, ...","[GO:0003674, GO:0003824, GO:0005488, GO:000551..."
95,687,2AAA_SCHPO,Q9UT08; Q10293;,MQTENQVNDLYPIAVLIDELKHDEITYRLNALERLSTIALALGPER...,"[GO:0005737|IBA, GO:0005829|HDA, GO:0090443|ID...","[IPR011989, IPR016024, IPR000357, IPR021133]",284812,"[GO:0005829, GO:0090443, GO:0110085, GO:004473...","[GO:0043227, GO:0022402, GO:0035556, GO:004578...",True,"[1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, ...","[0.9657183, 0.0043522124, 0.0008709415, 0.0034...",590,"[1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[GO:0110165, GO:0005575, GO:0003674, GO:002241...","[1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, ...","[GO:0000070, GO:0000159, GO:0000226, GO:000027..."
11368,62522,CH1CO_SYNAS,Q2LQN9;,MKGPIKFNALSLQGRSVMSNQSNDTTITQRRDTMNELTEEQKLLME...,"[GO:0003995|IEA, GO:0050660|IDA, GO:0052890|ID...","[IPR006089, IPR006091, IPR036250, IPR009075, I...",56780,"[GO:0050660, GO:0052890, GO:0051262]","[GO:0043167, GO:1901363, GO:0051262, GO:005125...",False,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.13663456, 0.00061553024, 0.00029366414, 0.0...",414,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[GO:0003674, GO:0008150, GO:0043167, GO:190136...","[1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[GO:0000166, GO:0003674, GO:0003824, GO:000399..."


In [16]:
def compute_acc(labels, preds):
    scores = []
    for label, pred in zip(labels, preds):
        acc = precision_score(label, pred)
        scores.append(acc)
    return scores

def compute_recall(labels, preds):
    scores = []
    for label, pred in zip(labels, preds):
        acc = recall_score(label, pred)
        scores.append(acc)
    return scores


def compute_f1score(labels, preds):
    scores = []
    for label, pred in zip(labels, preds):
        acc = f1_score(label, pred)
        scores.append(acc)
    return scores

In [17]:
multilabels = predictions['labels'].tolist()
model_pred_acc = compute_acc(multilabels, model_pred_multilabel)
blast_pred_acc = compute_acc(multilabels, blast_preds_multilabel)

model_pred_recall = compute_recall(multilabels, model_pred_multilabel)
blast_pred_recall = compute_recall(multilabels, blast_preds_multilabel)

model_pred_f1score = compute_f1score(multilabels, model_pred_multilabel)
blast_pred_f1score = compute_f1score(multilabels, blast_preds_multilabel)

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_pr

In [29]:
df = pd.DataFrame(zip(model_pred_acc, blast_pred_acc, model_pred_recall, blast_pred_recall, model_pred_f1score, blast_pred_f1score), index=list(predictions.proteins),
                                              columns =['model_acc', 'blast_acc', 'model_recall', 'blast_recall', 'model_f1_score', 'blast_f1_score'])
df['prot_len'] = predictions['prot_len'].tolist()

In [30]:
df.head()

Unnamed: 0,model_acc,blast_acc,model_recall,blast_recall,model_f1_score,blast_f1_score,prot_len
ACK1_YEAST,0.272727,0.238095,0.486486,0.405405,0.349515,0.3,623
KIFC1_CRIGR,0.033333,0.025316,0.666667,0.666667,0.063492,0.04878,622
DSC3_HUMAN,0.589286,0.451613,0.634615,0.807692,0.611111,0.57931,896
2AAA_SCHPO,0.480519,0.352459,0.389474,0.452632,0.430233,0.396313,590
CH1CO_SYNAS,0.681818,0.230769,0.652174,0.652174,0.666667,0.340909,414


In [49]:
import plotly.express as px
fig = px.scatter(df, x="blast_acc", y="model_acc",  hover_data=[df.index])
fig.show()

In [50]:
import plotly.express as px
fig = px.scatter(df, x="blast_recall", y="model_recall",  hover_data=[df.index])
fig.show()

In [51]:
import plotly.express as px
fig = px.scatter(df, x="blast_f1_score", y="model_f1_score",  hover_data=[df.index])
fig.show()

In [52]:
import plotly.express as px
fig = px.scatter(df, x="prot_len", y="model_acc",  hover_data=[df.index])
fig.show()

In [53]:
import plotly.express as px
fig = px.scatter(df, x="prot_len", y="blast_acc",  hover_data=[df.index])
fig.show()

In [54]:
import plotly.express as px
fig = px.scatter(df, x="prot_len", y="model_recall",  hover_data=[df.index])
fig.show()

In [55]:
import plotly.express as px
fig = px.scatter(df, x="prot_len", y="blast_recall",  hover_data=[df.index])
fig.show()

In [56]:
import plotly.express as px
fig = px.scatter(df, x="prot_len", y="blast_f1_score",  hover_data=[df.index])
fig.show()

In [57]:
import plotly.express as px
fig = px.scatter(df, x="prot_len", y="model_f1_score",  hover_data=[df.index])
fig.show()