In [1]:
import os
import numpy as np
from collections import deque
from collections import Counter

import math
import gzip
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split

import click as ck
from sklearn.metrics import classification_report
from sklearn.metrics.pairwise import cosine_similarity
import sys
import time
from sklearn.metrics import roc_curve, auc, matthews_corrcoef
from scipy.spatial import distance
from scipy import sparse
from matplotlib import pyplot as plt
import json

In [2]:
import os

# Directory holding the DeepGOPlus data files. Set the PROTEIN_DATA_DIR
# environment variable to point elsewhere instead of editing this path.
base_path = os.environ.get('PROTEIN_DATA_DIR', r'/Users/robin/xbiome/datasets/protein')
if not os.path.isdir(base_path):
    # next(os.walk(...)) would otherwise fail with an opaque StopIteration.
    raise FileNotFoundError(f'Data directory not found: {base_path}')

# Map every file directly inside base_path (sub-directories are not searched)
# to the key "<stem>_<first extension>", e.g. "train_data.pkl" -> "train_data_pkl"
# and "uniprot_sprot.dat.gz" -> "uniprot_sprot_dat". Files are visited in sorted
# order so that colliding keys resolve deterministically (the last one wins).
data_ls = sorted(next(os.walk(base_path))[2])
data_path_dict = {}
for data in data_ls:
    parts = data.split('.')
    if len(parts) < 2:
        # No extension: a "<stem>_<ext>" key cannot be formed, so skip the file.
        continue
    file_name = parts[0] + '_' + parts[1]
    data_path_dict[file_name] = os.path.join(base_path, data)

# Minimum number of annotated proteins required for each GO term
min_count = 50

# Maximum sequence length
MAXLEN = 2000

# GO sub-ontologies to evaluate (mf, bp, cc)
onts = ['mf', 'bp', 'cc']

data_path_dict

{'train_data_pkl': '/Users/robin/xbiome/datasets/protein/train_data.pkl',
 'terms_pkl': '/Users/robin/xbiome/datasets/protein/terms.pkl',
 'test_data_pkl': '/Users/robin/xbiome/datasets/protein/test_data.pkl',
 'go_obo': '/Users/robin/xbiome/datasets/protein/go.obo',
 'predictions_pkl': '/Users/robin/xbiome/datasets/protein/predictions.pkl',
 'swissprot_pkl': '/Users/robin/xbiome/datasets/protein/swissprot.pkl',
 'test_diamond_res': '/Users/robin/xbiome/datasets/protein/test_diamond.res',
 'uniprot_sprot_dat': '/Users/robin/xbiome/datasets/protein/uniprot_sprot.dat.gz'}

In [3]:
class Ontology(object):
    """Gene Ontology parsed from an OBO file.

    ``self.ont`` maps GO ids to term dicts. Each term dict has the keys
    ``id``, ``is_a`` (list of parent ids), ``alt_ids``, ``is_obsolete`` and
    ``children`` (set of child ids), plus ``name`` / ``namespace`` when the
    stanza provides them. ``part_of`` and ``regulates`` are created as empty
    lists but never filled: with ``with_rels=True`` relationship targets are
    appended to ``is_a`` instead. Every ``alt_id`` becomes an extra key that
    points at the same dict as its primary id, and obsolete primary ids are
    removed.

    ``self.ic`` is ``None`` until :meth:`calculate_ic` has been called.
    """

    def __init__(self, filename='data/go.obo', with_rels=False):
        """Load ``filename``; see :meth:`load` for the meaning of ``with_rels``."""
        self.ont = self.load(filename, with_rels)
        self.ic = None

    def has_term(self, term_id):
        """Return True if ``term_id`` (primary or alternative id) is loaded."""
        return term_id in self.ont

    def get_term(self, term_id):
        """Return the term dict for ``term_id``, or None if it is unknown."""
        return self.ont.get(term_id)

    def calculate_ic(self, annots):
        """Compute the information content of every term occurring in ``annots``.

        Args:
            annots: iterable of collections of GO ids, one per annotated entity.

        IC(t) = log2(min over parents p of count(p) / count(t)); terms without
        parents in the ontology get 0.0. Counts come only from ``annots``, so it
        should presumably be closed under ancestors: a parent with count 0
        makes ``math.log`` raise ``ValueError``.
        """
        cnt = Counter()
        for x in annots:
            cnt.update(x)
        self.ic = {}
        for go_id, n in cnt.items():
            parents = self.get_parents(go_id)
            if len(parents) == 0:
                min_n = n
            else:
                min_n = min([cnt[x] for x in parents])

            self.ic[go_id] = math.log(min_n / n, 2)

    def get_ic(self, go_id):
        """Return the IC of ``go_id`` (0.0 if it never occurred in the annotations).

        Raises:
            Exception: if :meth:`calculate_ic` has not been called yet.
        """
        if self.ic is None:
            raise Exception('Not yet calculated')
        if go_id not in self.ic:
            return 0.0
        return self.ic[go_id]

    def load(self, filename, with_rels):
        """Parse an OBO file into a ``{go_id: term_dict}`` mapping.

        Only ``[Term]`` stanzas are kept; any other stanza (``[Typedef]``,
        ``[Instance]``, ...) is skipped entirely. With ``with_rels=True`` the
        target of every ``relationship:`` line, whatever its type, is appended
        to ``is_a`` so that it is treated as a parent.
        """
        ont = dict()
        obj = None
        # OBO files are UTF-8; do not depend on the platform default encoding.
        with open(filename, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                if line.startswith('[') and line.endswith(']'):
                    # Stanza header: store the term being built, then start a
                    # new one only for [Term]; other stanzas are ignored so
                    # their tags cannot leak into the previous term.
                    if obj is not None:
                        ont[obj['id']] = obj
                    if line == '[Term]':
                        obj = {
                            'is_a': [],
                            'part_of': [],
                            'regulates': [],
                            'alt_ids': [],
                            'is_obsolete': False,
                        }
                    else:
                        obj = None
                    continue
                if obj is None:
                    # File header or a tag of a skipped stanza.
                    continue
                # Split on the first ': ' only so values containing ': '
                # (e.g. term names) are kept intact.
                tag, _, value = line.partition(': ')
                if tag == 'id':
                    obj['id'] = value
                elif tag == 'alt_id':
                    obj['alt_ids'].append(value)
                elif tag == 'namespace':
                    obj['namespace'] = value
                elif tag == 'is_a':
                    obj['is_a'].append(value.split(' ! ')[0])
                elif with_rels and tag == 'relationship':
                    # "<relation type> <target id> ! <name>": keep the target id
                    # for every relation type.
                    it = value.split()
                    if len(it) >= 2:
                        obj['is_a'].append(it[1])
                elif tag == 'name':
                    obj['name'] = value
                elif tag == 'is_obsolete' and value == 'true':
                    obj['is_obsolete'] = True
            if obj is not None:
                ont[obj['id']] = obj
        # Register alternative ids as aliases of their term; drop obsolete ids.
        for term_id in list(ont.keys()):
            for t_id in ont[term_id]['alt_ids']:
                ont[t_id] = ont[term_id]
            if ont[term_id]['is_obsolete']:
                del ont[term_id]
        # Build reverse (child) links. Alias keys share their term's dict, so
        # an alias id is also recorded as a child of the term's parents.
        for term_id, val in ont.items():
            val.setdefault('children', set())
            for p_id in val['is_a']:
                if p_id in ont:
                    ont[p_id].setdefault('children', set()).add(term_id)
        return ont

    def get_anchestors(self, term_id):
        """Return ``term_id`` together with all its transitive parents.

        Unknown ids yield an empty set. The misspelled name is kept for
        backward compatibility; ``get_ancestors`` is an alias.
        """
        if term_id not in self.ont:
            return set()
        term_set = set()
        q = deque([term_id])
        while q:
            t_id = q.popleft()
            if t_id not in term_set:
                term_set.add(t_id)
                for parent_id in self.ont[t_id]['is_a']:
                    if parent_id in self.ont:
                        q.append(parent_id)
        return term_set

    get_ancestors = get_anchestors

    def get_parents(self, term_id):
        """Return the direct parents of ``term_id`` that exist in the ontology."""
        if term_id not in self.ont:
            return set()
        return {p_id for p_id in self.ont[term_id]['is_a'] if p_id in self.ont}

    def get_namespace_terms(self, namespace):
        """Return every id (aliases included) whose term has the given namespace."""
        return {go_id for go_id, obj in self.ont.items()
                if obj.get('namespace') == namespace}

    def get_namespace(self, term_id):
        """Return the namespace of ``term_id`` (KeyError if unknown)."""
        return self.ont[term_id]['namespace']

    def get_term_set(self, term_id):
        """Return ``term_id`` together with all its transitive children.

        Unknown ids yield an empty set. Alias ids reachable as children are
        included.
        """
        if term_id not in self.ont:
            return set()
        term_set = set()
        q = deque([term_id])
        while q:
            t_id = q.popleft()
            if t_id not in term_set:
                term_set.add(t_id)
                for ch_id in self.ont[t_id]['children']:
                    q.append(ch_id)
        return term_set

In [4]:
# Root term ids of the three GO sub-ontologies.
BIOLOGICAL_PROCESS = 'GO:0008150'
MOLECULAR_FUNCTION = 'GO:0003674'
CELLULAR_COMPONENT = 'GO:0005575'

# Sub-ontology abbreviation -> root term id.
FUNC_DICT = dict(
    cc=CELLULAR_COMPONENT,
    mf=MOLECULAR_FUNCTION,
    bp=BIOLOGICAL_PROCESS,
)

# Sub-ontology abbreviation -> value of the OBO `namespace:` tag.
NAMESPACES = dict(
    cc='cellular_component',
    mf='molecular_function',
    bp='biological_process',
)

In [5]:
def compute_roc(labels, preds):
    """Pooled (micro-averaged) ROC AUC over all label/score pairs.

    Both inputs are flattened, so every element is one binary decision and a
    single ROC curve is computed; this is NOT a per-class ROC.

    Args:
        labels: array-like of binary ground-truth labels (any shape).
        preds: array-like of scores with the same number of elements.

    Returns:
        float: area under the pooled ROC curve.
    """
    # np.ravel also accepts lists / nested sequences, unlike ndarray.flatten.
    fpr, tpr, _ = roc_curve(np.ravel(labels), np.ravel(preds))
    return auc(fpr, tpr)

def compute_mcc(labels, preds):
    """Matthews correlation coefficient over all flattened label/prediction pairs.

    Args:
        labels: array-like of ground-truth class labels (any shape).
        preds: array-like of predicted class labels with the same number of
            elements. sklearn's ``matthews_corrcoef`` expects discrete labels,
            so continuous scores must be thresholded by the caller.

    Returns:
        float: MCC in [-1, 1].
    """
    # np.ravel also accepts lists / nested sequences, unlike ndarray.flatten.
    return matthews_corrcoef(np.ravel(labels), np.ravel(preds))

In [6]:
def evaluate_annotations(go, real_annots, pred_annots):
    """Protein-centric evaluation of one set of GO predictions.

    Args:
        go: object exposing ``get_ic(go_id) -> float`` (e.g. an ``Ontology``
            on which ``calculate_ic`` has been run).
        real_annots: sequence of ground-truth GO-id collections, one per protein.
        pred_annots: sequence of predicted GO-id collections, aligned by index
            with ``real_annots``.

    Proteins without true annotations are skipped. Recall, remaining
    uncertainty (RU) and misinformation (MI) are averaged over the remaining
    proteins; precision is averaged only over those that also have at least
    one prediction.

    Returns:
        tuple ``(f, p, r, s, ru, mi, fps, fns)``: F-measure, average precision,
        average recall, semantic distance ``sqrt(ru**2 + mi**2)``, RU, MI, and
        the per-protein false-positive / false-negative sets of the evaluated
        proteins. When no protein has true annotations all metrics are 0.0
        and both lists are empty.
    """
    total = 0
    p_total = 0
    p = 0.0
    r = 0.0
    ru = 0.0
    mi = 0.0
    fps = []
    fns = []
    for i in range(len(real_annots)):
        if len(real_annots[i]) == 0:
            continue
        # Normalise to sets so lists / tuples work as well as sets.
        real = set(real_annots[i])
        pred = set(pred_annots[i])
        tp = real & pred
        fp = pred - tp
        fn = real - tp
        for go_id in fp:
            mi += go.get_ic(go_id)
        for go_id in fn:
            ru += go.get_ic(go_id)
        fps.append(fp)
        fns.append(fn)
        tpn = len(tp)
        fpn = len(fp)
        fnn = len(fn)
        total += 1
        r += tpn / (1.0 * (tpn + fnn))
        if len(pred) > 0:
            p_total += 1
            p += tpn / (1.0 * (tpn + fpn))
    if total == 0:
        # Nothing to average over (previously a ZeroDivisionError).
        return 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, fps, fns
    ru /= total
    mi /= total
    r /= total
    if p_total > 0:
        p /= p_total
    f = 0.0
    if p + r > 0:
        f = 2 * p * r / (p + r)
    s = math.sqrt(ru * ru + mi * mi)
    return f, p, r, s, ru, mi, fps, fns

In [7]:
def compute_fmax(deep_preds=None, labels=None, go_set=None, go_rels=None,
                 thresholds=None):
    """Sweep score thresholds and report Fmax, Smin and AUPR.

    For each threshold, a protein's predicted terms are those whose score is
    at or above it, expanded to all their ancestors and restricted to
    ``go_set``. They are then scored against ``labels`` with
    ``evaluate_annotations``.

    Every argument defaults to the notebook global of the same name, so the
    original zero-argument call ``compute_fmax()`` keeps working.

    Args:
        deep_preds: list of ``{go_id: score}`` dicts, one per protein.
        labels: list of ground-truth GO-id sets aligned with ``deep_preds``.
        go_set: GO ids of the sub-ontology being evaluated.
        go_rels: ``Ontology`` on which ``calculate_ic`` has been run.
        thresholds: iterable of score thresholds; defaults to
            0.01, 0.11, ..., 0.91. The grid affects the AUPR estimate.

    Returns:
        tuple ``(fmax, tmax, smin, aupr)``; the same values are also printed.
    """
    # Fall back to the notebook globals that the original version read.
    nb = globals()
    if deep_preds is None:
        deep_preds = nb['deep_preds']
    if labels is None:
        labels = nb['labels']
    if go_set is None:
        go_set = nb['go_set']
    if go_rels is None:
        go_rels = nb['go_rels']
    if thresholds is None:
        thresholds = [t / 100.0 for t in range(1, 101, 10)]

    # Ancestor sets do not depend on the threshold: compute each one once.
    ancestor_cache = {}

    def ancestors(go_id):
        if go_id not in ancestor_cache:
            ancestor_cache[go_id] = go_rels.get_anchestors(go_id)
        return ancestor_cache[go_id]

    fmax = 0.0
    tmax = 0.0
    smin = math.inf
    precisions = []
    recalls = []
    for threshold in thresholds:
        preds = []
        for scores in deep_preds:
            # Terms at/above the threshold, propagated to all their ancestors
            # (ids unknown to the ontology contribute nothing).
            annots = set()
            for go_id, score in scores.items():
                if score >= threshold:
                    annots |= ancestors(go_id)
            # Keep only terms of the sub-ontology under evaluation.
            preds.append({go_id for go_id in annots if go_id in go_set})

        fscore, prec, rec, s, _, _, _, _ = evaluate_annotations(go_rels, labels, preds)
        precisions.append(prec)
        recalls.append(rec)
        if fmax < fscore:
            fmax = fscore
            tmax = threshold
        if smin > s:
            smin = s
    print(f'threshold: {tmax}')
    print(f'Smin: {smin:0.3f}')
    print(f'Fmax: {fmax:0.3f}')

    # Integrate precision over recall; points must be ordered by recall.
    precisions = np.array(precisions)
    recalls = np.array(recalls)
    order = np.argsort(recalls)
    # np.trapz is deprecated in NumPy 2 in favour of np.trapezoid.
    trapezoid = np.trapezoid if hasattr(np, 'trapezoid') else np.trapz
    aupr = trapezoid(precisions[order], recalls[order])
    print(f'AUPR: {aupr:0.3f}')
    return fmax, tmax, smin, aupr

In [8]:
# GO graph; with_rels=True makes every `relationship:` target count as a parent.
go_rels = Ontology(data_path_dict['go_obo'], with_rels=True)

# Model vocabulary: terms[j] is the GO id scored by element j of a `preds` vector.
terms_df = pd.read_pickle(data_path_dict['terms_pkl'])
terms = terms_df['terms'].values.flatten()
# GO id -> position in `terms` (a repeated id keeps its last position).
terms_dict = dict(zip(terms, range(len(terms))))

In [9]:
import pickle

# Load both splits. NOTE: pd.read_pickle can execute arbitrary code, so only
# load pickles from a trusted source.
# The evaluation frame comes from predictions.pkl (not test_data.pkl): the
# evaluation cells read both its `prop_annotations` and `preds` columns.
train_df, test_df = (
    pd.read_pickle(data_path_dict[key])
    for key in ('train_data_pkl', 'predictions_pkl')
)

print(test_df.shape[0])

3874


In [10]:
# Propagated annotations of every training / evaluation protein, as sets.
annotations = [set(go_ids) for go_ids in train_df['prop_annotations'].values]
test_annotations = [set(go_ids) for go_ids in test_df['prop_annotations'].values]

# Information content is estimated over train and test annotations combined.
go_rels.calculate_ic(annotations + test_annotations)

In [11]:
# Information content of each vocabulary term (0.0 for terms never annotated).
ics = {term: go_rels.get_ic(term) for term in terms}

# Protein id -> row position in train_df (a repeated id keeps its last row).
prot_index = {row.proteins: pos for pos, row in enumerate(train_df.itertuples())}

In [12]:
# DIAMOND hits for the test proteins: diamond_scores[query][subject] = score.
# Each line is expected to start with three whitespace-separated fields:
# query id, subject id, score (presumably the bitscore; confirm against the
# diamond command that produced the file).
diamond_scores = {}
with open(data_path_dict['test_diamond_res']) as f:
    for line in f:
        fields = line.split()
        if len(fields) < 3:
            # Skip blank or truncated lines instead of raising IndexError.
            continue
        diamond_scores.setdefault(fields[0], {})[fields[1]] = float(fields[2])

In [13]:
# DeepGOPlus
for ont in onts:
    print(ont)
    go_set = go_rels.get_namespace_terms(NAMESPACES[ont])
    go_set.remove(FUNC_DICT[ont])
    labels = test_df['prop_annotations'].values
    labels = list(map(lambda x: set(filter(lambda y: y in go_set, x)), labels))
    deep_preds = []

    for i, row in enumerate(test_df.itertuples()):
        annots_dict = {}
        for j, score in enumerate(row.preds):
            go_id = terms[j]
            annots_dict[go_id] = score
        deep_preds.append(annots_dict)

    compute_fmax()

mf
threshold: 0.11
Smin: 14.854
Fmax: 0.353
AUPR: 0.286
bp
threshold: 0.21
Smin: 52.248
Fmax: 0.329
AUPR: 0.279
cc
threshold: 0.31
Smin: 12.719
Fmax: 0.614
AUPR: 0.614


In [14]:
# Disabled output step: dumps one "protein<TAB>GO id<TAB>score" line per
# prediction after AUTHOR/MODEL/KEYWORDS header lines (looks like a CAFA
# submission format; confirm). It is wrapped in a string literal so it does not
# execute, so running this cell only displays the string itself.
'''
print('AUTHOR DeepGOPlus')
print('MODEL 1')
print('KEYWORDS sequence alignment.')
for i, row in enumerate(test_df.itertuples()):
    prot_id = row.proteins
    for go_id, score in deep_preds[i].items():
        print(f'{prot_id}\t{go_id}\t{score:.2f}')
'''

"\nprint('AUTHOR DeepGOPlus')\nprint('MODEL 1')\nprint('KEYWORDS sequence alignment.')\nfor i, row in enumerate(test_df.itertuples()):\n    prot_id = row.proteins\n    for go_id, score in deep_preds[i].items():\n        print(f'{prot_id}\t{go_id}\t{score:.2f}')\n"