In [8]:
from torch_geometric.data import HeteroData, DataLoader
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, to_hetero , SAGEConv
from torch_geometric.utils import negative_sampling
from torch_geometric.loader import LinkNeighborLoader

import torch
from torch import nn 
import torch.nn.functional as F
import torch.optim as optim

from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import LabelEncoder , label_binarize , OneHotEncoder
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

import os 
import pandas as pd
import numpy as np
from tqdm import tqdm
from itertools import product
import random
from collections import Counter
import warnings
warnings.filterwarnings("ignore") 


In [9]:
# *****************************************************************************
# Load the serialized graph :
# Root working directory; the graph file is read from its train_nn/ sub-folder.
path_work = "/home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/ficheros_28032023"

# NOTE(review): torch.load unpickles the file, so only point it at graph files
# from a trusted source.
graph_file = f'{path_work}/train_nn/graph_file.1107.pt'
graph_data = torch.load(graph_file)


In [None]:
# *****************************************************************************
# The model :
class GNN(torch.nn.Module):
    """Two stacked GCNConv layers with a ReLU in between.

    Args:
        hidden_channels: Output width of the first convolution.
        out_channels: Output width of the second convolution.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        # in_channels=-1: the input width is not fixed here (PyG documents -1
        # as "infer on first forward pass"). Self-loops are disabled on both layers.
        self.conv1 = GCNConv(-1, hidden_channels, add_self_loops=False)
        self.conv2 = GCNConv(-1, out_channels, add_self_loops=False)

    def forward(self, x, edge_index):
        """Return the second layer's output for features ``x`` over ``edge_index``."""
        hidden = self.conv1(x, edge_index)
        hidden = F.relu(hidden)
        return self.conv2(hidden, edge_index)

class Classifier(torch.nn.Module):
    """Dot-product scorer for ("B1", "infects", "A") node pairs."""

    def forward(self, x, edge_index):
        """Score every (B1, A) pair listed under the "infects" edge type.

        Args:
            x: Mapping with node embeddings under the keys "B1" and "A".
            edge_index: Mapping keyed by edge type; the entry for
                ("B1", "infects", "A") holds B1 indices in row 0 and
                A indices in row 1.

        Returns:
            One score per pair: the sum over the last dimension of the
            element-wise product of the two gathered embeddings.
        """
        pair_index = edge_index[("B1", "infects", "A")]
        b1_rows = x["B1"][pair_index[0]]
        a_rows = x["A"][pair_index[1]]
        scores = b1_rows * a_rows
        return scores.sum(dim=-1)
    
class Model(torch.nn.Module):
    """Encode B2 -> B1 -> A in two GNN stages, then score ("B1", "infects", "A") pairs.

    Args:
        hidden_channels: Hidden width handed to both GNN encoders.
        out_channels: Embedding width handed to both GNN encoders.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.gnn_B2_B1 = GNN(hidden_channels, out_channels)
        self.gnn_B1_A = GNN(hidden_channels, out_channels)
        self.classifier = Classifier()

    def forward(self, graph_data):
        """Return one score per candidate ("B1", "infects", "A") pair.

        Args:
            graph_data: Heterogeneous graph (or mini-batch) providing
                ``graph_data['B2'].x`` and ``edge_index`` for the
                ('B2', 'expressed', 'B1') and ('B1', 'infects', 'A') edge types.

        Returns:
            1-D tensor of scores. Pairs are taken from ``edge_label_index`` of
            the "infects" edge type when present (so the scores line up with
            ``edge_label``), otherwise from its ``edge_index``.
        """
        infects_type = ('B1', 'infects', 'A')

        # Stage 1: propagate B2 features along the "expressed" edges.
        # NOTE(review): the encoder is a GCNConv stack fed source-type features
        # together with a cross-type edge_index; confirm the B2/B1 (and B1/A)
        # index spaces line up, or switch to a bipartite-capable layer.
        x_B2 = graph_data['B2'].x
        edge_index_B2_B1 = graph_data[('B2', 'expressed', 'B1')].edge_index
        x_B1_from_B2 = self.gnn_B2_B1(x_B2, edge_index_B2_B1)

        # Stage 2: propagate the new B1 features along the "infects" edges.
        infects_store = graph_data[infects_type]
        x_A_from_B1 = self.gnn_B1_A(x_B1_from_B2, infects_store.edge_index)

        x = {'B1': x_B1_from_B2, 'A': x_A_from_B1}

        # The classifier expects (embedding dict, edge-index dict keyed by edge
        # type); the previous call passed the two embedding tensors instead.
        pair_index = getattr(infects_store, 'edge_label_index', None)
        if pair_index is None:
            pair_index = infects_store.edge_index
        pred = self.classifier(x, {infects_type: pair_index})
        return pred



In [2]:
# *****************************************************************************
# Pre-process data :
# Edge type whose links are split and predicted, and its reverse edge type.
SUPERVISED_EDGE_TYPE = ("B1", "infects", "A")
REVERSE_EDGE_TYPE = ("A", "harbors", "B1")

transform = T.RandomLinkSplit(
    num_val=0.1,
    num_test=0.2,
    #disjoint_train_ratio=...,
    neg_sampling_ratio=1.0,
    add_negative_train_samples=True,
    edge_types=SUPERVISED_EDGE_TYPE,
    rev_edge_types=REVERSE_EDGE_TYPE,
)

train_data, val_data, test_data = transform(graph_data)


def make_link_loader(split_data, shuffle, batch_size=128):
    """Build a LinkNeighborLoader over the supervised edges of one split.

    Args:
        split_data: One of the graphs returned by the link split above.
        shuffle: Whether the loader reshuffles the supervision edges.
        batch_size: Number of supervision edges per mini-batch.

    Returns:
        A LinkNeighborLoader seeded with the split's ``edge_label_index`` and
        ``edge_label`` for SUPERVISED_EDGE_TYPE.
    """
    edge_store = split_data[SUPERVISED_EDGE_TYPE]
    return LinkNeighborLoader(
        data=split_data,
        num_neighbors=[-1],
        edge_label_index=(SUPERVISED_EDGE_TYPE, edge_store.edge_label_index),
        edge_label=edge_store.edge_label,
        batch_size=batch_size,
        shuffle=shuffle,
    )


# Only the training loader shuffles; validation/test keep a fixed batch order
# so evaluation runs are comparable.
train_loader = make_link_loader(train_data, shuffle=True)
val_loader = make_link_loader(val_data, shuffle=False)
test_loader = make_link_loader(test_data, shuffle=False)

NameError: name 'T' is not defined

In [None]:
# *****************************************************************************
# Training :
def train(model, loader, optimizer, criterion, edge_type):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: Module called as ``model(batch)``; must return one prediction
            per supervision edge of ``edge_type``.
        loader: Iterable of batches exposing ``.to(device)`` and
            ``batch[edge_type].edge_label``.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss called as ``criterion(predictions, edge_labels)``.
        edge_type: Key of the supervised edge type inside each batch.

    Returns:
        Sum of the per-batch losses divided by the number of batches
        (0.0 when the loader yields no batch).
    """
    model.train()
    # Follow the model's own device rather than a notebook-level `device`
    # variable, which may not exist on a fresh kernel.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    total_loss = 0.0
    num_batches = 0
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        # Supervision targets live in `edge_label`; the previous `.y` /
        # `.train_mask` attributes are never created by the link split/loader.
        edge_labels = data[edge_type].edge_label
        loss = criterion(out, edge_labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    return total_loss / max(num_batches, 1)

@torch.no_grad()
def evaluate(model, loader, criterion, edge_type):
    """Evaluate binary link predictions over ``loader``.

    Args:
        model: Module called as ``model(batch)``; must return one logit per
            supervision edge of ``edge_type``.
        loader: Iterable of batches exposing ``.to(device)`` and
            ``batch[edge_type].edge_label``.
        criterion: Loss called as ``criterion(logits, edge_labels)``.
        edge_type: Key of the supervised edge type inside each batch.

    Returns:
        Tuple ``(mean_loss, f1, precision, accuracy, auc)``. F1 and precision
        are macro-averaged over hard predictions (logit > 0); AUC is computed
        from the sigmoid scores.

    Raises:
        ValueError: Propagated from scikit-learn, e.g. when the collected
            labels are empty or contain a single class (AUC undefined).
    """
    model.eval()
    # Follow the model's own device rather than a notebook-level `device`.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    total_loss = 0.0
    num_batches = 0
    all_scores = []
    all_preds = []
    all_labels = []
    for data in loader:
        data = data.to(device)
        pred = model(data)
        # Targets live in `edge_label`; `.y` / `.val_mask` are never created
        # by the link split/loader.
        edge_labels = data[edge_type].edge_label
        total_loss += criterion(pred, edge_labels).item()
        num_batches += 1
        # One logit per edge: the class is the sign of the logit (sigmoid > 0.5),
        # not an argmax over a class dimension.
        scores = torch.sigmoid(pred)
        all_scores.extend(scores.cpu().numpy())
        all_preds.extend((pred > 0).long().cpu().numpy())
        all_labels.extend(edge_labels.long().cpu().numpy())
    # Calculate the metrics
    f1 = f1_score(all_labels, all_preds, average='macro')
    precision = precision_score(all_labels, all_preds, average='macro')
    accuracy = accuracy_score(all_labels, all_preds)
    # Ranking metric: use the continuous scores, not the thresholded classes.
    auc = roc_auc_score(all_labels, all_scores)
    return total_loss / max(num_batches, 1), f1, precision, accuracy, auc

# Device used by the training cell; defined here so a fresh kernel does not
# depend on a variable from an earlier, no-longer-present cell.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def main():
    """Train the link-prediction model, report metrics every 10 epochs, save weights.

    Reads the notebook-level ``train_loader``, ``test_loader`` and ``path_work``.
    Writes the trained ``state_dict`` to ``{path_work}/GCNConv.model.1307.pt``.
    """
    hidden_channels = 580
    out_channels = 100
    # `train`/`evaluate` call `model(batch)`, which is the signature of Model
    # (GNN.forward needs `x` and `edge_index` separately).
    model = Model(hidden_channels, out_channels).to(device)
    # The GCNConv layers are created with in_channels=-1; run one batch through
    # the model so their parameters exist before the optimizer is built.
    with torch.no_grad():
        model(next(iter(train_loader)).to(device))
    criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    edge_type = ("B1", "infects", "A")  # adjust this as per your requirement
    for epoch in range(100):
        train_loss = train(model, train_loader, optimizer, criterion, edge_type)
        if epoch % 10 == 0:
            # NOTE(review): progress is monitored on the test split;
            # `val_loader` is built but unused — confirm which split is intended.
            test_loss, f1, precision, accuracy, auc = evaluate(model, test_loader, criterion, edge_type)
            print(f'Epoch: {epoch}, Train Loss: {train_loss}, Test Loss: {test_loss}, F1 Score: {f1}, Precision: {precision}, Accuracy: {accuracy}, AUC: {auc}')
    # Save the model
    torch.save(model.state_dict(), f"{path_work}/GCNConv.model.1307.pt")

if __name__ == "__main__":
    main()

In [None]:
# Cross-validation
# NOTE(review): draft cell. Against the definitions visible in this notebook it cannot run as written:
#   * `kfold`, `model`, `optimizer`, `criterion` and `test` are not defined at notebook scope
#     (`model`, `optimizer` and `criterion` only exist as locals inside `main`).
#   * `train` takes five positional arguments (the last one is `edge_type`); four are passed here.
#   * The evaluation helper in this notebook is `evaluate`, which returns
#     (loss, f1, precision, accuracy, auc), not the (acc, prec, rec, f1, auroc) tuple unpacked below.
# TODO confirm: that `graph_data` exposes `node_items` / `y` and can be indexed by an array of fold indices.
# TODO confirm: whether folds should be built over the ("B1", "infects", "A") edges instead.
for fold, (train_idx, test_idx) in enumerate(kfold.split(graph_data.node_items['A'], graph_data.y)):
    # Per-fold train/test subsets and their loaders (these rebind the notebook-level loader names).
    train_data = graph_data[train_idx]
    test_data = graph_data[test_idx]
    train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=32)
    for epoch in range(100):  # adjust as needed
        train_loss = train(model, train_loader, optimizer, criterion)
        print(f"Fold: {fold+1}, Epoch: {epoch+1}, Train loss: {train_loss}")
    # NOTE(review): `model` is not re-created inside the loop, so weights would carry over between folds.
    acc, prec, rec, f1, auroc = test(model, test_loader)
    print(f"Fold: {fold+1}, Accuracy: {acc}, Precision: {prec}, Recall: {rec}, F1-score: {f1}, AUROC: {auroc}")

