### MFO Evaluation

In [5]:
import os
from pathlib import Path

import numpy as np
import pandas as pd



# Root folder that holds one sub-folder per ontology (bp / cc / mf).
base_path = Path('path_to_data_directory')

# Biological-process split pickles.
bp_train_path, bp_valid_path, bp_test_path = (
    base_path.joinpath('bp/train_data.pkl'),
    base_path.joinpath('bp/valid_data.pkl'),
    base_path.joinpath('bp/test_data.pkl'),
)

# Cellular-component split pickles.
cc_train_path, cc_valid_path, cc_test_path = (
    base_path.joinpath('cc/train_data.pkl'),
    base_path.joinpath('cc/valid_data.pkl'),
    base_path.joinpath('cc/test_data.pkl'),
)

# Molecular-function split pickles.
mf_train_path, mf_valid_path, mf_test_path = (
    base_path.joinpath('mf/train_data.pkl'),
    base_path.joinpath('mf/valid_data.pkl'),
    base_path.joinpath('mf/test_data.pkl'),
)


def preprocess(data_path, data_type, ont):
    """Load one split pickle and normalise it to the notebook's column layout.

    Parameters
    ----------
    data_path : path-like
        Pickle file read with ``pd.read_pickle``.
    data_type : str
        Value stored in the ``Set`` column for every row (e.g. "Train").
    ont : str
        Value stored in the ``aspect`` column for every row (e.g. "MFO").

    Returns
    -------
    pandas.DataFrame
        Columns ``protein_name``, ``sequences``, ``term``, ``Set``, ``aspect``.
    """
    raw = pd.read_pickle(data_path)
    # 'prop_annotations' becomes 'term' before the column selection below.
    with_terms = raw.rename(columns={'prop_annotations': 'term'})
    selected = with_terms[['proteins', 'sequences', 'term']]
    renamed = selected.rename(columns={'proteins': 'protein_name'})
    # Tag every row with the split name and the ontology aspect.
    return renamed.assign(Set=data_type, aspect=ont)


# Load every (ontology, split) pickle; each frame is tagged with its split
# name in 'Set' and its ontology in 'aspect' by preprocess().
bp_train = preprocess(bp_train_path, "Train", "BPO")
cc_train = preprocess(cc_train_path, "Train", "CCO")
mf_train = preprocess(mf_train_path, "Train", "MFO")

bp_valid = preprocess(bp_valid_path, "Valid", "BPO")
cc_valid = preprocess(cc_valid_path, "Valid", "CCO")
mf_valid = preprocess(mf_valid_path, "Valid", "MFO")

bp_test = preprocess(bp_test_path, "Test", "BPO")
cc_test = preprocess(cc_test_path, "Test", "CCO")
mf_test = preprocess(mf_test_path, "Test", "MFO")

# Stack train/valid/test into one frame per ontology.
mf = pd.concat([mf_train, mf_valid, mf_test], ignore_index=True)
cc = pd.concat([cc_train, cc_valid, cc_test], ignore_index=True)
bp = pd.concat([bp_train, bp_valid, bp_test], ignore_index=True)

# All three ontologies in one frame.
# NOTE(review): the name `data` is rebound to a torch_geometric `Data` object
# in a later cell, so this frame is not reachable after that cell runs.
data = pd.concat([bp, cc, mf], ignore_index=True)


In [6]:
# Preview the combined MFO frame (train + valid + test rows).
mf.head()

Unnamed: 0,protein_name,sequences,term,Set,aspect
0,CAM1_CAEEL,MSPRPEDDDLVIEPADDEGLHYGNASMEGTSTGQRPYIRLTSQLRN...,"[GO:0044260, GO:0004888, GO:0032502, GO:000471...",Train,MFO
1,PTP3_DICDI,MISSSMSYRHSTNSVYTLNPHLNIPISTSTTIPPTSFYANNTPEMI...,"[GO:0010033, GO:0007165, GO:0009966, GO:000810...",Train,MFO
2,ECT2_MOUSE,MADDSVLPSPSEITSLADSSVFDSKVAEMSKENLCLASTSNVDEEM...,"[GO:0032502, GO:0008104, GO:0051234, GO:004230...",Train,MFO
3,ECT2_HUMAN,MAENSVLTSTTGRTSLADSSIFDSKVTEISKENLLIGSTSYVEEEM...,"[GO:0071277, GO:0071214, GO:0005654, GO:000585...",Train,MFO
4,NEIL1_MOUSE,MPEGPELHLASHFVNETCKGLVFGGCVEKSSVSRNPEVPFESSAYH...,"[GO:0044260, GO:0043170, GO:0034641, GO:000628...",Train,MFO


In [3]:
# One row per distinct (protein_name, sequences) pair in the MFO frame.
df_seq = mf[['protein_name', 'sequences']].drop_duplicates().reset_index(drop = True)
df_seq.head()

Unnamed: 0,protein_name,sequences
0,CAM1_CAEEL,MSPRPEDDDLVIEPADDEGLHYGNASMEGTSTGQRPYIRLTSQLRN...
1,PTP3_DICDI,MISSSMSYRHSTNSVYTLNPHLNIPISTSTTIPPTSFYANNTPEMI...
2,ECT2_MOUSE,MADDSVLPSPSEITSLADSSVFDSKVAEMSKENLCLASTSNVDEEM...
3,ECT2_HUMAN,MAENSVLTSTTGRTSLADSSIFDSKVTEISKENLLIGSTSYVEEEM...
4,NEIL1_MOUSE,MPEGPELHLASHFVNETCKGLVFGGCVEKSSVSRNPEVPFESSAYH...


In [5]:
# Summary counts: compare `count` with `unique` to see how many proteins share a sequence.
df_seq.describe()

Unnamed: 0,protein_name,sequences
count,43279,43279
unique,43279,42902
top,CAM1_CAEEL,MADQLTEEQIAEFKEAFSLFDKDGDGTITTKELGTVMRSLGQNPTE...
freq,1,13


# GO Terms

In [10]:
# Long format: explode the per-protein list in 'term' into one row per
# (protein_name, term) pair, keeping the split and aspect tags.
terms = mf[['protein_name', 'term', 'Set', 'aspect']]
go_terms = terms.explode('term').reset_index(drop=True)
go_terms.head()

Unnamed: 0,protein_name,term,Set,aspect
0,CAM1_CAEEL,GO:0044260,Train,MFO
1,CAM1_CAEEL,GO:0004888,Train,MFO
2,CAM1_CAEEL,GO:0032502,Train,MFO
3,CAM1_CAEEL,GO:0004713,Train,MFO
4,CAM1_CAEEL,GO:0016310,Train,MFO


## LOAD LLM EMBEDDINGS


In [11]:
# Load the precomputed ANKH embeddings and their protein ids (two parallel
# .npy arrays) and pair them row by row in a DataFrame.
# NOTE(review): presumably row i of the embeddings belongs to id i — confirm
# against the script that wrote these files.
ankh_embeds_array = np.load("path_to_directory/ankh_embeddings_zerogo_embeds.npy")
ankh_ids_array = np.load("path_to_directory/ankh_embeddings_zerogo_ids.npy")
# .tolist() stores each embedding as a (nested) Python list in the 'ANKH' column.
ankh = pd.DataFrame({'protein_name': ankh_ids_array, 'ANKH': ankh_embeds_array.tolist()})
ankh.head()

Unnamed: 0,protein_name,ANKH
0,VGFR2_MOUSE,"[[0.02048352360725403, 0.0012277299538254738, ..."
1,VGFR2_RAT,"[[0.020595740526914597, 0.0011519812978804111,..."
2,VGFR2_HUMAN,"[[0.020025640726089478, 0.001570425578393042, ..."
3,VGFR2_DANRE,"[[0.02045329473912716, 0.0011346322717145085, ..."
4,KIT_MOUSE,"[[0.01818038336932659, 0.0016508179251104593, ..."


# Ontology

In [12]:
# Aspect label used to filter `go_terms` (matches the 'aspect' column values).
ontology = 'MFO'
# Relative sub-path inserted into every prediction/model/result path by paths().
task_name = 'zerogo/mf'

In [14]:
import os

def paths(task_name, base_dir='define_directory'):
    """Build every file and folder path used by the training/evaluation cells.

    Parameters
    ----------
    task_name : str
        Sub-path identifying the task (e.g. ``'zerogo/mf'``); it is inserted
        under the ``predictions``, ``metrics/gt``, ``models`` and ``results``
        folders.
    base_dir : str, optional
        Root directory all paths are joined onto. Defaults to the placeholder
        ``'define_directory'`` the notebook originally hard-coded.

    Returns
    -------
    tuple of str
        ``(valid_pred_folder, test_pred_folder, obo_file, valid_gt, test_gt,
        model_dir, valid_preds, test_preds, result_df, result_folder)``.
        Nothing is created on disk; the paths are only assembled.
    """
    def under_base(relative):
        # Single place where the root is applied, so every entry shares it.
        return os.path.join(base_dir, relative)

    valid_pred_folder = under_base(f'predictions/{task_name}/Valid')
    test_pred_folder = under_base(f'predictions/{task_name}/Test')
    obo_file = under_base('metrics/go-basic.obo')
    valid_gt = under_base(f'metrics/gt/{task_name}/valid.tsv')
    test_gt = under_base(f'metrics/gt/{task_name}/test.tsv')
    model_dir = under_base(f'models/{task_name}')
    valid_preds = under_base(f'predictions/{task_name}/Valid/valid_pred.tsv')
    test_preds = under_base(f'predictions/{task_name}/Test/test_pred.tsv')
    result_df = under_base(f'results/{task_name}/evaluation_all.tsv')
    result_folder = under_base(f'results/{task_name}')

    return (valid_pred_folder, test_pred_folder, obo_file, valid_gt, test_gt,
            model_dir, valid_preds, test_preds, result_df, result_folder)


# Unpack the path tuple into notebook-level names used by the cells below.
(valid_pred_folder, test_pred_folder, obo_file, valid_gt, test_gt, model_dir, valid_preds, test_preds, result_df, result_folder) = paths(task_name)

In [15]:
# Term vocabulary for this task; its 'gos' column is used below to filter annotations.
# NOTE(review): this is a relative path, unlike the `base_path`-based paths above — confirm the working directory.
go_df = pd.read_pickle('data/mf/terms.pkl')

def _write_ground_truth(task, split_name, out_path):
    """Write the (EntryID, term) pairs of one split to a tab-separated file."""
    split_rows = task.loc[task['Set'] == split_name, ['protein_name', 'term']]
    split_rows = split_rows.rename(columns={'protein_name': 'EntryID'}).reset_index(drop=True)
    # Create the target folder first so a fresh checkout does not fail on write.
    folder = os.path.dirname(out_path)
    if folder:
        os.makedirs(folder, exist_ok=True)
    split_rows.to_csv(out_path, sep='\t', index=False, header=True)


def valid_and_test_gt(go_terms, ontology, valid_gt, test_gt, allowed_terms=None):
    """Filter annotations to one ontology/vocabulary and write ground-truth files.

    Parameters
    ----------
    go_terms : pandas.DataFrame
        Long-format annotations with columns ``protein_name``, ``term``,
        ``Set`` and ``aspect``.
    ontology : str
        Value of ``aspect`` to keep.
    valid_gt, test_gt : str
        Output paths for the ``Set == 'Valid'`` and ``Set == 'Test'`` rows.
        Each file is tab-separated with header ``EntryID``, ``term``.
    allowed_terms : iterable, optional
        Terms to keep. When omitted, the notebook-level ``go_df['gos']`` is
        used, which is what the original implementation always did.

    Returns
    -------
    (task, set_df) : tuple of pandas.DataFrame
        ``task`` is the filtered, de-duplicated annotation table (all splits);
        ``set_df`` holds the distinct ``(protein_name, Set)`` pairs from it.
    """
    if allowed_terms is None:
        allowed_terms = go_df['gos']

    in_aspect = go_terms[go_terms['aspect'] == ontology]
    task = in_aspect[in_aspect['term'].isin(allowed_terms)]
    task = task.drop_duplicates().reset_index(drop=True)

    set_df = task[['protein_name', 'Set']].drop_duplicates().reset_index(drop=True)

    # `task` already contains only `ontology` rows, so only the split needs filtering.
    _write_ground_truth(task, 'Valid', valid_gt)
    _write_ground_truth(task, 'Test', test_gt)
    return task, set_df

# Side effect: writes the validation and test ground-truth TSVs to `valid_gt` / `test_gt`.
task, set_df = valid_and_test_gt(go_terms, ontology, valid_gt, test_gt)

In [16]:
# Preview the filtered annotation table (all splits, vocabulary terms only).
task.head()

Unnamed: 0,protein_name,term,Set,aspect
0,CAM1_CAEEL,GO:0004888,Train,MFO
1,CAM1_CAEEL,GO:0004713,Train,MFO
2,CAM1_CAEEL,GO:0005109,Train,MFO
3,CAM1_CAEEL,GO:0016773,Train,MFO
4,CAM1_CAEEL,GO:0038023,Train,MFO


# Labels

In [17]:
# Number of annotation rows per term, most frequent first.
ordered_labels = task.groupby('term')['protein_name'].count().sort_values(ascending=False)
# Keep terms annotated on at least 10 rows; `task` spans train, valid and
# test, so the threshold is applied to counts over all three splits.
labels =ordered_labels[ordered_labels>=10]
# Column order of the label matrix and of every prediction vector below.
labels_names = labels.index.values

In [18]:
# Number of output classes (terms that passed the frequency threshold).
len(labels_names)

2065

In [19]:
from tqdm import tqdm

# protein_name -> list of its terms, and the proteins in first-seen order.
id_labels = task.groupby('protein_name')['term'].apply(list).to_dict()
ids = np.array(task['protein_name'].unique())

# term -> column index in the label matrix (same order as `labels_names`).
go_terms_map = {label: i for i, label in enumerate(labels_names)}
labels_matrix = np.zeros((len(ids), len(labels_names)))

# Multi-hot encode each protein. Iterating over the protein's own terms and
# looking them up in the dict sets exactly the same cells as scanning every
# label and testing list membership, without the per-protein scan over the
# whole vocabulary. Terms that are not in `go_terms_map` are skipped.
for row, protein in tqdm(enumerate(ids), total=len(ids)):
    columns = [go_terms_map[go] for go in id_labels[protein] if go in go_terms_map]
    labels_matrix[row, columns] = 1

# One row vector per protein, stored alongside its name.
labels_list = list(labels_matrix)

labels_df = pd.DataFrame(data={"protein_name":ids, "labels_vect":labels_list})
labels_df.head()

43272it [00:11, 3800.33it/s]


Unnamed: 0,protein_name,labels_vect
0,CAM1_CAEEL,"[1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, ..."
1,PTP3_DICDI,"[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, ..."
2,ECT2_MOUSE,"[1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
3,ECT2_HUMAN,"[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ..."
4,NEIL1_MOUSE,"[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, ..."


In [20]:
# (number of proteins, number of kept terms)
labels_matrix.shape

(43272, 2065)

# Interproscan embedding

In [21]:
# Per-protein InterPro membership vectors ('protein_name', 'binary_vector').
interpro = pd.read_pickle('path_to_directory/binary_zerogo.pkl')
pro_proteins = labels_df[['protein_name']]

# Left join keeps every labelled protein; those without an InterPro entry get NaN.
binary = pd.merge(pro_proteins, interpro, on='protein_name', how='left')
# Read the vector length positionally: `[0]` is a label lookup and breaks (or
# picks a different row) when the pickled frame does not carry a 0-based index.
vector_length = len(interpro['binary_vector'].iloc[0])
# Proteins missing from `interpro` get an all-zero vector.
# NOTE(review): any value that is not a Python `list` (e.g. a NumPy array) is
# also replaced by zeros — confirm the pickle stores plain lists.
binary['binary_vector'] = binary['binary_vector'].apply(lambda x: x if isinstance(x, list) else [0] * vector_length)


# Merge ALL

In [22]:
# Inner-join embeddings, label vectors and InterPro vectors on the protein
# name; proteins missing from any of the three tables are dropped.
merged_df = ankh.merge(labels_df, on='protein_name').merge(binary, on='protein_name')

# Attach the split tag, then keep the first row per protein and renumber so
# the row position can serve as the node index.
final_df = (
    merged_df
    .merge(set_df, on='protein_name')
    .drop_duplicates(subset=['protein_name'])
    .reset_index(drop=True)
)

In [23]:
# Preview the merged per-protein table (embedding, labels, InterPro vector, split).
final_df.head()

Unnamed: 0,protein_name,ANKH,labels_vect,binary_vector,Set
0,VGFR2_MOUSE,"[[0.02048352360725403, 0.0012277299538254738, ...","[1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",Train
1,VGFR2_RAT,"[[0.020595740526914597, 0.0011519812978804111,...","[1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",Train
2,VGFR2_HUMAN,"[[0.020025640726089478, 0.001570425578393042, ...","[1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",Train
3,KIT_MOUSE,"[[0.01818038336932659, 0.0016508179251104593, ...","[1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",Train
4,KIT_HUMAN,"[[0.018293721601366997, 0.001633501029573381, ...","[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",Train


In [19]:
# Sanity check on the first row: the number of 1s in its InterPro vector.
# The message says "first", so read position 0 (the original read element 1).
print(f"The first protein belongs to {sum(final_df['binary_vector'].iloc[0])} families.")

The first protein belongs to 4 families.


## Protein Family Network

In [20]:
# NOTE(review): `pfam` is not defined anywhere in the visible notebook, so this
# cell fails on a fresh kernel. It is used here as an edge table with
# 'protein1' / 'protein2' columns; the cell that loaded it appears to have
# been removed — restore that load before this line.
interpro = pd.read_pickle('path_to_directory/binary_zerogo.pkl')

# Keep only edges whose two endpoints are both present in `final_df`, and
# renumber the rows. Chaining reset_index avoids mutating a filtered slice
# in place.
network = pfam[
    pfam['protein1'].isin(final_df['protein_name']) & pfam['protein2'].isin(final_df['protein_name'])
].reset_index(drop=True)

In [22]:
import numpy as np
import torch
from torch_geometric.data import Data


# Node index = row position in `final_df`.
protein_names = final_df['protein_name'].tolist()
protein_to_index = {protein: index for index, protein in enumerate(protein_names)}

# Per-node tensors. `.squeeze()` drops size-1 axes (the ANKH column holds
# nested lists, see its preview above).
interpro = torch.tensor(np.array(final_df['binary_vector'].tolist()), dtype=torch.float64).squeeze().contiguous()
ankh = torch.tensor(np.array(final_df['ANKH'].tolist()), dtype=torch.float64).squeeze().contiguous()
labels = torch.tensor(np.array(final_df['labels_vect'].tolist()), dtype=torch.float64)

# Boolean node masks per split, built from plain NumPy arrays rather than
# handing a pandas Series to torch.tensor.
train_mask = torch.tensor((final_df['Set'] == 'Train').to_numpy(), dtype=torch.bool)
valid_mask = torch.tensor((final_df['Set'] == 'Valid').to_numpy(), dtype=torch.bool)
test_mask = torch.tensor((final_df['Set'] == 'Test').to_numpy(), dtype=torch.bool)

# Translate both endpoint columns to node indices in one vectorised pass
# instead of iterating the edge table row by row. Converting to int64 raises
# if any endpoint is missing from `protein_to_index` (it would map to NaN).
source_nodes = network['protein1'].map(protein_to_index).to_numpy(dtype=np.int64)
target_nodes = network['protein2'].map(protein_to_index).to_numpy(dtype=np.int64)
edges = np.stack([source_nodes, target_nodes], axis=1)  # one (source, target) row per edge
edge_index = torch.tensor(edges, dtype=torch.long).t()


# x carries the InterPro vectors; the ANKH embeddings ride along as an extra node attribute.
data = Data(edge_index=edge_index,  x=interpro, ankh =ankh,  y = labels,  train_mask = train_mask, valid_mask = valid_mask, test_mask = test_mask )
# The returned flag is discarded; the trailing `;` only hides it from the cell output.
data.is_undirected();

In [23]:
# Show the graph object's attribute shapes.
data

Data(x=[43272, 10879], edge_index=[2, 3058603], y=[43272, 2065], ankh=[43272, 1536], train_mask=[43272], valid_mask=[43272], test_mask=[43272])

In [24]:
# Basic graph statistics; the average degree is computed as num_edges / num_nodes.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of node features: {data.num_node_features}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Number of training nodes: {data.train_mask.sum()}')
print(f'Number of validation nodes: {data.valid_mask.sum()}')
print(f'Number of test nodes: {data.test_mask.sum()}')

Number of nodes: 43272
Number of node features: 10879
Number of edges: 3058603
Average node degree: 70.68
Number of training nodes: 34710
Number of validation nodes: 3850
Number of test nodes: 4712


In [25]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, SAGEConv
from torch_geometric.utils import dropout_edge
from torch import nn

# Dimensions read by train() and by the test cell when building the networks.
num_classes = len(labels_names)      # one output per kept term
family_features = data.x.size(1)     # width of the InterPro vector
ankh_features = data.ankh.size(1)    # width of the ANKH embedding


class PLLM(torch.nn.Module):
    """Three-layer perceptron: input_dim -> 2048 -> 1024 -> 512.

    LeakyReLU (slope 0.1) follows the first two linear layers; the last layer
    is left linear. In this notebook it is instantiated with
    ``input_dim=family_features`` and fed ``batch.x``.
    """

    def __init__(self, input_dim):
        super().__init__()
        # Re-seeds the global torch RNG, so every construction starts from
        # the same initial weights.
        torch.manual_seed(1234)
        self.linear1 = nn.Linear(input_dim, 2048)
        self.activation1 = nn.LeakyReLU(negative_slope=0.1)
        self.linear2 = nn.Linear(2048, 1024)
        self.activation2 = nn.LeakyReLU(negative_slope=0.1)
        self.linear3 = nn.Linear(1024, 512)

    def forward(self, x):
        """Return the 512-wide projection of ``x`` (one row per input row)."""
        hidden = self.activation1(self.linear1(x))
        hidden = self.activation2(self.linear2(hidden))
        return self.linear3(hidden)
  

class GATN(torch.nn.Module):
    """Two stacked GATConv layers producing a 512-wide node representation.

    The first layer uses ``num_heads`` attention heads whose outputs are
    concatenated (width ``hidden_dim * num_heads``); the second uses a single
    head with 512 output channels. In this notebook it is fed ``batch.ankh``.
    """

    def __init__(self, input_dim, hidden_dim, num_heads):
        super().__init__()
        # Re-seeds the global torch RNG, so every construction starts from
        # the same initial weights.
        torch.manual_seed(1234)
        first_layer_width = hidden_dim * num_heads
        self.gat1 = GATConv(
            in_channels=input_dim,
            out_channels=hidden_dim,
            heads=num_heads,
            concat=True,
            add_self_loops=True,
            bias=True,
        )
        self.gat2 = GATConv(
            in_channels=first_layer_width,
            out_channels=512,
            heads=1,
            concat=True,
            add_self_loops=True,
            bias=True,
        )

    def forward(self, x, edge_index):
        """Apply gat1 -> ReLU -> gat2 over the graph given by ``edge_index``."""
        attended = F.relu(self.gat1(x, edge_index))
        return self.gat2(attended, edge_index)


class Voltron(torch.nn.Module):
    """Fusion head: concatenates two 512-wide inputs and emits class logits.

    Pipeline: concat (1024) -> Linear(1024, 512) -> BatchNorm1d -> LeakyReLU(0.1)
    -> Linear(512, num_classes). The output is raw logits (no sigmoid).
    """

    def __init__(self, num_classes):
        super().__init__()
        # Re-seeds the global torch RNG, so every construction starts from
        # the same initial weights.
        torch.manual_seed(1234)
        self.lin1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.act1 = nn.LeakyReLU(negative_slope=0.1)
        self.lin2 = nn.Linear(512, num_classes)

    def forward(self, mlp_output, gat_output):
        """Fuse the two branch outputs row-wise and return ``num_classes`` logits."""
        fused = torch.cat((mlp_output, gat_output), dim=1)
        hidden = self.act1(self.bn1(self.lin1(fused)))
        return self.lin2(hidden)


In [31]:
from torch_geometric.loader import NeighborLoader
# Make the tensors contiguous before handing the graph to the samplers.
data.x = data.x.contiguous()
data.ankh = data.ankh.contiguous()
data.edge_index = data.edge_index.contiguous()

# Two-hop neighbour sampling (up to 75 neighbours in the first hop, 30 in the
# second). Training batches hold 128 seed nodes and are shuffled.
train_loader = NeighborLoader(data, input_nodes=data.train_mask,
                              num_neighbors=[75, 30], batch_size= 128, shuffle=True)

# Validation and test use one seed node per batch and no shuffling: the
# evaluation loops below zip these loaders with the protein names of the
# split and write exactly `num_classes` scores per batch.
val_loader   = NeighborLoader(data, input_nodes=data.valid_mask,
                              num_neighbors=[75, 30], batch_size=1, shuffle=False)

test_loader  = NeighborLoader(data, input_nodes=data.test_mask,
                              num_neighbors=[75, 30], batch_size=1,  shuffle=False)

In [32]:
import os
import numpy as np
import torch
import torch_geometric
from torch_geometric.data import  Data
from torch.optim.lr_scheduler import ReduceLROnPlateau
import cafaeval
from cafaeval.evaluation import cafa_eval, write_results


# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Training will start on {device}')

# Per-epoch histories appended to by train(). They are module-level lists, so
# calling train() again keeps extending them unless this cell is re-run first.
train_loss_history=[]
val_loss_history=[]
val_fmax_history = []

def train(train_loader, val_loader, out_dir, num_epochs):
    """Jointly train the MLP, GAT and fusion head; checkpoint on validation F-max.

    Each epoch runs one pass over ``train_loader`` (BCE-with-logits loss, Adam,
    cosine-annealed learning rate), then scores every validation protein,
    writes the scores to ``valid_preds`` and evaluates them with ``cafa_eval``.

    Parameters
    ----------
    train_loader : iterable of mini-batches
        Batches exposing ``x``, ``ankh``, ``edge_index``, ``y`` and
        ``batch_size``; the first ``batch_size`` rows are the seed nodes.
    val_loader : iterable of mini-batches
        Must yield exactly one seed node per batch, in the same order as the
        ``Set == 'Valid'`` rows of ``final_df``.
    out_dir : str
        Folder receiving ``mf_mlp_model.pt``, ``mf_gat_model.pt`` and
        ``mf_voltron_model.pt`` whenever the validation F-max does not drop
        below the best value so far. Created if missing.
    num_epochs : int
        Maximum number of epochs (also the cosine schedule's ``T_max``).

    Returns
    -------
    Voltron
        The fusion head in its last-epoch state (not necessarily the best
        one; the best weights of all three networks are the files on disk).

    Raises
    ------
    ValueError
        If the number of validation batches differs from the number of
        validation proteins in ``final_df``.

    Notes
    -----
    Appends to the module-level ``train_loss_history``, ``val_loss_history``
    and ``val_fmax_history``. Training stops early after 5 consecutive epochs
    in which neither the reported S value nor the validation loss improved.
    """
    mlp_model = PLLM(input_dim=family_features).to(device).double()
    gat_model = GATN(input_dim=ankh_features, hidden_dim=256, num_heads=3).to(device).double()
    model = Voltron(num_classes=num_classes).to(device).double()
    networks = (mlp_model, gat_model, model)

    params = [param for net in networks for param in net.parameters()]
    optimizer = torch.optim.Adam(params, lr = 0.0001)
    criterion = torch.nn.BCEWithLogitsLoss()
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    # Make sure the checkpoint and prediction folders exist before the first write.
    for folder in (out_dir, os.path.dirname(valid_preds)):
        if folder:
            os.makedirs(folder, exist_ok=True)

    # Validation names are matched to batches by position, so the two
    # sequences must have the same length; zip() would otherwise truncate
    # silently and leave part of the prediction arrays unfilled.
    val_protein_names = final_df[final_df['Set'] == 'Valid']['protein_name'].tolist()
    n_val = len(val_loader)
    if n_val != len(val_protein_names):
        raise ValueError(
            f'val_loader yields {n_val} batches but final_df has '
            f'{len(val_protein_names)} validation proteins; use batch_size=1 over the validation mask.'
        )

    mlp_path = os.path.join(out_dir, 'mf_mlp_model.pt')
    gat_path = os.path.join(out_dir, 'mf_gat_model.pt')
    voltron_path = os.path.join(out_dir, 'mf_voltron_model.pt')

    patience_counter = 0
    best_fmax = 0.0
    best_val_loss = float('inf')
    best_s_min = float('inf')

    for epoch in range(num_epochs):

        ## TRAIN
        for net in networks:
            net.train()

        train_losses = []
        for batch in train_loader:
            optimizer.zero_grad()
            batch = batch.to(device)
            n_seed = batch.batch_size

            # Both branches run on the whole sampled sub-graph; only the seed
            # rows (the first `batch_size`) are scored.
            mlp_out = mlp_model(batch.x)[:n_seed]
            gat_out = gat_model(batch.ankh, batch.edge_index)[:n_seed]
            out = model(mlp_out, gat_out)

            y = batch.y[:n_seed].double()
            loss = criterion(out, y)

            train_losses.append(loss.item())
            loss.backward()
            optimizer.step()

        avg_train_loss = np.mean(train_losses)
        train_loss_history.append(avg_train_loss)

        ## Evaluation on validation Set
        # Switch every network to eval mode, not just the fusion head.
        for net in networks:
            net.eval()

        val_losses = []
        with torch.no_grad():
            names = np.empty(shape=(n_val*num_classes,), dtype=object)
            go = np.empty(shape=(n_val*num_classes,), dtype=object)
            confidence = np.empty(shape=(n_val*num_classes,), dtype=np.float64)

            for i, (batch, p_name) in enumerate(zip(val_loader, val_protein_names)):
                batch = batch.to(device)
                n_seed = batch.batch_size

                mlp_out = mlp_model(batch.x)[:n_seed]
                gat_out = gat_model(batch.ankh, batch.edge_index)[:n_seed]
                out = model(mlp_out, gat_out)

                y = batch.y[:n_seed].double()
                val_losses.append(criterion(out, y).item())

                # One block of `num_classes` rows per protein: (name, term, score).
                block = slice(i*num_classes, (i+1)*num_classes)
                confidence[block] = torch.sigmoid(out).squeeze().cpu().numpy()
                names[block] = p_name
                go[block] = labels_names

        submission = pd.DataFrame(data={"Protein_name": names, "GO term": go, "Confidence": confidence})
        # Scores below 0.01 are dropped before writing the prediction file.
        submission_df = submission[submission['Confidence'] >= 0.01]
        submission_df.to_csv(valid_preds, sep='\t', index=False, header=False)

        df, df_best = cafa_eval(obo_file, valid_pred_folder, valid_gt)
        best_row = df_best['f'].loc['valid_pred.tsv'].xs('molecular_function', level='ns')
        f_max = best_row['f'].iloc[0]
        # NOTE(review): this is the 's' column of the row selected under the
        # 'f' key; whether that equals the minimum S over all thresholds
        # depends on cafaeval's result layout — confirm.
        s_min = best_row['s'].iloc[0]

        avg_fmax = float(f_max)
        avg_s_min = float(s_min)
        avg_val_loss = np.mean(val_losses)

        scheduler.step()

        val_loss_history.append(avg_val_loss)
        val_fmax_history.append(avg_fmax)

        # Checkpoint all three networks whenever F-max ties or beats the best so far.
        if avg_fmax >= best_fmax:
            best_fmax = avg_fmax
            torch.save(mlp_model.state_dict(), mlp_path)
            torch.save(gat_model.state_dict(), gat_path)
            torch.save(model.state_dict(), voltron_path)

        # Patience resets when either the S value or the validation loss improves.
        if (avg_s_min < best_s_min) or (avg_val_loss < best_val_loss):
            best_val_loss = min(avg_val_loss, best_val_loss)
            best_s_min = min(avg_s_min, best_s_min)
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter == 5:
            print('=' * 100)
            print('EARLY STOPPING ACTIVATED')
            break

        print(f'EPOCH: {epoch + 1:03d} | S_min:{avg_s_min: .5f} | F_max:{avg_fmax : .3f}')

    return model

Training will start on cuda


In [33]:
# Run training for up to 50 epochs; checkpoints are written to `model_dir`.
# The returned `model` is only the Voltron fusion head as it stood after the last epoch run.
model = train(train_loader = train_loader, val_loader = val_loader, out_dir = model_dir, num_epochs=50)

EPOCH: 001 | S_min: 10.07319 | F_max: 0.423
EPOCH: 002 | S_min: 9.45467 | F_max: 0.425
EPOCH: 003 | S_min: 9.12944 | F_max: 0.437
EPOCH: 004 | S_min: 8.97271 | F_max: 0.449
EPOCH: 005 | S_min: 8.58325 | F_max: 0.478
EPOCH: 006 | S_min: 8.27210 | F_max: 0.505
EPOCH: 007 | S_min: 7.79740 | F_max: 0.539
EPOCH: 008 | S_min: 7.57232 | F_max: 0.562
EPOCH: 009 | S_min: 7.34423 | F_max: 0.582
EPOCH: 010 | S_min: 7.17292 | F_max: 0.596
EPOCH: 011 | S_min: 7.05719 | F_max: 0.603
EPOCH: 012 | S_min: 6.91555 | F_max: 0.614
EPOCH: 013 | S_min: 6.87475 | F_max: 0.618
EPOCH: 014 | S_min: 6.78102 | F_max: 0.624
EPOCH: 015 | S_min: 6.74714 | F_max: 0.625
EPOCH: 016 | S_min: 6.72133 | F_max: 0.628
EPOCH: 017 | S_min: 6.66571 | F_max: 0.633
EPOCH: 018 | S_min: 6.63296 | F_max: 0.635
EPOCH: 019 | S_min: 6.66021 | F_max: 0.633
EPOCH: 020 | S_min: 6.59490 | F_max: 0.638
EPOCH: 021 | S_min: 6.59309 | F_max: 0.639
EPOCH: 022 | S_min: 6.56294 | F_max: 0.640
EPOCH: 023 | S_min: 6.59091 | F_max: 0.641
EPOCH: 024

# TEST TIME

In [35]:
# Load the checkpoints written by train().
mlp_path = os.path.join(model_dir, 'mf_mlp_model.pt')
gat_path = os.path.join(model_dir, 'mf_gat_model.pt')
voltron_path = os.path.join(model_dir, 'mf_voltron_model.pt')

mlp_model = PLLM(input_dim=family_features).to(device).double()
gat_model = GATN(input_dim=ankh_features, hidden_dim = 256, num_heads=3).to(device).double()
model = Voltron(num_classes = num_classes).to(device).double()

# map_location lets checkpoints saved on a GPU load on a CPU-only machine.
mlp_model.load_state_dict(torch.load(mlp_path, map_location=device))
gat_model.load_state_dict(torch.load(gat_path, map_location=device))
model.load_state_dict(torch.load(voltron_path, map_location=device))

# Evaluate on TEST SET: put all three networks (not only the fusion head) in eval mode.
mlp_model.eval()
gat_model.eval()
model.eval()

names = np.empty(shape=(len(test_loader)*num_classes,), dtype=object)
go = np.empty(shape=(len(test_loader)*num_classes,), dtype=object)
confidence = np.empty(shape=(len(test_loader)*num_classes,), dtype=np.float64)
test_protein_names = final_df[final_df['Set'] == 'Test']['protein_name']

# Names are matched to batches by position; zip() would truncate silently on a mismatch.
assert len(test_protein_names) == len(test_loader), \
    'test_loader must yield exactly one batch per test protein (batch_size=1)'

# Inference only: no autograd graph is needed.
with torch.no_grad():
    for i, (batch, p_name) in tqdm(enumerate(zip(test_loader, test_protein_names)), total=len(test_loader)):
        batch = batch.to(device)

        mlp_out = mlp_model(batch.x)
        mlp_out = mlp_out[:batch.batch_size]
        gat_out = gat_model(batch.ankh, batch.edge_index)
        gat_out = gat_out[:batch.batch_size]
        out = model(mlp_out, gat_out)

        # One block of `num_classes` rows per protein: (name, term, score).
        confidence[i*num_classes:(i+1)*num_classes] = torch.sigmoid(out).squeeze().cpu().numpy()
        names[i*num_classes:(i+1)*num_classes] = p_name
        go[i*num_classes:(i+1)*num_classes] = labels_names


testing = pd.DataFrame(data={"protein_name": names, "GO term": go, "Confidence": confidence})
# Scores below 0.01 are dropped before writing the prediction file.
submission_df = testing[testing['Confidence'] >= 0.01]

# Write by path only (the original also held an unused open handle on the
# same file), creating the folder first if needed.
if os.path.dirname(test_preds):
    os.makedirs(os.path.dirname(test_preds), exist_ok=True)
submission_df.to_csv(test_preds, sep='\t', index=False, header=False)

df, df_b =cafa_eval(obo_file, test_pred_folder, test_gt)
# Select this run's prediction file explicitly for both metrics, so another
# file in the prediction folder cannot be picked up by position.
best_mf = df_b['f'].loc['test_pred.tsv'].xs('molecular_function', level='ns')
test_f_max = best_mf['f'].iloc[0]
# NOTE(review): this is the 's' column of the row selected under the 'f' key;
# whether that equals the minimum S over all thresholds depends on cafaeval's
# result layout — confirm.
smin = best_mf['s'].iloc[0]

# Area under the precision-recall curve, integrated over recall in ascending order.
precisions = df.xs('molecular_function', level='ns').loc['test_pred.tsv']['pr'].to_numpy()
recalls = df.xs('molecular_function', level='ns').loc['test_pred.tsv']['rc'].to_numpy()
sorted_index = np.argsort(recalls)
recalls = recalls[sorted_index]
precisions = precisions[sorted_index]
# np.trapz is deprecated in NumPy 2.x in favour of np.trapezoid; use whichever exists.
aupr = (np.trapezoid if hasattr(np, 'trapezoid') else np.trapz)(precisions, recalls)

4712it [00:28, 166.21it/s]


In [36]:
# Final test-set metrics for the molecular-function namespace.
print(f'Fmax: {test_f_max:0.3f}')
print(f'Smin: {smin :0.3f}')
print(f'AUPR: {aupr:0.3f}')

Fmax: 0.693
Smin: 5.930
AUPR: 0.658
