# Convert to knowledge graph

# In [None]:
def build_knowledge_graph(df, index, kg_col, keyindex_col):
    """
    Build a knowledge graph from a specified row of the DataFrame.

    The cell text is searched for two marker-delimited JSON arrays
    (``#Final_Entity_List_Start# json [...] #Final_Entity_List_End#`` and the
    equivalent ``Final_Relationship_List`` markers). Parsing problems are
    reported with ``print`` and the offending item is skipped, so one bad row
    or record does not abort the whole build.

    Parameters:
      df          : pandas DataFrame containing the knowledge graph string.
      index       : index label of the row to process.
      kg_col      : column name in the DataFrame that contains the KG string.
      keyindex_col: kept for interface compatibility; it is not read here
                    (privacy content removed).

    Returns:
      G           : A networkx MultiDiGraph. Each entity dict becomes a node
                    keyed by its "name" (all entity fields are stored as node
                    attributes); each relationship becomes an edge from "sub"
                    to "obj" carrying the remaining fields as edge attributes.
    """
    import re
    import networkx as nx
    import json_repair
    import ast

    def is_hashable(value):
        # networkx node identifiers must be hashable.
        try:
            hash(value)
        except TypeError:
            return False
        return True

    def extract_json_list(pattern, label):
        # Return the parsed list found between the markers, or [] when the
        # section is missing, unparsable, or not a JSON array.
        match = re.search(pattern, data_str)
        if not match:
            return []
        try:
            parsed = json_repair.loads(match.group(1))
        except Exception as e:  # third-party parser boundary: keep best-effort
            print(f"Error parsing {label} JSON:", e)
            return []
        if not isinstance(parsed, list):
            print(f"Warning: {label} section did not parse to a list. Ignoring it.")
            return []
        return parsed

    G = nx.MultiDiGraph()

    # Get the knowledge graph string from the specified row.
    data_str = df.loc[index, kg_col]
    # Note: Privacy information from keyindex is not extracted and used anymore.
    if not isinstance(data_str, str):
        # Empty cells are read as NaN (a float); re.search would raise TypeError.
        print(f"Warning: Row {index} has no knowledge graph text. Returning an empty graph.")
        return G

    # 1-2. Extract the entity and relationship lists.
    entity_pattern = r'#Final_Entity_List_Start#\s*json\s*(\[[\s\S]*?\])\s*#Final_Entity_List_End#'
    rel_pattern = r'#Final_Relationship_List_Start#\s*json\s*(\[[\s\S]*?\])\s*#Final_Relationship_List_End#'
    entities = extract_json_list(entity_pattern, "entity")
    relationships = extract_json_list(rel_pattern, "relationship")

    # 3-4. Add entities as nodes (using "name" as the node identifier).
    for i, entity in enumerate(entities):
        if not isinstance(entity, dict):
            print(f"Warning: Entity number {i} is not an object. Skipping entity. Content: {entity}")
            continue
        node_id = entity.get("name")
        if node_id is None:
            print(f"Warning: Entity number {i} is missing 'name' field. Skipping entity. Content: {entity}")
            continue
        if not is_hashable(node_id):
            print(f"Warning: Entity number {i} has an unhashable 'name'. Skipping entity. Content: {entity}")
            continue
        # Privacy: removed addition of private info.
        G.add_node(node_id, **entity)

    # 5. Process relationships and add them as edges.
    for rel in relationships:
        # If relationship is a string, try converting it to a dict.
        if isinstance(rel, str):
            try:
                rel = ast.literal_eval(rel)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as e:
                print(f"Error parsing relationship {rel}: {e}")
                continue
        if not isinstance(rel, dict):
            print("Warning: Relationship is not an object. Skipping relationship.", rel)
            continue
        source = rel.get("sub")
        target = rel.get("obj")
        if source is None or target is None:
            print("Warning: Relationship missing 'sub' or 'obj' field. Skipping relationship.", rel)
            continue
        if not (is_hashable(source) and is_hashable(target)):
            print("Warning: Relationship has an unhashable 'sub' or 'obj'. Skipping relationship.", rel)
            continue
        # Exclude 'sub' and 'obj' from edge attributes.
        edge_attr = {key: value for key, value in rel.items() if key not in ("sub", "obj")}
        # Privacy: removed addition of private info.
        G.add_edge(source, target, **edge_attr)

    return G


def aggregate_knowledge_graph(df, kg_col, keyindex_col):
    """
    Aggregate knowledge graphs from all rows of the DataFrame into one graph.

    Steps:
      1. Build one graph per row with ``build_knowledge_graph`` and union them.
         When a node already exists, every attribute is merged into a
         de-duplicated list (so merged attributes become lists).
      2. Group node names by their lowercase spaCy lemma string.
      3. Collapse each group into the member with the highest total degree,
         merging attributes and re-attaching the other members' edges to it.

    Parameters:
      df          : pandas DataFrame with one knowledge-graph string per row.
      kg_col      : column name holding the knowledge-graph string.
      keyindex_col: forwarded unchanged to ``build_knowledge_graph``.

    Returns:
      A networkx MultiDiGraph containing the merged graph.
    """
    import networkx as nx
    import spacy

    # Load spaCy English model.
    nlp = spacy.load("en_core_web_sm")

    def spacy_lemmatize(text):
        """
        Convert text to lowercase and lemmatize it using spaCy.
        For example, 'Social Engineering Attacks' becomes 'social engineering attack'
        """
        doc = nlp(text.lower())
        return " ".join(token.lemma_ for token in doc if not token.is_punct and not token.is_space)

    # Helper function: merge two values into a list (preserving order and deduplication).
    def merge_values(val1, val2):
        def to_list(x):
            if x is None:
                return []
            if isinstance(x, list):
                return x
            return [x]

        combined = to_list(val1) + to_list(val2)
        try:
            return list(dict.fromkeys(combined))
        except TypeError:
            # Attribute values parsed from JSON may be dicts or nested lists,
            # which are unhashable; fall back to an equality-based dedupe.
            merged = []
            for item in combined:
                if item not in merged:
                    merged.append(item)
            return merged

    def merge_attrs_into(node, incoming):
        # Merge `incoming` attributes into the attributes already stored on
        # `node`; existing values come first in each merged list.
        existing = G_total.nodes[node]
        for key in set(existing) | set(incoming):
            existing[key] = merge_values(existing.get(key), incoming.get(key))

    # 1) First, merge all the knowledge graphs from each row.
    G_total = nx.MultiDiGraph()
    for index in df.index:
        G = build_knowledge_graph(df, index, kg_col, keyindex_col)
        # Merge nodes: if a node exists, merge its attributes.
        for node, attr in G.nodes(data=True):
            if node not in G_total:
                G_total.add_node(node, **attr)
            else:
                merge_attrs_into(node, attr)
        # Merge all edges (allowing multiple edges).
        for u, v, edge_attr in G.edges(data=True):
            G_total.add_edge(u, v, **edge_attr)

    # 2) Use spaCy to lemmatize node names and group nodes by their lemma.
    lemma_to_nodes = {}
    for node in list(G_total.nodes()):
        # Non-string node identifiers are grouped under themselves.
        lemma = spacy_lemmatize(node) if isinstance(node, str) else node
        lemma_to_nodes.setdefault(lemma, []).append(node)

    duplicate_groups_count = sum(1 for group in lemma_to_nodes.values() if len(group) > 1)
    print("There are {} groups of nodes with identical lemmas that need to be merged.".format(duplicate_groups_count))

    # 3) For each group with more than one node, merge them.
    #    The node with the highest total degree (in_degree + out_degree) is retained.
    for nodes_same_lemma in lemma_to_nodes.values():
        if len(nodes_same_lemma) <= 1:
            continue
        degree_dict = {node: G_total.in_degree(node) + G_total.out_degree(node)
                       for node in nodes_same_lemma}
        node_keep = max(degree_dict, key=degree_dict.get)

        for node_merge in nodes_same_lemma:
            if node_merge == node_keep:
                continue
            # a) Merge node attributes.
            merge_attrs_into(node_keep, G_total.nodes[node_merge])

            # b) Redirect all incoming edges from node_merge to node_keep.
            in_edges = list(G_total.in_edges(node_merge, keys=True, data=True))
            for u, _, _, data_edge in in_edges:
                G_total.add_edge(u, node_keep, **data_edge)

            # c) Redirect all outgoing edges from node_merge to node_keep.
            #    The snapshot is taken after step (b) on purpose: a self-loop
            #    on node_merge was just re-added as node_merge -> node_keep and
            #    is turned into a node_keep self-loop here.
            out_edges = list(G_total.out_edges(node_merge, keys=True, data=True))
            for _, v, _, data_edge in out_edges:
                G_total.add_edge(node_keep, v, **data_edge)

            # d) Remove the merged node (this also drops its remaining edges).
            G_total.remove_node(node_merge)

    return G_total


# === Read the DataFrame ===
import pandas as pd

df = pd.read_excel('RQ3.xlsx')

# Aggregate the knowledge graphs from all rows of the DataFrame.
G_aggregated = aggregate_knowledge_graph(df, 'knowledge graph', 'keyindex')

import random

# Sanity check: show one random node and one random edge.
# random.choice raises IndexError on an empty sequence, so guard both prints.
_sample_nodes = list(G_aggregated.nodes(data=True))
_sample_edges = list(G_aggregated.edges(data=True))
if _sample_nodes:
    print("Random node and attributes:", random.choice(_sample_nodes))
else:
    print("Aggregated graph has no nodes.")
if _sample_edges:
    print("Random edge and attributes:", random.choice(_sample_edges))
else:
    print("Aggregated graph has no edges.")


# Model Training

# In [None]:
import os
import pickle
import random
import time
import math
import numpy as np
import pandas as pd
import networkx as nx
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch_geometric.data import Data
from torch_geometric.utils import from_networkx
from torch_geometric.nn import GCNConv, SAGEConv, GATConv
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, roc_curve
from tqdm import tqdm
import matplotlib.pyplot as plt
import tools  # For calling tools.get_embedding interface

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#############################################
# 1. Define Structural Feature Computation Function
#############################################
def compute_structural_features(G, node_pairs, max_path_length=3):
    """
    Compute structural features for node pairs and return a tensor of shape [len(node_pairs), 16].

    Feature layout per pair (u, v); unused trailing slots are zero-padded and
    anything beyond 16 values is truncated:
      0. 1 / (1 + shortest path length), or 0 when no path exists
      1. log(1 + number of common neighbors)
      2. Jaccard coefficient of the two neighbor sets
      3. Adamic-Adar style index, sum of 1 / log(degree(w) + 1) over common neighbors
      4. log(1 + degree(u) * degree(v))  (preferential attachment)
      5. Resource allocation index, sum of 1 / degree(w) over common neighbors
      6+. log(1 + number of simple paths of exactly length L) for
          L = 2 .. max_path_length (two values with the default of 3)

    Neighbor sets come from ``G.neighbors``; on a networkx directed graph that
    yields successors only. A pair with an endpoint missing from ``G`` gets an
    all-zero row.

    Parameters:
      G              : networkx graph to measure on.
      node_pairs     : sized iterable of (u, v) node pairs.
      max_path_length: largest simple-path length to count.

    Returns:
      torch.FloatTensor of shape [len(node_pairs), 16].
    """
    from collections import Counter

    feature_dim = 16
    print(f"Computing structural features for {len(node_pairs)} node pairs...")
    features = []
    for u, v in tqdm(node_pairs):
        # Every feature is defined as 0 for unknown nodes. Checking up front
        # avoids per-feature exception handling and guarantees no value from a
        # previous pair can leak into this row.
        if u not in G or v not in G:
            features.append([0.0] * feature_dim)
            continue

        feature_vector = []

        # 1. Shortest path (normalized as 1/(1+length))
        try:
            path_length = nx.shortest_path_length(G, source=u, target=v)
            feature_vector.append(1.0 / (1.0 + path_length))
        except nx.NetworkXNoPath:
            feature_vector.append(0.0)

        u_neighbors = set(G.neighbors(u))
        v_neighbors = set(G.neighbors(v))
        common_neighbors = u_neighbors & v_neighbors

        # 2. Number of common neighbors (normalized as log(1+count))
        feature_vector.append(math.log(1 + len(common_neighbors)))

        # 3. Jaccard coefficient
        union_count = len(u_neighbors | v_neighbors)
        feature_vector.append(len(common_neighbors) / union_count if union_count > 0 else 0.0)

        # 4. Adamic-Adar index (degree + 1 inside the log keeps the denominator positive)
        feature_vector.append(sum(1.0 / math.log(G.degree(w) + 1) for w in common_neighbors))

        # 5. Preferential Attachment
        feature_vector.append(math.log(1 + G.degree(u) * G.degree(v)))

        # 6. Resource Allocation Index (a common neighbor always has degree >= 1)
        feature_vector.append(sum(1.0 / G.degree(w) for w in common_neighbors))

        # 7+. Count of simple paths of exactly length 2 .. max_path_length.
        # Enumerate once with the largest cutoff and bucket by length rather
        # than re-enumerating for every length.
        if max_path_length >= 2:
            length_counts = Counter(
                len(path) - 1
                for path in nx.all_simple_paths(G, source=u, target=v, cutoff=max_path_length)
            )
            for path_len in range(2, max_path_length + 1):
                feature_vector.append(math.log(1 + length_counts[path_len]))

        # Pad or truncate to 16 dimensions
        if len(feature_vector) < feature_dim:
            feature_vector.extend([0.0] * (feature_dim - len(feature_vector)))
        else:
            feature_vector = feature_vector[:feature_dim]
        features.append(feature_vector)

    if not features:
        # torch.tensor([]) would have shape [0]; keep the documented 2-D shape.
        return torch.zeros((0, feature_dim), dtype=torch.float)
    return torch.tensor(features, dtype=torch.float)

#############################################
# 2. Define Node Embedding Generation Function (with caching)
#############################################
def get_node_embeddings_with_cache(G, cache_path, embedding_model="text-embedding-3-small"):
    """
    Generate node embeddings and return a new graph with only the 'x' attribute.

    Embeddings are looked up by ``str(node)`` in a pickle cache stored under
    ``cache_path``; missing ones are fetched concurrently through
    ``tools.get_embedding`` and successful results are written back to the
    cache. A node whose fetch failed gets an all-zero vector in the returned
    graph, but that placeholder is NOT cached, so the fetch is retried on the
    next call.

    Parameters:
      G              : networkx multigraph (edges are read with keys) to copy.
      cache_path     : directory holding the embedding cache file.
      embedding_model: model name forwarded to ``tools.get_embedding``.

    Returns:
      A new networkx MultiDiGraph with the same nodes and edges as ``G`` where
      every node carries exactly one attribute, 'x' (its embedding).
    """
    from concurrent.futures import ThreadPoolExecutor, as_completed

    # Fallback vector length used only when no real embedding is available to
    # infer the length from.
    default_dim = 1536

    # Create a copy of the graph to avoid modifying the original.
    G_copy = nx.MultiDiGraph()
    G_copy.add_nodes_from(G.nodes(data=True))
    G_copy.add_edges_from([(u, v, k, d.copy()) for u, v, k, d in G.edges(data=True, keys=True)])

    embedding_cache = {}
    cache_file = os.path.join(cache_path, f'embedding_cache_{embedding_model.replace("-", "_")}.pkl')
    if os.path.exists(cache_file):
        try:
            # NOTE(security): pickle.load executes arbitrary code from the
            # file; only point cache_path at a directory you trust.
            with open(cache_file, 'rb') as f:
                embedding_cache = pickle.load(f)
            print(f"Successfully loaded embedding cache, count: {len(embedding_cache)}")
        except Exception as e:
            print(f"Failed to load cache: {e}")
    else:
        print(f"Embedding cache not found: {cache_file}")

    # Distinct texts that still need an embedding (different nodes can share
    # the same str() form, so de-duplicate while preserving order).
    texts_to_process = list(dict.fromkeys(
        str(node) for node in G_copy.nodes() if str(node) not in embedding_cache
    ))
    print(f"Number of embeddings to generate: {len(texts_to_process)}")

    def fetch_embedding(node_text):
        # Returns (text, embedding, error); exactly one of the last two is None.
        try:
            return node_text, tools.get_embedding(node_text, model=embedding_model), None
        except Exception as e:
            return node_text, None, str(e)

    if texts_to_process:
        new_entries = 0
        failed = 0
        with ThreadPoolExecutor(max_workers=32) as executor:
            futures = [executor.submit(fetch_embedding, text) for text in texts_to_process]
            for future in tqdm(as_completed(futures), total=len(futures), desc="Generating embeddings"):
                node_text, emb, err = future.result()
                if err is not None or emb is None:
                    failed += 1
                    print(f"Error processing node {node_text!r}: {err}")
                    continue
                embedding_cache[node_text] = emb
                new_entries += 1
        print(f"New embeddings added: {new_entries}")
        if failed:
            print(f"Embeddings that failed and will use a zero vector: {failed}")
        if new_entries > 0:
            os.makedirs(cache_path, exist_ok=True)
            try:
                with open(cache_file, 'wb') as f:
                    pickle.dump(embedding_cache, f)
                print(f"Embedding cache updated successfully, total count: {len(embedding_cache)}")
            except Exception as e:
                print(f"Failed to save cache: {e}")

    # Size the zero placeholder like the real embeddings when any are known.
    embedding_dim = default_dim
    for cached in embedding_cache.values():
        try:
            embedding_dim = len(cached)
        except TypeError:
            continue
        break

    # Clear node attributes and retain only 'x'
    for node in G_copy.nodes():
        x_value = embedding_cache.get(str(node))
        if x_value is None:
            x_value = [0.0] * embedding_dim
        G_copy.nodes[node].clear()
        G_copy.nodes[node]['x'] = x_value
    return G_copy

#############################################
# 3. Define Negative Sample Generation Function (edges not in graph)
#############################################
def generate_negative_edges(G, num_neg_samples, exclude_edges=None):
    """
    Generate edges that are not present in the graph as negative samples.

    Ordered pairs (u, v) with u != v are drawn uniformly at random (with
    replacement, so the result may contain repeats) and kept when ``G`` has no
    u->v edge and the pair is not listed in ``exclude_edges``.

    Parameters:
      G              : networkx graph to sample node pairs from.
      num_neg_samples: number of negative pairs to return.
      exclude_edges  : optional iterable of (u, v) pairs that must not be returned.

    Returns:
      List of ``num_neg_samples`` (u, v) tuples.

    Raises:
      ValueError: if samples are requested but no valid negative pair exists
                  (previously this situation looped forever).
    """
    exclude_edges = set() if exclude_edges is None else set(exclude_edges)
    nodes = list(G.nodes())
    neg_edges = []

    if num_neg_samples > 0:
        # Rejection sampling below only terminates if at least one candidate
        # pair is acceptable, so verify that before starting.
        node_set = set(nodes)
        directed = G.is_directed()
        blocked = set()
        for u, v in G.edges():
            if u != v:
                blocked.add((u, v))
                if not directed:
                    blocked.add((v, u))
        for pair in exclude_edges:
            if (isinstance(pair, tuple) and len(pair) == 2 and pair[0] != pair[1]
                    and pair[0] in node_set and pair[1] in node_set):
                blocked.add(pair)
        if len(nodes) * (len(nodes) - 1) - len(blocked) <= 0:
            raise ValueError(
                "Cannot generate negative edges: every ordered node pair is "
                "either an existing edge or excluded."
            )

    # The context manager closes the progress bar even if sampling raises.
    with tqdm(total=num_neg_samples, desc="Generating negative samples") as pbar:
        while len(neg_edges) < num_neg_samples:
            u = random.choice(nodes)
            v = random.choice(nodes)
            if u != v and not G.has_edge(u, v) and (u, v) not in exclude_edges:
                neg_edges.append((u, v))
                pbar.update(1)
    return neg_edges

#############################################
# 4. Split the original graph by nodes: 70% training, 30% testing (ensure no overlap to avoid data leakage)
#############################################
# Split the node set (not the edge set) with a fixed seed so the split is reproducible.
all_nodes = list(G_aggregated.nodes())
train_nodes, test_nodes = train_test_split(all_nodes, test_size=0.3, random_state=610)
print(f"Number of training nodes: {len(train_nodes)}, testing nodes: {len(test_nodes)}")
# Construct induced subgraphs for training and testing.
# Each subgraph is built from its own node list only; .copy() detaches it from
# G_aggregated so it can be modified independently.
G_train_sub = G_aggregated.subgraph(train_nodes).copy()
G_test_sub = G_aggregated.subgraph(test_nodes).copy()
print(f"Training subgraph - nodes: {len(G_train_sub.nodes())}, edges: {len(G_train_sub.edges())}")
print(f"Testing subgraph - nodes: {len(G_test_sub.nodes())}, edges: {len(G_test_sub.edges())}")

#############################################
# 5. Generate Node Embeddings
#############################################
# Directory that holds the pickle embedding cache shared by both calls below.
cache_path = "embedding_cache_path"  # Replace with your own path
# Each call returns a copy of the subgraph whose nodes carry only the 'x' embedding.
G_train_embed = get_node_embeddings_with_cache(G_train_sub, cache_path)
G_test_embed = get_node_embeddings_with_cache(G_test_sub, cache_path)

#############################################
# 6. Construct PyG format graph and index mapping
#############################################
def create_pyg_data(G):
    """
    Convert a networkx multigraph whose nodes carry an 'x' attribute into a
    PyG data object.

    Edge attribute dicts of ``G`` are emptied in place before conversion, and
    ``data.x`` is rebuilt as a float tensor from the nodes' 'x' values in
    ``G.nodes()`` order.
    """
    # Clear edge attributes.
    for *_, edge_attrs in G.edges(keys=True, data=True):
        edge_attrs.clear()

    pyg_data = from_networkx(G, group_node_attrs=None, group_edge_attrs=None)

    node_features = [G.nodes[n]['x'] for n in G.nodes()]
    pyg_data.x = torch.tensor(node_features, dtype=torch.float)
    return pyg_data

# PyG data objects used for message passing (create_pyg_data clears the edge
# attributes of the graphs it is given).
data_train = create_pyg_data(G_train_embed)
data_test = create_pyg_data(G_test_embed)

# Node -> row index, following G.nodes() order, which is the same order
# create_pyg_data uses when it builds data.x.
train_nodes_list = list(G_train_embed.nodes())
train_node_to_idx = {node: i for i, node in enumerate(train_nodes_list)}

test_nodes_list = list(G_test_embed.nodes())
test_node_to_idx = {node: i for i, node in enumerate(test_nodes_list)}

#############################################
# 7. Construct Positive/Negative Edges and Features
#############################################
def prepare_edge_data(G, node_to_idx):
    """
    Collect the positive edges of ``G`` in three aligned forms.

    Returns:
      (pairs, index, struct) where ``pairs`` is the list of (u, v) node pairs
      from ``G.edges()``, ``index`` is the transposed long tensor of their
      ``node_to_idx`` indices, and ``struct`` is the tensor returned by
      ``compute_structural_features`` for those pairs.
    """
    edge_pairs = []
    for src, dst in G.edges():
        edge_pairs.append((src, dst))

    index_rows = []
    for src, dst in edge_pairs:
        index_rows.append((node_to_idx[src], node_to_idx[dst]))
    edge_index = torch.tensor(index_rows, dtype=torch.long).t()

    struct_features = compute_structural_features(G, edge_pairs)
    return edge_pairs, edge_index, struct_features

# Positive edges (pairs, index tensor, structural features) for both subgraphs.
train_pos_pairs, train_pos_index, train_pos_struct = prepare_edge_data(G_train_embed, train_node_to_idx)
test_pos_pairs, test_pos_index, test_pos_struct = prepare_edge_data(G_test_embed, test_node_to_idx)

# Fixed test negatives: as many non-edges as there are test positives, sampled
# once here and reused by every evaluation call.
test_neg_pairs = generate_negative_edges(G_test_embed, len(test_pos_pairs))
test_neg_indices = [(test_node_to_idx[u], test_node_to_idx[v]) for u, v in test_neg_pairs]
test_neg_index = torch.tensor(test_neg_indices, dtype=torch.long).t()
test_neg_struct = compute_structural_features(G_test_embed, test_neg_pairs)

#############################################
# 8. Define Model, Training and Validation Functions
#############################################
# 1) Model Initialization
# Input width is the node embedding length; the structural width is taken from
# the precomputed training features so the model matches the data.
in_channels = data_train.x.size(1)
hidden_channels = 128
struct_features_dim = train_pos_struct.size(1)

class AdvancedGNNLinkPredictor(nn.Module):
    """
    Link predictor that stacks GCN -> GAT -> GraphSAGE layers to encode nodes
    and scores a candidate edge from the two endpoint embeddings concatenated
    with that edge's structural feature vector.

    ``forward``/``encode`` return node embeddings; ``decode`` returns one raw
    score (logit) per candidate edge.
    """

    def __init__(self, in_channels, hidden_channels, struct_features_dim=16, dropout=0.3):
        """
        Parameters:
          in_channels        : length of each input node feature vector.
          hidden_channels    : width of the encoder output / hidden layers.
          struct_features_dim: length of each per-edge structural feature vector.
          dropout            : dropout probability used throughout the model.
        """
        super().__init__()

        # GCN layer
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.bn1 = nn.BatchNorm1d(hidden_channels)

        # GAT layer (4 heads, so the following layers see hidden_channels * 4 features)
        self.gat = GATConv(hidden_channels, hidden_channels, heads=4, dropout=dropout)
        self.bn2 = nn.BatchNorm1d(hidden_channels * 4)

        # GraphSAGE layer
        self.sage = SAGEConv(hidden_channels * 4, hidden_channels)
        self.bn3 = nn.BatchNorm1d(hidden_channels)

        self.dropout = nn.Dropout(dropout)

        # Fusion layer: fuse embeddings of two nodes with the structural features.
        fusion_in_dim = 2 * hidden_channels + struct_features_dim
        self.fusion = nn.Sequential(
            nn.Linear(fusion_in_dim, hidden_channels * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_channels * 2, hidden_channels)
        )

        # Predictor layer
        self.predictor = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_channels // 2, 1)
        )

        print(f"Model configuration: input dimension = {in_channels}, hidden dimension = {hidden_channels}, structural feature dimension = {struct_features_dim}")

    def encode(self, x, edge_index):
        """Return node embeddings for features ``x`` over the graph ``edge_index``."""
        # GCN layer
        x = self.conv1(x, edge_index)
        x = self.bn1(x)
        x = F.elu(x)
        x = self.dropout(x)
        # GAT layer
        x = self.gat(x, edge_index)
        x = self.bn2(x)
        x = F.elu(x)
        x = self.dropout(x)
        # GraphSAGE layer
        x = self.sage(x, edge_index)
        x = self.bn3(x)
        x = F.elu(x)
        return x

    def decode(self, z, edge_index, structural_features):
        """
        Score candidate edges.

        Parameters:
          z                  : node embeddings, indexed by node id.
          edge_index         : tensor whose two rows are source / target node ids.
          structural_features: one feature row per candidate edge, in the same
                               order as the columns of ``edge_index``.

        Returns:
          Tensor with one row and one column per candidate edge holding the
          raw (pre-sigmoid) score.
        """
        row, col = edge_index
        num_edges = edge_index.size(1)
        # All edges are processed in one batched pass instead of a Python loop
        # over edges; this also handles an empty edge set. In training mode the
        # dropout masks are consequently drawn once for the whole batch.
        combined = torch.cat([z[row], z[col], structural_features[:num_edges]], dim=-1)
        edge_features = self.fusion(combined)
        scores = self.predictor(edge_features)
        return scores

    def forward(self, x, edge_index):
        """Alias for ``encode``: returns node embeddings, not edge scores."""
        z = self.encode(x, edge_index)
        return z

# Instantiate the model on the selected device.
model = AdvancedGNNLinkPredictor(
    in_channels=in_channels,
    hidden_channels=hidden_channels,
    struct_features_dim=struct_features_dim,
    dropout=0.3
).to(device)

# Adam with L2 weight decay; the criterion applies the sigmoid internally, so
# it must be given raw decode() scores rather than probabilities.
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
criterion = nn.BCEWithLogitsLoss()

# 2) Early Stopping
class EarlyStopping:
    """
    Stop training when the validation loss stops improving.

    Call the instance once per epoch with the validation loss and the model.
    A checkpoint is saved whenever the loss improves by at least ``delta``;
    after ``patience`` consecutive calls without such an improvement,
    ``early_stop`` becomes True.
    """

    def __init__(self, patience=5, verbose=True, delta=0.001, path='best_link_predictor.pt'):
        """
        Parameters:
          patience: number of consecutive non-improving calls tolerated.
          verbose : print counter / checkpoint messages when True.
          delta   : minimum decrease in loss that counts as an improvement.
          path    : file the best model state_dict is saved to.
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # float("inf") instead of np.Inf: that alias was removed in NumPy 2.0.
        self.val_loss_min = float("inf")
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):
        # Higher score is better, so negate the loss.
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: {self.counter} / {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Save ``model.state_dict()`` to ``self.path`` and record the new best loss."""
        if self.verbose:
            print(f"Validation loss decreased ({self.val_loss_min:.6f} → {val_loss:.6f}), saving model...")
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

# Stop after 5 consecutive epochs without improvement; the best weights are
# written to best_link_predictor.pt.
early_stopping = EarlyStopping(patience=5, verbose=True, path="best_link_predictor.pt")

# 3) Training function using Hard Negative Sampling
def train_one_epoch():
    """
    Run one full-batch optimization step on the training subgraph and return
    the loss as a float.

    Negatives are chosen by hard negative sampling: a pool three times the
    size of the positive set is sampled, scored without gradients, and the
    highest-scoring pool entries (as many as there are positives) are used as
    the negatives for the loss.
    """
    model.train()
    optimizer.zero_grad()

    # Encode all training nodes.
    node_emb = model(data_train.x.to(device), data_train.edge_index.to(device))

    # Positive samples
    pos_logits = model.decode(node_emb, train_pos_index.to(device), train_pos_struct.to(device))
    n_pos = train_pos_index.size(1)

    # Hard Negative Sampling: build the candidate pool.
    pool_pairs = generate_negative_edges(G_train_embed, n_pos * 3)
    pool_index = torch.tensor(
        [(train_node_to_idx[a], train_node_to_idx[b]) for a, b in pool_pairs],
        dtype=torch.long,
    ).t().to(device)
    pool_struct = compute_structural_features(G_train_embed, pool_pairs).to(device)

    # Rank the pool with the current model; no gradients are needed for ranking.
    with torch.no_grad():
        pool_probs = torch.sigmoid(model.decode(node_emb, pool_index, pool_struct)).squeeze()
    hardest = torch.topk(pool_probs, k=n_pos).indices

    # Re-score only the selected negatives, this time with gradients.
    neg_logits = model.decode(node_emb, pool_index[:, hardest], pool_struct[hardest])

    labels = torch.cat([torch.ones(n_pos), torch.zeros(n_pos)]).to(device)
    logits = torch.cat([pos_logits, neg_logits]).squeeze()

    loss = criterion(logits, labels)
    loss.backward()
    optimizer.step()
    return loss.item()

# 4) Evaluation function on the test subgraph
@torch.no_grad()
def evaluate_test():
    """
    Evaluate the model on the test subgraph's positive edges and the fixed
    test negatives.

    Returns:
      (val_loss, acc, auc): the criterion's loss on the raw scores, accuracy at
      a 0.5 probability threshold, and ROC AUC.
    """
    model.eval()
    x = data_test.x.to(device)
    edge_index = data_test.edge_index.to(device)
    z = model(x, edge_index)

    # reshape(-1) rather than squeeze() so a single-edge set stays 1-D.
    pos_logits = model.decode(z, test_pos_index.to(device), test_pos_struct.to(device)).reshape(-1)
    neg_logits = model.decode(z, test_neg_index.to(device), test_neg_struct.to(device)).reshape(-1)
    edge_label = torch.cat([
        torch.ones(test_pos_index.size(1)),
        torch.zeros(test_neg_index.size(1))
    ]).to(device)
    edge_logits = torch.cat([pos_logits, neg_logits], dim=0)

    # The criterion is BCEWithLogitsLoss, which applies the sigmoid itself, so
    # it must receive raw logits. Passing probabilities would apply the sigmoid
    # twice and distort the loss that drives early stopping.
    val_loss = criterion(edge_logits, edge_label).item()

    # Probabilities are only needed for the threshold-based metrics.
    edge_pred = torch.sigmoid(edge_logits)
    auc = roc_auc_score(edge_label.cpu().numpy(), edge_pred.cpu().numpy())
    acc = ((edge_pred >= 0.5).float() == edge_label).sum().item() / len(edge_label)
    return val_loss, acc, auc

# 5) Main training loop
# Upper bound on epochs; early stopping normally ends training sooner.
num_epochs = 300
# Per-epoch history of the training loss and the test-subgraph metrics.
train_losses, test_losses, test_aucs, test_accs = [], [], [], []

for epoch in range(num_epochs):
    train_loss = train_one_epoch()
    # NOTE(review): the metrics that drive early stopping come from
    # evaluate_test, i.e. from the test subgraph; there is no separate
    # validation split.
    val_loss, val_acc, val_auc = evaluate_test()

    train_losses.append(train_loss)
    test_losses.append(val_loss)
    test_aucs.append(val_auc)
    test_accs.append(val_acc)

    print(f"Epoch {epoch+1}: TrainLoss={train_loss:.4f}, TestLoss={val_loss:.4f}, Acc={val_acc:.4f}, AUC={val_auc:.4f}")
    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("Early stopping triggered.")
        break

# Load the best model
# (the checkpoint written by EarlyStopping.save_checkpoint).
model.load_state_dict(torch.load("best_link_predictor.pt"))


# Evaluation

# In [None]:
@torch.no_grad()
def evaluate_test_and_typewise_5fold():
    """
    Evaluate the trained model on the test edges over 5 folds and break the
    results down by node-type combination.

    The shuffled test positives are split into 5 folds. For each fold, that
    fold's positive edges are removed from a copy of the test graph and the
    structural features of ALL evaluation edges (every test positive followed
    by every test negative) are recomputed on the reduced graph. Node
    embeddings are computed once from the full test graph (``data_test``), so
    only the structural features differ between folds.

    Side effects: reseeds the global ``random`` module with 610, prints
    per-fold and global metrics, and saves/shows the ROC plot
    (``roc_curve.png``).

    Returns:
      (df, correct_pos_edges, correct_neg_edges, overall_auc) where ``df`` has
      the 30 most frequent type combinations with precision / recall / F1,
      the two lists hold the correctly predicted edges accumulated over all
      folds (an edge can appear once per fold), and ``overall_auc`` is the AUC
      over the pooled scores of all folds.
    """
    num_folds = 5
    num_pos = len(test_pos_pairs)

    # Split the test positive edge set (test_pos_pairs) into 5 folds (after shuffling).
    pos_edges = test_pos_pairs.copy()
    random.seed(610)
    random.shuffle(pos_edges)
    # Plain list slicing keeps the (u, v) tuples intact. Converting the list to
    # a NumPy array would coerce node labels to NumPy strings and silently
    # break edge removal for non-string nodes. Fold sizes follow the usual
    # "first len % k folds get one extra item" rule.
    base, extra = divmod(len(pos_edges), num_folds)
    pos_folds, start = [], 0
    for fold in range(num_folds):
        stop = start + base + (1 if fold < extra else 0)
        pos_folds.append(pos_edges[start:stop])
        start = stop

    # Construct the evaluation edge set: positive edges + negative edges
    # (negative edges are not in the graph). Positives come first.
    eval_edges = test_pos_pairs + test_neg_pairs

    # Evaluation edge index tensor (based on test_node_to_idx); identical for every fold.
    eval_index = torch.tensor(
        [(test_node_to_idx[u], test_node_to_idx[v]) for u, v in eval_edges],
        dtype=torch.long,
    ).t().to(device)

    # Model inference (using the global structure of the test subgraph for
    # message passing). In eval mode this is the same for every fold, so it is
    # computed once.
    model.eval()
    z = model(data_test.x.to(device), data_test.edge_index.to(device))

    # Extract node types from G_aggregated (excluding 'unknown').
    def get_node_types(n):
        if n not in G_aggregated.nodes():
            return []
        attr = G_aggregated.nodes[n].get("type", None)
        if isinstance(attr, list):
            return [str(t).lower() for t in attr if t and str(t).lower() != 'unknown']
        elif isinstance(attr, str):
            return [attr.lower()] if attr.lower() != 'unknown' else []
        else:
            return []

    # Type combinations touched by each evaluation edge (order-insensitive pairs).
    edge_type_pairs = []
    for u, v in eval_edges:
        types_u = get_node_types(u)
        types_v = get_node_types(v)
        edge_type_pairs.append([tuple(sorted([t1, t2])) for t1 in types_u for t2 in types_v])

    # Variables for accumulating global statistics (across all folds)
    total_correct_pos, total_correct_neg = 0, 0
    agg_pair_stats = {}  # Accumulate statistics for each type combination

    # Lists for recording global AUC and accuracy of each fold
    auc_list = []
    acc_list = []

    # Collect correctly predicted edges
    correct_pos_edges = []  # Correctly predicted positive edges (predicted as 1)
    correct_neg_edges = []  # Correctly predicted negative edges (predicted as 0)

    # Lists for collecting ROC curve data
    all_y_true = []
    all_y_scores = []

    # 5-fold evaluation
    for fold, held_out in enumerate(pos_folds):
        print(f"\n===== Fold {fold+1}/{num_folds} =====")

        # Copy the test graph and remove the positive edges held out in this fold.
        G_temp = G_test_embed.copy()
        G_temp.remove_edges_from(held_out)

        # Compute structural features for all evaluation edges using the modified test graph
        struct_tensor = compute_structural_features(G_temp, eval_edges).to(device)

        # reshape(-1) rather than squeeze() so a single-edge set stays 1-D.
        pos_score = torch.sigmoid(
            model.decode(z, eval_index[:, :num_pos], struct_tensor[:num_pos])
        ).reshape(-1).cpu().numpy()
        neg_score = torch.sigmoid(
            model.decode(z, eval_index[:, num_pos:], struct_tensor[num_pos:])
        ).reshape(-1).cpu().numpy()

        # Global predicted labels, ground truth and scores for this fold
        y_pred = np.concatenate([(pos_score >= 0.5).astype(int), (neg_score >= 0.5).astype(int)])
        y_true = np.concatenate([np.ones(len(pos_score)), np.zeros(len(neg_score))])
        y_scores = np.concatenate([pos_score, neg_score])

        # Collect ROC data
        all_y_true.extend(y_true)
        all_y_scores.extend(y_scores)

        fold_auc = roc_auc_score(y_true, y_scores)
        fold_acc = (np.sum(y_pred == y_true)) / len(y_true)
        print(f"[Fold {fold+1}] AUC: {fold_auc:.4f}, Acc: {fold_acc:.4f}")
        auc_list.append(fold_auc)
        acc_list.append(fold_acc)

        # Accumulate the number of correctly predicted edges (positives and negatives)
        total_correct_pos += int(np.sum(y_pred[:num_pos] == 1))
        total_correct_neg += int(np.sum(y_pred[num_pos:] == 0))

        # Collect correctly predicted edges
        correct_pos_edges.extend(
            edge for edge, pred in zip(eval_edges[:num_pos], y_pred[:num_pos]) if pred == 1
        )
        correct_neg_edges.extend(
            edge for edge, pred in zip(eval_edges[num_pos:], y_pred[num_pos:]) if pred == 0
        )

        # === Type Combination Statistics ===
        # For each evaluation edge, accumulate prediction statistics for the
        # corresponding type combinations.
        for i, pairs in enumerate(edge_type_pairs):
            for pair in pairs:
                stats = agg_pair_stats.setdefault(pair, {'count': 0, 'tp': 0, 'fp': 0, 'fn': 0})
                stats['count'] += 1
                if i < num_pos:  # positive example
                    if y_pred[i] == 1:
                        stats['tp'] += 1
                    else:
                        stats['fn'] += 1
                elif y_pred[i] == 1:  # negative example predicted as positive
                    stats['fp'] += 1

    # Aggregate global statistics across all folds
    print("\n====== 5-Fold Global Metrics ======")
    print("Number of correctly predicted positive edges:", total_correct_pos)
    print("Number of correctly predicted negative edges:", total_correct_neg)
    print(f"Average AUC: {np.mean(auc_list):.4f} ± {np.std(auc_list):.4f}")
    print(f"Average Accuracy: {np.mean(acc_list):.4f}")

    # Aggregate type combination statistics into a DataFrame (based on accumulated results)
    rows = []
    for pair, stats in agg_pair_stats.items():
        tp, fp, fn = stats['tp'], stats['fp'], stats['fn']
        prec = (tp / (tp + fp) * 100) if (tp + fp) > 0 else 0
        rec = (tp / (tp + fn) * 100) if (tp + fn) > 0 else 0
        f1 = (2 * prec * rec / (prec + rec)) if (prec + rec) > 0 else 0
        rows.append({
            "Type Combination": f"{pair[0]} - {pair[1]}",
            "Count": stats['count'],
            "Precision": f"{prec:.2f}%",
            "Recall": f"{rec:.2f}%",
            "F1 Score": f"{f1:.2f}%"
        })
    # Explicit columns keep sort_values working when no node has a usable type.
    summary_columns = ["Type Combination", "Count", "Precision", "Recall", "F1 Score"]
    summary_df = pd.DataFrame(rows, columns=summary_columns)
    summary_df = summary_df.sort_values(by="Count", ascending=False).head(30).reset_index(drop=True)

    # Plot overall ROC curve
    plt.figure(figsize=(10, 8))
    fpr, tpr, _ = roc_curve(all_y_true, all_y_scores)
    overall_auc = roc_auc_score(all_y_true, all_y_scores)

    # Use a more prominent color and thicker line width
    plt.plot(fpr, tpr, lw=3, label=f'ROC curve (AUC = {overall_auc:.4f})')
    plt.plot([0, 1], [0, 1], 'k--', lw=2)

    # Adjust axis range to avoid overlap with borders
    plt.xlim([-0.01, 1.01])
    plt.ylim([-0.01, 1.01])

    # Add margins
    plt.subplots_adjust(left=0.12, right=0.95, bottom=0.12, top=0.95)

    plt.xlabel('False Positive Rate', fontsize=14)
    plt.ylabel('True Positive Rate', fontsize=14)
    plt.title('Receiver Operating Characteristic (ROC)', fontsize=16, fontweight='bold')
    plt.legend(loc="lower right", fontsize=12)

    # Add detailed grid
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.xticks(np.arange(0, 1.1, 0.1))
    plt.yticks(np.arange(0, 1.1, 0.1))

    # Add minor ticks on axes
    plt.minorticks_on()
    plt.tick_params(which='both', direction='in')

    plt.tight_layout()
    plt.savefig('roc_curve.png', dpi=300, bbox_inches='tight')  # Save a high resolution image
    plt.show()

    print(f"\nNumber of correctly predicted positive edges: {len(correct_pos_edges)}")
    print(f"Number of correctly predicted negative edges: {len(correct_neg_edges)}")

    return summary_df, correct_pos_edges, correct_neg_edges, overall_auc

# Call the 5-fold evaluation function
# Run the evaluation and show the type-combination table.
df_result, correct_pos_edges, correct_neg_edges, overall_auc = evaluate_test_and_typewise_5fold()
print(df_result)
# The lists are in collection order, so these are the first 10 entries rather
# than a ranking.
print(f"\nTop 10 examples of correctly predicted positive edges:")
for i, edge in enumerate(correct_pos_edges[:10]):
    print(f"{i+1}. {edge}")

print(f"\nTop 10 examples of correctly predicted negative edges:")
for i, edge in enumerate(correct_neg_edges[:10]):
    print(f"{i+1}. {edge}")
