In [3]:
!pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.5.1+cu124.html --no-cache-dir
!pip install pyg-lib -f https://data.pyg.org/whl/torch-2.5.1+cu124.html --no-cache-dir
!pip install torch-geometric

Looking in links: https://data.pyg.org/whl/torch-2.5.1+cu124.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-2.5.0%2Bcu124/torch_scatter-2.1.2%2Bpt25cu124-cp311-cp311-linux_x86_64.whl (10.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.8/10.8 MB[0m [31m199.8 MB/s[0m eta [36m0:00:00[0m [36m0:00:01[0m
[?25hCollecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-2.5.0%2Bcu124/torch_sparse-0.6.18%2Bpt25cu124-cp311-cp311-linux_x86_64.whl (5.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.2/5.2 MB[0m [31m278.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torch-scatter, torch-sparse
Successfully installed torch-scatter-2.1.2+pt25cu124 torch-sparse-0.6.18+pt25cu124
Looking in links: https://data.pyg.org/whl/torch-2.5.1+cu124.html
Collecting pyg-lib
  Downloading https://data.pyg.org/whl/torch-2.5.0%2Bcu124/pyg_lib-0.4.0%2Bpt25cu124-cp311-cp311-linux_x86_64.whl (2.5 MB)
[2K     [

In [4]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GINConv, SAGEConv
from torch_geometric.data import Data
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.utils import coalesce, subgraph
from tqdm.auto import tqdm # Still useful for data generation progress
import numpy as np
import random
import os # Import os for checking file existence

# Function to set seeds for reproducibility (less critical for inference, but good practice)
def set_seed(seed: int = 42):
    """Sets the seed for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed) # for multi-GPU
        # Optional: If you need deterministic behavior, uncomment these lines.
        # This might slow down training.
        # torch.backends.cudnn.deterministic = True
        # torch.backends.cudnn.benchmark = False
    print(f"Random seed set to {seed}")

# Determine the device to use (single GPU or CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
set_seed(42)

Using device: cuda
Random seed set to 42


In [7]:
class LinkPredictor(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, model_type='GIN'):
        super(LinkPredictor, self).__init__()
        self.num_layers = num_layers # Store the number of layers as an attribute
        self.convs = torch.nn.ModuleList()
        if model_type == 'GIN':
            for i in range(num_layers):
                # Using Linear layers within GINConv
                nn_GIN = torch.nn.Sequential(
                    torch.nn.Linear(in_channels if i == 0 else hidden_channels, hidden_channels),
                    torch.nn.BatchNorm1d(hidden_channels),
                    torch.nn.ReLU(),
                    torch.nn.Linear(hidden_channels, hidden_channels),
                    torch.nn.BatchNorm1d(hidden_channels),
                    torch.nn.ReLU()
                )
                self.convs.append(GINConv(nn_GIN))
        elif model_type == 'SAGE':
            for i in range(num_layers):
                # SAGEConv layers
                self.convs.append(SAGEConv(in_channels if i == 0 else hidden_channels, hidden_channels))
        else:
            raise ValueError("Model type must be 'GIN' or 'SAGE'")

        # Linear layer for prediction (used in the predict method)
        self.lin = torch.nn.Linear(2 * hidden_channels, out_channels)
        self.model_type = model_type

    def forward(self, x, edge_index):
        # Pass node features through graph convolution layers
        for conv_layer in self.convs:
            # Ensure inputs to convolution are on the correct device
            x = conv_layer(x, edge_index)
            x = F.relu(x) # Apply ReLU activation after each layer
        return x

    # Predict method using dot product between node embeddings
    # This method expects edge indices that are LOCAL to the provided embeddings `z`
    def predict(self, z, edge_index_pos, edge_index_neg):
        # Calculate scores for positive links
        if edge_index_pos.numel() > 0:
            row_pos, col_pos = edge_index_pos
            pos_out = (z[row_pos] * z[col_pos]).sum(dim=-1)
        else:
            pos_out = torch.empty(0).to(z.device) # Handle case with no positive edges

        # Calculate scores for negative links
        if edge_index_neg.numel() > 0:
            row_neg, col_neg = edge_index_neg
            neg_out = (z[row_neg] * z[col_neg]).sum(dim=-1)
        else:
            neg_out = torch.empty(0).to(z.device) # Handle case with no negative edges

        # Apply sigmoid to get probabilities
        return torch.sigmoid(pos_out), torch.sigmoid(neg_out)



In [29]:
import json

with open('/kaggle/input/team-prediction-inference-example/authors.json', 'r') as file:
    author_index = json.load(file)

author_embedings_np = np.load('/kaggle/input/team-prediction-inference-example/nodes.npy', mmap_mode='r')
edges_np = np.load('/kaggle/input/edges-inference-1/edges (3).npy',allow_pickle = True)

In [38]:
invert_author_index = {}
for key,value in author_index.items():
    invert_author_index[value] = key

In [31]:
edges_np

array([[     0,      0,      0, ...,  80695,  80695,  80695],
       [     1,      2,      3, ..., 100669,  30314, 100670]], dtype=int32)

In [32]:
x_features = torch.tensor(author_embedings_np, dtype=torch.float)
# edges are typically (2, num_edges), representing source and target nodes
edge_index = torch.tensor(edges_np, dtype=torch.long).contiguous() # Transpose to get (2, num_edges)
data_obj = Data(x=x_features, edge_index=edge_index, num_nodes=num_nodes)
print("Data object created.")

Data object created.


In [33]:
full_graph_data = data_obj # This contains the entire graph structure

# --- Model Initialization and Loading ---
# Hyperparameters - must match the saved model's hyperparameters
# Use the feature size from the generated data
in_channels = 559
hidden_channels = 32 # Keep consistent with training
out_channels = 1 # Keep consistent with training
num_layers = 2 # Keep consistent with training
model_type = 'GIN' # Must match the model type used for training

# Initialize the model with the same architecture as the saved model
model = LinkPredictor(in_channels, hidden_channels, out_channels, num_layers, model_type=model_type)

# Define the path to your saved model checkpoint file
model_checkpoint_path = '/kaggle/input/teams_prediction/pytorch/default/1/model_GIN_checkpoint_final.pth' 
model.load_state_dict(torch.load(model_checkpoint_path, map_location=device,weights_only=True))
model.to(device)
model.eval()

LinkPredictor(
  (convs): ModuleList(
    (0): GINConv(nn=Sequential(
      (0): Linear(in_features=559, out_features=32, bias=True)
      (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Linear(in_features=32, out_features=32, bias=True)
      (4): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
    ))
    (1): GINConv(nn=Sequential(
      (0): Linear(in_features=32, out_features=32, bias=True)
      (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Linear(in_features=32, out_features=32, bias=True)
      (4): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
    ))
  )
  (lin): Linear(in_features=64, out_features=1, bias=True)
)

In [40]:
target_node_index = 0

with torch.no_grad():
    # Move full graph data to device for embedding computation
    full_graph_data = full_graph_data.to(device)
    all_node_embeddings = model(full_graph_data.x, full_graph_data.edge_index)

target_node_embedding = all_node_embeddings[target_node_index]

target_node_edges_mask = (full_graph_data.edge_index[0] == target_node_index) | (full_graph_data.edge_index[1] == target_node_index)
existing_neighbors = torch.unique(full_graph_data.edge_index[:, target_node_edges_mask].flatten())

existing_neighbors = existing_neighbors[existing_neighbors != target_node_index]

existing_neighbors_set = set(existing_neighbors.cpu().numpy())

potential_target_nodes_indices = [
    i for i in range(full_graph_data.num_nodes)
    if i != target_node_index and i not in existing_neighbors_set
]

potential_target_nodes_indices = torch.tensor(potential_target_nodes_indices, dtype=torch.long).to(device)

potential_edges_from_target = torch.stack([
    torch.full_like(potential_target_nodes_indices, target_node_index),
    potential_target_nodes_indices
], dim=0).to(device)

print(f"Evaluating potential links from node {target_node_index} to {potential_edges_from_target.size(1)} other nodes...")

with torch.no_grad():
     predicted_scores, _ = model.predict(all_node_embeddings, potential_edges_from_target, torch.empty(2, 0).to(device))

ranked_predictions = sorted(zip(potential_target_nodes_indices.cpu().numpy(), predicted_scores.cpu().numpy()),
                            key=lambda x: x[1],
                            reverse=True)

# Display the top predicted links
top_k = 10 # Number of top links to display
print(f"\nTop {top_k} predicted links from Author {invert_author_index[target_node_index]} [node {target_node_index}] (excluding existing connections):")
for node_idx, score in ranked_predictions[:top_k]:
    print(f"  -> Author {invert_author_index[node_idx]} [Node {node_idx}]: Score {score:.4f}")

print(min(ranked_predictions))

Evaluating potential links from node 0 to 100647 other nodes...

Top 10 predicted links from Author /A5065430546 [node 0] (excluding existing connections):
  -> Author /A5087615023 [Node 687]: Score 1.0000
  -> Author /A5026302045 [Node 688]: Score 1.0000
  -> Author /A5110213087 [Node 883]: Score 1.0000
  -> Author /A5091324914 [Node 884]: Score 1.0000
  -> Author /A5067637550 [Node 6370]: Score 1.0000
  -> Author /A5022469337 [Node 7050]: Score 1.0000
  -> Author /A5084108976 [Node 7052]: Score 1.0000
  -> Author /A5108314166 [Node 7053]: Score 1.0000
  -> Author /A5066149235 [Node 7054]: Score 1.0000
  -> Author /A5005817118 [Node 7055]: Score 1.0000
(10, 0.6423824)
