# CAFA5 DeepGoZero

Based on Geraseva's original notebook, adapted for inference on the CAFA5 dataset.

### Instructions
A separate submission must be generated for each ontology class.
Go to Hyperparameter Definition and run the notebook once for each `ont` value (mf, bp, cc). You will then need to combine the 3 submissions in a manual postprocessing step.
Note: to limit memory use, the test set can be run on a 100-protein sample instead of the full dataset. See the commented-out line in Hyperparameter Definition.
Depending on your hardware, you may also need a test dataloader to run inference in batches. A batched evaluation loop from the original repo is kept, commented out, in the inference cells.

Note that Geraseva's dataset provides three levels of GO fidelity (NB: this is the data from the original repo, not CAFA5). These are also discussed under Hyperparameter Definition.
TODO: Diamond and BLAST predictions are provided in Geraseva's data. If you have your own Diamond predictions, you can combine them as well.
A link to an example of how DeepGoPlus does this is provided in the final cell.
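
The manual combination step could be sketched as below. This is a minimal sketch and not part of the original notebook: `combine_submissions` and the demo file names are hypothetical, and it assumes each per-ontology file holds tab-separated `protein, GO term, score` rows as written by the inference cell. Blank separator lines are skipped.

```python
import os
import tempfile


def combine_submissions(paths, out_path):
    """Concatenate per-ontology TSV submissions, skipping blank lines.

    Hypothetical helper: assumes each row is 'protein<TAB>GO id<TAB>score'.
    Returns the number of rows written.
    """
    n_rows = 0
    with open(out_path, 'wt') as out:
        for path in paths:
            with open(path, 'rt') as f:
                for line in f:
                    if line.strip():  # drop blank separator lines
                        out.write(line.rstrip('\n') + '\n')
                        n_rows += 1
    return n_rows


# Tiny demo with made-up rows, one file per ontology class
demo_dir = tempfile.mkdtemp()
demo_rows = {
    'mf': 'P1\tGO:0003674\t0.900\n\n',
    'bp': 'P1\tGO:0008150\t0.800\n\n',
    'cc': 'P1\tGO:0005575\t0.700\n\n',
}
demo_paths = []
for ont_name, rows in demo_rows.items():
    path = os.path.join(demo_dir, f'{ont_name}_submission.tsv')
    with open(path, 'wt') as f:
        f.write(rows)
    demo_paths.append(path)

combined_path = os.path.join(demo_dir, 'submission.tsv')
n_written = combine_submissions(demo_paths, combined_path)
print(n_written)
```

In practice, `paths` would be the three `sub_file` values produced by running the notebook with `ont` set to mf, bp and cc.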


In [1]:
import numpy as np 
import pandas as pd 
from IPython.display import clear_output
from tqdm import tqdm
import os

In [2]:
pd.options.display.max_rows = 4000

In [3]:
import torch as th
from torch import nn
from torch.nn import functional as F
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
from sklearn.metrics import roc_curve, auc, matthews_corrcoef
import copy
from torch.utils.data import DataLoader, IterableDataset, TensorDataset
from itertools import cycle
import math
from dgl.nn import GraphConv, GATConv
import dgl
from collections import deque, Counter

## Hyperparameter Definition

* Note: for retraining, GO Norm (go.norm) would have to be replaced with our own to match size. This is irrelevant for inference.
* With the default model and terms files (deepgozero_zero_10.th & terms_zero_10.pkl), we observe 10490 annotations in the output, fewer than our 40K.
* Each GO class has 3 model/terms file pairs, allowing up to 3x3 inference run configurations. See the extended comment in the cell below for details.
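
The 3x3 configurations can be enumerated with a small lookup, sketched below. `resolve_files` and `MODE_SUFFIX` are hypothetical names that are not in the original notebook; the file naming follows the model and terms file names used in the hyperparameter cell.

```python
# Suffix shared by the model file and the terms file for each terms_mode
MODE_SUFFIX = {0: '_zero_10', 1: '_zero', 2: ''}


def resolve_files(data_root, ont, terms_mode):
    """Return (model_file, terms_file) for one ontology class and terms mode."""
    if ont not in ('mf', 'bp', 'cc'):
        raise ValueError(f'unknown ontology class: {ont!r}')
    if terms_mode not in MODE_SUFFIX:
        raise ValueError(f'terms_mode must be 0, 1 or 2, got {terms_mode!r}')
    suffix = MODE_SUFFIX[terms_mode]
    model_file = f'{data_root}/{ont}/deepgozero{suffix}.th'
    terms_file = f'{data_root}/{ont}/terms{suffix}.pkl'
    return model_file, terms_file


# Enumerate all 3 x 3 run configurations
all_configs = [
    (o, m, *resolve_files('../input/deepgozero-data/data', o, m))
    for o in ('mf', 'bp', 'cc')
    for m in (0, 1, 2)
]
print(len(all_configs))
```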





In [4]:
data_root='../input/deepgozero-data/data'
ont='bp'  # Ontology class definition, because DGZ has models for each class. mf, bp, or cc
device='cpu' #'cuda:0'
batch_size=37 # batch size currently not used for inference, training only
epochs=256 # epochs currently not used for inference, training only
load=True # If False, retrain. Note not relevant for CAFA5, keep True for inference
go_file = f'{data_root}/go.norm' # GO DGZ file. contains 10490 GO annotations
"""
Note: for each GO ontology class there are 3 possible model files, each with a matching terms file:
- deepgozero.th          / terms.pkl
- deepgozero_zero.th     / terms_zero.pkl
- deepgozero_zero_10.th  / terms_zero_10.pkl

These correspond to different sets of GO terms used for training. For example, for ont='mf':
- len(terms) = 6868
- len(terms_zero) = 6863
- len(terms_zero_10) = 2041

The notebook defaults to the files ending in zero_10, which also consume the least memory.
However, it may be worth testing the other ones if time allows.
"""
terms_mode = 0 # 0,1 or 2

if terms_mode == 0:
    model_file = f'{data_root}/{ont}/deepgozero_zero_10.th' # Model file
    terms_file = f'{data_root}/{ont}/terms_zero_10.pkl'
elif terms_mode == 1:
    model_file = f'{data_root}/{ont}/deepgozero_zero.th' # Model file
    terms_file = f'{data_root}/{ont}/terms_zero.pkl'
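elif terms_mode == 2:
    model_file = f'{data_root}/{ont}/deepgozero.th' # Model file
    terms_file = f'{data_root}/{ont}/terms.pkl'
else:
    # Fail early instead of hitting a NameError on model_file later
    raise ValueError(f'terms_mode must be 0, 1 or 2, got {terms_mode}')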

out_file = "../output/DGZ/predictions_deepgozero_zero_10.pkl"
test_df = pd.read_csv('../input/cafa-fasta-4/test_df.csv') # Full CAFA5 test dataframe, generated using https://github.com/bio-ontology-research-group/deepgozero/blob/main/interpro_data.py
#test_df= pd.read_csv('../input/cafa-fasta-4/test_sample.csv') # Alternative: instead of taking the whole dataset, take a 100-row sample to save memory
threshold=0.1 # Probability threshold to accept prediction

sub_file = f'../output/DGZ/{ont}_th-{threshold}_{os.path.basename(model_file)[:-3]}_submission.tsv'

## DeepGoZero repository code


We use code from https://github.com/bio-ontology-research-group/deepgozero.
Instead of cloning the repo, we copy all necessary code into the notebook and run it.

Note: the copied code is sufficient for inference only, not for CAFA5 retraining.

In [5]:
# from https://github.com/bio-ontology-research-group/deepgozero/blob/main/utils.py
AALETTER = [
    'A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I',
    'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V']
AANUM = len(AALETTER)
AAINDEX = dict()
for i in range(len(AALETTER)):
    AAINDEX[AALETTER[i]] = i + 1
INVALID_ACIDS = set(['U', 'O', 'B', 'Z', 'J', 'X', '*'])
MAXLEN = 2000
NGRAMS = {}
for i in range(20):
    for j in range(20):
        for k in range(20):
            ngram = AALETTER[i] + AALETTER[j] + AALETTER[k]
            index = 400 * i + 20 * j + k + 1
            NGRAMS[ngram] = index

def is_ok(seq):
    for c in seq:
        if c in INVALID_ACIDS:
            return False
    return True

def to_ngrams(seq):
    l = min(MAXLEN, len(seq) - 3)
    ngrams = np.zeros((l,), dtype=np.int32)
    for i in range(l):
        ngrams[i] = NGRAMS.get(seq[i: i + 3], 0)
    return ngrams

def to_tokens(seq):
    tokens = np.zeros((MAXLEN, ), dtype=np.float32)
    l = min(MAXLEN, len(seq))
    for i in range(l):
        tokens[i] = AAINDEX.get(seq[i], 0)
    return tokens

def to_onehot(seq, start=0):
    onehot = np.zeros((21, MAXLEN), dtype=np.float32)
    l = min(MAXLEN, len(seq))
    for i in range(start, start + l):
        onehot[AAINDEX.get(seq[i - start], 0), i] = 1
    onehot[0, 0:start] = 1
    onehot[0, start + l:] = 1
    return onehot

In [6]:
# from https://github.com/bio-ontology-research-group/deepgozero/blob/main/utils.py
# Ontology hierarchy handling methods
class Ontology(object):

    def __init__(self, filename='data/go.obo', with_rels=False):
        self.ont = self.load(filename, with_rels)
        self.ic = None
        self.ic_norm = 0.0

    def has_term(self, term_id):
        return term_id in self.ont

    def get_term(self, term_id):
        if self.has_term(term_id):
            return self.ont[term_id]
        return None

    def calculate_ic(self, annots):
        cnt = Counter()
        for x in annots:
            cnt.update(x)
        self.ic = {}
        for go_id, n in cnt.items():
            parents = self.get_parents(go_id)
            if len(parents) == 0:
                min_n = n
            else:
                min_n = min([cnt[x] for x in parents])

            self.ic[go_id] = math.log(min_n / n, 2)
            self.ic_norm = max(self.ic_norm, self.ic[go_id])
    
    def get_ic(self, go_id):
        if self.ic is None:
            raise Exception('Not yet calculated')
        if go_id not in self.ic:
            return 0.0
        return self.ic[go_id]

    def get_norm_ic(self, go_id):
        return self.get_ic(go_id) / self.ic_norm

    def load(self, filename, with_rels):
        ont = dict()
        obj = None
        with open(filename, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                if line == '[Term]':
                    if obj is not None:
                        ont[obj['id']] = obj
                    obj = dict()
                    obj['is_a'] = list()
                    obj['part_of'] = list()
                    obj['regulates'] = list()
                    obj['alt_ids'] = list()
                    obj['is_obsolete'] = False
                    continue
                elif line == '[Typedef]':
                    if obj is not None:
                        ont[obj['id']] = obj
                    obj = None
                else:
                    if obj is None:
                        continue
                    l = line.split(": ")
                    if l[0] == 'id':
                        obj['id'] = l[1]
                    elif l[0] == 'alt_id':
                        obj['alt_ids'].append(l[1])
                    elif l[0] == 'namespace':
                        obj['namespace'] = l[1]
                    elif l[0] == 'is_a':
                        obj['is_a'].append(l[1].split(' ! ')[0])
                    elif with_rels and l[0] == 'relationship':
                        it = l[1].split()
                        # add all types of relationships
                        obj['is_a'].append(it[1])
                    elif l[0] == 'name':
                        obj['name'] = l[1]
                    elif l[0] == 'is_obsolete' and l[1] == 'true':
                        obj['is_obsolete'] = True
            if obj is not None:
                ont[obj['id']] = obj
        for term_id in list(ont.keys()):
            for t_id in ont[term_id]['alt_ids']:
                ont[t_id] = ont[term_id]
            if ont[term_id]['is_obsolete']:
                del ont[term_id]
        for term_id, val in ont.items():
            if 'children' not in val:
                val['children'] = set()
            for p_id in val['is_a']:
                if p_id in ont:
                    if 'children' not in ont[p_id]:
                        ont[p_id]['children'] = set()
                    ont[p_id]['children'].add(term_id)
     
        return ont

    def get_anchestors(self, term_id):
        if term_id not in self.ont:
            return set()
        term_set = set()
        q = deque()
        q.append(term_id)
        while(len(q) > 0):
            t_id = q.popleft()
            if t_id not in term_set:
                term_set.add(t_id)
                for parent_id in self.ont[t_id]['is_a']:
                    if parent_id in self.ont:
                        q.append(parent_id)
        return term_set

    def get_prop_terms(self, terms):
        prop_terms = set()

        for term_id in terms:
            prop_terms |= self.get_anchestors(term_id)
        return prop_terms


    def get_parents(self, term_id):
        if term_id not in self.ont:
            return set()
        term_set = set()
        for parent_id in self.ont[term_id]['is_a']:
            if parent_id in self.ont:
                term_set.add(parent_id)
        return term_set


    def get_namespace_terms(self, namespace):
        terms = set()
        for go_id, obj in self.ont.items():
            if obj['namespace'] == namespace:
                terms.add(go_id)
        return terms

    def get_namespace(self, term_id):
        return self.ont[term_id]['namespace']
    
    def get_term_set(self, term_id):
        if term_id not in self.ont:
            return set()
        term_set = set()
        q = deque()
        q.append(term_id)
        while len(q) > 0:
            t_id = q.popleft()
            if t_id not in term_set:
                term_set.add(t_id)
                for ch_id in self.ont[t_id]['children']:
                    q.append(ch_id)
        return term_set


In [7]:
# from https://github.com/bio-ontology-research-group/deepgozero/blob/main/deepgozero.py
# Primary load data methods.  
def compute_roc(labels, preds):
    # Compute ROC curve and ROC area for each class
    fpr, tpr, _ = roc_curve(labels.flatten(), preds.flatten())
    roc_auc = auc(fpr, tpr)

    return roc_auc

def load_normal_forms(go_file, terms_dict):
    nf1 = []
    nf2 = []
    nf3 = []
    nf4 = []
    relations = {}
    zclasses = {}
    
    def get_index(go_id):
        if go_id in terms_dict:
            index = terms_dict[go_id]
        elif go_id in zclasses:
            index = zclasses[go_id]
        else:
            zclasses[go_id] = len(terms_dict) + len(zclasses)
            index = zclasses[go_id]
        return index

    def get_rel_index(rel_id):
        if rel_id not in relations:
            relations[rel_id] = len(relations)
        return relations[rel_id]
                
    with open(go_file) as f:
        for line in f:
            line = line.strip().replace('_', ':')
            if line.find('SubClassOf') == -1:
                continue
            left, right = line.split(' SubClassOf ')
            # C SubClassOf D
            if len(left) == 10 and len(right) == 10:
                go1, go2 = left, right
                nf1.append((get_index(go1), get_index(go2)))
            elif left.find('and') != -1: # C and D SubClassOf E
                go1, go2 = left.split(' and ')
                go3 = right
                nf2.append((get_index(go1), get_index(go2), get_index(go3)))
            elif left.find('some') != -1:  # R some C SubClassOf D
                rel, go1 = left.split(' some ')
                go2 = right
                nf3.append((get_rel_index(rel), get_index(go1), get_index(go2)))
            elif right.find('some') != -1: # C SubClassOf R some D
                go1 = left
                rel, go2 = right.split(' some ')
                nf4.append((get_index(go1), get_rel_index(rel), get_index(go2)))
    return nf1, nf2, nf3, nf4, relations, zclasses 
    
def load_data(data_root, ont, terms_file):
    terms_df = pd.read_pickle(terms_file)
    terms = terms_df['gos'].values.flatten()
    terms_dict = {v: i for i, v in enumerate(terms)}
    print('Terms', len(terms))
    
    ipr_df = pd.read_pickle(f'{data_root}/{ont}/interpros.pkl')
    iprs = ipr_df['interpros'].values
    iprs_dict = {v:k for k, v in enumerate(iprs)}
    return iprs_dict, terms_dict

def get_data(df, iprs_dict, terms_dict):
    data = th.zeros((len(df), len(iprs_dict)), dtype=th.float32)
    labels = th.zeros((len(df), len(terms_dict)), dtype=th.float32)
    for i, row in enumerate(df.itertuples()):
        # Iterate over each row in the dataframe
        for ipr in row.interpros:
            # For each InterPro ID in the interpros column of a row, check if it is in the dict
            # TODO: how is this interpros column made? Is it InterProScan?
            if ipr in iprs_dict:
                data[i, iprs_dict[ipr]] = 1 # if so, set a binary presence flag (not a count)
        for go_id in row.prop_annotations: # prop_annotations for full model
            if go_id in terms_dict:
                g_id = terms_dict[go_id]
                labels[i, g_id] = 1
    return data, labels


In [8]:
# from https://github.com/bio-ontology-research-group/deepgozero/blob/main/deepgozero.py
# Model definition
class Residual(nn.Module):

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return x + self.fn(x)
    
        
class MLPBlock(nn.Module):

    def __init__(self, in_features, out_features, bias=True, layer_norm=True, dropout=0.1, activation=nn.ReLU):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias)
        self.activation = activation()
        self.layer_norm = nn.BatchNorm1d(out_features, track_running_stats=False) if layer_norm else None
        self.dropout = nn.Dropout(dropout) if dropout else None

    def forward(self, x):
        x = self.activation(self.linear(x))
        if self.layer_norm:
            x = self.layer_norm(x)
        if self.dropout:
            x = self.dropout(x)
        return x


class DGELModel(nn.Module):

    def __init__(self, nb_iprs, nb_gos, nb_zero_gos, nb_rels, device, hidden_dim=1024, embed_dim=1024, margin=0.1):
        super().__init__()
        self.nb_gos = nb_gos
        self.nb_zero_gos = nb_zero_gos
        input_length = nb_iprs
        net = []
        net.append(MLPBlock(input_length, hidden_dim))
        net.append(Residual(MLPBlock(hidden_dim, hidden_dim)))
        self.net = nn.Sequential(*net)

        # ELEmbeddings
        self.embed_dim = embed_dim
        self.hasFuncIndex = th.LongTensor([nb_rels]).to(device)
        self.go_embed = nn.Embedding(nb_gos + nb_zero_gos, embed_dim)
        self.go_norm = nn.BatchNorm1d(embed_dim)
        k = math.sqrt(1 / embed_dim)
        nn.init.uniform_(self.go_embed.weight, -k, k)
        self.go_rad = nn.Embedding(nb_gos + nb_zero_gos, 1)
        nn.init.uniform_(self.go_rad.weight, -k, k)
        # self.go_embed.weight.requires_grad = False
        # self.go_rad.weight.requires_grad = False
        
        self.rel_embed = nn.Embedding(nb_rels + 1, embed_dim)
        nn.init.uniform_(self.rel_embed.weight, -k, k)
        self.all_gos = th.arange(self.nb_gos).to(device)
        self.margin = margin

        
    def forward(self, features):
        x = self.net(features)
        go_embed = self.go_embed(self.all_gos)
        hasFunc = self.rel_embed(self.hasFuncIndex)
        hasFuncGO = go_embed + hasFunc
        go_rad = th.abs(self.go_rad(self.all_gos).view(1, -1))
        x = th.matmul(x, hasFuncGO.T) + go_rad
        logits = th.sigmoid(x)
        return logits

    def predict_zero(self, features, data):
        x = self.net(features)
        go_embed = self.go_embed(data)
        hasFunc = self.rel_embed(self.hasFuncIndex)
        hasFuncGO = go_embed + hasFunc
        go_rad = th.abs(self.go_rad(data).view(1, -1))
        x = th.matmul(x, hasFuncGO.T) + go_rad
        logits = th.sigmoid(x)
        return logits


    def el_loss(self, go_normal_forms):
        nf1, nf2, nf3, nf4 = go_normal_forms
        nf1_loss = self.nf1_loss(nf1)
        nf2_loss = self.nf2_loss(nf2)
        nf3_loss = self.nf3_loss(nf3)
        nf4_loss = self.nf4_loss(nf4)
        # print()
        # print(nf1_loss.detach().item(),
        #       nf2_loss.detach().item(),
        #       nf3_loss.detach().item(),
        #       nf4_loss.detach().item())
        return nf1_loss + nf3_loss + nf4_loss + nf2_loss

    def class_dist(self, data):
        c = self.go_norm(self.go_embed(data[:, 0]))
        d = self.go_norm(self.go_embed(data[:, 1]))
        rc = th.abs(self.go_rad(data[:, 0]))
        rd = th.abs(self.go_rad(data[:, 1]))
        dist = th.linalg.norm(c - d, dim=1, keepdim=True) + rc - rd
        return dist
        
    def nf1_loss(self, data):
        pos_dist = self.class_dist(data)
        loss = th.mean(th.relu(pos_dist - self.margin))
        return loss

    def nf2_loss(self, data):
        c = self.go_norm(self.go_embed(data[:, 0]))
        d = self.go_norm(self.go_embed(data[:, 1]))
        e = self.go_norm(self.go_embed(data[:, 2]))
        rc = th.abs(self.go_rad(data[:, 0]))
        rd = th.abs(self.go_rad(data[:, 1]))
        re = th.abs(self.go_rad(data[:, 2]))
        
        sr = rc + rd
        dst = th.linalg.norm(c - d, dim=1, keepdim=True)
        dst2 = th.linalg.norm(e - c, dim=1, keepdim=True)
        dst3 = th.linalg.norm(e - d, dim=1, keepdim=True)
        loss = th.mean(th.relu(dst - sr - self.margin)
                    + th.relu(dst2 - rc - self.margin)
                    + th.relu(dst3 - rd - self.margin))

        return loss

    def nf3_loss(self, data):
        # R some C subClassOf D
        n = data.shape[0]
        # rS = self.rel_space(data[:, 0])
        # rS = rS.reshape(-1, self.embed_dim, self.embed_dim)
        rE = self.rel_embed(data[:, 0])
        c = self.go_norm(self.go_embed(data[:, 1]))
        d = self.go_norm(self.go_embed(data[:, 2]))
        # c = th.matmul(c, rS).reshape(n, -1)
        # d = th.matmul(d, rS).reshape(n, -1)
        rc = th.abs(self.go_rad(data[:, 1]))
        rd = th.abs(self.go_rad(data[:, 2]))
        
        rSomeC = c + rE
        euc = th.linalg.norm(rSomeC - d, dim=1, keepdim=True)
        loss = th.mean(th.relu(euc + rc - rd - self.margin))
        return loss


    def nf4_loss(self, data):
        # C subClassOf R some D
        n = data.shape[0]
        c = self.go_norm(self.go_embed(data[:, 0]))
        rE = self.rel_embed(data[:, 1])
        d = self.go_norm(self.go_embed(data[:, 2]))
        
        rc = th.abs(self.go_rad(data[:, 0]))  # radius of C; column 1 holds the relation index, not a GO class
        rd = th.abs(self.go_rad(data[:, 2]))
        sr = rc + rd
        # c should intersect with d + r
        rSomeD = d + rE
        dst = th.linalg.norm(c - rSomeD, dim=1, keepdim=True)
        loss = th.mean(th.relu(dst - sr - self.margin))
        return loss

In [9]:
# https://github.com/bio-ontology-research-group/deepgozero/blob/main/torch_utils.py
# Not currently implemented for our test set, but may be necessary given its size.

import torch

class FastTensorDataLoader:
    """
    A DataLoader-like object for a set of tensors that can be much faster than
    TensorDataset + DataLoader because dataloader grabs individual indices of
    the dataset and calls cat (slow).
    Source: https://discuss.pytorch.org/t/dataloader-much-slower-than-manual-batching/27014/6
    
    This DataLoader generates tuple outputs of format (tensors,labels)
    """
    def __init__(self, *tensors, batch_size=32, shuffle=False):
        """
        Initialize a FastTensorDataLoader.
        :param *tensors: tensors to store. Must have the same length @ dim 0.
        :param batch_size: batch size to load.
        :param shuffle: if True, shuffle the data *in-place* whenever an
            iterator is created out of this object.
        :returns: A FastTensorDataLoader.
        """
        assert all(t.shape[0] == tensors[0].shape[0] for t in tensors)
        self.tensors = tensors

        self.dataset_len = self.tensors[0].shape[0]
        self.batch_size = batch_size
        self.shuffle = shuffle

        # Calculate # batches
        n_batches, remainder = divmod(self.dataset_len, self.batch_size)
        if remainder > 0:
            n_batches += 1
        self.n_batches = n_batches
    def __iter__(self):
        if self.shuffle:
            r = torch.randperm(self.dataset_len)
            self.tensors = [t[r] for t in self.tensors]
        self.i = 0
        return self

    def __next__(self):
        if self.i >= self.dataset_len:
            raise StopIteration
        batch = tuple(t[self.i:self.i+self.batch_size] for t in self.tensors)
        self.i += self.batch_size
        return batch

    def __len__(self):
        return self.n_batches

In [10]:
# from https://github.com/bio-ontology-research-group/deepgozero/blob/main/deepgozero.py

# Prep dictionary terms
loss_func = nn.BCELoss()
iprs_dict, terms_dict = load_data(data_root, ont, terms_file)
n_terms = len(terms_dict)
n_iprs = len(iprs_dict)
    
nf1, nf2, nf3, nf4, relations, zero_classes = load_normal_forms(go_file, terms_dict)
n_rels = len(relations)
n_zeros = len(zero_classes)
    
normal_forms = nf1, nf2, nf3, nf4
nf1 = th.LongTensor(nf1).to(device)
nf2 = th.LongTensor(nf2).to(device)
nf3 = th.LongTensor(nf3).to(device)
nf4 = th.LongTensor(nf4).to(device)
normal_forms = nf1, nf2, nf3, nf4



Terms 10101


In [11]:
print('Number of terms:',n_terms)
print('Number of interpros:',n_iprs)

Number of terms: 10101
Number of interpros: 26406


## Initialize Model

In [12]:
# initialize
net = DGELModel(n_iprs, n_terms, n_zeros, n_rels, device).to(device)
print(net)


DGELModel(
  (net): Sequential(
    (0): MLPBlock(
      (linear): Linear(in_features=26406, out_features=1024, bias=True)
      (activation): ReLU()
      (layer_norm): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (1): Residual(
      (fn): MLPBlock(
        (linear): Linear(in_features=1024, out_features=1024, bias=True)
        (activation): ReLU()
        (layer_norm): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (go_embed): Embedding(45257, 1024)
  (go_norm): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (go_rad): Embedding(45257, 1)
  (rel_embed): Embedding(9, 1024)
)


In [13]:
# If not loading checkpoint, prepare for training
if not load:
    train_df = pd.read_pickle(f'{data_root}/{ont}/train_data.pkl')
    train_data = get_data(train_df, iprs_dict, terms_dict)
    print(train_data[0].shape)
    train_loader = FastTensorDataLoader(
            *train_data, batch_size=batch_size, shuffle=True)
    del train_df,train_data 

    valid_df = pd.read_pickle(f'{data_root}/{ont}/valid_data.pkl')
    valid_data = get_data(valid_df, iprs_dict, terms_dict)
    print(valid_data[0].shape)
    valid_loader = FastTensorDataLoader(
            *valid_data, batch_size=batch_size, shuffle=False) 

    del valid_df, valid_data

In [14]:
import gc
gc.collect()

0

In [15]:
# If not loading checkpoint, train the model
if not load:
    optimizer = th.optim.Adam(net.parameters(), lr=5e-4)
    scheduler = MultiStepLR(optimizer, milestones=[5, 20], gamma=0.1)
    best_loss = 10000.0

    print('Training the model')
    for epoch in range(epochs):
        net.train()
        train_loss = 0
        train_elloss = 0
        lmbda = 0.1
        train_steps = len(train_loader)
        for batch_features, batch_labels in train_loader:
            batch_features = batch_features.to(device)
            batch_labels = batch_labels.to(device)
            logits = net(batch_features)
            loss = F.binary_cross_entropy(logits, batch_labels)
            el_loss = net.el_loss(normal_forms)
            total_loss = loss + el_loss
            train_loss += loss.detach().item()
            train_elloss = el_loss.detach().item()
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
                    
        train_loss /= train_steps
                
        net.eval()
        with th.no_grad():
            valid_steps = len(valid_loader)
            valid_loss = 0
            preds = []
            valid_labels = []
            for batch_features, batch_labels in valid_loader:
                batch_features = batch_features.to(device)
                batch_labels = batch_labels.to(device)
                logits = net(batch_features)
                batch_loss = F.binary_cross_entropy(logits, batch_labels)
                valid_loss += batch_loss.detach().item()
                preds = np.append(preds, logits.detach().cpu().numpy())
                # valid_data was deleted earlier, so collect the labels batch by batch
                valid_labels = np.append(valid_labels, batch_labels.detach().cpu().numpy())
            valid_loss /= valid_steps
            roc_auc = compute_roc(valid_labels, preds)
            print(f'Epoch {epoch}: Loss - {train_loss}, EL Loss: {train_elloss}, Valid loss - {valid_loss}, AUC - {roc_auc}')
    
            print('EL Loss', train_elloss)
            if valid_loss < best_loss:
                best_loss = valid_loss
                print('Saving model')
                th.save(net.state_dict(), model_file)

            scheduler.step()
            


## Inference based on CAFA5 test data
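
Running the whole test matrix through the model in a single call can exhaust memory on the full test set. Below is a minimal sketch of batching, not part of the original notebook. `iter_batches` and `predict_in_batches` are hypothetical names, and `predict_fn` stands for any callable that maps a batch of rows to a list of per-row outputs. In this notebook it would presumably wrap a call to `net` inside `th.no_grad()`, which is an assumption and is not shown here.

```python
def iter_batches(n_rows, batch_size):
    """Yield (start, stop) index pairs that cover range(n_rows)."""
    if batch_size <= 0:
        raise ValueError('batch_size must be positive')
    for start in range(0, n_rows, batch_size):
        yield start, min(start + batch_size, n_rows)


def predict_in_batches(rows, predict_fn, batch_size):
    """Apply predict_fn to slices of rows and collect the per-row outputs in order."""
    outputs = []
    for start, stop in iter_batches(len(rows), batch_size):
        outputs.extend(predict_fn(rows[start:stop]))
    return outputs


# Demo with a stand-in "model" that sums each feature row
demo_features = [[0, 1], [1, 1], [0, 0], [1, 0], [1, 1]]
demo_scores = predict_in_batches(
    demo_features, lambda batch: [sum(r) for r in batch], batch_size=2)
print(demo_scores)
```

The last batch is simply shorter than the others, so no padding is needed.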

In [16]:
"""
Custom get_data and load_data for the CAFA5 test dataset.

Major changes

get_data:
* Convert the string interproscan column back into list form (previously generated at runtime, so not necessary)
* Prepare array input for the model
* Removed label preparation and return
* Return protein ids for the submission later

load_data:
* Take a pre-prepared test dataframe as input rather than generating it
* Removed valid and train df generation
* No longer used here because we loaded iprs_dict and terms_dict previously. Saves memory.

Recall there are 3 model files and 3 terms files. Keep this in mind for experiments.
The terms file in the original approach is /kaggle/input/deepgozero-data/data/mf/terms_zero_10.pkl, but note that there are three sets:
- terms_zero_10.pkl
- terms_zero.pkl
- terms.pkl
"""
import ast
def get_data_test(df, iprs_dict, terms_dict):
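    # Parse stringified lists; already-parsed lists are left untouched so the cell can be re-run
    df['interproscan']=df['interproscan'].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x)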
    data = th.zeros((len(df), len(iprs_dict)), dtype=th.float32)
    prot_ids = []
    for i, row in enumerate(df.itertuples()):
        # Add protein id for submission writeup later
        prot_id = row.proteins
        prot_ids.append(prot_id)
        # Add relevant interpro embedding into matrix
        for ipr in row.interproscan:
            if ipr in iprs_dict:
                data[i, iprs_dict[ipr]] = 1
               
   
    return data,prot_ids



def load_data_test(data_root, ont, terms_file, df):
    terms_df = pd.read_pickle(terms_file)
    terms = terms_df['gos'].values.flatten()
    terms_dict = {v: i for i, v in enumerate(terms)}
    print('Terms', len(terms))
    
    ipr_df = pd.read_pickle(f'{data_root}/{ont}/interpros.pkl')
    iprs = ipr_df['interpros'].values
    iprs_dict = {v:k for k, v in enumerate(iprs)}

    #df['interproscan']=df['interproscan'].apply(ast.literal_eval) # Convert entries to lists, otherwise we would just iterate over letters
    test_data = get_data_test(df, iprs_dict, terms_dict)
    
    return iprs_dict, terms_dict, test_data, df

In [17]:
"""
# We've loaded iprs_dict and terms_dict previously, so no need to call load_data
# iprs_dict: contains all InterPro features in the entire dataset (from the original repo, not CAFA5; we use it because the model was trained on it, otherwise the input matrix would have the wrong shape)
# terms_dict: contains the ontology-specific GO annotations; note these come in different sizes (see Hyperparameter Definition)
"""
# With the full test set, test_data is a tensor of torch.Size([138417, 26406]), which is equal to len(test_df) x len(iprs_dict)
test_data, prot_ids= get_data_test(test_df, iprs_dict,terms_dict)

In [18]:
print(len(prot_ids))

138417


In [19]:

go = Ontology(f'{data_root}/go.obo', with_rels=True)

# Loading best model
print('Loading the best model')
net.load_state_dict(th.load(model_file, map_location=device))
net.eval()

"""
# Original Test evaluation. Not necessary for us, but could be useful for batching if needed.
# Set to evaluation mode
with th.no_grad():
    test_steps = int(math.ceil(len(test_labels) / batch_size))
    test_loss = 0
    preds = []
    # Forward pass test data into the model, calculate BCE as well as ROC AUC
    for batch_features, batch_labels in tqdm(test_loader,total=len(test_loader)):
        batch_features = batch_features.to(device)
        # 
        print("batch features")
        print(batch_features)
        batch_labels = batch_labels.to(device)
        print("batch labels")
        print(batch_labels)
        logits = net(batch_features)
        batch_loss = F.binary_cross_entropy(logits, batch_labels)
        test_loss += batch_loss.detach().cpu().item()
        preds = np.append(preds, logits.detach().cpu().numpy())
    test_loss /= test_steps
    preds = preds.reshape(-1, n_terms)
    roc_auc = compute_roc(test_labels, preds)
    print(f'Test Loss - {test_loss}, AUC - {roc_auc}')
"""
with th.no_grad():

    preds = []
    batch_features = test_data.to(device)
    logits = net(batch_features)   
    preds = np.append(preds, logits.detach().cpu().numpy())
    preds = preds.reshape(-1, n_terms)
    

w = open(sub_file, 'wt')
preds = list(preds)
# Propagate scores using the ontology structure
# Iterates over each of the score vectors in preds
# With the default terms and model setup, len(scores) is 10101
for i, scores in tqdm(enumerate(preds), total=len(preds)):
    
    # Use the index in preds to fetch protein id
    prot_id = prot_ids[i]
    prop_annots = {}
    for go_id, j in terms_dict.items():
        score = scores[j]
        # iterates over the ancestors of a given term in the ontology.
        # If an ancestor term already has a score in prop_annots, 
        # it is updated with the maximum of the current score and the new score. 
        # If it does not have a score, it is assigned the new score.
        for sup_go in go.get_anchestors(go_id):
            if sup_go in prop_annots:
                prop_annots[sup_go] = max(prop_annots[sup_go], score)
            else:
                prop_annots[sup_go] = score
    # loop over prop_annots.items() 
    # updates the scores in the original scores vector based on the propagated scores.
    for go_id, score in prop_annots.items():
        if go_id in terms_dict:
            scores[terms_dict[go_id]] = score
    # Sort the propagated annotations by score and write those at or above the threshold to the submission file
    # For default terms_zero_10, len of prop_annots and sannots is 10490. This is a bit more than the 10101 scores,
    # because propagation over the GO hierarchy also adds ancestor terms that are not in terms_dict.
    
    sannots = sorted(prop_annots.items(), key=lambda x: x[1], reverse=True)
    for go_id, score in sannots:
        if score >= threshold:
            w.write(prot_id + '\t' + go_id + '\t%.3f\n' % score)
    w.write('\n')
w.close()
    

#  TODO, add diamond /blast results. Existing implementation can be found here https://www.kaggle.com/code/geraseva/deepgoplus under "Combine diamond preds and deepgo"

Loading the best model


 27%|████████▊                       | 37863/138417 [2:09:08<5:41:34,  4.91it/s]
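
The score propagation in the final inference cell can be illustrated on a toy hierarchy. This is a sketch, not the notebook's code: `ancestors` and `propagate_scores` are hypothetical helpers, the `toy_parents` dict stands in for the `Ontology` class, and the term names are made up. Every ancestor receives the maximum score of any of its descendants, which is also why the propagated result can contain more terms than the input scores.

```python
from collections import deque


def ancestors(term, parents):
    """Return term plus all of its ancestors, following parent links breadth-first."""
    seen = set()
    queue = deque([term])
    while queue:
        t = queue.popleft()
        if t not in seen:
            seen.add(t)
            queue.extend(parents.get(t, []))
    return seen


def propagate_scores(scores, parents):
    """Give every ancestor the maximum score of any of its descendants."""
    propagated = {}
    for term, score in scores.items():
        for anc in ancestors(term, parents):
            propagated[anc] = max(propagated.get(anc, 0.0), score)
    return propagated


# Toy hierarchy: leaf -> mid -> root; 'root' has no predicted score of its own
toy_parents = {'leaf': ['mid'], 'mid': ['root'], 'root': []}
toy_scores = {'leaf': 0.9, 'mid': 0.2}
toy_propagated = propagate_scores(toy_scores, toy_parents)
print(toy_propagated)
```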