In [1]:
import os
import sys
import pickle
import random

import networkx as nx
import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import SAGEConv
from torch_geometric.utils import to_undirected
from torch_geometric.utils.convert import from_networkx

# Add "src" path to Python path
sys.path.append(os.path.abspath("../src"))

# Import custom graph formatting function
from graph_formatting_utils import format_graph_for_graphsage

In [2]:
# Check CUDA status
print("CUDA Available:", torch.cuda.is_available())
print("GPU Name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No GPU detected")
print("CUDA Device Count:", torch.cuda.device_count())

CUDA Available: True
GPU Name: NVIDIA L40S
CUDA Device Count: 1


In [3]:
# Load baseline graph
with open("../data/MultiHop_graph_w_sem_embeddings.pkl", "rb") as f:
    G = pickle.load(f)

In [4]:
cleaned_G = format_graph_for_graphsage(G, embedding_dim=1024)

In [5]:
# Sanity check for initial node embeddings
for node in cleaned_G.nodes():
    emb = cleaned_G.nodes[node]["embedding"]
    if emb is None:
        print(f"[ERROR] Node {node} is still missing embedding!")
    elif not isinstance(emb, torch.Tensor):
        print(f"[ERROR] Node {node} embedding is not a torch.Tensor!")
    elif emb.shape != (1024,):
        print(f"[ERROR] Node {node} has shape {emb.shape} != ({1024},)!")

In [6]:
# Convert the NetworkX graph to a PyTorch Geometric Data object
data = from_networkx(cleaned_G)

In [7]:
# Ensure the graph is undirected
data.edge_index = to_undirected(data.edge_index)

In [8]:
# Create data attribute "x" containing the embeddings of each node complying with the PyTorch Geometric API
data.x = data.embedding

In [9]:
# Define the GraphSAGE model class
class GraphSAGE(torch.nn.Module):
    """
    Implementation of the GraphSAGE model for node representation learning in graphs.

    Parameters:
    -----------
    in_channels : int
        The dimensionality of input node features (e.g., embedding size).
    hidden_channels : int
        The dimensionality of hidden layers in the GraphSAGE model.
    out_channels : int
        The dimensionality of the output node representations.
    num_layers : int, optional
        The number of GraphSAGE layers (default is 2).

    Methods:
    --------
    forward(x, edge_index):
        Performs forward propagation through the GraphSAGE layers.

    Returns:
    --------
    torch.Tensor
        The learned node embeddings of shape (num_nodes, out_channels).

    Notes:
    ------
    - The first layer transforms input embeddings into a hidden representation.
    - Intermediate layers apply non-linear transformations (`ReLU` activation).
    - The final layer outputs node embeddings without activation.
    - Uses `SAGEConv` layers from PyTorch Geometric (`torch_geometric.nn`).
    """
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2):
        super(GraphSAGE, self).__init__()
        self.convs = torch.nn.ModuleList()
        
        # First GraphSAGE layer: input (embeddings) → hidden layer
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        
        # Intermediate layers (if num_layers > 2)
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
        
        # Last GraphSAGE layer: hidden layer → final embedding
        self.convs.append(SAGEConv(hidden_channels, out_channels))

    def forward(self, x, edge_index):
        for conv in self.convs[:-1]:  # Intermediate layers
            x = conv(x, edge_index)
            x = F.relu(x)  # ReLU activation
        x = self.convs[-1](x, edge_index)  # Last layer (no activation)
        return x

In [10]:
# Instantiate the GraphSAGE model
model = GraphSAGE(
    in_channels=1024,   # Input features (BGE-M3 embeddings)
    hidden_channels=512,  # First hidden layer (alto para máxima capacidad)
    out_channels=256,   # Output embeddings (más ricos)
    num_layers=2        # Mantenemos 2 capas (2 hops)
)

In [11]:
# Instantiate the NeighborLoader for mini-batch training
train_loader = NeighborLoader(
    data,
    num_neighbors=[25, 15],  # 25 neighbors for the first layer, 15 for the second
    batch_size=512,  # Batch size
    shuffle=True
)



In [12]:
# Set device for model training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the model to the device
model = model.to(device)

# Check model device
print("Model in:", next(model.parameters()).device)

# Define the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Define scaler
scaler = torch.cuda.amp.GradScaler()

Model in: cuda:0


In [13]:
def unsupervised_loss(z, edge_index, num_neg_samples=5):
    """
    Compute the unsupervised loss for GraphSAGE using contrastive negative sampling.

    Parameters:
    -----------
    z : torch.Tensor
        Node embeddings of shape `(num_nodes, embedding_dim)`.
    edge_index : torch.Tensor
        Graph connectivity matrix of shape `(2, num_edges)`, where each column represents an edge `(u, v)`.
    num_neg_samples : int, optional
        Number of negative samples per positive edge (default is 5).

    Returns:
    --------
    torch.Tensor
        The computed contrastive loss value.

    Notes:
    ------
    - The loss function follows the **GraphSAGE unsupervised learning approach**, leveraging contrastive learning.
    - **Positive pairs**: Directly connected nodes in `edge_index`.
    - **Negative pairs**: Randomly sampled nodes that are not neighbors.
    - The objective is to maximize similarity for positive pairs and minimize it for negative pairs.
    """
    pos_loss = torch.tensor(0.0, device=z.device)  # Loss for positive node pairs
    neg_loss = torch.tensor(0.0, device=z.device)  # Loss for negative node pairs

    num_nodes = z.shape[0]  # Number of nodes in the graph

    for edge in edge_index.T:  # Iterate over each edge in the graph
        u, v = edge  # Extract source node (u) and destination node (v)

        # Positive pair loss (nodes that are neighbors)
        pos_loss += torch.log(torch.sigmoid(torch.dot(z[u], z[v])))

        # Negative sampling (random nodes that are NOT neighbors)
        for _ in range(num_neg_samples):
            v_neg = random.randint(0, num_nodes - 1)
            while v_neg in edge_index[1]:  # Ensure v_neg is NOT a neighbor
                v_neg = random.randint(0, num_nodes - 1)

            neg_loss += torch.log(1 - torch.sigmoid(torch.dot(z[u], z[v_neg])))

    loss = -(pos_loss + neg_loss) / edge_index.shape[1]  # Normalize by number of edges
    return loss

In [14]:
def train():
    model.train()
    for batch in train_loader:
        batch = batch.to(device)  # Move batch to device
        optimizer.zero_grad()

        with torch.amp.autocast('cuda'):  # mixed precission
            z = model(batch.x, batch.edge_index)
            loss = unsupervised_loss(z, batch.edge_index)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

In [15]:
num_epochs = 100  # Number of training epochs (Each epoch takes apr. 25-30 seconds)

for epoch in range(num_epochs): 
    train()
    print(f"Epoch {epoch} completed.")

Epoch 0 completed.
Epoch 1 completed.
Epoch 2 completed.
Epoch 3 completed.
Epoch 4 completed.
Epoch 5 completed.
Epoch 6 completed.
Epoch 7 completed.
Epoch 8 completed.
Epoch 9 completed.
Epoch 10 completed.
Epoch 11 completed.
Epoch 12 completed.
Epoch 13 completed.
Epoch 14 completed.
Epoch 15 completed.
Epoch 16 completed.
Epoch 17 completed.
Epoch 18 completed.
Epoch 19 completed.
Epoch 20 completed.
Epoch 21 completed.
Epoch 22 completed.
Epoch 23 completed.
Epoch 24 completed.
Epoch 25 completed.
Epoch 26 completed.
Epoch 27 completed.
Epoch 28 completed.
Epoch 29 completed.
Epoch 30 completed.
Epoch 31 completed.
Epoch 32 completed.
Epoch 33 completed.
Epoch 34 completed.
Epoch 35 completed.
Epoch 36 completed.
Epoch 37 completed.
Epoch 38 completed.
Epoch 39 completed.
Epoch 40 completed.
Epoch 41 completed.
Epoch 42 completed.
Epoch 43 completed.
Epoch 44 completed.
Epoch 45 completed.
Epoch 46 completed.
Epoch 47 completed.
Epoch 48 completed.
Epoch 49 completed.
Epoch 50 c

In [16]:
# Save embeddings as .npy file

## 1) Get the device from any parameter in the model
#device = next(model.parameters()).device
#
## 2) Now move your data.x and data.edge_index to that device:
#data_x = data.x.to(device)
#data_edge_index = data.edge_index.to(device)
#
## 3) Forward pass with torch.no_grad():
#with torch.no_grad():
#    embeddings = model(data_x, data_edge_index)
#
#embeddings_np = embeddings.cpu().numpy()
#np.save("../data/graphsage_embeddings.npy", embeddings_np)
#print("Embeddings saved to graphsage_embeddings.npy")


In [17]:
# Add Graph_SAGE embeddings to the baseline graph

# 1) Move data to the same device as the model
device = next(model.parameters()).device
data_x = data.x.to(device)
data_edge_index = data.edge_index.to(device)

# 2) Obtain final embeddings from the trained model
with torch.no_grad():
    final_emb = model(data_x, data_edge_index)  # shape [num_nodes, embedding_dim]
    final_emb_np = final_emb.cpu().numpy()

# 3) Add them back to the cleaned_G graph
list_of_nodes = list(G.nodes())  # Must match the node ordering in data
for i, node in enumerate(list_of_nodes):
    # Store as a NumPy array (or you could store as a list if you prefer)
    G.nodes[node]["SAGE_embedding"] = final_emb_np[i]

print("SAGE embeddings added to G under 'SAGE_embedding' attribute.")


SAGE embeddings added to G under 'SAGE_embedding' attribute.


In [19]:
with open(f"../data/MultiHop_graph_w_sage{num_epochs}_embeddings.pkl", "wb") as f:
    pickle.dump(G, f)