# DeepGO Pytorch implementation

In [39]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from utils import (
    get_gene_ontology,
    get_go_set,
    get_anchestors,
    get_parents,
    FUNC_DICT)
from sklearn.metrics import roc_curve, auc, matthews_corrcoef


Global variables and constants

In [15]:
# Maximum protein sequence n-grams length
MAXLEN = 2000
# Gene ontology file in OBO format
go = get_gene_ontology('data/go.obo')
DATA_ROOT = 'data/binary/'
# List of GO terms for prediction sorted in reverse topological order
functions = pd.read_pickle(DATA_ROOT + 'functions.pkl')['functions']
nb_classes = len(functions)
go_indexes = {}
for i, go_id in enumerate(functions):
    go_indexes[go_id] = i
# List of InterPRO ids
interpros = []
ipro_indexes = {}
nb_interpros = 10000
with open(DATA_ROOT + 'interpros.list') as f:
    for i in range(nb_interpros):
        it = next(f).split('\t')
        ipro_indexes[it[0]] = i
        interpros.append(it[0])

device = torch.device('cuda')

Custom data loader from pandas data frame

In [154]:
class DFGenerator(object):

    def __init__(self, df, batch_size=256):
        self.batch_size = batch_size
        self.start = 0
        self.size = len(df)
        self.df = df
        
    def __iter__(self):
        return self
    
    def __next__(self):
        return self.next()

    def reset(self):
        self.start = 0

    def next(self):
        if self.start < self.size:
            batch_index = np.arange(
                self.start, min(self.size, self.start + self.batch_size))
            df = self.df.iloc[batch_index]
            data_seq = torch.zeros((len(df), MAXLEN), dtype=torch.long)
            data_net = torch.zeros((len(df), 256), dtype=torch.float32)
            ipros = torch.zeros((len(df), len(interpros)), dtype=torch.float32)
            labels = torch.zeros((len(df), nb_classes), dtype=torch.float32)
            for i, row in enumerate(df.itertuples()):
                st = 0
                if hasattr(row, 'starts'):
                    st = row.starts
                data_seq[i, st:(st + len(row.ngrams))] = torch.from_numpy(row.ngrams)
                if isinstance(row.embeddings, np.ndarray):
                    data_net[i, :] = torch.from_numpy(row.embeddings)
                if isinstance(row.interpros, list):
                    for ipro_id in row.interpros:
                        if ipro_id in ipro_indexes:
                            ipros[i, ipro_indexes[ipro_id]] = 1
            
                for go_id in row.functions:
                    if go_id in go_indexes:
                        labels[i, go_indexes[go_id]] = 1
            self.start += self.batch_size
            data = (data_seq, ipros, data_net)
            # data = data_seq
            return (data, labels)
        else:
            self.reset()
            raise StopIteration()

Load the data from preprocessed pandas file 
and split into training, validation and testing 
sets

In [192]:
df = pd.read_pickle('data/binary/data.pkl')
split = 0.8
train_n = int(len(df) * split)
valid_n = int(train_n * split)
index = np.arange(len(df))
# Setting random shuffle seed for reproducibility
np.random.seed(seed=0)
np.random.shuffle(index)
train_df = df.iloc[index[:valid_n]]
valid_df = df.iloc[index[valid_n:train_n]]
test_df = df.iloc[index[train_n:]]

print("Training data: ", len(train_df))
print("Validation data: ", len(valid_df))
print("Testing data: ", len(test_df))

Training data:  43372
Validation data:  10843
Testing data:  13554


Model class

In [187]:
class DeepGO(nn.Module):
    def __init__(self):
        super(DeepGO, self).__init__()
        embedding_size = 128
        self.embeddings = nn.Embedding(8001, embedding_size, padding_idx=0)
        out_channels = 32
        kernel_size = 8
        self.conv1 = nn.Conv1d(embedding_size, out_channels, kernel_size)
        self.conv2 = nn.Conv1d(out_channels, 4, kernel_size)
        self.conv3 = nn.Conv1d(4, 2, kernel_size)
        out_dim = 2 * (MAXLEN - 3 * (kernel_size - 1)) + nb_interpros + 256
        self.fc = nn.Linear(out_dim, nb_classes)
        
    def forward(self, x, y, z):
        x = self.embeddings(x)
        x = x.transpose(1, 2)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = x.view(x.size(0), -1)
        x = torch.cat([x, y, z], dim=1)
        x = self.fc(x)
        return F.sigmoid(x)

Train model function

In [188]:
def train(model, generator, optimizer, epoch):
    model.train()
    for batch_idx, (data, labels) in enumerate(generator):
        seq, ipro, net = data
        labels = labels.to(device)
        seq = seq.to(device)
        ipro = ipro.to(device)
        net = net.to(device)
        optimizer.zero_grad()
        output = model(seq, ipro, net)
        loss = F.binary_cross_entropy(output, labels)
        loss.backward()
        optimizer.step()
        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_df), loss.item()))

Validation function

In [189]:
def valid(model, generator):
    model.eval()
    loss = 0
    with torch.no_grad():
        for data, labels in generator:
            seq, ipro, net = data
            labels = labels.to(device)
            seq = seq.to(device)
            ipro = ipro.to(device)
            net = net.to(device)
            output = model(seq, ipro, net)
            batch_loss = F.binary_cross_entropy(output, labels, size_average=False).item() # sum up batch loss
            loss += batch_loss
    loss /= len(valid_df) * nb_classes
    print('\nValidation set: Average loss: {:.4f}\n'.format(loss))
    return loss

Test function

In [190]:
def test(model_file):
    batch_size = 256
    generator = DFGenerator(test_df, batch_size=batch_size)
    preds = torch.zeros((len(test_df), nb_classes), dtype=torch.float32)
    test_labels = torch.zeros((len(test_df), nb_classes), dtype=torch.float32)
    print('Loading the model ...')
    model = torch.load(model_file)
    model.eval()
    loss = 0
    with torch.no_grad():
        print('Predicting ...')
        for idx, (data, labels) in enumerate(generator):
            seq, ipro, net = data
            labels = labels.to(device)
            seq = seq.to(device)
            ipro = ipro.to(device)
            net = net.to(device)
            output = model(seq, ipro, net)
            preds[idx * batch_size: idx * batch_size + batch_size, :] = output
            test_labels[idx * batch_size: idx * batch_size + batch_size, :] = labels
            batch_loss = F.binary_cross_entropy(output, labels, size_average=False).item() # sum up batch loss
            loss += batch_loss
    loss /= len(test_df) * nb_classes
    print('\nTest set: Average loss: {:.4f}\n'.format(loss))
    return test_labels, preds

Train model

In [193]:
model = DeepGO().to(device)
optimizer = optim.Adam(model.parameters())

train_generator = DFGenerator(train_df)
valid_generator = DFGenerator(valid_df)

epochs = 12
loss = 100
model_file = 'deepgo_model.pt'
for epoch in range(epochs):
    train(model, train_generator, optimizer, epoch)
    valid_loss = valid(model, valid_generator)
    if loss > valid_loss:
        print('Saving best model to a file: ', model_file)
        loss = valid_loss
        torch.save(model, model_file)
    else:
        print('Validation loss did not improve from ', loss)

Train Epoch: 0 [0/43372]	Loss: 0.694534
Train Epoch: 0 [30/43372]	Loss: 0.170462
Train Epoch: 0 [60/43372]	Loss: 0.170864
Train Epoch: 0 [90/43372]	Loss: 0.149778
Train Epoch: 0 [120/43372]	Loss: 0.146979
Train Epoch: 0 [150/43372]	Loss: 0.127618
Train Epoch: 0 [180/43372]	Loss: 0.139042
Train Epoch: 0 [210/43372]	Loss: 0.134022
Train Epoch: 0 [240/43372]	Loss: 0.130518
Train Epoch: 0 [270/43372]	Loss: 0.126967
Train Epoch: 0 [300/43372]	Loss: 0.132297
Train Epoch: 0 [330/43372]	Loss: 0.128641
Train Epoch: 0 [360/43372]	Loss: 0.117012
Train Epoch: 0 [390/43372]	Loss: 0.130003
Train Epoch: 0 [420/43372]	Loss: 0.129127
Train Epoch: 0 [450/43372]	Loss: 0.123272
Train Epoch: 0 [480/43372]	Loss: 0.128666

Validation set: Average loss: 0.1215

Saving best model to a file:  deepgo_model.pt


  "type " + obj.__name__ + ". It won't be checked "


Train Epoch: 1 [0/43372]	Loss: 0.116854
Train Epoch: 1 [30/43372]	Loss: 0.118485
Train Epoch: 1 [60/43372]	Loss: 0.127627
Train Epoch: 1 [90/43372]	Loss: 0.123667
Train Epoch: 1 [120/43372]	Loss: 0.124387
Train Epoch: 1 [150/43372]	Loss: 0.110611
Train Epoch: 1 [180/43372]	Loss: 0.122045
Train Epoch: 1 [210/43372]	Loss: 0.117028
Train Epoch: 1 [240/43372]	Loss: 0.114216
Train Epoch: 1 [270/43372]	Loss: 0.114583
Train Epoch: 1 [300/43372]	Loss: 0.119937
Train Epoch: 1 [330/43372]	Loss: 0.117420
Train Epoch: 1 [360/43372]	Loss: 0.107101
Train Epoch: 1 [390/43372]	Loss: 0.118682
Train Epoch: 1 [420/43372]	Loss: 0.119664
Train Epoch: 1 [450/43372]	Loss: 0.115671
Train Epoch: 1 [480/43372]	Loss: 0.120906

Validation set: Average loss: 0.1171

Saving best model to a file:  deepgo_model.pt
Train Epoch: 2 [0/43372]	Loss: 0.111540
Train Epoch: 2 [30/43372]	Loss: 0.111324
Train Epoch: 2 [60/43372]	Loss: 0.119344
Train Epoch: 2 [90/43372]	Loss: 0.116933
Train Epoch: 2 [120/43372]	Loss: 0.115652
T

Train Epoch: 11 [120/43372]	Loss: 0.074551
Train Epoch: 11 [150/43372]	Loss: 0.068804
Train Epoch: 11 [180/43372]	Loss: 0.071452
Train Epoch: 11 [210/43372]	Loss: 0.070651
Train Epoch: 11 [240/43372]	Loss: 0.065254
Train Epoch: 11 [270/43372]	Loss: 0.070064
Train Epoch: 11 [300/43372]	Loss: 0.070097
Train Epoch: 11 [330/43372]	Loss: 0.067912
Train Epoch: 11 [360/43372]	Loss: 0.063014
Train Epoch: 11 [390/43372]	Loss: 0.066958
Train Epoch: 11 [420/43372]	Loss: 0.073252
Train Epoch: 11 [450/43372]	Loss: 0.072663
Train Epoch: 11 [480/43372]	Loss: 0.074296

Validation set: Average loss: 0.1571

Validation loss did not improve from  0.11712695036161863


Test and evaluate the model

In [194]:
def compute_roc(labels, preds):
    fpr, tpr, _ = roc_curve(labels.flatten(), preds.flatten())
    roc_auc = auc(fpr, tpr)
    return roc_auc

def compute_mcc(labels, preds):
    mcc = matthews_corrcoef(labels.flatten(), preds.flatten())
    return mcc

def compute_fmax(labels, preds):
    preds = np.round(preds, 2)
    f_max = 0
    p_max = 0
    r_max = 0
    t_max = 0
    predictions_max = (preds > 0).astype(np.int32)
    for t in range(1, 100):
        threshold = t / 100.0
        predictions = (preds > threshold).astype(np.int32)
        total = 0
        f = 0.0
        p = 0.0
        r = 0.0
        p_total = 0
        for i in range(labels.shape[0]):
            tp = np.sum(predictions[i, :] * labels[i, :])
            fp = np.sum(predictions[i, :]) - tp
            fn = np.sum(labels[i, :]) - tp
            if tp == 0 and fp == 0 and fn == 0:
                continue
            total += 1
            if tp != 0:
                p_total += 1
                precision = tp / (1.0 * (tp + fp))
                recall = tp / (1.0 * (tp + fn))
                p += precision
                r += recall
        if p_total == 0:
            continue
        r /= total
        p /= p_total
        if p + r > 0:
            f = 2 * p * r / (p + r)
            if f_max < f:
                f_max = f
                p_max = p
                r_max = r
                t_max = threshold
                predictions_max = predictions
    return f_max, p_max, r_max, t_max, predictions_max

model_file = 'deepgo_model.pt'
labels, preds = test(model_file)
labels = labels.cpu().numpy()
preds = preds.cpu().numpy()

roc_auc = compute_roc(labels, preds)
print('ROC AUC: ', roc_auc)
# f, p, r, t, preds_max = compute_fmax(labels, preds)
# print(f, p, r, t)

mf = get_go_set(go, FUNC_DICT['mf'])
bp = get_go_set(go, FUNC_DICT['bp'])
cc = get_go_set(go, FUNC_DICT['cc'])
bp_index = list()
mf_index = list()
cc_index = list()
for i, go_id in enumerate(functions):
    if go_id in mf:
        mf_index.append(i)
    elif go_id in bp:
        bp_index.append(i)
    elif go_id in cc:
        cc_index.append(i)
# mf_index = mf_index[-50:]
# cc_index = cc_index[-50:]
# bp_index = bp_index[-50:]

mf_labels = labels[:, mf_index]
bp_labels = labels[:, bp_index]
cc_labels = labels[:, cc_index]
mf_preds = preds[:, mf_index]
bp_preds = preds[:, bp_index]
cc_preds = preds[:, cc_index]

print('MF', len(mf_index))
f, p, r, t, preds_max = compute_fmax(mf_labels, mf_preds)
print(f, p, r, t)

print('BP', len(bp_index))
f, p, r, t, preds_max = compute_fmax(bp_labels, bp_preds)
print(f, p, r, t)

print('CC', len(cc_index))
f, p, r, t, preds_max = compute_fmax(cc_labels, cc_preds)
print(f, p, r, t)


Loading the model ...
Predicting ...

Test set: Average loss: 0.1174

ROC AUC:  0.9156755823185276
MF 134
0.41611479659697576 0.5334635350576244 0.3410846405559903 0.12
BP 761
0.4431181760102735 0.45989845737030927 0.42751931022279693 0.18
CC 144
0.6056459217947423 0.6981739081395634 0.5347731972287243 0.22
