# In [None]:
# GNN

import os
import torch
from torch_geometric.data import DataLoader
from torch_geometric.utils import negative_sampling
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, roc_auc_score, roc_curve
import matplotlib.pyplot as plt

class GraphAutoencoder(torch.nn.Module):
    """Two-layer GCN encoder with an MLP decoder that scores edges.

    The encoder maps node features to ``hidden_channels // 2``-dimensional
    node embeddings. The decoder scores an edge from the concatenation of
    its two endpoint embeddings and outputs one logit, which ``forward``
    turns into a probability with a sigmoid.

    Args:
        in_channels: number of input node features.
        hidden_channels: width of the first GCN layer; the node embedding
            size is ``hidden_channels // 2``.
    """

    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        self.encoder = GCNConv(in_channels, hidden_channels)
        # Second GCN layer halves the width. Its output size must match the
        # decoder's input, which is two embeddings concatenated together.
        self.hidden = GCNConv(hidden_channels, hidden_channels // 2)
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear((hidden_channels // 2) * 2, hidden_channels),
            torch.nn.LeakyReLU(0.01),
            torch.nn.Dropout(0.5),
            torch.nn.Linear(hidden_channels, 1),  # one logit per edge
        )

    def forward(self, x, edge_index):
        """Embed the nodes and score every edge in ``edge_index``.

        Args:
            x: node feature matrix passed to the GCN layers.
            edge_index: edge list used both for message passing and as the
                set of node pairs to score.

        Returns:
            tuple: ``(adj_hat, z)``. ``adj_hat`` is a 1-D tensor with one
            probability per column of ``edge_index``; ``z`` holds the node
            embeddings from the second GCN layer.
        """
        x = F.leaky_relu(self.encoder(x, edge_index))
        z = self.hidden(x, edge_index)
        z_edge = torch.cat([z[edge_index[0]], z[edge_index[1]]], dim=1)
        # squeeze(-1) rather than squeeze(): a graph with a single edge must
        # still produce a 1-D tensor of length 1, not a 0-d scalar.
        adj_hat = torch.sigmoid(self.decoder(z_edge).squeeze(-1))
        return adj_hat, z


def load_data_from_file(filename):
    """Deserialize and return whatever object was saved at ``filename``.

    Thin wrapper around ``torch.load``.

    NOTE(review): ``torch.load`` unpickles the file, so only load files
    from trusted sources. On newer PyTorch releases ``weights_only``
    defaults to True, which may reject pickled graph objects; confirm
    against the installed version.
    """
    return torch.load(filename)

def _score_edges(model, z, edge_index):
    """Return one probability per column of ``edge_index``, decoded from ``z``.

    This repeats the decoding step of ``GraphAutoencoder.forward``:
    concatenate the two endpoint embeddings, run ``model.decoder``, then
    apply a sigmoid. It does not re-run message passing, so it can score
    node pairs that are not edges of the input graph, such as sampled
    negatives. Keep it in sync with ``forward``.
    """
    z_edge = torch.cat([z[edge_index[0]], z[edge_index[1]]], dim=1)
    return torch.sigmoid(model.decoder(z_edge)).reshape(-1)


def train(loader, model, device, optimizer, epochs=10):
    """Train ``model`` with a link-prediction BCE loss.

    In every batch, the real edges are positives. ``loss_function`` samples
    an equal number of negative node pairs and scores them from the same
    node embeddings.

    Args:
        loader: iterable of graph batches exposing ``x``, ``edge_index``
            and ``num_nodes``.
        model: model returning ``(edge_probabilities, node_embeddings)`` and
            exposing a ``decoder`` module (as ``GraphAutoencoder`` does).
        device: device the batches are moved to.
        optimizer: optimizer over ``model``'s parameters.
        epochs: number of passes over ``loader``.

    Returns:
        list[float]: mean batch loss for each epoch (``nan`` for an epoch
        with no batches).
    """
    model.train()
    epoch_losses = []
    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        for data in loader:
            data = data.to(device)
            optimizer.zero_grad()
            pred_adj, z = model(data.x, data.edge_index)
            loss = loss_function(
                pred_adj,
                data.edge_index,
                data.num_nodes,
                device,
                # Score sampled negatives from this batch's embeddings.
                score_fn=lambda pairs, emb=z: _score_edges(model, emb, pairs),
            )
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        # Guard against an empty loader instead of dividing by zero.
        avg_loss = total_loss / num_batches if num_batches else float("nan")
        epoch_losses.append(avg_loss)
        print(f"Epoch {epoch+1}: Loss {avg_loss}")
    return epoch_losses


def _plot_roc_curve(labels, probs, auc_roc):
    """Plot and show the ROC curve of ``probs`` against binary ``labels``."""
    fpr, tpr, _ = roc_curve(labels, probs)
    plt.figure()
    plt.plot(fpr, tpr, color='darkorange', label='ROC curve (area = %0.2f)' % auc_roc)
    plt.plot([0, 1], [0, 1], color='navy', linestyle='--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic')
    plt.legend(loc="lower right")
    plt.show()


def validate(loader, model, device, threshold=0.00257):
    """Evaluate per-edge predictions against the labels stored in ``data.z``.

    Every edge of every batch is scored by ``model``. A probability strictly
    above ``threshold`` is predicted as label 1. The function prints and
    returns binary precision, recall, F1, accuracy and ROC-AUC, and it
    plots the ROC curve.

    Args:
        loader: iterable of graph batches, each with a ``z`` attribute.
            Presumably ``z`` holds one binary label per edge and must line
            up with the model's per-edge output (TODO confirm).
        model: model returning ``(edge_probabilities, node_embeddings)``.
        device: device the batches are moved to.
        threshold: decision threshold on predicted probabilities. The
            default keeps the previously hard-coded value; its origin is not
            documented here.

    Returns:
        tuple: ``(precision, recall, f1_score, accuracy, auc_roc)``.

    Raises:
        ValueError: if a batch has no ``z`` labels. scikit-learn's
            ``roc_auc_score`` also raises ValueError when the labels
            contain only one class.
    """
    model.eval()
    all_labels = []
    all_predictions = []
    all_probs = []

    with torch.no_grad():
        for data in loader:
            # Check for labels before spending time on inference.
            if not hasattr(data, 'z'):
                raise ValueError("Labels not found in data.")
            data = data.to(device)
            pred_adj, _ = model(data.x, data.edge_index)
            # Flatten so a single-edge graph (possibly a 0-d output) still
            # yields an iterable for the list extends below.
            probs = pred_adj.reshape(-1)
            all_probs.extend(probs.cpu().numpy())
            all_predictions.extend((probs > threshold).float().cpu().numpy())
            all_labels.extend(data.z.reshape(-1).cpu().numpy())

    precision, recall, f1_score, _ = precision_recall_fscore_support(all_labels, all_predictions, average='binary')
    accuracy = accuracy_score(all_labels, all_predictions)
    auc_roc = roc_auc_score(all_labels, all_probs)  # uses probabilities, not thresholded labels

    _plot_roc_curve(all_labels, all_probs, auc_roc)

    print(f"Precision: {precision}, Recall: {recall}, F1-Score: {f1_score}, Accuracy: {accuracy}, AUC-ROC: {auc_roc}")
    return precision, recall, f1_score, accuracy, auc_roc


def loss_function(pred_adj, edge_index, num_nodes, device, score_fn=None):
    """Compute the binary cross-entropy link-prediction loss.

    Args:
        pred_adj: predicted probabilities for the real edges in
            ``edge_index`` (the positives, target 1).
        edge_index: edge list of the real edges.
        num_nodes: number of nodes, passed to ``negative_sampling``.
        device: device on which the target tensors are created.
        score_fn: optional callable that maps an edge index of K node pairs
            to K probabilities. When given, ``negative_sampling`` draws up
            to ``edge_index.size(1)`` non-edges. They are scored with
            ``score_fn`` (target 0), and the loss is the mean BCE over
            positives and negatives together. When omitted, negatives cannot
            be scored, so the loss is BCE of the positives against 1. That is
            the same value the previous implementation produced, because it
            paired ``1 - pred_adj`` with target 0.

    Returns:
        torch.Tensor: scalar loss.
    """
    pred_pos = pred_adj.reshape(-1)
    pos_label = torch.ones(pred_pos.size(0), device=device)
    if score_fn is None:
        return F.binary_cross_entropy(pred_pos, pos_label)

    # NOTE(review): on a batch of several graphs this can pair nodes from
    # different graphs; PyG's batched_negative_sampling would keep negatives
    # within each graph. Consider switching.
    neg_edge_index = negative_sampling(edge_index, num_nodes, num_neg_samples=edge_index.size(1))
    pred_neg = score_fn(neg_edge_index).reshape(-1)
    # Size the negative targets from the actual number of sampled pairs,
    # which can be smaller than requested.
    neg_label = torch.zeros(pred_neg.size(0), device=device)
    pred_labels = torch.cat([pred_pos, pred_neg], dim=0)
    true_labels = torch.cat([pos_label, neg_label], dim=0)
    return F.binary_cross_entropy(pred_labels, true_labels)

def main():
    """Train a GraphAutoencoder on a directory of ``.pt`` graphs, then evaluate it.

    ``train_dir`` and ``test_file`` below are placeholders and must be set
    to real paths before running.

    Raises:
        FileNotFoundError: if ``train_dir`` contains no ``.pt`` files.
    """
    train_dir = 'directory for train data'
    test_file = 'directory for test'

    # os.listdir order is arbitrary. Sort it so batch composition is
    # reproducible across runs (the loader does not shuffle).
    train_files = sorted(
        os.path.join(train_dir, f) for f in os.listdir(train_dir) if f.endswith('.pt')
    )
    if not train_files:
        raise FileNotFoundError(f"No .pt training files found in {train_dir!r}")
    train_data = [load_data_from_file(f) for f in train_files]
    train_loader = DataLoader(train_data, batch_size=10, shuffle=False)

    test_data = load_data_from_file(test_file)
    test_loader = DataLoader([test_data], batch_size=1, shuffle=False)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Hyperparameters below are kept as found; their tuning origin is not shown here.
    model = GraphAutoencoder(train_data[0].num_features, 49).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.00013417234297668777, weight_decay=4.3296940930545564e-06)

    train(train_loader, model, device, optimizer, 20)
    metrics = validate(test_loader, model, device)
    print(f"Validation Metrics for {test_file}: {metrics}")

if __name__ == "__main__":
    # Run the full train/evaluate pipeline only when executed as a script,
    # not when this module is imported.
    main()