# Load Graph

In [None]:
import sys, pickle, os, json, re, time, random, logging, pandas as pd, numpy as np, matplotlib.pyplot as plt, seaborn as sns, scipy, sklearn, networkx as nx, importlib
# Load the pickled graph object that the rest of the notebook refers to as
# G_aggregated (later cells query it with `.edges(data=True)`).
# SECURITY: pickle.load can execute arbitrary code embedded in the file; only
# load 'RQ3_KnowledgeGraph.pkl' when it comes from a trusted source.
with open('RQ3_KnowledgeGraph.pkl', 'rb') as f:
    G_aggregated = pickle.load(f)

# Model Training

In [None]:
import os
import pickle
import random
import time
import math
import numpy as np
import pandas as pd
import networkx as nx
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch_geometric.data import Data
from torch_geometric.utils import from_networkx
from torch_geometric.nn import GCNConv, SAGEConv, GATConv
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc, roc_curve
from tqdm import tqdm
import matplotlib.pyplot as plt


# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#############################################
# 1. Define function to compute structural features
#############################################
def compute_structural_features(G, node_pairs, max_path_length=3):
    """
    Computes a fixed-width structural feature vector for every node pair.

    For each pair (u, v) the features are, in order:
      1. 1 / (1 + shortest path length), or 0.0 when v is unreachable from u
      2. log(1 + number of common neighbors)
      3. Jaccard coefficient of the two neighbor sets
      4. Adamic-Adar-style index, sum over common neighbors w of
         1 / log(degree(w) + 1)  (note the +1 inside the log)
      5. log(1 + degree(u) * degree(v))  (preferential attachment)
      6. Resource allocation index, sum over common neighbors w of 1 / degree(w)
      7+. log(1 + number of simple paths with exactly k edges) for every k in
          2..max_path_length (two features with the default of 3)
    The vector is then zero-padded (or truncated) to 16 entries.

    "Neighbors" are whatever ``G.neighbors`` returns, i.e. successors when G
    is a directed graph. A pair whose endpoints are not both in G gets an
    all-zero vector.

    NOTE(review): when (u, v) is itself an edge of G, feature 1 is always 0.5.
    For positive pairs taken from G this may leak the label - confirm that
    this is intended.

    Args:
        G: NetworkX graph to measure on.
        node_pairs: Sequence of (u, v) node pairs.
        max_path_length: Largest simple-path length (in edges) to count.

    Returns:
        Float tensor of shape [len(node_pairs), 16].
    """
    feature_dim = 16
    print(f"Computing structural features for {len(node_pairs)} node pairs...")
    features = []
    for u, v in tqdm(node_pairs):
        # Unknown nodes: every feature is defined as zero. Checking up front
        # replaces the previous bare `except:` blocks, which could silently
        # reuse neighbor sets left over from the preceding pair.
        if u not in G or v not in G:
            features.append([0.0] * feature_dim)
            continue

        feature_vector = []

        # 1. Shortest path length (normalized: 1/(1+length))
        try:
            path_length = nx.shortest_path_length(G, source=u, target=v)
            path_length_feature = 1.0 / (1.0 + path_length)
        except nx.NetworkXNoPath:
            path_length_feature = 0.0
        feature_vector.append(path_length_feature)

        # Neighbor sets are always recomputed for the current pair.
        u_neighbors = set(G.neighbors(u))
        v_neighbors = set(G.neighbors(v))
        common_neighbors = u_neighbors & v_neighbors
        common_neighbor_count = len(common_neighbors)

        # 2. Common neighbors count (normalized: log(1 + count))
        feature_vector.append(math.log(1 + common_neighbor_count))

        # 3. Jaccard coefficient
        union_count = len(u_neighbors | v_neighbors)
        feature_vector.append(common_neighbor_count / union_count if union_count > 0 else 0)

        # 4. Adamic-Adar index (a common neighbor has degree >= 1, so the
        #    logarithm is strictly positive)
        feature_vector.append(sum(1.0 / math.log(G.degree(w) + 1) for w in common_neighbors))

        # 5. Preferential attachment
        feature_vector.append(math.log(1 + G.degree(u) * G.degree(v)))

        # 6. Resource allocation index
        feature_vector.append(sum(1.0 / G.degree(w) for w in common_neighbors))

        # 7+. Simple path counts for lengths 2..max_path_length. The paths are
        # enumerated once with the largest cutoff and bucketed by length,
        # instead of re-enumerating (and materialising) them per length.
        length_counts = [0] * (max(max_path_length, 0) + 1)
        if max_path_length >= 2:
            for path in nx.all_simple_paths(G, source=u, target=v, cutoff=max_path_length):
                length_counts[len(path) - 1] += 1
        for path_len in range(2, max_path_length + 1):
            feature_vector.append(math.log(1 + length_counts[path_len]))

        # Pad to a fixed 16 dimensions
        if len(feature_vector) < feature_dim:
            feature_vector.extend([0.0] * (feature_dim - len(feature_vector)))
        else:
            feature_vector = feature_vector[:feature_dim]

        features.append(feature_vector)

    # reshape keeps the documented [N, 16] shape even when node_pairs is empty.
    return torch.tensor(features, dtype=torch.float).reshape(-1, feature_dim)

#############################################
# 2. Define function to generate node embeddings (using cache)
#############################################
def get_node_embeddings_with_cache(G, cache_path, embedding_model="text-embedding-3-small", embedding_dim=1536):
    """
    Generates node embeddings and returns a new graph where nodes only retain the 'x' attribute.

    The text embedded for a node is ``str(node)``. Embeddings are looked up in
    a pickle cache stored under ``cache_path`` (one file per embedding model);
    missing ones are requested concurrently through ``tools.get_embedding``
    and, when successful, written back to the cache. Nodes whose request
    fails receive an all-zero vector, but that placeholder is NOT cached, so
    a later run retries them.

    NOTE(review): ``tools`` is not defined in the visible part of this file;
    it is assumed to be provided elsewhere in the notebook.

    Args:
        G: Source multigraph (its edges are read with ``keys=True``). It is
            not modified; a MultiDiGraph copy is returned.
        cache_path: Directory holding the embedding cache file.
        embedding_model: Model name passed to ``tools.get_embedding``; also
            part of the cache file name.
        embedding_dim: Length of the zero vector used for nodes that have no
            embedding. Defaults to 1536, the value previously hard-coded.

    Returns:
        nx.MultiDiGraph with the same nodes/edges as G, where every node has
        exactly one attribute, 'x', holding its embedding.
    """
    from concurrent.futures import ThreadPoolExecutor, as_completed

    G_copy = nx.MultiDiGraph()
    G_copy.add_nodes_from(G.nodes(data=True))
    G_copy.add_edges_from([(u, v, k, d.copy()) for u, v, k, d in G.edges(data=True, keys=True)])

    embedding_cache = {}
    cache_file = os.path.join(cache_path, f'embedding_cache_{embedding_model.replace("-", "_")}.pkl')
    if os.path.exists(cache_file):
        try:
            # SECURITY: pickle.load can execute arbitrary code; the cache file
            # must come from a trusted location.
            with open(cache_file, 'rb') as f:
                embedding_cache = pickle.load(f)
            print(f"Successfully loaded embedding cache, count: {len(embedding_cache)}")
        except Exception as e:
            # Best effort: an unreadable cache only means everything is re-fetched.
            print(f"Failed to load cache: {e}")
    else:
        print(f"Embedding cache not found: {cache_file}")

    nodes_to_process = [node for node in G_copy.nodes() if str(node) not in embedding_cache]
    print(f"Number of new embeddings to generate: {len(nodes_to_process)}")

    def fetch_embedding(node):
        """Return (node_text, embedding, error); embedding is None on failure."""
        node_text = str(node)
        try:
            return node_text, tools.get_embedding(node_text, model=embedding_model), None
        except Exception as e:
            return node_text, None, str(e)

    if nodes_to_process:
        new_entries = 0
        failed_entries = 0
        with ThreadPoolExecutor(max_workers=32) as executor:
            futures = [executor.submit(fetch_embedding, node) for node in nodes_to_process]
            for future in tqdm(as_completed(futures), total=len(futures), desc="Generating embeddings"):
                node_text, emb, err = future.result()
                if err is not None:
                    # Do not store a zero vector for a failed request: it
                    # would be persisted and the node never retried.
                    failed_entries += 1
                    print(f"Error processing node: {err}")
                    continue
                embedding_cache[node_text] = emb
                new_entries += 1
        print(f"Number of new embeddings added: {new_entries}")
        if failed_entries:
            print(f"Number of embeddings that failed (zero vector used, not cached): {failed_entries}")
        if new_entries > 0:
            try:
                os.makedirs(cache_path, exist_ok=True)
                with open(cache_file, 'wb') as f:
                    pickle.dump(embedding_cache, f)
                print(f"Successfully updated embedding cache, total count: {len(embedding_cache)}")
            except Exception as e:
                # Best effort: the embeddings are still available in memory.
                print(f"Failed to save cache: {e}")

    # Clean node attributes, keeping only 'x'
    for node in G_copy.nodes():
        x_value = embedding_cache.get(str(node))
        if x_value is None:
            x_value = [0.0] * embedding_dim
        G_copy.nodes[node].clear()
        G_copy.nodes[node]['x'] = x_value
    return G_copy

#############################################
# 3. Define function to generate negative samples (edges not in the graph)
#############################################
def generate_negative_edges(G, num_neg_samples, exclude_edges=None):
    """
    Generates edges not present in the graph as negative samples.

    Pairs (u, v) are drawn uniformly at random (with replacement) from the
    nodes of G and kept when u != v, G has no edge u -> v, and (u, v) is not
    in ``exclude_edges``. The same pair may be returned more than once.

    Args:
        G: Graph providing ``nodes()`` and ``has_edge(u, v)``.
        num_neg_samples: Number of negative pairs to return.
        exclude_edges: Optional iterable of (u, v) pairs that must never be
            returned (e.g. known true edges that are absent from G).

    Returns:
        List of ``num_neg_samples`` (u, v) tuples.

    Raises:
        ValueError: If samples are requested but G has fewer than two nodes.
        RuntimeError: If the sampling budget is exhausted, which indicates
            that (almost) no valid negative pair exists. Previously this
            situation resulted in an endless loop.
    """
    excluded = set() if exclude_edges is None else set(exclude_edges)
    nodes = list(G.nodes())
    if num_neg_samples > 0 and len(nodes) < 2:
        raise ValueError("At least two nodes are required to sample negative edges")

    # Generous budget: only hit when nearly every candidate pair is rejected.
    max_attempts = max(10_000, 1_000 * num_neg_samples)
    attempts = 0
    neg_edges = []
    # The context manager closes the progress bar even if an error is raised.
    with tqdm(total=num_neg_samples, desc="Generating negative samples") as pbar:
        while len(neg_edges) < num_neg_samples:
            if attempts >= max_attempts:
                raise RuntimeError(
                    f"Could only generate {len(neg_edges)} of {num_neg_samples} "
                    f"negative samples after {attempts} attempts"
                )
            attempts += 1
            u = random.choice(nodes)
            v = random.choice(nodes)
            if u == v or G.has_edge(u, v) or (u, v) in excluded:
                continue
            neg_edges.append((u, v))
            pbar.update(1)
    return neg_edges

#############################################
# 4. Split training and test sets based on the "source" edge attribute
#############################################
# Extract all sources from the edge attributes of the original graph G_aggregated
all_edges = list(G_aggregated.edges(data=True))
all_sources = {d['source'] for _, _, d in all_edges if 'source' in d}
# Sort before shuffling: the iteration order of a set of strings changes
# between interpreter runs (hash randomization), so seeding the shuffle alone
# does not make the split reproducible.
all_sources = sorted(all_sources, key=str)
random.seed(610)
random.shuffle(all_sources)
train_source_count = int(0.7 * len(all_sources))
train_sources = set(all_sources[:train_source_count])
test_sources = set(all_sources[train_source_count:])
print(f"Number of sources for training: {len(train_sources)}, for testing: {len(test_sources)}")

# Partition edges based on source (edges without a 'source' attribute belong
# to neither partition)
train_edges = [(u, v, d) for u, v, d in all_edges if d.get('source') in train_sources]
test_edges = [(u, v, d) for u, v, d in all_edges if d.get('source') in test_sources]
print(f"Number of training edges: {len(train_edges)}, Number of test edges: {len(test_edges)}")

# Construct the training graph: extract all involved nodes from the training edges.
# Nodes are inserted in order of first appearance in the edge list (not in set
# order) so that node ordering, and everything derived from it such as index
# mappings and random node sampling, is the same on every run.
train_nodes = set()
for u, v, d in train_edges:
    train_nodes.add(u)
    train_nodes.add(v)
print(f"Number of training nodes: {len(train_nodes)}")
G_train = nx.MultiDiGraph()
G_train.add_nodes_from(dict.fromkeys(n for u, v, _ in train_edges for n in (u, v)))
G_train.add_edges_from(train_edges)

# Construct the test graph: extract all involved nodes from the test edges
test_nodes = set()
for u, v, d in test_edges:
    test_nodes.add(u)
    test_nodes.add(v)
print(f"Number of test nodes: {len(test_nodes)}")
G_test = nx.MultiDiGraph()
G_test.add_nodes_from(dict.fromkeys(n for u, v, _ in test_edges for n in (u, v)))
G_test.add_edges_from(test_edges)

#############################################
# 5. Generate Node Embeddings
#############################################
# Directory that holds the embedding cache file (placeholder path).
cache_path = '/path/to/embedding_cache'

# Embed the training graph first, then the test graph; both share one cache.
print("Generating node embeddings for the training graph...")
G_train_embed = get_node_embeddings_with_cache(
    G_train,
    cache_path,
    embedding_model="text-embedding-3-small",
)

print("Generating node embeddings for the test graph...")
G_test_embed = get_node_embeddings_with_cache(
    G_test,
    cache_path,
    embedding_model="text-embedding-3-small",
)

#############################################
# 6. Convert graphs to PyG Data and create node index mappings
#############################################
def create_pyg_data(G_embed):
    """
    Convert an embedded NetworkX graph into a PyG ``Data`` object.

    Side effect: all edge attribute dicts of ``G_embed`` are emptied in place
    before the conversion. The node feature matrix ``data.x`` is rebuilt from
    the per-node 'x' attribute, in ``G_embed.nodes()`` order.
    """
    # Strip edge attributes so the conversion carries no edge features.
    for *_, attrs in G_embed.edges(keys=True, data=True):
        attrs.clear()

    data = from_networkx(G_embed, group_node_attrs=None, group_edge_attrs=None)

    node_features = [G_embed.nodes[n]['x'] for n in G_embed.nodes()]
    data.x = torch.tensor(node_features, dtype=torch.float)
    return data

# PyG representations of the two graphs.
data_train = create_pyg_data(G_train_embed)
data_test = create_pyg_data(G_test_embed)

# Node -> row-index lookups, built independently for train and test and
# following each graph's own node iteration order.
train_nodes_list = list(G_train_embed.nodes())
train_node_to_idx = dict(zip(train_nodes_list, range(len(train_nodes_list))))

test_nodes_list = list(G_test_embed.nodes())
test_node_to_idx = dict(zip(test_nodes_list, range(len(test_nodes_list))))

#############################################
# 7. Prepare positive edge samples and structural features (separately for train and test)
#############################################
def prepare_edge_data(G_embed, node_to_idx):
    """
    Collect the positive (existing) edges of ``G_embed``.

    Returns a triple: the (u, v) pairs as yielded by ``G_embed.edges()``, the
    same pairs mapped through ``node_to_idx`` as a long tensor transposed to
    one column per edge, and the structural features of those pairs.
    """
    pos_edge_pairs = list(G_embed.edges())
    index_rows = [(node_to_idx[src], node_to_idx[dst]) for src, dst in pos_edge_pairs]
    pos_edge_index = torch.tensor(index_rows, dtype=torch.long).t()
    return (
        pos_edge_pairs,
        pos_edge_index,
        compute_structural_features(G_embed, pos_edge_pairs),
    )

# Positive samples of the training graph.
print("Preparing training positive samples...")
(train_pos_edge_pairs,
 train_pos_edge_index,
 train_pos_structural) = prepare_edge_data(G_train_embed, train_node_to_idx)
print(f"Number of training positive samples: {len(train_pos_edge_pairs)}")

# Positive samples of the test graph.
print("Preparing test positive samples...")
(test_pos_edge_pairs,
 test_pos_edge_index,
 test_pos_structural) = prepare_edge_data(G_test_embed, test_node_to_idx)
print(f"Number of test positive samples: {len(test_pos_edge_pairs)}")

#############################################
# 8. Generate negative samples for testing (via random sampling)
#############################################
print("Generating test negative samples...")
# A pair that is a true edge anywhere in the full graph must not be used as a
# negative: the sampler itself only checks the test graph, so pairs that are
# connected through a training source are excluded explicitly here.
known_true_pairs = {(u, v) for u, v, _ in all_edges}
test_neg_edge_pairs = generate_negative_edges(
    G_test_embed,
    len(test_pos_edge_pairs),
    exclude_edges=known_true_pairs,
)
test_neg_edge_indices = [(test_node_to_idx[u], test_node_to_idx[v]) for u, v in test_neg_edge_pairs]
test_neg_edge_index = torch.tensor(test_neg_edge_indices, dtype=torch.long).t()
test_neg_structural = compute_structural_features(G_test_embed, test_neg_edge_pairs)
print(f"Number of test negative samples: {len(test_neg_edge_pairs)}")

#############################################
# 9. Define the Advanced GNN Link Predictor Model
#############################################
class AdvancedGNNLinkPredictor(nn.Module):
    """
    Link predictor combining a three-layer GNN encoder with an MLP decoder.

    Encoder: GCNConv -> GATConv (4 heads) -> SAGEConv, each followed by batch
    normalisation and ELU, with dropout after the first two layers.
    Decoder: for every candidate edge, the two endpoint embeddings and the
    edge's structural feature vector are concatenated, passed through a
    fusion MLP and then a prediction MLP that outputs one raw score (logit).
    """

    def __init__(self, in_channels, hidden_channels, struct_features_dim=16, dropout=0.5):
        """
        Args:
            in_channels: Size of the input node features.
            hidden_channels: Size of the node embeddings produced by ``encode``.
            struct_features_dim: Size of the per-edge structural feature vector.
            dropout: Dropout probability used in the encoder, inside the GAT
                layer and in both MLPs.
        """
        super().__init__()

        # GCN layer
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.bn1 = nn.BatchNorm1d(hidden_channels)

        # GAT layer; its output is normalised over hidden_channels * 4 features
        self.gat = GATConv(hidden_channels, hidden_channels, heads=4, dropout=dropout)
        self.bn2 = nn.BatchNorm1d(hidden_channels * 4)

        # GraphSAGE layer, projecting back to hidden_channels
        self.sage = SAGEConv(hidden_channels * 4, hidden_channels)
        self.bn3 = nn.BatchNorm1d(hidden_channels)

        self.dropout = nn.Dropout(dropout)

        # Fusion layer: Concatenates and fuses the embeddings of two nodes with their structural features
        fusion_in_dim = 2 * hidden_channels + struct_features_dim
        self.fusion = nn.Sequential(
            nn.Linear(fusion_in_dim, hidden_channels * 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_channels * 2, hidden_channels)
        )

        # Prediction layer
        self.predictor = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_channels // 2, 1)
        )

        print(f"Model configuration: Input dim = {in_channels}, Hidden dim = {hidden_channels}, Struct features dim = {struct_features_dim}")

    def encode(self, x, edge_index):
        """Return node embeddings (one row of size hidden_channels per node)."""
        # GCN
        x = self.conv1(x, edge_index)
        x = self.bn1(x)
        x = F.elu(x)
        x = self.dropout(x)
        # GAT
        x = self.gat(x, edge_index)
        x = self.bn2(x)
        x = F.elu(x)
        x = self.dropout(x)
        # GraphSAGE
        x = self.sage(x, edge_index)
        x = self.bn3(x)
        x = F.elu(x)
        return x

    def decode(self, z, edge_index, structural_features):
        """
        Score candidate edges.

        Args:
            z: Node embeddings as returned by ``encode``.
            edge_index: Long tensor with two rows (source indices, target
                indices) and one column per candidate edge.
            structural_features: One structural feature row per candidate
                edge, in the same order as the columns of ``edge_index``.

        Returns:
            Tensor of shape [num_edges, 1] with raw scores (logits).
        """
        row, col = edge_index[0], edge_index[1]
        # All edges are processed in one batched pass. The MLPs act on each
        # row independently, so this matches scoring the edges one at a time
        # while avoiding a Python loop per edge, and it also handles an empty
        # edge set.
        combined = torch.cat([z[row], z[col], structural_features], dim=1)
        edge_features = self.fusion(combined)
        return self.predictor(edge_features)

    def forward(self, x, edge_index):
        """Alias for ``encode``: returns the node embeddings."""
        return self.encode(x, edge_index)

# Model dimensions are derived from the prepared training data: the input
# size is the node feature width, the structural size is the width of the
# per-edge structural feature tensor.
in_channels = data_train.x.size(1)
hidden_channels = 128
struct_features_dim = train_pos_structural.size(1)

# Seed all three random number generators BEFORE the model is constructed so
# that parameter initialisation happens after seeding.
torch.manual_seed(610)
np.random.seed(610)
random.seed(610)

model = AdvancedGNNLinkPredictor(
    in_channels=in_channels, 
    hidden_channels=hidden_channels,
    struct_features_dim=struct_features_dim,
    dropout=0.3
).to(device)

# Adam with L2 weight decay; the loss expects raw scores (logits), not
# probabilities.
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-3)
criterion = nn.BCEWithLogitsLoss()

#############################################
# 10. Define EarlyStopping
#############################################
class EarlyStopping:
    """
    Stop training when the monitored loss stops improving.

    Each call receives the current loss. If it improves on the best loss seen
    so far by at least ``delta``, the model's state dict is saved to ``path``
    and the patience counter is reset; otherwise the counter is incremented
    and ``early_stop`` becomes True once it reaches ``patience``.
    """

    def __init__(self, patience=3, verbose=True, delta=0.001, path='best_link_predictor_inductive.pt'):
        # Configuration
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.path = path
        # Running state
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf

    def __call__(self, val_loss, model):
        # Higher score == lower loss.
        score = -val_loss

        # Very first observation: always accepted as the baseline.
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            return

        # Not enough improvement over the best score: spend patience.
        if score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
            return

        # Improvement: record it, checkpoint, and restore full patience.
        self.best_score = score
        self.save_checkpoint(val_loss, model)
        self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Write the model's state dict to ``self.path`` and remember the loss."""
        if self.verbose:
            print(f"Validation loss decreased ({self.val_loss_min:.6f} -> {val_loss:.6f}). Saving model ...")
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss


early_stopping = EarlyStopping(
    patience=3,
    verbose=True,
    delta=0.001,
    path='best_link_predictor_inductive.pt',
)

#############################################
# 11. Define the training function (with Hard Negative Sampling)
#############################################
def train_func():
    """
    Run one full-batch training step with hard negative mining.

    Uses the module-level model, optimizer, criterion and training tensors.
    A pool of random negative pairs three times the number of positives is
    scored with the current model, the highest-scoring ones (as many as there
    are positives) are kept as hard negatives, and a binary cross-entropy
    loss over positives and hard negatives is back-propagated.

    Returns:
        The loss of this step as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    # Node embeddings of the training graph.
    x = data_train.x.to(device)
    edge_index = data_train.edge_index.to(device)
    z = model(x, edge_index)

    # Scores of the true edges.
    pos_score = model.decode(z, train_pos_edge_index.to(device), train_pos_structural.to(device))
    num_pos = train_pos_edge_index.size(1)

    # Candidate pool of random negatives (3x the number of positives).
    pool_pairs = generate_negative_edges(G_train_embed, num_pos * 3)
    pool_rows = [(train_node_to_idx[a], train_node_to_idx[b]) for a, b in pool_pairs]
    pool_edge_index = torch.tensor(pool_rows, dtype=torch.long).t().to(device)
    pool_structural = compute_structural_features(G_train_embed, pool_pairs).to(device)

    # Rank the pool with the current model; no gradients are needed for this.
    # (The model is still in training mode at this point.)
    with torch.no_grad():
        pool_logits = model.decode(z, pool_edge_index, pool_structural)
        pool_prob = torch.sigmoid(pool_logits).squeeze()

    # Hard negatives = the candidates the model is most confident about.
    hardest = torch.topk(pool_prob, k=num_pos).indices
    neg_score = model.decode(z, pool_edge_index[:, hardest], pool_structural[hardest])

    # Targets: 1 for positives, 0 for negatives, in the same order as the scores.
    targets = torch.cat([torch.ones(num_pos), torch.zeros(num_pos)], dim=0).to(device)
    logits = torch.cat([pos_score, neg_score], dim=0).squeeze()

    loss = criterion(logits, targets)
    loss.backward()
    optimizer.step()
    return loss.item()

#############################################
# 12. Define the test evaluation function (evaluate link prediction on the test graph)
#############################################
@torch.no_grad()
def evaluate_test():
    """
    Evaluate link prediction on the test graph.

    Uses the module-level model, criterion and test tensors. Positive test
    edges and the pre-sampled negative test edges are scored on embeddings
    computed from the test graph.

    Returns:
        Tuple ``(loss, accuracy, auc_roc)``: the criterion value on the raw
        scores, the accuracy at a 0.5 probability threshold, and the ROC AUC.
    """
    model.eval()
    x = data_test.x.to(device)
    edge_index = data_test.edge_index.to(device)
    z = model(x, edge_index)

    # Keep the raw scores: the criterion (BCEWithLogitsLoss) applies the
    # sigmoid itself, so it must not be fed probabilities. squeeze(-1) drops
    # only the trailing score dimension, which also works for a single edge.
    pos_logits = model.decode(z, test_pos_edge_index.to(device), test_pos_structural.to(device)).squeeze(-1)
    neg_logits = model.decode(z, test_neg_edge_index.to(device), test_neg_structural.to(device)).squeeze(-1)

    edge_label = torch.cat([
        torch.ones(test_pos_edge_index.size(1)),
        torch.zeros(test_neg_edge_index.size(1))
    ], dim=0).to(device)
    edge_logits = torch.cat([pos_logits, neg_logits], dim=0)
    # Probabilities are used for the threshold-based and ranking metrics.
    edge_pred = torch.sigmoid(edge_logits)

    auc_roc = roc_auc_score(edge_label.cpu().numpy(), edge_pred.cpu().numpy())
    preds = (edge_pred >= 0.5).float()
    acc = (preds == edge_label).sum().item() / len(edge_label)
    loss_val = criterion(edge_logits, edge_label).item()
    return loss_val, acc, auc_roc

#############################################
# 13. Training Loop
#############################################
num_epochs = 300
# Per-epoch histories, consumed later for plotting.
train_losses, test_losses = [], []
test_accs, test_aucs = [], []
best_test_auc = 0

for epoch in range(num_epochs):
    # One optimisation step on the training graph.
    loss = train_func()
    train_losses.append(loss)

    # Metrics on the test graph after this step.
    t_loss, t_acc, t_auc = evaluate_test()
    test_losses.append(t_loss)
    test_accs.append(t_acc)
    test_aucs.append(t_auc)

    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {loss:.4f}, Test Loss: {t_loss:.4f}, Test Acc: {t_acc:.4f}, Test AUC: {t_auc:.4f}")

    # Early stopping (and checkpointing) is driven by the test loss.
    early_stopping(t_loss, model)
    if early_stopping.early_stop:
        print(f"Early stopping at epoch {epoch+1}")
        break

print(f"Training complete, Best Test AUC: {max(test_aucs):.4f}")
# Restore the weights of the best checkpoint written by early stopping.
best_state = torch.load('best_link_predictor_inductive.pt')
model.load_state_dict(best_state)

#############################################
# 14. Print final evaluation results on the test set
#############################################
final_test_loss, final_test_acc, final_test_auc = evaluate_test()
print("\nFinal evaluation results on the test set:")
print(f"Test Set - Final Loss: {final_test_loss:.4f}, Final Accuracy: {final_test_acc:.4f}, Final AUC: {final_test_auc:.4f}")

#############################################
# 15. Plot training curves
#############################################
plt.figure(figsize=(15,5))

# One entry per panel: (curves to draw, y-axis label, panel title).
panels = [
    ([(train_losses, "Train Loss"), (test_losses, "Test Loss")], "Loss", "Train vs. Test Loss"),
    ([(test_accs, "Test Accuracy")], "Accuracy", "Test Accuracy"),
    ([(test_aucs, "Test AUC")], "AUC", "Test AUC"),
]
for panel_position, (curves, y_label, panel_title) in enumerate(panels, start=1):
    plt.subplot(1, 3, panel_position)
    for series, series_label in curves:
        plt.plot(series, label=series_label)
    plt.xlabel("Epoch")
    plt.ylabel(y_label)
    plt.legend()
    plt.title(panel_title)

plt.tight_layout()
plt.show()

# Evaluation

In [None]:
import torch
import numpy as np
import random
from tqdm import tqdm
from sklearn.metrics import roc_auc_score

@torch.no_grad()
def evaluate_link_ranking_filtered(model, G_test_embed, test_pos_edge_pairs, all_true_edges_set, test_node_to_idx, data_test, device, k_values=(1, 3, 10), sample_size=None):
    """
    Performs ranking evaluation (MRR, Hits@K) for the link prediction model using the strict 'Filtered' setting.

    For every evaluated positive edge (u, v_true), u is scored against every
    node of the test graph as a candidate tail. Candidates that form another
    known true edge with u are masked out, and the rank of v_true is the
    number of remaining candidates scoring at least as high as v_true (ties
    therefore count against the true edge).

    Args:
    - model: The trained PyTorch model.
    - G_test_embed: The test graph containing node embeddings.
    - test_pos_edge_pairs: A list of positive edge pairs from the test set to be evaluated.
    - all_true_edges_set: A set containing all true edges (u, v) from the entire dataset (train + test).
    - test_node_to_idx: A dictionary mapping node names to their indices in the test set.
    - data_test: The test graph data in PyTorch Geometric format.
    - device: 'cuda' or 'cpu'.
    - k_values: An iterable of K values for calculating Hits@K.
    - sample_size (int, optional): The number of samples to randomly evaluate.

    Returns:
    - A dictionary containing the evaluation results for MRR and Hits@K.
      All values are NaN when there is no edge to evaluate.
    """
    model.eval()
    z = model.encode(data_test.x.to(device), data_test.edge_index.to(device))

    # The default is an immutable tuple; work on a private list copy.
    k_values = list(k_values)

    # Candidate tails are the same for every evaluated edge, so everything
    # derived from them is built once, outside the loop.
    candidate_v_nodes = list(test_node_to_idx.keys())
    candidate_v_indices = list(test_node_to_idx.values())
    candidate_index_tensor = torch.tensor(candidate_v_indices, dtype=torch.long, device=device)
    # Position of each node index within the candidate list (first occurrence),
    # replacing a linear list.index() search per evaluated edge.
    position_of_index = {}
    for position, node_idx in enumerate(candidate_v_indices):
        position_of_index.setdefault(node_idx, position)

    reciprocal_ranks = []
    hits_at_k = {k: 0 for k in k_values}

    if sample_size is not None and sample_size < len(test_pos_edge_pairs):
        print(f"Randomly sampling {sample_size} positive edges from {len(test_pos_edge_pairs)} for evaluation...")
        edges_to_evaluate = random.sample(test_pos_edge_pairs, sample_size)
    else:
        edges_to_evaluate = test_pos_edge_pairs
        print(f"Evaluating all {len(edges_to_evaluate)} positive edges...")

    # Nothing to rank: the metrics are undefined rather than a division error.
    if len(edges_to_evaluate) == 0:
        results = {'MRR': float('nan')}
        for k in k_values:
            results[f'Hits@{k}'] = float('nan')
        return results

    for u_node, v_node_true in tqdm(edges_to_evaluate, desc="Filtered Link Ranking"):
        u_idx = test_node_to_idx[u_node]
        v_idx_true = test_node_to_idx[v_node_true]

        batch_edge_pairs = [(u_node, v_cand_node) for v_cand_node in candidate_v_nodes]
        batch_edge_index = torch.stack([
            torch.full_like(candidate_index_tensor, u_idx),
            candidate_index_tensor,
        ])

        batch_struct_features = compute_structural_features(G_test_embed, batch_edge_pairs).to(device)
        batch_scores = model.decode(z, batch_edge_index, batch_struct_features).squeeze(-1)

        true_score = batch_scores[position_of_index[v_idx_true]]

        # Filtered setting: every OTHER known true tail of u is pushed to the
        # bottom so it cannot outrank the edge under evaluation. The true tail
        # itself is never masked.
        filter_indices = [
            i for i, v_cand_node in enumerate(candidate_v_nodes)
            if v_cand_node != v_node_true and (u_node, v_cand_node) in all_true_edges_set
        ]
        if filter_indices:
            mask_positions = torch.tensor(filter_indices, dtype=torch.long, device=batch_scores.device)
            batch_scores[mask_positions] = -float('inf')

        rank = (batch_scores >= true_score).sum().item()

        reciprocal_ranks.append(1.0 / rank)
        for k in k_values:
            if rank <= k:
                hits_at_k[k] += 1

    results = {'MRR': np.mean(reciprocal_ranks)}
    for k in k_values:
        results[f'Hits@{k}'] = hits_at_k[k] / len(edges_to_evaluate)

    return results

# 1. Create a set containing all true edges.
#    (Assumes train_pos_edge_pairs and test_pos_edge_pairs are already defined)
# Every known true edge (train and test), used to filter competing true tails.
all_true_edges_set = set(train_pos_edge_pairs).union(test_pos_edge_pairs)

# 2. Run the filtered ranking evaluation on a random sample of 100 test edges
#    (sampling keeps the run time manageable).
ranking_results_filtered = evaluate_link_ranking_filtered(
    model,
    G_test_embed,
    test_pos_edge_pairs,
    all_true_edges_set,
    test_node_to_idx,
    data_test,
    device,
    k_values=[1, 3, 10],
    sample_size=100,
)

# 3. Report MRR followed by Hits@K as percentages.
print("\n===== Link Ranking Evaluation Results (Filtered Setting) =====")
mrr_value = ranking_results_filtered['MRR']
print(f"MRR (Mean Reciprocal Rank): {mrr_value:.4f}")
for k in [1, 3, 10]:
    print(f"Hits@{k}: {ranking_results_filtered[f'Hits@{k}'] * 100:.2f}%")