In [9]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, Sequential

In [None]:
class ContextGNN(torch.nn.Module):
    """Hybrid scorer that mixes a GNN-based score with a two-tower score.

    A two-layer GraphSAGE backbone encodes every node of the input graph.
    For each requested (user, item) pair the model then computes:

    * a "pair-wise" score: a linear head on the user's node representation,
    * a two-tower score: the dot product between the projected user
      representation and the projected learned item embedding,
    * a fusion weight: a linear head on the user's node representation,

    and returns ``pairwise + fusion * twotower``.

    NOTE(review): as written, the "pair-wise" term depends only on the user
    node representation, not on the item; confirm this is intended.

    Args:
        in_channels: Width of the input node features.
        hidden_channels: Width of the GNN output and of the item embeddings.
        out_channels: Width of the shared user/item two-tower space.
        num_items: Number of rows in the item-embedding table; every value in
            ``item_ids`` must lie in ``[0, num_items)``. Defaults to 1000.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_items=1000):
        super().__init__()

        # Backbone GNN: SAGEConv -> ReLU -> SAGEConv, output width = hidden_channels.
        self.gnn = Sequential('x, edge_index', [
            (SAGEConv(in_channels, hidden_channels), 'x, edge_index -> x'),
            torch.nn.ReLU(),
            (SAGEConv(hidden_channels, hidden_channels), 'x, edge_index -> x'),
        ])

        # Pair-wise scorer: one scalar per node representation.
        self.pairwise_mlp = torch.nn.Linear(hidden_channels, 1)

        # Two-tower scorer: learned item table plus one projection per tower.
        self.item_embedding = torch.nn.Embedding(
            num_embeddings=num_items, embedding_dim=hidden_channels
        )
        self.user_tower = torch.nn.Linear(hidden_channels, out_channels)
        self.item_tower = torch.nn.Linear(hidden_channels, out_channels)

        # Fusion scorer: scalar weight applied to the two-tower score.
        self.fusion_mlp = torch.nn.Linear(hidden_channels, 1)

    def forward(self, x, edge_index, user_ids, item_ids):
        """Score ``(user_ids[i], item_ids[i])`` pairs.

        Args:
            x: Node features passed to the GNN backbone; presumably shaped
                (num_nodes, in_channels).
            edge_index: Graph connectivity passed to the GNN backbone.
            user_ids: Indices into the rows of the GNN output (node indices).
            item_ids: Indices into the item-embedding table.

        Returns:
            Tensor of scores; shape (B, 1) when ``user_ids`` and ``item_ids``
            are 1-D tensors of length B.
        """
        # Node representations from the backbone: (num_nodes, hidden_channels).
        node_repr = self.gnn(x, edge_index)

        # Gather the requested user rows once; every head below works on them.
        user_repr = node_repr[user_ids]  # (B, hidden_channels)

        # The linear head acts row-wise, so scoring only the gathered rows is
        # equivalent to scoring all nodes and then selecting user_ids.
        pairwise_scores = self.pairwise_mlp(user_repr)  # (B, 1)

        # Two-tower representation and dot-product score.
        user_emb = self.user_tower(user_repr)  # (B, out_channels)
        item_emb = self.item_tower(self.item_embedding(item_ids))  # (B, out_channels)
        twotower_scores = torch.sum(user_emb * item_emb, dim=1, keepdim=True)  # (B, 1)

        # Fusion of pair-wise and two-tower scores.
        fusion_score = self.fusion_mlp(user_repr)  # (B, 1)
        final_scores = pairwise_scores + fusion_score * twotower_scores  # (B, 1)
        return final_scores



# Dummy data for testing
# Seed the RNG so the random graph, the sampled ids, the initial model weights
# and therefore the printed scores are reproducible across runs.
torch.manual_seed(0)

NUM_NODES = 500      # nodes in the synthetic graph
NUM_FEATURES = 64    # input features per node
NUM_EDGES = 2000     # random directed edges
BATCH_SIZE = 50      # number of (user, item) pairs to score
NUM_ITEMS = 1000     # must not exceed the model's item-embedding table size (1000)

node_features = torch.randn((NUM_NODES, NUM_FEATURES))
edges = torch.randint(0, NUM_NODES, (2, NUM_EDGES))
user_ids = torch.randint(0, NUM_NODES, (BATCH_SIZE,))
print(user_ids)
print('-'*50)
item_ids = torch.randint(0, NUM_ITEMS, (BATCH_SIZE,))
print(item_ids)
print('-'*50)

# Model setup
model = ContextGNN(in_channels=NUM_FEATURES, hidden_channels=128, out_channels=64)
scores = model(node_features, edges, user_ids, item_ids)

# One score per requested (user, item) pair.
assert scores.shape == (BATCH_SIZE, 1)
print(scores)


tensor([292, 246,  29, 205, 185, 183, 374, 348, 309, 473, 124, 365, 399, 235,
        131, 428, 221, 361,  65, 410, 477, 300,  28, 303, 465,  25, 387, 265,
        407, 238,   4, 199, 449, 467,  27, 486, 422,  13, 443, 136, 277, 413,
        214,   8,  27, 326, 281, 429,  71, 219])
--------------------------------------------------
tensor([711,  78, 416, 243, 368, 487, 900, 875, 879, 881, 205, 339, 180, 520,
        136, 873, 216, 930, 506,  72,  19, 161, 697, 140, 107, 465, 835, 855,
        764, 113, 458,  12, 401,  30,  45, 444, 242,  94, 398, 974, 540, 693,
         58, 251, 876, 312, 615, 310, 406, 608])
--------------------------------------------------
tensor([[-0.3345],
        [-0.0546],
        [ 0.0628],
        [-0.0540],
        [-0.2733],
        [ 0.0954],
        [ 0.1112],
        [-0.0277],
        [ 0.0926],
        [-0.0582],
        [ 0.0962],
        [-0.2725],
        [ 0.0816],
        [-0.0040],
        [ 0.0120],
        [-0.0255],
        [-0.0984],
        [