In [8]:
from torch_geometric.data import HeteroData, DataLoader
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, to_hetero , SAGEConv
from torch_geometric.utils import negative_sampling
from torch_geometric.loader import LinkNeighborLoader

import torch
from torch import nn 
import torch.nn.functional as F
import torch.optim as optim

from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import LabelEncoder , label_binarize , OneHotEncoder
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

import os 
import pandas as pd
import numpy as np
from tqdm import tqdm
from itertools import product
import random
from collections import Counter
import warnings
warnings.filterwarnings("ignore") 


In [9]:
# *****************************************************************************
# Load the Dataframes :
path_work = "/media/concha-eloko/Linux/PPT_clean"
#path_work = "/home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/ficheros_28032023"

    # Open the DF
DF_info = pd.read_csv(f"{path_work}/DF_Dpo.final.2705.tsv", sep = "\t" ,  header = 0 )
    # Open the embeddings
DF_embeddings = pd.read_csv(f"{path_work}/Dpo.2705.embeddings.ultimate.csv", sep = ",", header= None )
DF_embeddings.rename(columns={0: 'index'}, inplace=True)

    # Filter the DF :
DF_info_filtered = DF_info[~DF_info["KL_type_LCA"].str.contains("\\|")]
DF_info_ToReLabel = DF_info[DF_info["KL_type_LCA"].str.contains("\\|")]
all_data = pd.merge(DF_info_filtered , DF_embeddings , on = "index")

# Mind the over representation of outbreaks :
all_data = all_data.drop_duplicates(subset = ["Infected_ancestor","index","prophage_id"] , keep = "first").reset_index(drop=True)


In [10]:
graph_data = torch.load(f'{path_work}/graph_file.1107.pt')
#graph_data = torch.load(f'{path_work}/train_nn/graph_file.1107.pt')

graph_data

HeteroData(
  [1mA[0m={ x=[4530, 127] },
  [1mB1[0m={ x=[11339, 0] },
  [1mB2[0m={ x=[3608, 1280] },
  [1m(B1, infects, A)[0m={
    edge_index=[2, 9677],
    y=[9677]
  },
  [1m(B2, expressed, B1)[0m={
    edge_index=[2, 13285],
    y=[13285]
  },
  [1m(A, harbors, B1)[0m={
    edge_index=[2, 9677],
    y=[9677]
  }
)

In [23]:
all_data[all_data["index"] == "anubis__1644"]

Unnamed: 0,Phage,KL_type_LCA,Infected_ancestor,Protein_name,Dataset,index,seq,prophage_id,1,2,...,1271,1272,1273,1274,1275,1276,1277,1278,1279,1280
9673,GCF_004311345.1__phage11,KL34,GCF_004311345.1,GCF_004311345.1__phage11__99,anubis,anubis__1644,MTANYPASILPPNATAVERAIDRASAAALERLPVYLIRWVKDPDSC...,prophage_11944,-0.001691,-0.067717,...,0.119783,0.073706,0.057788,0.004341,0.007389,-0.081588,0.100995,-0.045545,0.021685,0.013167


***
# Build the graph :

In [70]:
# *****************************************************************************
# Nodes A : the bacteria (ancestors) - KLtype feature
# Nodes B1 : the prophage (phage) - No feature
# Nodes B2 : the depo (index_seq) - 1280-d embeddings

# Build the Graph Data :
graph_data = HeteroData()

    # Indexation process (shall add the N phages to predict)
indexation_nodes_A = all_data["Infected_ancestor"].unique().tolist()  
indexation_nodes_B1 = all_data["Phage"].unique().tolist() + [f"Dpo_to_predict_{n}" for n in DF_embeddings["index"].unique().tolist()]
indexation_nodes_B2 = DF_embeddings["index"].unique().tolist() 

ID_nodes_A = {item:index for index, item in enumerate(indexation_nodes_A)}
ID_nodes_A_r = {index:item for index, item in enumerate(indexation_nodes_A)}

ID_nodes_B1 = {item:index for index, item in enumerate(indexation_nodes_B1)}
ID_nodes_B1_r = {index:item for index, item in enumerate(indexation_nodes_B1)}

ID_nodes_B2 = {item:index for index, item in enumerate(indexation_nodes_B2)}
ID_nodes_B2_r = {index:item for index, item in enumerate(indexation_nodes_B2)}

In [71]:
# Make the node feature file : 
OHE = OneHotEncoder(sparse=False)
one_hot_encoded = OHE.fit_transform(all_data[["KL_type_LCA"]])
label_mapping = {label: one_hot_encoded[i] for i, label in enumerate(OHE.categories_[0])}

node_feature_A = torch.tensor([label_mapping[all_data[all_data["Infected_ancestor"] == ID_nodes_A_r[i]]["KL_type_LCA"].values[0]] for i in range(0,len(ID_nodes_A_r))], dtype=torch.float)
node_feature_B1 = torch.zeros((len(ID_nodes_B1), 0), dtype=torch.float)
node_feature_B2 = torch.tensor([DF_embeddings[DF_embeddings["index"] == ID_nodes_B2_r[i]].values[0][1:1281].tolist() for i in range(0,len(ID_nodes_B2_r))] , dtype=torch.float)

# feed the graph
graph_data["A"].x = node_feature_A
graph_data["B1"].x = node_feature_B1
graph_data["B2"].x = node_feature_B2

# Write files : 
node_feature_A_array = node_feature_A.numpy()
node_feature_B1_array = node_feature_B1.numpy()
node_feature_B2_array = node_feature_B2.numpy()

df_node_feature_A_array = pd.DataFrame(node_feature_A_array)
df_node_feature_B1_array = pd.DataFrame(node_feature_B1_array)
df_node_feature_B2_array = pd.DataFrame(node_feature_B2_array)

df_node_feature_A_array.to_csv(f"{path_work}/node_features.A.csv", index=False, header=False)
df_node_feature_B1_array.to_csv(f"{path_work}/node_features.B1.csv", index=False, header=False)
df_node_feature_B2_array.to_csv(f"{path_work}/node_features.B2.csv", index=False, header=False)

In [73]:
# Make edge file
# Node B1 (prophage) - Node A (bacteria) :
edge_index_B1_A = []
for _, row in all_data.iterrows() :
    edge_index_B1_A.append([ID_nodes_B1[row["Phage"]], ID_nodes_A[row["Infected_ancestor"]]])
edge_index_B1_A = torch.tensor(edge_index_B1_A , dtype=torch.long)

# Node A (bacteria) - Node B1 (prophage) :
edge_index_A_B1 = []
for _, row in all_data.iterrows() :
    edge_index_A_B1.append([ID_nodes_A[row["Infected_ancestor"]] , ID_nodes_B1[row["Phage"]]])
edge_index_A_B1 = torch.tensor(edge_index_A_B1 , dtype=torch.long)

# Node B2 (depolymerase) - Node B1 (prophage) :
edge_index_B2_B1 = []
for phage in all_data.Phage.unique() :
    all_data_phage = all_data[all_data["Phage"] == phage]
    for _, row in all_data_phage.iterrows() :
        edge_index_B2_B1.append([ID_nodes_B2[row["index"]], ID_nodes_B1[row["Phage"]]])
# Add in there the edges between the fake prophages and the each Dpos :
for prophage , index in ID_nodes_B1.items() :
    if prophage.count("Dpo_to_predict_") > 0 : 
        id_dpo = prophage.split("Dpo_to_predict_")[1]
        edge_index_B2_B1.append([ID_nodes_B2[id_dpo], index])
edge_index_B2_B1 = torch.tensor(edge_index_B2_B1 , dtype=torch.long)

# feed the graph
graph_data['B1', 'infects', 'A'].edge_index = edge_index_B1_A.t().contiguous()
graph_data['B2', 'expressed', 'B1'].edge_index = edge_index_B2_B1.t().contiguous()
# That one is optional  
graph_data['A', 'harbors', 'B1'].edge_index = edge_index_A_B1.t().contiguous()

# Write files : 
edge_index_B1_A_array = edge_index_B1_A.numpy()
edge_index_A_B1_array = edge_index_A_B1.numpy()
edge_index_B2_B1_array = edge_index_B2_B1.numpy()

df_edge_index_B1_A_array = pd.DataFrame(edge_index_B1_A_array)
df_edge_index_A_B1_array = pd.DataFrame(edge_index_A_B1_array)
df_edge_index_B2_B1_array = pd.DataFrame(edge_index_B2_B1_array)

df_edge_index_B1_A_array.to_csv(f"{path_work}/edge_index_B1_A_array.csv", index=False, header=False)
df_edge_index_A_B1_array.to_csv(f"{path_work}/edge_index_A_B1_array.csv", index=False, header=False)
df_edge_index_B2_B1_array.to_csv(f"{path_work}/edge_index_B2_B1_array.csv", index=False, header=False)

In [77]:
    # Make the Y file : 
graph_data['B1', 'infects', 'A'].y = torch.ones(len(graph_data['B1', 'infects', 'A'].edge_index[0]))
graph_data['B2', 'expressed', 'B1'].y = torch.ones(len(graph_data['B2', 'expressed', 'B1'].edge_index[0]))
# That one is optional  
graph_data['A', 'harbors', 'B1'].y = torch.ones(len(graph_data['A', 'harbors', 'B1'].edge_index[0]))

***
# Work on the GNN

In [87]:
# *****************************************************************************
# Data instance : 

#graph_data
#torch.save(graph_data , f'{path_work}/graph_file.1107.pt')

In [22]:
graph_data 

HeteroData(
  [1mA[0m={ x=[4530, 127] },
  [1mB1[0m={ x=[11339, 0] },
  [1mB2[0m={ x=[3608, 1280] },
  [1m(B1, infects, A)[0m={
    edge_index=[2, 9677],
    y=[9677]
  },
  [1m(B2, expressed, B1)[0m={
    edge_index=[2, 13285],
    y=[13285]
  },
  [1m(A, harbors, B1)[0m={
    edge_index=[2, 9677],
    y=[9677]
  }
)

> Plot the graph : 

In [38]:
from torch_geometric.utils import to_networkx
import networkx as nx
import matplotlib.pyplot as plt

def visualize_graph(G, color):
    G = G.to_homogeneous()
    plt.figure(figsize=(7,7))
    plt.xticks([])
    plt.yticks([])
    nx.draw_networkx(G, pos=nx.spring_layout(G, seed=42), with_labels=False,
                     node_color=color, cmap="Set2")
    plt.show()
    #plt.save(f"{path_work}/graph_representation.svg" , dpi = 800)
    
G = to_networkx(data_sub, to_undirected=False)
visualize_graph(G, color=data.y)

AttributeError: 'HeteroData' has no attribute 'nodes'

In [None]:
class GNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(-1, hidden_channels, add_self_loops=False)
        self.conv2 = GCNConv(-1, out_channels, add_self_loops=False)
    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return x

class Classifier(torch.nn.Module):
    def forward(self, x, edge_index):
        edge_feat_B1 = x["B1"][edge_index[("B1", "infects", "A")][0]]
        edge_feat_A = x["A"][edge_index[("B1", "infects", "A")][1]]
        return (edge_feat_B1 * edge_feat_A).sum(dim=-1)
    
class Model(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.gnn_B2_B1 = GNN(hidden_channels, out_channels)
        self.gnn_B1_A = GNN(hidden_channels, out_channels)
        self.classifier = Classifier()
    def forward(self, graph_data):
        # Propagate B2 features to B1
        x_B2 = graph_data['B2'].x
        edge_index_B2_B1 = graph_data[('B2', 'expressed', 'B1')].edge_index
        x_B1_from_B2 = self.gnn_B2_B1(x_B2, edge_index_B2_B1)
        # Propagate new B1 features to A
        x_B1 = x_B1_from_B2  
        edge_index_B1_A = graph_data[('B1', 'infects', 'A')].edge_index
        x_A_from_B1 = self.gnn_B1_A(x_B1, edge_index_B1_A)
        # Classification based on new features
        x = {'B1': x_B1_from_B2, 'A': x_A_from_B1}
        pred = self.classifier(x['B1'], x['A'])
        return pred



> Template (functional)

In [None]:
class GNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(-1, hidden_channels, add_self_loops=False)
        self.conv2 = GCNConv(-1, out_channels, add_self_loops=False)
        #self.conv2 = GCNConv(-1, hidden_channels, add_self_loops=False)
        #self.conv3 = GCNConv(-1, out_channels, add_self_loops=False)
    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        #x = F.relu(self.conv2(x, edge_index))
        x = self.conv2(x, edge_index)
        return x

class Classifier(torch.nn.Module):
    def __init__(self, A_dim, B1_dim):
        super().__init__()
        self.pad = torch.nn.ConstantPad1d((0, A_dim - B1_dim), 0)
    def forward(self, x, edge_index):
        edge_feat_B1 = x["B1"][edge_index[("B1", "infects", "A")][0]]
        edge_feat_A = x["A"][edge_index[("B1", "infects", "A")][1]]
        # Padding B1 features to match A features dimension
        edge_feat_B1 = self.pad(edge_feat_B1)
        return (edge_feat_B1 * edge_feat_A).sum(dim=-1)
    
class Model(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.gnn_B2_B1 = GNN(hidden_channels, out_channels)
        self.gnn_B1_A = GNN(hidden_channels, out_channels)
        self.classifier = Classifier()
    def forward(self, graph_data):
        # Propagate B2 features to B1
        x_B2 = graph_data['B2'].x
        edge_index_B2_B1 = graph_data[('B2', 'expressed', 'B1')].edge_index
        x_B1_from_B2 = self.gnn_B2_B1(x_B2, edge_index_B2_B1)
        # Propagate new B1 features to A
        x_B1 = x_B1_from_B2  
        edge_index_B1_A = graph_data[('B1', 'infects', 'A')].edge_index
        x_A_from_B1 = self.gnn_B1_A(x_B1, edge_index_B1_A)
        # Concatenate the updated A features with the initial A features
        x_A_initial = graph_data['A'].x.clone() # Keeping a copy of initial features
        x_A = torch.cat((x_A_initial, x_A_from_B1), dim=-1)
        # Classification based on new features
        x = {'B1': x_B1_from_B2, 'A': x_A}
        pred = self.classifier(x, edge_index_B1_A)
        return pred

model = Model(hidden_channels= 580 , out_channels = 100)

***
> Training : 

In [2]:
transform = T.RandomLinkSplit(
    num_val=0.1, 
    num_test=0.2, 
    #disjoint_train_ratio=...,  
    neg_sampling_ratio=1.0,  
    add_negative_train_samples=True, 
    edge_types=("B1", "infects", "A"),
    rev_edge_types=("A", "harbors", "B1"), 
)

train_data, val_data, test_data = transform(graph_data)

NameError: name 'T' is not defined

In [49]:
train_loader = LinkNeighborLoader(
    data=train_data,  
    num_neighbors= [-1],  
    edge_label_index=(("B1", "infects", "A"), train_data["B1", "infects", "A"].edge_label_index),
    edge_label=train_data["B1", "infects", "A"].edge_label,
    batch_size=128,
    shuffle=True,
)

val_loader = LinkNeighborLoader(
    data=val_data,  
    num_neighbors= [-1],  
    edge_label_index=(("B1", "infects", "A"), val_data["B1", "infects", "A"].edge_label_index),
    edge_label=val_data["B1", "infects", "A"].edge_label,
    batch_size=128,
    shuffle=True,
)

test_loader = LinkNeighborLoader(
    data=test_data,  
    num_neighbors= [-1],  
    edge_label_index=(("B1", "infects", "A"), test_data["B1", "infects", "A"].edge_label_index),
    edge_label=test_data["B1", "infects", "A"].edge_label,
    batch_size=128,
    shuffle=True,
)

In [51]:
sampled_data = next(iter(train_loader))
print(sampled_data)

ImportError: 'NeighborSampler' requires either 'pyg-lib' or 'torch-sparse'

In [None]:
from torch_geometric.utils import to_networkx
import networkx as nx
import matplotlib.pyplot as plt

def visualize_graph(G, color):
    G = G.to_homogeneous()
    plt.figure(figsize=(7,7))
    plt.xticks([])
    plt.yticks([])
    nx.draw_networkx(G, pos=nx.spring_layout(G, seed=42), with_labels=False,
                     node_color=color, cmap="Set2")
    plt.show()
    #plt.save(f"{path_work}/graph_representation.svg" , dpi = 800)
    
G = to_networkx(data_sub, to_undirected=False)
visualize_graph(G, color=data.y)

In [None]:
# Define seed edges:
edge_label_index = train_data["B1", "infects", "A"].edge_label_index
edge_label = train_data["B1", "infects", "A"].edge_label

train_loader = LinkNeighborLoader(
    data=...,  # TODO
    num_neighbors=...,  # TODO
    neg_sampling_ratio=...,  # TODO
    edge_label_index=(("B1", "infects", "A"), edge_label_index),
    edge_label=edge_label,
    batch_size=128,
    shuffle=True,
)


In [24]:
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# setup Dataloader for batch processing
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

for epoch in range(100):  # loop over the dataset multiple times
    all_labels = []
    all_preds = []
    for i, data in enumerate(dataloader):
        # get the inputs
        inputs, labels = data.x, data.y
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward + backward + optimize
        outputs = model(data)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # collect all labels and predictions for metrics calculation
        all_labels += labels.tolist()
        all_preds += torch.sigmoid(outputs).detach().tolist()

    # calculate metrics
    if epoch % 10 == 0:
        preds_bin = [1 if x >= 0.5 else 0 for x in all_preds]
        precision = precision_score(all_labels, preds_bin)
        f1 = f1_score(all_labels, preds_bin)
        auc = roc_auc_score(all_labels, all_preds)
        print(f'Epoch: {epoch}, Loss: {loss.item()}, Precision: {precision}, F1-score: {f1}, AUC: {auc}')

print('Finished Training')

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize the model and the optimizer
model = Model(hidden_channels=580, out_channels=100).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
loss_function = torch.nn.BCEWithLogitsLoss()

def train():
    model.train()
    total_loss = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        edge_index = data[('B1', 'infects', 'A')].edge_index
        pos_out = model(data)
        # Negative sampling
        total_edge = edge_index.size(1)
        edge_index, _ = negative_sampling(edge_index, num_nodes=data['A'].x.size(0), num_neg_samples=total_edge)
        neg_out = model(data)
        out = torch.cat([pos_out, neg_out])
        # Generate labels for the loss function
        y = torch.cat([torch.ones(total_edge, device=device), torch.zeros(total_edge, device=device)])
        loss = loss_function(out, y)
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
    return total_loss / len(train_loader)

def test(loader):
    model.eval()
    y_true = []
    y_pred = []
    for data in loader:
        data = data.to(device)
        with torch.no_grad():
            pred = model(data)
        y_pred.append(pred.cpu().numpy())
        y_true.append(data.y.cpu().numpy())
    y_true = np.concatenate(y_true)
    y_pred = np.concatenate(y_pred)
    return roc_auc_score(y_true, y_pred)

# Training loop
for epoch in range(100): 
    loss = train()
    auc = test(test_loader)
    print(f'Epoch: {epoch+1}, Loss: {loss:.4f}, AUC: {auc:.4f}')

In [None]:
def train(model, loader, optimizer, criterion):
    model.train()
    total_loss = 0
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = criterion(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

@torch.no_grad()
def evaluate(model, loader, criterion):
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []
    for data in loader:
        data = data.to(device)
        pred = model(data)
        val_loss = criterion(pred[data.val_mask], data.y[data.val_mask])
        total_loss += val_loss.item()
        _, pred_class = pred.max(dim=1)
        all_preds.extend(pred_class.cpu().numpy())
        all_labels.extend(data.y[data.val_mask].cpu().numpy())
    # Calculate the metrics
    f1 = f1_score(all_labels, all_preds, average='macro')
    precision = precision_score(all_labels, all_preds, average='macro')
    accuracy = accuracy_score(all_labels, all_preds)
    auc = roc_auc_score(all_labels, all_preds)
    return total_loss / len(loader), f1, precision, accuracy, auc

def main():
    hidden_channels = 580 
    out_channels = 100 
    model = GNN(hidden_channels, out_channels).to(device)
    criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    for epoch in range(100): 
        train_loss = train(model, train_loader, optimizer, criterion)
        if epoch % 10 == 0:
            test_loss, f1, precision, accuracy, auc = evaluate(model, test_loader, criterion)
            print(f'Epoch: {epoch}, Train Loss: {train_loss}, Test Loss: {test_loss}, F1 Score: {f1}, Precision: {precision}, Accuracy: {accuracy}, AUC: {auc}')
    # Save the model
    torch.save(model.state_dict(), f"{path_work}/GCNConv.model.1307.pt")

if __name__ == "__main__":
    main()

In [None]:
from sklearn.model_selection import train_test_split

# Split the data into train and test
train_idx, test_idx = train_test_split(range(len(graph_data['B1', 'infects', 'A'].y)), test_size=0.2, random_state=42)

# Initialize the model, optimizer and loss function
model = Model(hidden_channels=64)
model = to_hetero(model, graph_data, aggr='sum')
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.BCELoss()

for epoch in range(100):  # adjust as needed
    optimizer.zero_grad()
    out = model(train_data)
    pos_score = model.classifier(out, train_data.edge_index)
    neg_score = model.classifier(out, generate_negative_samples(train_data.edge_index, train_data.num_nodes, train_data.edge_index.size(1)))
    loss = criterion(torch.cat([pos_score, neg_score]), torch.cat([torch.ones(pos_score.size()), torch.zeros(neg_score.size())]))
    loss.backward()
    optimizer.step()
    print(f"Epoch: {epoch+1}, Train loss: {loss.item()}")

out = model(test_data)
pred = torch.sigmoid(model.classifier(out, test_data.edge_index)).numpy()
acc = accuracy_score(test_data.y.numpy(), pred > 0.5)
prec = precision_score(test_data.y.numpy(), pred > 0.5)
rec = recall_score(test_data.y.numpy(), pred > 0.5)
f1 = f1_score(test_data.y.numpy(), pred > 0.5)
auroc = roc_auc_score(test_data.y.numpy(), pred)

print(f"Accuracy: {acc}, Precision: {prec}, Recall: {rec}, F1-score: {f1}, AUROC: {auroc}")



In [None]:
# *****************************************************************************
# Train / Eval : 
def generate_negative_samples(edge_index, num_nodes, num_neg_samples):
    all_possible_edges = torch.tensor(list(product(range(num_nodes['A']), range(num_nodes['B1']))), dtype=torch.long).t()
    positive_edges = set((i.item(), j.item()) for i, j in edge_index.t().tolist())
    all_possible_edges = [edge for edge in all_possible_edges if edge not in positive_edges]
    negative_edges = random.sample(all_possible_edges, min(num_neg_samples, len(all_possible_edges)))
    return torch.tensor(negative_edges, dtype=torch.long).t()

def train(model, loader, optimizer, criterion):
    model.train()
    total_loss = 0
    for batch in loader:
        optimizer.zero_grad()
        out = model(batch)
        pos_batch = batch.edge_index
        pos_score = model.classifier(out, pos_batch)
        neg_batch = generate_negative_samples(batch.edge_index, batch.num_nodes, pos_batch.size(1))
        neg_score = model.classifier(out, neg_batch)
        all_score = torch.cat([pos_score, neg_score], dim=0)
        all_target = torch.cat([torch.ones(pos_score.size(0)), torch.zeros(neg_score.size(0))], dim=0).to(device)
        loss = criterion(all_score, all_target)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

@torch.no_grad()
def test(model, loader):
    model.eval()
    all_preds, all_labels = [], []
    for batch in loader:
        out = model(batch)
        pred = torch.sigmoid(model.classifier(out, batch.edge_index)).numpy()
        all_preds.append(pred)
        all_labels.append(batch.y.numpy())
    all_preds = np.concatenate(all_preds, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)
    return accuracy_score(all_labels, all_preds > 0.5), precision_score(all_labels, all_preds > 0.5), recall_score(all_labels, all_preds > 0.5), f1_score(all_labels, all_preds > 0.5), roc_auc_score(all_labels, all_preds)

# Split the dataset into 5 folds for cross-validation
kfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)

# Initialize the model, optimizer and loss function
model = Model(hidden_channels=64)
model = to_hetero(model, data, aggr='sum')
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
#criterion = F.binary_cross_entropy_with_logits()
criterion = torch.nn.BCEWithLogitsLoss()

# Cross-validation
for fold, (train_idx, test_idx) in enumerate(kfold.split(graph_data.node_data['A'], graph_data.y)):
    train_data = graph_data[train_idx]
    test_data = graph_data[test_idx]
    train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=32)
    for epoch in range(100):  # adjust as needed
        train_loss = train(model, train_loader, optimizer, criterion)
        print(f"Fold: {fold+1}, Epoch: {epoch+1}, Train loss: {train_loss}")
    acc, prec, rec, f1, auroc = test(model, test_loader)
    print(f"Fold: {fold+1}, Accuracy: {acc}, Precision: {prec}, Recall: {rec}, F1-score: {f1}, AUROC: {auroc}")




In [None]:
# Cross-validation
for fold, (train_idx, test_idx) in enumerate(kfold.split(graph_data.node_items['A'], graph_data.y)):
    train_data = graph_data[train_idx]
    test_data = graph_data[test_idx]
    train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=32)
    for epoch in range(100):  # adjust as needed
        train_loss = train(model, train_loader, optimizer, criterion)
        print(f"Fold: {fold+1}, Epoch: {epoch+1}, Train loss: {train_loss}")
    acc, prec, rec, f1, auroc = test(model, test_loader)
    print(f"Fold: {fold+1}, Accuracy: {acc}, Precision: {prec}, Recall: {rec}, F1-score: {f1}, AUROC: {auroc}")



In [None]:
# *****************************************************************************
# Save model after training
torch.save(model.state_dict(), f"{path_work}/train_nn/MultiDomain.LSTM.model")

import json
with open(f"{path_work}/train_nn/MultiDomain.LSTM.model.out" , "w") as outfile :
    outfile.write(json.dumps(history))

