In [1]:
!pip install torch-geometric

Collecting torch-geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m22.2 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: torch-geometric
Successfully installed torch-geometric-2.6.1


In [2]:
import torch
import torch.nn.functional as F
import torch_geometric
from torch_geometric.nn import GCNConv, GATConv, GINConv
from torch_geometric.datasets import CoraFull, DeezerEurope
from torch_geometric.utils import train_test_split_edges, negative_sampling, subgraph
from torch_geometric.utils import to_dense_adj, to_undirected, train_test_split_edges
from sklearn.model_selection import train_test_split

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

corafull = CoraFull(root='CoraFull')[0]
deezerEurope = DeezerEurope(root='DeezerEurope')[0]

datasets = {'CoraFull': corafull.to(device), 'DeezerEurope': deezerEurope.to(device)}

Downloading https://github.com/abojchevski/graph2gauss/raw/master/data/cora.npz
Processing...
Done!
Downloading https://graphmining.ai/datasets/ptg/deezer_europe.npz
Processing...
Done!


In [3]:
print (corafull)
print (deezerEurope)

Data(x=[19793, 8710], edge_index=[2, 126842], y=[19793])
Data(x=[28281, 128], edge_index=[2, 185504], y=[28281])


In [4]:
from sklearn.metrics import pairwise_distances
import numpy as np

def compute_nodes_in_triangle(edge_index, num_nodes):
    """
    Identify nodes in Q: nodes part of at least one triangle in the graph.
    """
    adj = to_dense_adj(edge_index, max_num_nodes=num_nodes).squeeze(0)
    triangles = torch.mm(adj, adj.T) * adj  # Triangle adjacency
    Q = torch.nonzero(torch.sum(triangles, dim=1)).squeeze()
    return Q[:1000]

In [5]:
# Split Datasets into train, test and validation
train_ratio, val_ratio, test_ratio = 0.6, 0.2, 0.2

num_nodes = {}
edge_index = {}

train_nodes = {}
train_data = {}
train_edge_index = {}

test_nodes = {}
test_data = {}
test_edge_index = {}
Q = {}

val_nodes = {}
val_data = {}
val_edge_index = {}
val_neg_edge_index = {}

for dataset_name in datasets:
    num_nodes[dataset_name] = datasets[dataset_name].num_nodes
    edge_index[dataset_name] = datasets[dataset_name].edge_index

    edge_splits = train_test_split_edges(datasets[dataset_name], test_ratio=test_ratio, val_ratio=val_ratio)
    
    train_edge_index[dataset_name] = to_undirected(edge_splits.train_pos_edge_index)
    test_edge_index[dataset_name] = to_undirected(edge_splits.test_pos_edge_index)
    val_edge_index[dataset_name] = to_undirected(edge_splits.val_pos_edge_index)

    train_nodes[dataset_name] = torch.sort(torch.unique(train_edge_index[dataset_name]))[0].to(device)
    train_data[dataset_name] = datasets[dataset_name].x[train_nodes[dataset_name]].to(device)
    train_edge_index[dataset_name] = subgraph(train_nodes[dataset_name], train_edge_index[dataset_name], relabel_nodes=True)[0].to(device)

    
    val_nodes[dataset_name] = torch.sort(torch.unique(val_edge_index[dataset_name]))[0].to(device)
    val_data[dataset_name] = datasets[dataset_name].x[val_nodes[dataset_name]].to(device)
    val_edge_index[dataset_name] = subgraph(val_nodes[dataset_name], val_edge_index[dataset_name], relabel_nodes=True)[0].to(device)

    val_neg_edge_index[dataset_name] = negative_sampling(edge_index=val_edge_index[dataset_name], 
                                                         num_nodes=val_nodes[dataset_name].shape[0],
                                                         num_neg_samples=val_edge_index[dataset_name].shape[1]).to(device)
    
    
    test_nodes[dataset_name] = torch.sort(torch.unique(test_edge_index[dataset_name]))[0].to(device)
    test_data[dataset_name] = datasets[dataset_name].x[test_nodes[dataset_name]].to(device)
    test_edge_index[dataset_name] = subgraph(test_nodes[dataset_name], test_edge_index[dataset_name], relabel_nodes=True)[0].to(device)

    Q[dataset_name] = compute_nodes_in_triangle(test_edge_index[dataset_name], test_nodes[dataset_name].shape[0])



In [6]:
print (datasets['CoraFull'])

Data(x=[19793, 8710], y=[19793], val_pos_edge_index=[2, 12684], test_pos_edge_index=[2, 12684], train_pos_edge_index=[2, 76106], train_neg_adj_mask=[19793, 19793], val_neg_edge_index=[2, 12684], test_neg_edge_index=[2, 12684])


In [7]:
print (train_edge_index['CoraFull'].shape)
print (test_edge_index['CoraFull'].shape)
print (val_edge_index['CoraFull'].shape)

print (train_edge_index['DeezerEurope'].shape)
print (test_edge_index['DeezerEurope'].shape)
print (val_edge_index['DeezerEurope'].shape)

torch.Size([2, 76106])
torch.Size([2, 25368])
torch.Size([2, 25368])
torch.Size([2, 111304])
torch.Size([2, 37100])
torch.Size([2, 37100])


In [8]:
# GNN Model Definition
class GNNModel(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels=64, out_channels=64, gnn_type="GCN"):
        super(GNNModel, self).__init__()
        if gnn_type == "GCN":
            self.conv1 = GCNConv(in_channels, hidden_channels)
            self.conv2 = GCNConv(hidden_channels, out_channels)
        elif gnn_type == "GAT":
            self.conv1 = GATConv(in_channels, hidden_channels)
            self.conv2 = GATConv(hidden_channels, out_channels)
        elif gnn_type == "GIN":
            self.conv1 = GINConv(torch.nn.Sequential(
                torch.nn.Linear(in_channels, hidden_channels),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden_channels, hidden_channels)
            ))
            self.conv2 = GINConv(torch.nn.Sequential(
                torch.nn.Linear(hidden_channels, out_channels),
                torch.nn.ReLU(),
                torch.nn.Linear(out_channels, out_channels)
            ))
        else:
            raise ValueError("Unsupported GNN type")

    def forward(self, input, edge_index):
        hidden_output = F.relu(self.conv1(input, edge_index))
        output = self.conv2(hidden_output, edge_index)
        return output

In [9]:
def auc_loss(node_embeddings, pos_edge_index, neg_edge_index, margin=1.0):
    # Positive edge scores
    pos_u, pos_v = pos_edge_index
    pos_scores = F.cosine_similarity(node_embeddings[pos_u], node_embeddings[pos_v])

    # Negative edge scores
    neg_u, neg_v = neg_edge_index
    neg_scores = F.cosine_similarity(node_embeddings[neg_u], node_embeddings[neg_v])
    
    # Margin-based ranking loss
    loss = F.relu(margin + neg_scores - pos_scores).mean()
    return loss

In [10]:
from rich.console import Console
from rich.table import Table
from rich.panel import Panel

console = Console()

def train(model, optimizer, train_data, train_edge_index, val_data, val_edge_index, val_neg_edge_index, epochs=200, val_interval=20):
    model.to(device)
    model.train()

    for epoch in range(1, epochs + 1):
        train_neg_edge_index = negative_sampling(edge_index=train_edge_index, 
                                                 num_nodes=train_data.shape[0], 
                                                 num_neg_samples=train_edge_index.shape[1]).to(device)

        node_embeddings = model(train_data, train_edge_index)
        loss = auc_loss(node_embeddings, train_edge_index, train_neg_edge_index)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if epoch % val_interval == 0:
            model.eval()
            with torch.no_grad():
                node_embeddings = model(val_data, val_edge_index)
                val_loss = auc_loss(node_embeddings, val_edge_index, val_neg_edge_index)
            model.train()

            console.print(f"Epoch {epoch:3d} | [cyan]Train Loss:[/] {loss:.5f} | [green]Val Loss:[/] {val_loss:.5f}")

In [11]:
models = {}

for dataset_name in datasets:
    models[dataset_name] = {'GCN': None, 'GAT': None, 'GIN': None}
    
    for gnn_model in models[dataset_name]:
        console.rule(f"[bold blue]Dataset: {dataset_name}, GNN Model: {gnn_model}")
        
        model = GNNModel(datasets[dataset_name].num_features, gnn_type=gnn_model)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    
        train(model, optimizer, train_data[dataset_name], train_edge_index[dataset_name], 
              val_data[dataset_name], val_edge_index[dataset_name], val_neg_edge_index[dataset_name])

        models[dataset_name][gnn_model] = model

In [12]:
import numpy as np

def get_evaluation(scores, labels, K_values):
    ranked_nodes = torch.argsort(scores, descending=True)
    rankings = torch.argsort(ranked_nodes)

    precision = {}
    for k in K_values:
        precision[k] = (labels[ranked_nodes[:k]].sum() / k).item()
    mrr = 0
    if torch.sum(labels) != 0:
        rank_of_first_positive = torch.min(rankings[torch.nonzero(labels)]).item() + 1
        mrr = 1 / rank_of_first_positive

    return precision, mrr

In [13]:
def evaluate_inference_basecase(model, test_data, test_edge_index, Q, K_values):
    model.eval()
    with torch.no_grad():
        precision_scores = {k: [] for k in K_values}
        mrr_scores = []
        adj = to_dense_adj(test_edge_index, max_num_nodes=test_data.shape[0]).squeeze(0)
        degree = torch.sum(adj, dim=1)
        
        embeddings = model(test_data, test_edge_index)
        
        for q in Q:
            scores = F.cosine_similarity(embeddings[q].unsqueeze(0), embeddings, dim=1)
            scores[q] = float('-inf')
            labels = adj[q]
    
            p, m = get_evaluation(scores, labels, K_values)
            for k in K_values:
                precision_scores[k].append(p[k])
            mrr_scores.append(m)
        
        # Average metrics
        precision_avg = {k: np.mean(precision_scores[k]) for k in K_values}
        mrr_avg = np.mean(mrr_scores)
        
    return precision_avg, mrr_avg

In [14]:
def adamic_adar(test_data, test_edge_index, Q, K_values):
    precision_scores = {k: [] for k in K_values}
    mrr_scores = []
    adj = to_dense_adj(test_edge_index, max_num_nodes=test_data.shape[0]).squeeze(0).to(device)
    degree = torch.sum(adj, dim=1).to(device)
    
    for q in Q:
        scores = adj[q] * adj
        scores = (1 / torch.log(degree.unsqueeze(0) * (scores > 0))).sum(dim=1)
        # scores = torch.stack([(1 / torch.log(degree[row > 0])).sum() for row in scores])
        scores[q] = float('-inf')
        labels = adj[q]
        
        p, m = get_evaluation(scores, labels, K_values)
        for k in K_values:
            precision_scores[k].append(p[k])
        mrr_scores.append(m)
    
    # Average metrics
    precision_avg = {k: np.mean(precision_scores[k]) for k in K_values}
    mrr_avg = np.mean(mrr_scores)
    
    return precision_avg, mrr_avg

def common_neighbors(test_data, test_edge_index, Q, K_values):
    precision_scores = {k: [] for k in K_values}
    mrr_scores = []
    adj = to_dense_adj(test_edge_index, max_num_nodes=test_data.shape[0]).squeeze(0).to(device)
    degree = torch.sum(adj, dim=1).to(device)

    for q in Q:
        scores = torch.sum(adj[q] * adj, dim=1)
        scores[q] = float('-inf')
        labels = adj[q]

        p, m = get_evaluation(scores, labels, K_values)
        for k in K_values:
            precision_scores[k].append(p[k])
        mrr_scores.append(m)
    
    # Average metrics
    precision_avg = {k: np.mean(precision_scores[k]) for k in K_values}
    mrr_avg = np.mean(mrr_scores)
    
    return precision_avg, mrr_avg

# **NORMAL INFERENCE**  

In [15]:
import time

K_values = [1, 5, 10]

for dataset_name in datasets:
    console.rule(f"[bold blue]Dataset: {dataset_name}")

    table = Table(show_header=True, header_style="bold green")
    table.add_column("GNN Models", style="dim", justify="center")
    for k in K_values:
        table.add_column(f"Precision@{k}", style="bold magenta", justify="center")
    table.add_column("MRR", style="bold magenta", justify="center")
    table.add_column("Inference Time (s)", style="bold red", justify="center")

    for gnn_model in models[dataset_name]:
        start_time = time.time()
        model = models[dataset_name][gnn_model]
        precision_avg, mrr_avg = evaluate_inference_basecase(model, test_data[dataset_name], test_edge_index[dataset_name], Q[dataset_name], K_values)
        end_time = time.time()
        
        row = [gnn_model]
        for k in K_values:
            row.append(f"{precision_avg[k]:.5f}")
        row.append(f"{mrr_avg:.5f}")
        row.append(f"{(end_time - start_time):.5f}")
        table.add_row(*row)

    
    # Adamic Adar
    start_time = time.time()
    precision_avg, mrr_avg = adamic_adar(test_data[dataset_name], test_edge_index[dataset_name], Q[dataset_name], K_values)
    end_time = time.time()
    
    row = ["Adamic-Adar"]
    for k in K_values:
        row.append(f"{precision_avg[k]:.5f}")
    row.append(f"{mrr_avg:.5f}")
    row.append(f"{(end_time - start_time):.5f}")
    table.add_row(*row)
    
    # Common Neighbors
    start_time = time.time()
    precision_avg, mrr_avg = common_neighbors(test_data[dataset_name], test_edge_index[dataset_name], Q[dataset_name], K_values)
    end_time = time.time()
    
    row = ["Common-Neighbors"]
    for k in K_values:
        row.append(f"{precision_avg[k]:.5f}")
    row.append(f"{mrr_avg:.5f}")
    row.append(f"{(end_time - start_time):.5f}")
    table.add_row(*row)
    
    console.print(table)

In [16]:
from collections import defaultdict
import numpy as np

def generate_random_hyperplanes(num_features, num_planes):
    return torch.randn(num_features, num_planes).to(device)

def projection_hash(features, hyperplanes):
    projections = torch.mm(features, hyperplanes)
    hash_codes = (projections > 0).float()
    return hash_codes

def get_hash_codes(train_features, test_features, num_planes=8):
    hyperplanes = generate_random_hyperplanes(train_features.shape[1], num_planes)    
    train_hash_codes = projection_hash(train_features, hyperplanes)

    powers_of_two = torch.pow(2, torch.arange(num_planes - 1, -1, -1)).float().to(device)
    
    train_hash_codes_dict = defaultdict(list)
    for i, train_hash_code in enumerate(train_hash_codes):
        key = torch.dot(train_hash_code, powers_of_two).int().item()
        train_hash_codes_dict[key].append(i)

    for code in train_hash_codes_dict.keys():
        train_hash_codes_dict[code] = torch.tensor(train_hash_codes_dict[code]).to(device)
    
    test_hash_codes = projection_hash(test_features, hyperplanes)
    test_hash_codes = [torch.dot(test_hash_code, powers_of_two).item() for test_hash_code in test_hash_codes]
    
    return train_hash_codes_dict, test_hash_codes

In [17]:
def evaluate_inference_random_lsh(embeddings, test_edge_index, Q, K_values):
    precision_scores = {k: [] for k in K_values}
    mrr_scores = []
    adj = to_dense_adj(test_edge_index, max_num_nodes=embeddings.shape[0]).squeeze(0)
    
    embeddings_hash_codes_dict, Q_hash_codes = get_hash_codes(embeddings, embeddings[Q])
    
    for q, q_hash_code in zip(Q, Q_hash_codes):
        subset_indices = embeddings_hash_codes_dict[q_hash_code]
        scores = F.cosine_similarity(embeddings[q].unsqueeze(0), embeddings[subset_indices], dim=1)
        scores[torch.nonzero(subset_indices == q, as_tuple=True)[0].item()] = float('-inf')
        labels = adj[q][subset_indices]

        p, m = get_evaluation(scores, labels, K_values)
        for k in K_values:
            precision_scores[k].append(p[k])
        mrr_scores.append(m)
    
    # Average metrics
    precision_avg = {k: np.mean(precision_scores[k]) for k in K_values}
    mrr_avg = np.mean(mrr_scores)
        
    return precision_avg, mrr_avg

# **RANDOM LSH INFERENCE**

In [18]:
import time

K_values = [1, 5, 10]

for dataset_name in datasets:
    console.rule(f"[bold blue]Dataset: {dataset_name}")

    table = Table(show_header=True, header_style="bold green")
    table.add_column("GNN Models", style="dim", justify="center")
    for k in K_values:
        table.add_column(f"Precision@{k}", style="bold magenta", justify="center")
    table.add_column("MRR", style="bold magenta", justify="center")
    table.add_column("Inference Time (s)", style="bold red", justify="center")

    for gnn_model in models[dataset_name]:
        start_time = time.time()
        model = models[dataset_name][gnn_model]
        model.eval()
        with torch.no_grad():
            embeddings = model(test_data[dataset_name], test_edge_index[dataset_name])
            precision_avg, mrr_avg = evaluate_inference_random_lsh(embeddings, test_edge_index[dataset_name], Q[dataset_name], K_values)
        end_time = time.time()
        
        row = [gnn_model]
        for k in K_values:
            row.append(f"{precision_avg[k]:.5f}")
        row.append(f"{mrr_avg:.5f}")
        row.append(f"{(end_time - start_time):.5f}")
        table.add_row(*row)
    
    console.print(table)

In [19]:
from collections import defaultdict
import numpy as np
from torch import nn
import random

class NeuralLSH(nn.Module):
    def __init__(self, input_dim, hash_dim, num_tables, subset_size):
        super(NeuralLSH, self).__init__()
        self.input_dim = input_dim
        self.hash_dim = hash_dim
        self.num_tables = num_tables
        self.subset_size = subset_size
        
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.powers_of_two = torch.pow(2, torch.arange(subset_size - 1, -1, -1)).float().to(device)
        self.zero = torch.tensor([0], device=self.device)

        self.init_hash_functions()
        self.hyperplanes = nn.Parameter(torch.randn(self.input_dim, self.hash_dim, device=self.device))
        
    def init_hash_functions(self):
        self.hash_functions = torch.tensor([], device=self.device).long()
        indices = list(range(self.hash_dim))
        for _ in range(self.num_tables):
            random.shuffle(indices)
            self.hash_functions = torch.cat((self.hash_functions,
                torch.tensor([indices[:self.subset_size]], device=self.device).long()), dim=0)
    
    def _projection(self, features):
        return torch.mm(features, self.hyperplanes)
    
    def forward(self, features):
        return torch.tanh(self._projection(features))
    
    
    def init_hash_tables(self, features):
        features = features.to(self.device)
        full_hash_codes = self._projection(features)
        self.hash_tables = []

        full_hash_values = torch.transpose(((full_hash_codes[:, self.hash_functions] > 0).float() @ self.powers_of_two).int(), 0, 1)
        for table in range(self.num_tables):
            self.hash_tables.append([])
            for hash_val in range(2 ** self.subset_size):
                self.hash_tables[table].append(torch.nonzero(full_hash_values[table] == hash_val).T[0].tolist())      

    def get_corpus_indices(self, corpus_features, test_features):
        self.init_hash_tables(corpus_features)

        test_features = test_features.to(device)
        full_hash_codes = self._projection(test_features)
        
        full_hash_values = ((full_hash_codes[:, self.hash_functions] > 0).float() @ self.powers_of_two).int()

        corpus_indices = []
        for hash_values in full_hash_values:
            indices = set()
            for hash_table, hash_val in zip(self.hash_tables, hash_values):
                indices.update(hash_table[hash_val.item()])
            
            corpus_indices.append(torch.tensor(list(indices)).to(self.device))
        
        return corpus_indices

In [20]:
def loss_func(hash_codes, neg_indices):
    # taking alpha = beta = gamma = 1/3
    term1 = torch.sum(torch.abs(torch.sum(hash_codes, dim=1))) / hash_codes.shape[0]
    
    term2 = torch.sum(torch.abs(torch.abs(hash_codes) - torch.ones(hash_codes.shape[1], device=device))) / hash_codes.shape[0]
        
    negs = hash_codes[neg_indices]
    term3 = torch.sum((negs[0] * negs[1])) / neg_indices.shape[1]

    return (term1 + term2 + term3) / 3


def train_lsh_model(model, optimizer, train_data_embeddings, train_edge_index, val_data, val_neg_edge_index, epochs=200, val_interval=20):
    model.to(device)
    model.train()

    for epoch in range(1, epochs + 1):
        train_neg_edge_index = negative_sampling(edge_index=train_edge_index, 
                                                 num_nodes=train_data_embeddings.shape[0], 
                                                 num_neg_samples=train_edge_index.shape[1]).to(device)

        hash_codes = model(train_data_embeddings)
        loss = loss_func(hash_codes, train_neg_edge_index)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

                
        if epoch % val_interval == 0:
            model.eval()
            with torch.no_grad():
                hash_codes = model(val_data)
                val_loss = loss_func(hash_codes, val_neg_edge_index)
            model.train()

            console.print(f"Epoch {epoch:3d} | [cyan]Train Loss:[/] {loss:.5f} | [green]Val Loss:[/] {val_loss:.5f}")

In [21]:
hash_dim = 16
num_tables = 10
subset_size = 8


lsh_models = {}

for dataset_name in datasets:
    lsh_models[dataset_name] = {'GCN': None, 'GAT': None, 'GIN': None}
    
    for gnn_model in models[dataset_name]:
        console.rule(f"[bold blue]Dataset: {dataset_name}, GNN Model: {gnn_model}")

        with torch.no_grad():
            graph_model = models[dataset_name][gnn_model]
            train_data_embeddings = graph_model(train_data[dataset_name], train_edge_index[dataset_name])
            val_data_embeddings = graph_model(val_data[dataset_name], val_edge_index[dataset_name])

        num_features = train_data_embeddings.shape[1]
        model = NeuralLSH(num_features, hash_dim, num_tables, subset_size)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    
        train_lsh_model(model, optimizer, train_data_embeddings, train_edge_index[dataset_name], 
              val_data_embeddings, val_neg_edge_index[dataset_name])

        lsh_models[dataset_name][gnn_model] = model

In [22]:
def evaluate_inference_neural_lsh(model, embeddings, test_edge_index, Q, K_values):
    precision_scores = {k: [] for k in K_values}
    mrr_scores = []
    adj = to_dense_adj(test_edge_index, max_num_nodes=embeddings.shape[0]).squeeze(0)

    Q_corpus_indices = model.get_corpus_indices(embeddings, embeddings[Q])
    
    for q, subset_indices in zip(Q, Q_corpus_indices):
        scores = F.cosine_similarity(embeddings[q].unsqueeze(0), embeddings[subset_indices], dim=1)
        scores[torch.nonzero(subset_indices == q, as_tuple=True)[0].item()] = float('-inf')
        labels = adj[q][subset_indices]

        p, m = get_evaluation(scores, labels, K_values)
        for k in K_values:
            precision_scores[k].append(p[k])
        mrr_scores.append(m)
    
    # Average metrics
    precision_avg = {k: np.mean(precision_scores[k]) for k in K_values}
    mrr_avg = np.mean(mrr_scores)
        
    return precision_avg, mrr_avg

# **Neural LSH INFERENCE**

In [23]:
import time

K_values = [1, 5, 10]

for dataset_name in datasets:
    console.rule(f"[bold blue]Dataset: {dataset_name}")

    table = Table(show_header=True, header_style="bold green")
    table.add_column("GNN Models", style="dim", justify="center")
    for k in K_values:
        table.add_column(f"Precision@{k}", style="bold magenta", justify="center")
    table.add_column("MRR", style="bold magenta", justify="center")
    table.add_column("Inference Time (s)", style="bold red", justify="center")

    for gnn_model in models[dataset_name]:
        start_time = time.time()
        model = models[dataset_name][gnn_model]
        lsh_model = lsh_models[dataset_name][gnn_model]
        model.eval()
        with torch.no_grad():
            embeddings = model(test_data[dataset_name], test_edge_index[dataset_name])
            precision_avg, mrr_avg = evaluate_inference_neural_lsh(lsh_model, embeddings, test_edge_index[dataset_name], Q[dataset_name], K_values)
        end_time = time.time()
        
        row = [gnn_model]
        for k in K_values:
            row.append(f"{precision_avg[k]:.5f}")
        row.append(f"{mrr_avg:.5f}")
        row.append(f"{(end_time - start_time):.5f}")
        table.add_row(*row)
    
    console.print(table)