In [7]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
import random
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader

In [None]:
class GraphSAGE(torch.nn.Module):
    """Stack of ``SAGEConv`` layers with ReLU between consecutive layers.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width of every intermediate layer.
        out_channels: Size of the embedding produced by the last layer.
        num_layers: Total number of ``SAGEConv`` layers (>= 1). With
            ``num_layers=1`` a single ``in_channels -> out_channels`` layer
            is built and ``hidden_channels`` is unused.

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        # Layer widths, e.g. num_layers=3 -> [in, hidden, hidden, out].
        # Building from this list makes the number of layers match
        # `num_layers` exactly (previously num_layers=1 still built 2 layers).
        dims = [in_channels] + [hidden_channels] * (num_layers - 1) + [out_channels]
        self.convs = torch.nn.ModuleList(
            [SAGEConv(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])]
        )

    def forward(self, x, edge_index):
        """Apply every layer in order; ReLU follows all layers but the last.

        Args:
            x: Node feature tensor passed to the first ``SAGEConv``.
            edge_index: Connectivity tensor forwarded unchanged to each layer.

        Returns:
            The output of the last layer (no activation applied).
        """
        for conv in self.convs[:-1]:
            x = F.relu(conv(x, edge_index))
        return self.convs[-1](x, edge_index)


In [None]:
# Two-layer GraphSAGE (i.e. two message-passing hops):
#   1024-d input features (BGE-M3 embeddings, per the original author's note)
#   -> 512-d hidden layer (kept wide for maximum capacity)
#   -> 256-d output embeddings.
model = GraphSAGE(in_channels=1024, hidden_channels=512, out_channels=256, num_layers=2)


In [None]:
# Mini-batch neighbour sampler: up to 25 neighbours at the first hop and 15 at
# the second, 512 seed nodes per batch, reshuffled every epoch. The fan-out and
# batch size were chosen to stay large without exhausting GPU memory.
# NOTE(review): `data` is not defined in the visible cells; presumably it is
# created by an earlier cell - confirm it exists on a fresh kernel.
train_loader = NeighborLoader(data, num_neighbors=[25, 15], batch_size=512, shuffle=True)


In [None]:
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Move both the network and the full graph object onto the chosen device.
# NOTE(review): train() already moves every mini-batch to `device`; confirm
# that relocating the whole graph here is intended and fits in device memory.
model, data = model.to(device), data.to(device)


In [None]:
# Adam over every parameter of the GraphSAGE model, learning rate 1e-2.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-2)

In [None]:
def unsupervised_loss(z, edge_index, num_neg_samples=5):
    """
    Compute the GraphSAGE unsupervised loss with negative sampling.

    For every edge ``(u, v)`` the loss rewards a high dot product
    ``z[u] . z[v]`` and, for ``num_neg_samples`` randomly drawn nodes
    ``v_neg``, penalises a high dot product ``z[u] . z[v_neg]``.

    A sampled negative is rejected when ``(u, v_neg)`` is an edge in
    ``edge_index`` or when ``v_neg == u``. Rejected samples are redrawn a
    bounded number of times; any that are still invalid afterwards are
    excluded from the loss, so the function always terminates (even on a
    fully connected graph).

    Parameters:
        z (Tensor): Node embeddings of shape [num_nodes, embedding_dim].
        edge_index (Tensor): Graph connectivity of shape [2, num_edges]
            (integer node indices into ``z``).
        num_neg_samples (int): Number of negative samples per edge.

    Returns:
        loss (Tensor): Scalar loss, normalised by the number of edges.
            Zero (still attached to ``z`` for autograd) when there are no
            nodes or no edges.
    """
    num_nodes = z.shape[0]
    num_edges = edge_index.shape[1]
    if num_nodes == 0 or num_edges == 0:
        # Nothing to contrast; keep the result differentiable w.r.t. z so a
        # subsequent backward() call does not fail.
        return z.sum() * 0.0

    # Scores are accumulated in float32 even if z arrives as float16
    # (e.g. produced under autocast) to avoid overflow/underflow.
    z = z.float()
    src, dst = edge_index[0], edge_index[1]
    z_src = z[src]

    # log(sigmoid(x)) computed via logsigmoid, which stays finite for
    # large |x| where log(sigmoid(x)) would return -inf.
    pos_loss = F.logsigmoid((z_src * z[dst]).sum(dim=-1)).sum()

    # Encode each (u, v) pair as one integer so edge membership can be
    # tested for all candidates at once.
    edge_keys = src * num_nodes + dst

    def _invalid(candidates):
        # True where the candidate is the source itself or a real neighbour.
        return (candidates == src) | torch.isin(src * num_nodes + candidates, edge_keys)

    max_resample_rounds = 10
    neg_loss = z.new_zeros(())
    for _ in range(num_neg_samples):
        neg = torch.randint(num_nodes, (num_edges,), device=z.device)
        invalid = _invalid(neg)
        for _retry in range(max_resample_rounds):
            if not bool(invalid.any()):
                break
            redraw = torch.randint(num_nodes, (num_edges,), device=z.device)
            neg = torch.where(invalid, redraw, neg)
            invalid = _invalid(neg)

        # log(1 - sigmoid(x)) == logsigmoid(-x); drop still-invalid samples.
        neg_scores = (z_src * z[neg]).sum(dim=-1)
        neg_loss = neg_loss + F.logsigmoid(-neg_scores).masked_fill(invalid, 0.0).sum()

    return -(pos_loss + neg_loss) / num_edges


In [None]:
# Gradient scaler for mixed-precision training. It is only enabled when CUDA
# is available; otherwise it acts as a pass-through (scale() returns the loss
# unchanged and step() simply calls optimizer.step()), which matches the CPU
# fallback used when selecting the device and avoids a warning on CPU.
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

In [None]:
def train():
    """Run one training epoch over ``train_loader``.

    Uses the module-level ``model``, ``optimizer``, ``scaler``, ``device``,
    ``train_loader`` and ``unsupervised_loss``. Mixed precision (autocast) is
    enabled only when ``device`` is a CUDA device, so the same function runs
    without warnings on CPU.

    Returns:
        float: Mean loss over the batches of this epoch, or ``nan`` if the
        loader yielded no batches.
    """
    model.train()
    use_amp = device.type == "cuda"

    total_loss = 0.0
    num_batches = 0
    for batch in train_loader:
        batch = batch.to(device)  # move the sampled sub-graph to the device
        optimizer.zero_grad()

        # Mixed precision for speed; a no-op context when use_amp is False.
        with torch.cuda.amp.autocast(enabled=use_amp):
            z = model(batch.x, batch.edge_index)
            loss = unsupervised_loss(z, batch.edge_index)

        # Scale the loss before backward, then step and update the scale.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        num_batches += 1

    return total_loss / num_batches if num_batches else float("nan")

# Training driver: call train() once per epoch and print a progress line.
for epoch in range(100):  # 100 full passes over train_loader
    train()
    print(f"🔥 Epoch {epoch} completada. 🔥")