In [1]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
import random
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader
from torch_geometric.utils import to_undirected
import networkx as nx
from torch_geometric.utils.convert import from_networkx
import pickle
import numpy as np

In [2]:
print("CUDA Available:", torch.cuda.is_available())
print("GPU Name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No GPU detected")
print("CUDA Device Count:", torch.cuda.device_count())

CUDA Available: True
GPU Name: NVIDIA L40S
CUDA Device Count: 1


In [3]:
# Cargar el grafo
with open("../data/graph_w_embeddings_full_prototype.pkl", "rb") as f:
    G = pickle.load(f)

In [4]:
def clean_graph_for_graphsage(G, embedding_dim=1024, remove_incomplete=True):
    """
    Cleans a NetworkX graph for GraphSAGE training by:
      1. Converting the graph to an undirected version.
      2. Removing all edge attributes.
      3. Ensuring each node has a valid "embedding" attribute:
         - If an embedding is missing:
             - If remove_incomplete is True: remove the node.
             - Otherwise, print a warning and assign a default zero embedding.
         - If the embedding is stored as a dictionary (with key "embedding"),
           extract the value.
         - If the embedding is a list or numpy.ndarray, convert it to a torch.Tensor.
      4. Removing all other node attributes (keeping only "embedding").

    Parameters:
      G (networkx.Graph): The input graph.
      embedding_dim (int): Expected dimension of each node's embedding.
      remove_incomplete (bool): If True, remove nodes without an embedding;
                                if False, assign a default zero tensor.

    Returns:
      networkx.Graph: The cleaned, undirected graph.
    """
    # 1. Convert to undirected graph
    G_undirected = nx.Graph(nx.to_undirected(G))
    
    # 2. Clean node attributes
    # Use list(G_undirected.nodes()) since we may remove nodes
    for node in list(G_undirected.nodes()):
        node_attrs = G_undirected.nodes[node]
        embedding = node_attrs.get("embedding", None)
        
        if embedding is None:
            # Node is missing embedding
            if remove_incomplete:
                print(f"[INFO] Removing node {node} because it has no 'embedding'.")
                G_undirected.remove_node(node)
                continue  # Skip further processing for this node
            else:
                print(f"[WARNING] Node {node} has no 'embedding'. Assigning default zero embedding.")
                embedding = np.zeros(embedding_dim, dtype=np.float32)
        
        # If embedding is a dictionary with a key "embedding", extract it.
        if isinstance(embedding, dict) and "embedding" in embedding:
            embedding = embedding["embedding"]
        
        # Convert list or numpy.ndarray to a numpy array of type float32
        if isinstance(embedding, list):
            embedding = np.array(embedding, dtype=np.float32)
        elif isinstance(embedding, np.ndarray):
            embedding = embedding.astype(np.float32)
        
        # Now convert to a torch tensor if it's not already
        if not isinstance(embedding, torch.Tensor):
            embedding = torch.tensor(embedding, dtype=torch.float32)
        
        # Check if the embedding has the expected shape
        if embedding.dim() != 1 or embedding.shape[0] != embedding_dim:
            print(f"[WARNING] Node {node} embedding shape is {embedding.shape}, expected ({embedding_dim},).")
        
        # Reassign the cleaned embedding back to the node
        node_attrs["embedding"] = embedding
        
        # Remove any other attributes besides 'embedding'
        keys_to_remove = [k for k in node_attrs.keys() if k != "embedding"]
        for k in keys_to_remove:
            del node_attrs[k]

    #Remove all edge attributes
    for u, v in G_undirected.edges():
        for key in list(G_undirected[u][v].keys()):
            del G_undirected[u][v][key] 
    
    return G_undirected

In [5]:
cleaned_G = clean_graph_for_graphsage(G, embedding_dim=1024)

[INFO] Removing node Yardbarker because it has no 'embedding'.
[INFO] Removing node Engadget because it has no 'embedding'.


In [6]:
for node in cleaned_G.nodes():
    emb = cleaned_G.nodes[node]["embedding"]
    if emb is None:
        print(f"[ERROR] Node {node} is still missing embedding!")
    elif not isinstance(emb, torch.Tensor):
        print(f"[ERROR] Node {node} embedding is not a torch.Tensor!")
    elif emb.shape != (1024,):
        print(f"[ERROR] Node {node} has shape {emb.shape} != ({1024},)!")

In [7]:
# Convertir el grafo de NetworkX a PyTorch Geometric Data
data = from_networkx(cleaned_G)

In [8]:
data.edge_index = to_undirected(data.edge_index)

In [9]:
data.x = data.embedding

In [10]:
class GraphSAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2):
        super(GraphSAGE, self).__init__()
        self.convs = torch.nn.ModuleList()
        
        # First GraphSAGE layer: input (embeddings) → hidden layer
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        
        # Intermediate layers (if num_layers > 2)
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
        
        # Last GraphSAGE layer: hidden layer → final embedding
        self.convs.append(SAGEConv(hidden_channels, out_channels))

    def forward(self, x, edge_index):
        for conv in self.convs[:-1]:  # Intermediate layers
            x = conv(x, edge_index)
            x = F.relu(x)  # ReLU activation
        x = self.convs[-1](x, edge_index)  # Last layer (no activation)
        return x


In [11]:
model = GraphSAGE(
    in_channels=1024,   # Input features (BGE-M3 embeddings)
    hidden_channels=512,  # First hidden layer (alto para máxima capacidad)
    out_channels=256,   # Output embeddings (más ricos)
    num_layers=2        # Mantenemos 2 capas (2 hops)
)


In [12]:
train_loader = NeighborLoader(
    data,
    num_neighbors=[25, 15],  # Más vecinos en cada hop (sin desbordar memoria)
    batch_size=512,  # Balanceamos tamaño grande sin sobrecargar GPU
    shuffle=True
)




In [13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)


In [14]:
# Para el modelo:
print("Modelo en:", next(model.parameters()).device)

# Para los datos (ejemplo con data.x):#
#print("Data en:", data.device)


Modelo en: cuda:0


In [15]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [16]:
def unsupervised_loss(z, edge_index, num_neg_samples=5):
    """
    Compute GraphSAGE unsupervised loss with negative sampling.

    Parameters:
        z (Tensor): Node embeddings of shape [num_nodes, embedding_dim].
        edge_index (Tensor): Graph connectivity of shape [2, num_edges].
        num_neg_samples (int): Number of negative samples per node.

    Returns:
        loss (Tensor): Computed contrastive loss.
    """
    pos_loss = torch.tensor(0.0, device=z.device)  # Loss for positive node pairs
    neg_loss = torch.tensor(0.0, device=z.device)  # Loss for negative node pairs

    num_nodes = z.shape[0]  # Number of nodes in the graph

    for edge in edge_index.T:  # Iterate over each edge in the graph
        u, v = edge  # Extract source node (u) and destination node (v)

        # Positive pair loss (nodes that are neighbors)
        pos_loss += torch.log(torch.sigmoid(torch.dot(z[u], z[v])))

        # Negative sampling (random nodes that are NOT neighbors)
        for _ in range(num_neg_samples):
            v_neg = random.randint(0, num_nodes - 1)
            while v_neg in edge_index[1]:  # Ensure v_neg is NOT a neighbor
                v_neg = random.randint(0, num_nodes - 1)

            neg_loss += torch.log(1 - torch.sigmoid(torch.dot(z[u], z[v_neg])))

    loss = -(pos_loss + neg_loss) / edge_index.shape[1]  # Normalize by number of edges
    return loss


In [17]:
scaler = torch.cuda.amp.GradScaler()

In [18]:
data

Data(edge_index=[2, 9002], embedding=[2620, 1024], num_nodes=2620, x=[2620, 1024])

In [None]:
def train():
    model.train()
    for batch in train_loader:
        batch = batch.to(device)  # Mover batch a GPU
        optimizer.zero_grad()

        with torch.amp.autocast('cuda'):  # Precision mixta (velocidad extra)
            z = model(batch.x, batch.edge_index)
            loss = unsupervised_loss(z, batch.edge_index)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

for epoch in range(100):  # ¡Duro con 100 épocas!
    train()
    print(f"🔥 Epoch {epoch} completada. 🔥")

🔥 Epoch 0 completada. 🔥
🔥 Epoch 1 completada. 🔥
🔥 Epoch 2 completada. 🔥
🔥 Epoch 3 completada. 🔥
🔥 Epoch 4 completada. 🔥
