## Siamese Neural Network for predicting PPIs from function annotations

### Imports

In [36]:
import numpy as np
import click as ck
import tensorflow as tf
from tensorflow.keras.layers import (
    Input, Dense, Concatenate, Dot, Activation
)
from tensorflow.keras import optimizers
from tensorflow.keras import constraints
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, CSVLogger
from tensorflow.keras import backend as K
from tensorflow.keras.models import Sequential, Model
import math
from scipy.stats import rankdata
from elembeddings.utils import Ontology

### Prepare training and testing data

In [2]:
# Organism identifier; it is interpolated into the data file paths used below.
# NOTE(review): presumably an NCBI taxonomy ID as used in STRING file names
# (4932 would be baker's yeast) — confirm against the data source.
org_id = '4932'

def load_train_data(data_file):
    """Read training interactions and assign an integer index to every protein.

    Each line of ``data_file`` is split on whitespace. The first two fields are
    taken as the identifiers of an interacting protein pair; any further
    fields are ignored. Proteins are numbered consecutively from 0 in order of
    first appearance.

    Blank lines and lines with fewer than two fields are skipped instead of
    raising ``IndexError``.

    Args:
        data_file: Path to the whitespace-separated interactions file.

    Returns:
        A tuple ``(data, proteins)``: ``data`` is a list of
        ``(index1, index2)`` pairs and ``proteins`` maps each protein
        identifier to its integer index.
    """
    data = []
    proteins = {}
    with open(data_file, 'r') as f:
        for line in f:
            it = line.split()
            if len(it) < 2:
                # Blank or truncated line: nothing usable on it.
                continue
            # len(proteins) is evaluated before the insertion, so a new
            # identifier always receives the next free index.
            idx1 = proteins.setdefault(it[0], len(proteins))
            idx2 = proteins.setdefault(it[1], len(proteins))
            data.append((idx1, idx2))
    return data, proteins

def load_test_data(data_file, proteins):
    """Read interactions, keeping only pairs whose proteins are both known.

    Each line of ``data_file`` is split on whitespace and the first two fields
    are taken as protein identifiers. A pair is kept only if both identifiers
    are keys of ``proteins``; ``proteins`` itself is never modified.

    Blank lines and lines with fewer than two fields are skipped instead of
    raising ``IndexError``.

    Args:
        data_file: Path to the whitespace-separated interactions file.
        proteins: Mapping from protein identifier to integer index.

    Returns:
        A list of ``(index1, index2)`` pairs.
    """
    data = []
    with open(data_file, 'r') as f:
        for line in f:
            it = line.split()
            if len(it) < 2:
                # Blank or truncated line: nothing usable on it.
                continue
            id1, id2 = it[0], it[1]
            if id1 in proteins and id2 in proteins:
                data.append((proteins[id1], proteins[id2]))
    return data

# The training split is loaded first because it defines the
# protein identifier -> integer index mapping.
train_data, proteins = load_train_data(f'data/train/{org_id}.protein.links.v11.0.txt')
# Validation and test pairs are mapped through the training index; pairs that
# involve a protein not seen in training are dropped by load_test_data.
valid_data = load_test_data(f'data/valid/{org_id}.protein.links.v11.0.txt', proteins)
test_data = load_test_data(f'data/test/{org_id}.protein.links.v11.0.txt', proteins)
# Report the size of each split.
print('Number of proteins in training: ', len(proteins))
print('Training interactions: ', len(train_data))
print('Validation interactions: ', len(valid_data))
print('Testing interactions: ', len(test_data))

Number of proteins in training:  5926
Training interactions:  152386
Validation interactions:  37840
Testing interactions:  47284


### Load functional annotations

In [38]:
def load_annotations(data_file, proteins, propagate=False):
    """Load per-protein function annotations from a tab-separated file.

    Each line is expected to hold a protein identifier and a function (term)
    identifier separated by a tab. Lines whose protein is not a key of
    ``proteins`` are ignored, as are lines with fewer than two fields.

    Every annotated term is recorded in the returned function list, so each
    term stored in ``annots`` has an entry in ``functions``. This holds with
    ``propagate=False`` and for terms the ontology does not know.

    Args:
        data_file: Path to the annotation file.
        proteins: Mapping from protein identifier to integer index.
        propagate: If True, load the ontology from ``data/go.obo`` and, for
            every term it knows (``has_term``), also add the terms returned
            by ``get_anchestors``. The ontology file is only read when this
            flag is set.

    Returns:
        A tuple ``(annots, functions)``: ``annots`` maps a protein index to
        the set of its term identifiers, and ``functions`` is the sorted list
        of all distinct terms in ``annots``. Sorting makes the term -> column
        assignment derived from this list reproducible between runs.
    """
    # The ontology is needed only for propagation, so it is loaded lazily.
    go = Ontology('data/go.obo') if propagate else None
    annots = {}
    functions = set()
    with open(data_file, 'r') as f:
        for line in f:
            it = line.strip().split('\t')
            if len(it) < 2 or it[0] not in proteins:
                continue
            p_id = proteins[it[0]]
            terms = {it[1]}
            if propagate and go.has_term(it[1]):
                # Look the ancestors up once and reuse them for both outputs.
                terms |= go.get_anchestors(it[1])
            annots.setdefault(p_id, set()).update(terms)
            functions |= terms
    return annots, sorted(functions)

# Pass propagate=False to use only the annotations listed in the file,
# without adding ancestor terms from the ontology.
annotations, functions = load_annotations(f'data/train/{org_id}.annotation.txt', proteins, propagate=True)
print('Loaded annotations for', len(annotations), 'proteins')
print('Total number of distinct functions', len(functions))
# Position of each function in `functions`; the generators use it as the
# column index of that function in the multi-hot protein feature vectors.
functions_ix = {k:i for i, k in enumerate(functions)}

Loaded annotations for 5275 proteins
Total number of distinct functions 8384


### Generator object for feeding neural network model

In [40]:
class Generator(object):
    """Endless batch iterator over positive pairs plus sampled negatives.

    Every call to ``next()`` takes the next ``batch_size`` positive pairs from
    ``data`` and draws one negative pair per positive by replacing one of its
    two proteins with a random protein index. The negative is re-drawn until
    the pair is not in ``train_pairs``. Positives and negatives are shuffled
    together and each protein is encoded as a multi-hot float32 vector over
    ``functions_ix``.

    After ``steps`` batches the iterator wraps around and starts again from
    the first batch, so it can be consumed indefinitely.

    Args:
        data: Sequence of ``(protein_index1, protein_index2)`` positive pairs.
        proteins: Collection of all proteins; only its length is used, as the
            exclusive upper bound for sampled protein indices.
        annotations: Mapping protein index -> iterable of function ids.
            Proteins missing from the mapping get an all-zero vector.
        train_pairs: Container of known positive pairs that must never be
            produced as negatives.
        functions_ix: Mapping function id -> column index.
        batch_size: Number of positive pairs per batch; a full batch holds
            ``2 * batch_size`` rows.
        steps: Number of batches per pass over ``data``.
    """

    def __init__(self, data, proteins, annotations, train_pairs, functions_ix, batch_size=128, steps=100):
        self.data = data
        self.batch_size = batch_size
        self.steps = steps
        self.start = 0
        self.functions_ix = functions_ix
        self.input_length = len(functions_ix)
        self.train_pairs = train_pairs
        self.proteins = proteins
        self.annotations = annotations

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def reset(self):
        """Rewind to the first batch."""
        self.start = 0

    def _sample_negative(self, pr1, pr2):
        """Corrupt one side of (pr1, pr2) until the pair is not in train_pairs."""
        # The side to corrupt is chosen once per positive pair.
        corrupt_second = np.random.choice([True, False])
        # NOTE: loops until a valid negative is found; it would never finish
        # if every possible replacement formed a known pair.
        while True:
            neg = np.random.randint(0, len(self.proteins))
            candidate = (pr1, neg) if corrupt_second else (neg, pr2)
            # Use the set handed to this instance, not a notebook global.
            if candidate not in self.train_pairs:
                return candidate

    def _encode(self, protein_ids):
        """Return a (len(protein_ids), input_length) multi-hot float32 matrix."""
        features = np.zeros((len(protein_ids), self.input_length), dtype=np.float32)
        for row, p_id in enumerate(protein_ids):
            for go_id in self.annotations.get(p_id, ()):
                features[row, self.functions_ix[go_id]] = 1.0
        return features

    def next(self):
        """Return the next batch as ``([p1, p2], labels)``.

        ``p1`` and ``p2`` are the encoded first and second proteins of each
        pair; ``labels`` is 1 for positive pairs and 0 for sampled negatives.
        """
        if self.start >= self.steps:
            # Pass finished: wrap around so the caller always gets a batch
            # (previously this call returned None).
            self.reset()
        lo = self.start * self.batch_size
        batch_pos = list(self.data[lo: lo + self.batch_size])
        batch_neg = [self._sample_negative(pr1, pr2) for pr1, pr2 in batch_pos]
        # reshape keeps the array two-dimensional even for an empty batch.
        batch_data = np.array(batch_pos + batch_neg).reshape(-1, 2)
        labels = np.array([1] * len(batch_pos) + [0] * len(batch_neg))
        # Mix positives and negatives within the batch.
        index = np.arange(len(batch_data))
        np.random.shuffle(index)
        batch_data = batch_data[index]
        labels = labels[index]
        p1 = self._encode(batch_data[:, 0])
        p2 = self._encode(batch_data[:, 1])
        self.start += 1
        return ([p1, p2], labels)
# Known training interactions; the generators reject sampled negatives that
# appear in this set.
train_pairs = set(train_data)
batch_size = 128

# Number of batches needed to cover each split once (the last one may be partial).
train_steps, valid_steps, test_steps = (
    int(math.ceil(len(split) / batch_size))
    for split in (train_data, valid_data, test_data)
)

# The three generators share the protein index, the annotations and the
# feature layout; only the interaction list and the step count differ.
train_generator, valid_generator, test_generator = (
    Generator(
        split, proteins, annotations, train_pairs, functions_ix,
        batch_size=batch_size, steps=n_steps)
    for split, n_steps in (
        (train_data, train_steps),
        (valid_data, valid_steps),
        (test_data, test_steps),
    )
)

### Build NN model

In [41]:
# Shared encoder: three dense ReLU layers taking a multi-hot function vector
# down to a 256-dimensional embedding.
feature_model = Sequential([
    Dense(1024, input_shape=(len(functions),), activation='relu'),
    Dense(512, activation='relu'),
    Dense(256, activation='relu'),
])

# Both inputs go through the same encoder instance, so the weights are shared.
input1 = Input(shape=(len(functions),))
input2 = Input(shape=(len(functions),))
feature1 = feature_model(input1)
feature2 = feature_model(input2)

# Interaction score: sigmoid of the dot product of the two embeddings.
net = Activation('sigmoid')(Dot(axes=1)([feature1, feature2]))
model = Model(inputs=[input1, input2], outputs=net)
model.compile(loss='binary_crossentropy', optimizer='adam')
model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_3 (InputLayer)            (None, 8384)         0                                            
__________________________________________________________________________________________________
input_4 (InputLayer)            (None, 8384)         0                                            
__________________________________________________________________________________________________
sequential_1 (Sequential)       (None, 256)          9242368     input_3[0][0]                    
                                                                 input_4[0][0]                    
__________________________________________________________________________________________________
dot_1 (Dot)                     (None, 1)            0           sequential_1[1][0]               
          

### Train NN Model

In [42]:
# Upper bound on the number of training epochs.
epochs = 12
# Stop early after 3 epochs without improvement of the monitored quantity.
# NOTE(review): with Keras defaults this should be the validation loss —
# confirm for the installed version.
earlystopper = EarlyStopping(patience=3)
# NOTE(review): fit_generator is deprecated in newer TensorFlow releases in
# favour of model.fit — confirm which API the installed version expects.
# Both generators must keep yielding batches across epochs.
model.fit_generator(
    train_generator,
    steps_per_epoch=train_steps,
    epochs=epochs,
    validation_data=valid_generator,
    validation_steps=valid_steps,
    callbacks=[earlystopper,])


Epoch 1/12
Epoch 2/12
Epoch 3/12
Epoch 4/12
Epoch 5/12
Epoch 6/12
Epoch 7/12
Epoch 8/12
Epoch 9/12
Epoch 10/12
Epoch 11/12
Epoch 12/12


ValueError: `steps=None` is only valid for a generator based on the `keras.utils.Sequence` class. Please specify `steps` or use the `keras.utils.Sequence` class.

In [43]:
# Loss on the held-out test interactions. The generator pairs every test
# interaction with a randomly sampled negative, so the value depends on the
# random draws.
test_loss = model.evaluate_generator(test_generator, steps=test_steps, verbose=1)
print('Test loss:', test_loss)


Test loss: 0.4660752353334616


### Get prediction scores for all pairs

In [44]:
print('Total number of test proteins:', len(proteins))
# Every ordered pair (i, j) of protein indices, enumerated row by row
# (all j for i = 0, then all j for i = 1, ...). The flat prediction vector is
# later reshaped into a square matrix that relies on this order.
all_pairs = [
    (i, j)
    for i in range(len(proteins))
    for j in range(len(proteins))
]

batch_size = 128
class SimpleGenerator(object):
    """Endless batch iterator that encodes a fixed list of protein pairs.

    Unlike a training generator it performs no negative sampling and no
    shuffling: batches follow the order of ``data``. The returned labels are
    all-zero placeholders.

    After ``steps`` batches the iterator wraps around and starts again from
    the first batch, so it can be consumed indefinitely.

    Args:
        data: Sequence of ``(protein_index1, protein_index2)`` pairs.
        annotations: Mapping protein index -> iterable of function ids.
            Proteins missing from the mapping get an all-zero vector.
        functions_ix: Mapping function id -> column index.
        batch_size: Number of pairs per batch.
        steps: Number of batches per pass over ``data``.
    """

    def __init__(self, data, annotations, functions_ix, batch_size=128, steps=100):
        self.data = data
        self.batch_size = batch_size
        self.steps = steps
        self.start = 0
        self.functions_ix = functions_ix
        self.input_length = len(functions_ix)
        self.annotations = annotations

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def reset(self):
        """Rewind to the first batch."""
        self.start = 0

    def _encode(self, protein_ids):
        """Return a (len(protein_ids), input_length) multi-hot float32 matrix."""
        features = np.zeros((len(protein_ids), self.input_length), dtype=np.float32)
        for row, p_id in enumerate(protein_ids):
            for go_id in self.annotations.get(p_id, ()):
                features[row, self.functions_ix[go_id]] = 1.0
        return features

    def next(self):
        """Return the next batch as ``([p1, p2], labels)``.

        ``p1`` and ``p2`` encode the first and second protein of each pair;
        ``labels`` is a float32 zero array of shape ``(batch, 1)``.
        """
        if self.start >= self.steps:
            # Pass finished: wrap around so the caller always gets a batch
            # (previously this call returned None).
            self.reset()
        lo = self.start * self.batch_size
        batch_pairs = self.data[lo: lo + self.batch_size]
        p1 = self._encode([pair[0] for pair in batch_pairs])
        p2 = self._encode([pair[1] for pair in batch_pairs])
        labels = np.zeros((len(batch_pairs), 1), dtype=np.float32)
        self.start += 1
        return ([p1, p2], labels)

# Number of batches needed to cover every pair in all_pairs once.
all_steps = int(math.ceil(len(all_pairs) / batch_size))
# SimpleGenerator yields the pairs in list order together with all-zero
# placeholder labels.
all_generator = SimpleGenerator(
    all_pairs, annotations, functions_ix,
    batch_size=batch_size, steps=all_steps)
# One score per pair, in the order of all_pairs.
predictions = model.predict_generator(all_generator, steps=all_steps, verbose=True)


Total number of test proteins: 5926


### Evaluate predictions

In [45]:
def compute_rank_roc(ranks, n_prots):
    """Area under the cumulative rank curve, normalised by ``n_prots``.

    The curve has one point per observed rank: x is the rank and y is the
    fraction of all counted items whose rank is less than or equal to x. A
    final point ``(n_prots, 1)`` closes the curve. The area is integrated with
    the trapezoidal rule starting at the smallest observed rank and divided
    by ``n_prots``.

    Args:
        ranks: Mapping rank -> number of test items that obtained that rank.
        n_prots: Largest possible rank (the number of ranked candidates).

    Returns:
        The normalised area as a float.
    """
    auc_x = sorted(ranks)
    sum_rank = sum(ranks.values())
    auc_y = []
    tpr = 0
    for x in auc_x:
        tpr += ranks[x]
        auc_y.append(tpr / sum_rank)
    auc_x.append(n_prots)
    auc_y.append(1)
    # NumPy 2.0 renamed trapz to trapezoid and deprecated the old name; use
    # the new one when it exists and fall back to trapz on older versions.
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz
    auc = trapezoid(auc_y, auc_x) / n_prots
    return auc


# Square score matrix: sim[c, d] is the predicted score of pair (c, d).
# This relies on all_pairs having been enumerated row by row.
sim = predictions.reshape(len(proteins), len(proteins))

# Filter mask: 1 everywhere except at known training/validation interactions,
# which are set to 0 so that they can be excluded from the filtered ranking.
trlabels = np.ones((len(proteins), len(proteins)), dtype=np.int32)
for c, d in train_data:
    trlabels[c, d] = 0
for c, d in valid_data:
    trlabels[c, d] = 0

# Accumulators for the raw metrics and, prefixed with "f", the filtered ones.
top10 = 0
top100 = 0
mean_rank = 0
ftop10 = 0
ftop100 = 0
fmean_rank = 0
n = len(test_data)
# Marks the test interactions processed so far; filled inside the loop.
labels = np.zeros((len(proteins), len(proteins)), dtype=np.int32) 
# rank -> number of test pairs that obtained this rank (raw / filtered).
ranks = {}
franks = {}
with ck.progressbar(test_data) as prog_data:
    for c, d in prog_data:
        labels[c, d] = 1
        # Rank all candidates for protein c by descending score (rank 1 is
        # the highest score); tied scores share their average rank, so ranks
        # may be fractional.
        index = rankdata(-sim[c, :], method='average')
        rank = index[d]
        if rank <= 10:
            top10 += 1
        if rank <= 100:
            top100 += 1
        mean_rank += rank
        if rank not in ranks:
            ranks[rank] = 0
        ranks[rank] += 1

        # Filtered rank: scores of known train/valid partners of c are
        # multiplied by 0 before ranking. A pair that is also a test pair
        # already seen in this loop keeps its score (labels is 1 there).
        fil = sim[c, :] * (labels[c, :] | trlabels[c, :])
        index = rankdata(-fil, method='average')
        rank = index[d]
        if rank <= 10:
            ftop10 += 1
        if rank <= 100:
            ftop100 += 1
        fmean_rank += rank
        if rank not in franks:
            franks[rank] = 0
        franks[rank] += 1

    print()
    # Turn the counters into fractions / means over all test pairs.
    top10 /= n
    top100 /= n
    mean_rank /= n
    ftop10 /= n
    ftop100 /= n
    fmean_rank /= n

    rank_auc = compute_rank_roc(ranks, len(proteins))
    frank_auc = compute_rank_roc(franks, len(proteins))
    # Output columns: hits@10, hits@100, mean rank, rank AUC
    # (first line raw, second line filtered).
    print(f'{top10:.2f} {top100:.2f} {mean_rank:.2f} {rank_auc:.2f}')
    print(f'{ftop10:.2f} {ftop100:.2f} {fmean_rank:.2f} {frank_auc:.2f}')



0.08 0.50 543.56 0.91
0.19 0.72 491.56 0.92
