In [1]:
import logging
import os 
import random
import warnings
from collections import Counter
from itertools import product

import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
from torch import nn 
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

from torch_geometric.data import HeteroData, DataLoader
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, to_hetero , SAGEConv, GATv2Conv, HeteroConv
from torch_geometric.utils import negative_sampling
from torch_geometric.loader import LinkNeighborLoader

from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import LabelEncoder , label_binarize , OneHotEncoder
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score ,matthews_corrcoef

warnings.filterwarnings("ignore") 

In [2]:
import TropiGAT_graph

In [3]:
# -----------------------------------------------------------------------------
# Load the dataframes
path_work = "/media/concha-eloko/Linux/PPT_clean"
#path_work = "/home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/ficheros_28032023"

# Read the tab-separated table; its first row is the header.
DF_info = pd.read_csv(f"{path_work}/TropiGATv2.final_df.tsv", sep="\t", header=0)

# Keep only the rows whose KL_type_LCA has no "|" in it, then keep the first row
# of every (Infected_ancestor, index, prophage_id) triple and renumber the index.
kltype_has_pipe = DF_info["KL_type_LCA"].str.contains("\\|")
DF_info_lvl_0 = (
    DF_info[~kltype_has_pipe]
    .drop_duplicates(subset=["Infected_ancestor", "index", "prophage_id"], keep="first")
    .reset_index(drop=True)
)



In [4]:
%%time

# Build the baseline graph from the filtered dataframe. The helper also returns a
# second object that is passed on unchanged to the masking step below.
graph_baseline , dico_prophage_kltype_associated = TropiGAT_graph.build_graph_baseline(DF_info_lvl_0)
# Derive the graph for the "KL37" KL type from the baseline graph.
# NOTE(review): the meaning of the trailing positional arguments (2, 0.8, 0.1, 0.1)
# is not visible here -- presumably a sampling ratio followed by train/test/eval
# fractions; confirm against TropiGAT_graph.build_graph_masking.
graph_KL37 = TropiGAT_graph.build_graph_masking(graph_baseline , 
                                 dico_prophage_kltype_associated, 
                                 DF_info_lvl_0, 
                                 "KL37",
                                 2, 0.8, 0.1,0.1)

8815it [00:39, 221.79it/s]


CPU times: user 1min 22s, sys: 2.68 s, total: 1min 24s
Wall time: 1min 27s


In [78]:
%%time
# One masked graph per distinct value of KL_type_LCA, keyed by that value and
# built with the same positional settings as the KL37 example.
# NOTE(review): every call receives the same graph_baseline object; if
# build_graph_masking modifies it in place the entries could share state -- confirm.
graph_dico = {kltype : TropiGAT_graph.build_graph_masking(graph_baseline , dico_prophage_kltype_associated,DF_info_lvl_0, kltype,2, 0.8, 0.1,0.1) 
             for kltype in DF_info_lvl_0["KL_type_LCA"].unique()}

CPU times: user 5.95 s, sys: 8.74 s, total: 14.7 s
Wall time: 14.8 s


In [79]:
graph_dico

{'KL43': HeteroData(
   [1mA[0m={ x=[5069, 129] },
   [1mB1[0m={
     x=[8815, 0],
     y=[8815],
     train_mask=[8815],
     test_mask=[8815],
     eval_mask=[8815]
   },
   [1mB2[0m={ x=[4105, 1280] },
   [1m(B1, infects, A)[0m={ edge_index=[2, 8815] },
   [1m(B2, expressed, B1)[0m={ edge_index=[2, 11102] },
   [1m(A, harbors, B1)[0m={ edge_index=[2, 8815] }
 ),
 'KL14': HeteroData(
   [1mA[0m={ x=[5069, 129] },
   [1mB1[0m={
     x=[8815, 0],
     y=[8815],
     train_mask=[8815],
     test_mask=[8815],
     eval_mask=[8815]
   },
   [1mB2[0m={ x=[4105, 1280] },
   [1m(B1, infects, A)[0m={ edge_index=[2, 8815] },
   [1m(B2, expressed, B1)[0m={ edge_index=[2, 11102] },
   [1m(A, harbors, B1)[0m={ edge_index=[2, 8815] }
 ),
 'KL107': HeteroData(
   [1mA[0m={ x=[5069, 129] },
   [1mB1[0m={
     x=[8815, 0],
     y=[8815],
     train_mask=[8815],
     test_mask=[8815],
     eval_mask=[8815]
   },
   [1mB2[0m={ x=[4105, 1280] },
   [1m(B1, infects, A)[0m=

In [42]:
# Print every non-zero B1 label of the KL37 graph next to its node index.
labels_KL37 = graph_KL37["B1"].y
for index in torch.nonzero(labels_KL37).view(-1).tolist():
    print(labels_KL37[index], index)

tensor(1) 3307
tensor(1) 4515
tensor(1) 4516
tensor(1) 5713
tensor(1) 5824
tensor(1) 7558
tensor(1) 8215


# The architecture of the model : 

***
> Original : 

In [57]:
# The model : TropiGAT
class TropiGAT_small_module(torch.nn.Module):
    """One convolution over a single edge type followed by a 1000-520-1 feed-forward head.

    Parameters
    ----------
    hidden_channels : int
        Output size handed to the convolution.
    heads : int
        Number of attention heads; the head expects ``heads * hidden_channels`` inputs.
    edge_type : tuple
        Key under which the convolution is registered in the ``HeteroConv`` wrapper.
    dropout : float
        Dropout probability, used by the convolution and between the linear layers.
    conv : callable
        Convolution class; it is instantiated with lazy ``(-1, -1)`` input sizes.
    """

    def __init__(self, hidden_channels, heads, edge_type=("B2", "expressed", "B1"), dropout=0.2, conv=GATv2Conv):
        super().__init__()
        # Convolution module.
        # NOTE(review): GATv2Conv documents this option as `share_weights`; the
        # `shared_weights` spelling used here is presumably ignored -- confirm.
        conv_options = dict(add_self_loops=False, heads=heads, dropout=dropout, shared_weights=True)
        self.conv = conv((-1, -1), hidden_channels, **conv_options)
        self.hetero_conv = HeteroConv({edge_type: self.conv})
        # Feed-forward head: (Linear -> BatchNorm -> LeakyReLU -> Dropout) twice, then Linear to 1.
        widths = (heads * hidden_channels, 1000, 520)
        stack = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            stack.extend([nn.Linear(n_in, n_out),
                          nn.BatchNorm1d(n_out),
                          nn.LeakyReLU(),
                          torch.nn.Dropout(dropout)])
        stack.append(nn.Linear(widths[-1], 1))
        self.linear_layers = nn.Sequential(*stack)

    def forward(self, graph_data):
        """Return a flat tensor with the head's output for the "B1" entry of the convolution result."""
        conv_out = self.hetero_conv(graph_data.x_dict, graph_data.edge_index_dict)
        scores = self.linear_layers(conv_out["B1"])
        return scores.view(-1)

class TropiGAT_big_module(torch.nn.Module):
    """One convolution over a single edge type followed by a 1280-520-1 feed-forward head.

    Parameters
    ----------
    hidden_channels : int
        Output size handed to the convolution.
    heads : int
        Number of attention heads; the head expects ``heads * hidden_channels`` inputs.
    edge_type : tuple
        ``(source, relation, destination)`` key under which the convolution is
        registered in the ``HeteroConv`` wrapper.
    dropout : float
        Dropout probability, used by the convolution and between the linear layers.
    conv : callable
        Convolution class; it is instantiated with lazy ``(-1, -1)`` input sizes.
    """
    def __init__(self,hidden_channels, heads, edge_type = ("B2", "expressed", "B1") ,dropout = 0.2, conv = GATv2Conv):
        super().__init__()
        # GATv2 module :
        # NOTE(review): GATv2Conv documents this option as `share_weights`; the
        # `shared_weights` spelling used here is presumably ignored -- confirm.
        self.conv = conv((-1,-1), hidden_channels, add_self_loops = False, heads = heads, dropout = dropout, shared_weights = True)
        self.hetero_conv = HeteroConv({edge_type: self.conv})
        # HeteroConv returns a dict keyed by destination node type; remember which
        # entry feeds the feed-forward head ("B1" for the default edge type).
        self.target_node_type = edge_type[-1]
        # FNN layers : 
        self.linear_layers = nn.Sequential(nn.Linear(heads*hidden_channels, 1280),
                                           nn.BatchNorm1d(1280),
                                           nn.LeakyReLU(),
                                           torch.nn.Dropout(dropout),
                                           nn.Linear(1280, 520),
                                           nn.BatchNorm1d(520),
                                           nn.LeakyReLU(),
                                           torch.nn.Dropout(dropout),
                                           nn.Linear(520 , 1))
        
    def forward(self, graph_data):
        """Return a flat tensor with the head's output for the destination node type."""
        x_dict = self.hetero_conv(graph_data.x_dict, graph_data.edge_index_dict)
        # The linear stack needs a tensor, not the whole dict returned by HeteroConv.
        x = self.linear_layers(x_dict[self.target_node_type])
        
        return x.view(-1)

In [58]:
# Hyper-parameters and loss used for training.
parameters_model = {"hidden_channels" : 1000,
                    "lr" : 0.0001,
                    "heads" : 1,
                    "dropout" : 0.1,
                    "criterion" : torch.nn.BCEWithLogitsLoss(),
                   }

# Example instance for inspection. NOTE(review): it is built with literal values
# (hidden_channels = 1280, default dropout), not with the entries of parameters_model.
TropiGATv2_eg = TropiGAT_small_module(hidden_channels = 1280 , heads = 1)
# Switch to evaluation mode; as the last expression this also displays the module.
TropiGATv2_eg.eval()

TropiGAT_small_module(
  (conv): GATv2Conv((-1, -1), 1280, heads=1)
  (hetero_conv): HeteroConv(num_relations=1)
  (linear_layers): Sequential(
    (0): Linear(in_features=1280, out_features=1000, bias=True)
    (1): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.01)
    (3): Dropout(p=0.2, inplace=False)
    (4): Linear(in_features=1000, out_features=520, bias=True)
    (5): BatchNorm1d(520, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): LeakyReLU(negative_slope=0.01)
    (7): Dropout(p=0.2, inplace=False)
    (8): Linear(in_features=520, out_features=1, bias=True)
  )
)

In [59]:
# Forward pass of the example model on the KL37 graph, without tracking gradients.
with torch.no_grad():
    out = TropiGATv2_eg(graph_KL37)

In [76]:
out.argmax(dim=0)
out[graph_KL37["B1"].train_mask]
graph_KL37["B1"].train_mask

tensor(16)

In [64]:
graph_KL37["B1"].y

tensor([0, 0, 0,  ..., 0, 0, 0])

In [60]:
@torch.no_grad()
def evaluate(model, graph, mask, criterion=None):
    """Score ``model`` on the "B1" nodes selected by ``mask``.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(graph)``; must return one raw score per "B1" node.
    graph : mapping
        ``graph["B1"].y`` holds the binary labels.
    mask : torch.Tensor
        Selects the nodes to evaluate (e.g. the test or eval mask).
    criterion : callable, optional
        Loss applied to the raw scores; defaults to ``torch.nn.BCEWithLogitsLoss()``.

    Returns
    -------
    tuple
        ``(loss, (f1, precision, recall, mcc, accuracy, auc))`` with ``loss`` a float.
    """
    if criterion is None:
        criterion = torch.nn.BCEWithLogitsLoss()
    model.eval()
    logits = model(graph)[mask]
    labels = graph["B1"].y[mask]
    # BCEWithLogitsLoss needs floating-point targets; the stored labels are integers.
    val_loss = criterion(logits, labels.to(logits.dtype))
    # The model emits a single score per node, so the class is obtained by
    # rounding the sigmoid probability (argmax over dim 0 would yield one index).
    probs = torch.sigmoid(logits)
    # sklearn expects CPU arrays.
    y_true = labels.cpu().numpy()
    y_prob = probs.cpu().numpy()
    y_pred = probs.round().long().cpu().numpy()
    # Calculate the metrics
    f1 = f1_score(y_true, y_pred, average='binary')
    precision = precision_score(y_true, y_pred, average='binary')
    recall = recall_score(y_true, y_pred, average='binary') 
    mcc = matthews_corrcoef(y_true, y_pred)  
    accuracy = accuracy_score(y_true, y_pred)
    auc = roc_auc_score(y_true, y_prob)
    return val_loss.item(), (f1, precision, recall, mcc, accuracy, auc)
    


tensor([0.0089, 0.0101, 0.0101,  ..., 0.0016, 0.0049, 0.0145])

In [None]:
def train(model, graph, optimizer=None, criterion=None):
    """Run one optimisation step on the "B1" nodes selected by the train mask.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(graph)``; must return one raw score per "B1" node.
    graph : mapping
        ``graph["B1"]`` provides ``y`` (labels) and ``train_mask``.
    optimizer, criterion : optional
        When omitted, the notebook-level objects named ``optimizer`` and
        ``criterion`` are used, as in the original two-argument call.

    Returns
    -------
    torch.Tensor
        The training loss of this step, detached from the autograd graph.
    """
    if optimizer is None:
        optimizer = globals()["optimizer"]
    if criterion is None:
        criterion = globals()["criterion"]
    model.train()
    optimizer.zero_grad()
    out_train = model(graph)
    train_mask = graph["B1"].train_mask
    # The labels are stored as integers; BCE-style losses need floating-point targets.
    targets = graph["B1"].y[train_mask].to(out_train.dtype)
    loss = criterion(out_train[train_mask], targets)
    loss.backward()
    optimizer.step()
    return loss.detach()

In [None]:
def main(kltype):
    """Train one model on the graph of ``kltype``, save it and report the final metrics.

    Parameters
    ----------
    kltype : str
        Key of ``graph_dico`` selecting the graph to train on; also used in the
        name of the saved checkpoint.
    """
    # train() is called below without an optimizer/criterion, so it picks them up
    # from notebook scope; publish the ones created here under those names.
    global optimizer, criterion
    graph_data_kltype = graph_dico[kltype]
    logging.info(f"Let's start the work with {kltype}\t{parameters_model['hidden_channels']}\t{parameters_model['dropout']}\t{parameters_model['lr']}\t{parameters_model['heads']}")
    # Keyword arguments are required: the third positional parameter of the
    # module is edge_type, not dropout.
    model = TropiGAT_small_module(hidden_channels = parameters_model["hidden_channels"],
                                  heads = parameters_model["heads"],
                                  dropout = parameters_model["dropout"])
    # The convolution is built with lazy (-1, -1) input sizes, so run one forward
    # pass before handing the parameters to the optimizer.
    with torch.no_grad():
        model(graph_data_kltype)
    # "weight_decay" is not part of parameters_model as declared in this notebook;
    # fall back to no weight decay when it is absent.
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr = parameters_model["lr"],
                                 weight_decay = parameters_model.get("weight_decay", 0.0))
    scheduler = ReduceLROnPlateau(optimizer, 'min')
    criterion = torch.nn.BCEWithLogitsLoss()
    for epoch in range(500):
        train_loss = float(train(model, graph_data_kltype))
        if epoch % 25 == 0:
            # Get all metrics 
            test_loss, metrics = evaluate(model, graph_data_kltype, graph_data_kltype["B1"].test_mask)
            info_training_concise = f'Epoch: {epoch}\tTrain Loss: {train_loss}\tTest Loss: {test_loss}\tMCC: {metrics[3]}\tAUC: {metrics[5]}'
            info_training = (f'Epoch: {epoch}, Train Loss: {train_loss}, Test Loss: {test_loss}, '
                             f'F1 Score: {metrics[0]}, Precision: {metrics[1]}, Recall: {metrics[2]}, MCC: {metrics[3]}, '
                             f'Accuracy: {metrics[4]}, AUC: {metrics[5]}')
            logging.info(info_training_concise)
            print(info_training)
            # The scheduler follows the test loss, sampled every 25 epochs.
            scheduler.step(test_loss)
    # Save the model
    # NOTE(review): path_ensemble is not defined in the visible notebook -- define
    # it (e.g. next to path_work) before running.
    torch.save(model.state_dict(), f"{path_ensemble}/{kltype}.TropiGATv2.2509.pt")
    # The final eval :
    print("Final evaluation ...")
    test_loss, metrics = evaluate(model, graph_data_kltype, graph_data_kltype["B1"].eval_mask)
    f1, precision, recall, mcc, accuracy, auc = metrics
    print(f'MCC: {mcc}, F1 Score: {f1}, Precision: {precision}, Recall: {recall}, Accuracy: {accuracy}, AUC: {auc}')
    logging.info(f"Final evaluation ...\nF1 Score: {f1}, Precision: {precision}, Recall: {recall}, MCC: {mcc}, Accuracy: {accuracy}, AUC: {auc}")

if __name__ == "__main__":
    # main() requires the KL type to work on; run it once for every graph in graph_dico.
    for kltype in graph_dico:
        main(kltype)

In [None]:
def train(model, graph, optimizer, criterion):
    """Run one optimisation step on the "B1" nodes selected by the train mask.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(graph)``; must return one raw score per "B1" node.
    graph : mapping
        ``graph["B1"]`` provides ``y`` (labels) and ``train_mask``.
    optimizer : torch.optim.Optimizer
        Optimizer holding the model parameters.
    criterion : callable
        Loss applied to the masked scores and labels.

    Returns
    -------
    torch.Tensor
        The training loss of this step, detached from the autograd graph.
    """
    model.train()
    optimizer.zero_grad()
    out_train = model(graph)
    train_mask = graph["B1"].train_mask
    # The labels are stored as integers; BCE-style losses need floating-point targets.
    targets = graph["B1"].y[train_mask].to(out_train.dtype)
    loss = criterion(out_train[train_mask], targets)
    loss.backward()
    optimizer.step()
    return loss.detach()
    

In [None]:
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Training :
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  

# Hyper-parameters and loss used for training.
# NOTE(review): this re-declares parameters_model with the same entries as the
# earlier cell of this notebook.
parameters_model = {"hidden_channels" : 1000,
                    "lr" : 0.0001,
                    "heads" : 1,
                    "dropout" : 0.1,
                    "criterion" : torch.nn.BCEWithLogitsLoss(),
                   }

def train(model, data, optimizer, criterion, edge_type):
    """Perform one optimisation step against the edge labels of ``edge_type``.

    ``data`` is moved to ``device``, the model is run on it and ``criterion`` is
    applied to the model output and ``data[edge_type].edge_label``.
    Returns the loss of this step as a Python float.
    """
    model.train()
    batch = data.to(device)
    optimizer.zero_grad()
    predictions = model(batch)
    targets = batch[edge_type].edge_label
    step_loss = criterion(predictions, targets)
    step_loss.backward()
    optimizer.step()
    return step_loss.item()

@torch.no_grad()
def evaluate(model, data, criterion, edge_type):
    """Compute the loss and classification metrics against the edge labels of ``edge_type``.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(data)``; its output is treated as raw scores (a sigmoid
        is applied here).
    data : object
        Moved to ``device``; ``data[edge_type].edge_label`` holds the targets.
    criterion : callable
        Loss applied to the raw scores and the edge labels.
    edge_type : tuple
        Key of the edge store that carries ``edge_label``.

    Returns
    -------
    tuple
        ``(loss, f1, precision, recall, mcc, accuracy, auc)`` with ``loss`` a float.
    """
    model.eval()
    data = data.to(device)
    out = model(data)
    edge_labels = data[edge_type].edge_label
    val_loss = criterion(out, edge_labels)
    probs = torch.sigmoid(out)
    # sklearn expects CPU arrays, while the tensors live on `device` (possibly CUDA).
    all_labels = edge_labels.cpu().numpy()
    all_probs = probs.cpu().numpy()
    all_preds = probs.round().cpu().numpy()
    # Calculate the metrics
    f1 = f1_score(all_labels, all_preds, average='binary')
    precision = precision_score(all_labels, all_preds, average='binary')
    recall = recall_score(all_labels, all_preds, average='binary')  
    mcc = matthews_corrcoef(all_labels, all_preds)  
    accuracy = accuracy_score(all_labels, all_preds)
    auc = roc_auc_score(all_labels, all_probs)
    return val_loss.item(), f1, precision, recall, mcc, accuracy, auc 

def main():
    """Train a model on ``train_data``, monitor it on ``test_data`` and report on ``val_data``.

    Runs 3000 epochs; every 25 epochs the metrics on ``test_data`` are logged and
    printed and the learning-rate scheduler is stepped with the test loss.

    NOTE(review): ``Model``, ``train_data``, ``test_data``, ``val_data`` and the
    ``"conv"`` / ``"n_kl_types"`` entries of ``parameters_model`` are not defined in
    the visible notebook; they must exist before this function is called.
    """
    # Log the run configuration from parameters_model (missing entries show as None).
    run_config = "\t".join(str(parameters_model.get(key))
                           for key in ("conv", "hidden_channels", "dropout", "lr", "heads"))
    logging.info(f"Let's start the work with {run_config}")
    model = Model(parameters_model["conv"],parameters_model["hidden_channels"],
                  parameters_model["heads"],parameters_model["dropout"],
                  parameters_model["n_kl_types"]).to(device)
    # Forward pass before the optimizer is created.
    model(train_data)
    criterion = parameters_model["criterion"]
    optimizer = torch.optim.Adam(model.parameters(), lr = parameters_model["lr"] , weight_decay=0.001)
    #optimizer = torch.optim.AdamW(model.parameters(), lr = parameters_model["lr"])
    scheduler = ReduceLROnPlateau(optimizer, 'min')
    edge_type = ("B1", "infects", "A")
    for epoch in range(3000):
        train_loss = train(model, train_data, optimizer, criterion, edge_type)
        if epoch % 25 == 0:
            # Get all metrics including recall and MCC from evaluate function
            test_loss, f1, precision, recall, mcc, accuracy, auc = evaluate(model, test_data, criterion, edge_type)
            info_training_concise = f'Epoch: {epoch}, Train Loss: {train_loss}, Test Loss: {test_loss}, MCC: {mcc}, AUC: {auc}'
            info_training = f'Epoch: {epoch}, Train Loss: {train_loss}, Test Loss: {test_loss}, F1 Score: {f1}, Precision: {precision}, Recall: {recall}, MCC: {mcc}, Accuracy: {accuracy}, AUC: {auc}'
            logging.info(info_training_concise)
            print(info_training)
            scheduler.step(test_loss)
    # Save the model
    #torch.save(model.state_dict(), f"{path_work}/GATv2Conv.debud_clean.1909.pt")
    # The final eval :
    print("Final evaluation ...")
    # Use the same loss as during training (a bare `criterion` name did not exist here).
    val_loss, f1, precision, recall, mcc, accuracy, auc = evaluate(model, val_data, criterion, edge_type)
    print(f'MCC: {mcc}, F1 Score: {f1}, Precision: {precision}, Recall: {recall}, Accuracy: {accuracy}, AUC: {auc}')
    logging.info(f"Final evaluation ...\nF1 Score: {f1}, Precision: {precision}, Recall: {recall}, MCC: {mcc}, Accuracy: {accuracy}, AUC: {auc}")

# Run the training entry point when this code is executed as the main module.
if __name__ == "__main__":
    main()

In [None]:
# Training code : 
def train(model, graph, optimizer, criterion):
    """Run one optimisation step on the "B1" nodes selected by the train mask.

    NOTE(review): ``train`` is defined in several cells of this notebook; the
    definition from whichever cell ran last is the one in effect.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(graph)``; must return one raw score per "B1" node.
    graph : mapping
        ``graph["B1"]`` provides ``y`` (labels) and ``train_mask``.
    optimizer : torch.optim.Optimizer
        Optimizer holding the model parameters.
    criterion : callable
        Loss applied to the masked scores and labels.

    Returns
    -------
    torch.Tensor
        The training loss of this step, detached from the autograd graph.
    """
    model.train()
    optimizer.zero_grad()
    out_train = model(graph)
    train_mask = graph["B1"].train_mask
    # The labels are stored as integers; BCE-style losses need floating-point targets.
    targets = graph["B1"].y[train_mask].to(out_train.dtype)
    loss = criterion(out_train[train_mask], targets)
    loss.backward()
    optimizer.step()
    return loss.detach()

@torch.no_grad()
def evaluate(model, graph, criterion, mask):
    """Score ``model`` on the "B1" nodes selected by ``mask``.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(graph)``; must return one raw score per "B1" node.
    graph : mapping
        ``graph["B1"].y`` holds the binary labels.
    criterion : callable
        Loss applied to the masked raw scores and labels.
    mask : torch.Tensor
        Selects the nodes to evaluate (e.g. the test or eval mask).

    Returns
    -------
    tuple
        ``(loss, f1, precision, recall, mcc, accuracy, auc)`` with ``loss`` a float.
    """
    model.eval()
    logits = model(graph)[mask]
    labels = graph["B1"].y[mask]
    # The labels are stored as integers; BCE-style losses need floating-point targets.
    val_loss = criterion(logits, labels.to(logits.dtype))
    # The model emits a single score per node, so the class is obtained by
    # rounding the sigmoid probability (argmax over dim 0 would yield one index).
    probs = torch.sigmoid(logits)
    # sklearn expects CPU arrays.
    y_true = labels.cpu().numpy()
    y_prob = probs.cpu().numpy()
    y_pred = probs.round().long().cpu().numpy()
    # Calculate the metrics
    f1 = f1_score(y_true, y_pred, average='binary')
    precision = precision_score(y_true, y_pred, average='binary')
    recall = recall_score(y_true, y_pred, average='binary') 
    mcc = matthews_corrcoef(y_true, y_pred)  
    accuracy = accuracy_score(y_true, y_pred)
    auc = roc_auc_score(y_true, y_prob)
    return val_loss.item(), f1, precision, recall, mcc, accuracy, auc 
    

