In [1]:
from Bio.SeqIO.FastaIO import SimpleFastaParser
import pandas as pd
import numpy as np

In [2]:
root='/media/johannes/Crucial SSD/PP2CS/'
seq_files = {
    'test': 'ec_vs_NOec_pide20_c50_test.fasta',
    'train': 'ec_vs_NOec_pide20_c50_train.fasta',
    'val': 'ec_vs_NOec_pide20_c50_val.fasta'
}
data_sets = ['train','test','val']
ec_annotation_file = 'merged_anno.txt'

data_frames = {}

for data_set in data_sets:
    with open(root + seq_files[data_set]) as fasta:
        ids = []
        seqs = []
        lengths = []
        for title, seq in SimpleFastaParser(fasta):
            ids.append(title)
            seqs.append(seq)
            lengths.append(len(seq))
        data = {"sequence": seqs, "length": lengths,"dataset":data_set}
        data_frames[data_set] = pd.DataFrame(data=data,index=ids)
annotations =  pd.read_csv(root + ec_annotation_file,sep="\\t",names=['index','EC'],index_col=0, engine="python")

for data_set in data_sets:
    data_frames[data_set] = data_frames[data_set].merge(annotations,how="left",left_index=True, right_index=True)
    data_frames[data_set]['ec_or_nc'] = np.where(data_frames[data_set]['EC'].isna(), 'NC', 'EC')

In [3]:
import h5py
import torch

embeddings_train = dict()
embeddings_val = dict()

with h5py.File(root + 'ec_vs_NOec_pide20_c50_val.h5', 'r') as f:
    # List all groups
    print("Keys: %s" % f.keys())
    a_group_key = list(f.keys())[0]
    for key in f.keys():
        embeddings_val[key] = np.array(f[key][:]) # needs about 600MB RAM


pos_dict = dict()
emb_train_list = []
embeddings_train_concat = torch.empty([sum(data_frames['train']['length']), 1024], dtype=torch.float32)
with h5py.File(root + 'ec_vs_NOec_pide20_c50_train.h5', 'r') as f:
    # List all groups
    a_group_key = list(f.keys())[0]
    pos = 0
    for i, key in enumerate(f.keys()):
        if i %1000 == 0:
            print (i/len(f.keys()))
        embeddings_train_concat[pos:pos + len(f[key])] = torch.tensor(f[key][:], dtype=torch.float32)
        for x in range(pos, pos + len(f[key])):
            pos_dict[x] = key
        pos += len(f[key]) # store last position of protein sequence
print ("Finished loading training embeddings")

Keys: <KeysViewHDF5 ['A0A023PZC7', 'A0A023PZE6', 'A0A163UT06', 'A0A411KUP5', 'A0QRX9', 'A1L4Q6', 'A3MWN7', 'A3RGB0', 'A4R2Q6', 'A5D7C3', 'A5TZH0', 'A5UQF4', 'A6NNP5', 'A6VLC4', 'A7KAI8', 'A7UZ95', 'A7XCE8', 'A8E657', 'A8ESZ6', 'A8FDN5', 'A8HMZ4', 'A8MTZ7', 'A8QHQ0', 'B2A2M1', 'B3DHH5', 'B4YNG0', 'B7M9S5', 'B8II14', 'C0HJE6', 'C4XIR5', 'C6Y4B9', 'D0N4E0', 'D3Z9M3', 'D4AN96', 'D5KXG8', 'E1WAB4', 'F4KCE9', 'F5HEN7', 'F5HGJ4', 'G0S902', 'G2TRP5', 'G2TRR5', 'G5ECG2', 'J3K844', 'K7EIQ3', 'O07623', 'O07636', 'O10339', 'O13532', 'O13714', 'O13954', 'O14050', 'O14220', 'O19888', 'O23550', 'O25711', 'O26133', 'O27941', 'O28187', 'O28280', 'O28700', 'O28816', 'O29484', 'O29496', 'O29762', 'O29839', 'O30238', 'O30248', 'O31878', 'O31897', 'O31915', 'O32126', 'O32864', 'O34843', 'O35870', 'O43676', 'O43952', 'O48459', 'O52728', 'O53780', 'O61764', 'O64221', 'O64642', 'O67062', 'O67818', 'O71191', 'O74466', 'O80678', 'O80996', 'O82289', 'O82391', 'O83391', 'O83705', 'O83711', 'O83777', 'O83805', 'O8

In [4]:
# Basic idea:
# for all residue embeddings of a protein in the dev/test set
#   go through all residue embeddings of the train set
#   note category of most similar embedding in train set
# do majority vote to determine category of the protein in the dev/test set

torch.set_num_threads(12)

def most_similar_in_train(b,b_n, emb):
    sim_names = ["EMPTY"] * emb.shape[0]
    # Source https://en.wikipedia.org/wiki/Cosine_similarity
    concat_res = torch.Tensor()
    emb = torch.chunk(emb, int(emb.shape[0]/250 + 1)) # RAM dependent
    for a in emb:
        with torch.no_grad():
            upper = torch.matmul(a,b.T)
            a_n = a.norm(dim=1)[:, None]
            lower = torch.matmul(a_n, b_n.T)
            cosine_similarity = upper/lower + 1e-8 # eps for stability
            arg_cur_max = torch.argmax(cosine_similarity,dim=1)
            del cosine_similarity,a,lower,upper
            concat_res = torch.cat([concat_res, arg_cur_max])
    for i, res in enumerate(concat_res):
        sim_names[i] = pos_dict[res.item()]
    for index, name in enumerate(sim_names):
        sim_names[index] = data_frames['train']['ec_or_nc'][name]
    return sim_names #data_frames['train']['ec_or_nc'][sim_prot]


In [None]:
pred = dict()
b = embeddings_train_concat
b_n = b.norm(dim=1)[:, None] 
for i, protein in enumerate(data_frames['val'].index):
    sim = most_similar_in_train(b, b_n, torch.tensor(embeddings_val[protein], dtype=torch.float32))
    if sim.count("EC") > sim.count("NC"):
        res = "EC"
    else:
        res = "NC"
    print ("Protein " + protein+ ": Predicted: " + res + ", Actual: " + data_frames['val']['ec_or_nc'][protein])
    print ("Vote ended: {} EC vs {} NC".format(sim.count("EC"), sim.count("NC")))
    pred[protein] = res

Protein Q2V3S8: Predicted: NC, Actual: NC
Vote ended: 0 EC vs 83 NC
Protein Q9VES1: Predicted: NC, Actual: NC
Vote ended: 12 EC vs 185 NC
Protein P25201: Predicted: EC, Actual: EC
Vote ended: 277 EC vs 263 NC
Protein Q82IY3: Predicted: EC, Actual: EC
Vote ended: 333 EC vs 116 NC
Protein O64642: Predicted: NC, Actual: EC
Vote ended: 307 EC vs 457 NC
Protein Q48509: Predicted: NC, Actual: NC
Vote ended: 1 EC vs 61 NC
Protein P0DL35: Predicted: NC, Actual: NC
Vote ended: 0 EC vs 66 NC
Protein P37249: Predicted: NC, Actual: NC
Vote ended: 38 EC vs 136 NC
Protein Q8ZNR3: Predicted: NC, Actual: EC
Vote ended: 113 EC vs 669 NC
Protein P50534: Predicted: NC, Actual: EC
Vote ended: 119 EC vs 654 NC
Protein P15972: Predicted: NC, Actual: NC
Vote ended: 1 EC vs 64 NC
Protein P16796: Predicted: NC, Actual: NC
Vote ended: 0 EC vs 66 NC
Protein P19369: Predicted: NC, Actual: NC
Vote ended: 5 EC vs 65 NC
Protein P74095: Predicted: NC, Actual: NC
Vote ended: 29 EC vs 47 NC


In [None]:
print (data_frames['train']['ec_or_nc']['Q4ZJZ1'])
print (data_frames['train']['ec_or_nc']['P0DL49'])
print (data_frames['train']['ec_or_nc']['Q4ZJZ2'])
print (data_frames['train']['ec_or_nc']['P82746'])
print (sim.count("EC"))
print (sim.count("NC"))