In [1]:
from Bio.SeqIO.FastaIO import SimpleFastaParser
import pandas as pd
import numpy as np

In [2]:
# Location of the input files and the FASTA file name for each data split.
root = '/media/johannes/Crucial SSD/PP2CS/'
seq_files = {
    'test': 'ec_vs_NOec_pide20_c50_test.fasta',
    'train': 'ec_vs_NOec_pide20_c50_train.fasta',
    'val': 'ec_vs_NOec_pide20_c50_val.fasta'
}
data_sets = ['train', 'test', 'val']
ec_annotation_file = 'merged_anno.txt'

data_frames = {}

# Build one DataFrame per split, indexed by FASTA title, holding the
# sequence, its length and the name of the split it came from.
for data_set in data_sets:
    with open(root + seq_files[data_set]) as fasta:
        records = list(SimpleFastaParser(fasta))
    ids = [title for title, _ in records]
    seqs = [seq for _, seq in records]
    lengths = [len(seq) for seq in seqs]
    data = {"sequence": seqs, "length": lengths, "dataset": data_set}
    data_frames[data_set] = pd.DataFrame(data=data, index=ids)

# Two-column annotation table; the first column becomes the index that is
# joined against the FASTA titles, the second is stored as 'EC'.
annotations = pd.read_csv(
    root + ec_annotation_file,
    sep="\\t",
    names=['index', 'EC'],
    index_col=0,
    engine="python",
)

# Left-join the annotations onto every split; entries without an 'EC' value
# are labelled 'NC', all others 'EC'.
for data_set in data_sets:
    merged = data_frames[data_set].merge(annotations, how="left", left_index=True, right_index=True)
    merged['ec_or_nc'] = np.where(merged['EC'].isna(), 'NC', 'EC')
    data_frames[data_set] = merged

In [3]:
import h5py
import torch

embeddings_train = dict()  # not populated in this cell; name kept for compatibility
embeddings_val = dict()

# Validation embeddings: one array per HDF5 key, kept in a dict.
with h5py.File(root + 'ec_vs_NOec_pide20_c50_val.h5', 'r') as f:
    a_group_key = list(f.keys())[0]  # unused below; kept in case other cells reference it
    for key in f.keys():
        embeddings_val[key] = np.array(f[key][:]) # needs about 600MB RAM


# Training embeddings: all per-residue vectors are stacked into one big
# tensor; pos_dict maps every row of that tensor back to the HDF5 key
# (protein) the row came from.
pos_dict = dict()
emb_train_list = []  # not populated in this cell; name kept for compatibility
embeddings_train_concat = torch.empty([sum(data_frames['train']['length']), 1024], dtype=torch.float32)
with h5py.File(root + 'ec_vs_NOec_pide20_c50_train.h5', 'r') as f:
    train_keys = list(f.keys())
    n_train_keys = len(train_keys)
    a_group_key = train_keys[0]  # unused below; kept in case other cells reference it
    pos = 0
    for i, key in enumerate(train_keys):
        if i % 1000 == 0:
            print (i/n_train_keys)
        # Read the dataset once instead of looking it up three times per protein.
        residues = torch.tensor(f[key][:], dtype=torch.float32)
        n_residues = len(residues)
        embeddings_train_concat[pos:pos + n_residues] = residues
        for x in range(pos, pos + n_residues):
            pos_dict[x] = key
        pos += n_residues # store last position of protein sequence

# torch.empty leaves memory uninitialised. If the HDF5 file holds fewer
# residues than the FASTA lengths suggest, the unfilled tail would still take
# part in the similarity search and could win the argmax, yielding positions
# that are not in pos_dict. Drop that tail so only real embeddings remain.
if pos != embeddings_train_concat.shape[0]:
    print ("Warning: expected {} training residues but loaded {}; dropping unfilled rows".format(
        embeddings_train_concat.shape[0], pos))
    embeddings_train_concat = embeddings_train_concat[:pos]
print ("Finished loading training embeddings")

0.0
0.029871254891417988
0.059742509782835976
0.08961376467425397
0.11948501956567195
0.14935627445708993
0.17922752934850794
0.20909878423992592
0.2389700391313439
0.2688412940227619
0.29871254891417987
0.3285838038055979
0.3584550586970159
0.38832631358843384
0.41819756847985184
0.44806882337126985
0.4779400782626878
0.5078113331541058
0.5376825880455238
0.5675538429369418
0.5974250978283597
0.6272963527197778
0.6571676076111957
0.6870388625026137
0.7169101173940318
0.7467813722854497
0.7766526271768677
0.8065238820682857
0.8363951369597037
0.8662663918511216
0.8961376467425397
0.9260089016339577
0.9558801565253756
0.9857514114167936
Finished loading training embeddings


In [4]:
# Basic idea (residue-level nearest-neighbour label transfer):
# for every residue embedding of a protein in the validation/test set
#   compare it against all residue embeddings of the train set
#   and note the category (EC / NC) of the most similar train embedding
# then do a majority vote over the residues to determine the category of
# the protein in the validation/test set

# Limit the number of CPU threads PyTorch may use for its tensor operations.
torch.set_num_threads(6)

def most_similar_in_train(b, b_n, emb):
    """Return, for every row of ``emb``, the label of its nearest train residue.

    Nearness is cosine similarity (https://en.wikipedia.org/wiki/Cosine_similarity)
    between the query rows and all rows of ``b``.

    Parameters
    ----------
    b : torch.Tensor
        2-D tensor of training residue embeddings, one row per residue.
    b_n : torch.Tensor
        Column vector with the L2 norm of every row of ``b``
        (``b.norm(dim=1)[:, None]``), precomputed by the caller.
    emb : torch.Tensor
        2-D tensor with the residue embeddings of the query protein.

    Returns
    -------
    list
        One entry per row of ``emb``: the ``ec_or_nc`` value of the training
        protein that owns the most similar training residue.

    Notes
    -----
    Relies on the module-level ``pos_dict`` (row index of ``b`` -> protein id)
    and ``data_frames['train']['ec_or_nc']`` (protein id -> label).
    """
    n_query = emb.shape[0]
    if n_query == 0:
        return []

    eps = 1e-8
    nearest_chunks = []
    # Process the query in chunks of roughly 250 rows so the
    # (chunk x train residues) similarity matrix fits in RAM.
    with torch.no_grad():
        for a in torch.chunk(emb, int(n_query / 250 + 1)): # RAM dependent
            upper = torch.matmul(a, b.T)
            a_n = a.norm(dim=1)[:, None]
            # eps goes into the denominator so zero-norm vectors give a
            # similarity of 0 instead of NaN; in-place ops avoid extra
            # chunk-sized temporaries.
            lower = torch.matmul(a_n, b_n.T).add_(eps)
            cosine_similarity = upper.div_(lower)
            nearest_chunks.append(torch.argmax(cosine_similarity, dim=1))
            del cosine_similarity, a, lower, upper

    # Concatenate as int64: going through a float32 tensor would corrupt
    # indices above 2**24.
    nearest = torch.cat(nearest_chunks).tolist()

    labels = data_frames['train']['ec_or_nc']
    label_of = {}  # per-call cache: one pandas lookup per distinct protein
    sim_names = []
    for row in nearest:
        name = pos_dict[row]
        if name not in label_of:
            label_of[name] = labels[name]
        sim_names.append(label_of[name])
    return sim_names


In [5]:
import time
import pickle

# Predict EC / NC for every validation protein by majority vote over the
# labels of the nearest training residue of each of its residues.
pred = dict()
b = embeddings_train_concat
b_n = b.norm(dim=1)[:, None]  # row norms of the train matrix, computed once
for i, protein in enumerate(data_frames['val'].index):
    #if embeddings_val[protein].shape[0] < 800:
    #    continue
    start = time.time()
    sim = most_similar_in_train(b, b_n, torch.tensor(embeddings_val[protein], dtype=torch.float32))
    end = time.time()

    # Count the votes once; a tie is resolved as "NC".
    n_ec = sim.count("EC")
    n_nc = sim.count("NC")
    if n_ec > n_nc:
        res = "EC"
    else:
        res = "NC"

    print ("Protein " + protein+ ": Predicted: " + res + ", Actual: " + data_frames['val']['ec_or_nc'][protein])
    print ("Vote ended: {} EC vs {} NC".format(n_ec, n_nc))
    print ("Took {} seconds, Progress: {}".format(end-start, i))
    pred[protein] = res

    # Checkpoint all predictions so far after every protein; the context
    # manager closes the file even if pickling fails.
    with open("file.pkl","wb") as f:
        pickle.dump(pred,f)

Protein Q2V3S8: Predicted: NC, Actual: NC
Vote ended: 0 EC vs 83 NC
Took 8.429917335510254 seconds, Progress: 0
Protein Q9VES1: Predicted: NC, Actual: NC
Vote ended: 12 EC vs 185 NC
Took 14.076719999313354 seconds, Progress: 1
Protein P25201: Predicted: EC, Actual: EC
Vote ended: 277 EC vs 263 NC
Took 37.24457263946533 seconds, Progress: 2
Protein Q82IY3: Predicted: EC, Actual: EC
Vote ended: 333 EC vs 116 NC
Took 30.296613931655884 seconds, Progress: 3
Protein O64642: Predicted: NC, Actual: EC
Vote ended: 307 EC vs 457 NC
Took 53.08911156654358 seconds, Progress: 4
Protein Q48509: Predicted: NC, Actual: NC
Vote ended: 1 EC vs 61 NC
Took 6.81533670425415 seconds, Progress: 5
Protein P0DL35: Predicted: NC, Actual: NC
Vote ended: 0 EC vs 66 NC
Took 7.272502183914185 seconds, Progress: 6
Protein P37249: Predicted: NC, Actual: NC
Vote ended: 38 EC vs 136 NC
Took 12.47273564338684 seconds, Progress: 7
Protein Q8ZNR3: Predicted: NC, Actual: EC
Vote ended: 113 EC vs 669 NC
Took 53.05103969573

In [9]:
# Collect predicted (p) and true (t) labels in the same protein order.
p = []
t = []

for elem in pred.keys():
    p.append(pred[elem])
    t.append(data_frames['val']['ec_or_nc'][elem])

# `import sklearn` alone does not load the `metrics` submodule, so
# `sklearn.metrics` can raise AttributeError on a fresh kernel; import the
# submodule explicitly.
import sklearn.metrics

print(sklearn.metrics.classification_report(t, p))
print(sklearn.metrics.confusion_matrix(t,p))

              precision    recall  f1-score   support

          EC       0.85      0.49      0.62        81
          NC       0.91      0.98      0.94       419

    accuracy                           0.90       500
   macro avg       0.88      0.74      0.78       500
weighted avg       0.90      0.90      0.89       500

[[ 40  41]
 [  7 412]]
