In [None]:
import pandas as pd
import numpy as np
from typing import Dict, Tuple
from collections import defaultdict
import time

class PrimeKGLoader:
    """
    Prepares PrimeKG data for efficient loading into DGL heterogeneous graphs.
    Each node type gets sequential IDs starting from 0.

    The nodes table must provide the columns ``id``, ``node_type``, ``name``
    and ``metadata_source``; the edges table must provide ``source_id``,
    ``target_id`` and ``relationship_type``.
    """

    def __init__(self):
        self.node_type_mapping = {}  # string -> int
        self.relationship_type_mapping = {}  # string -> int
        self.reverse_node_type_mapping = {}  # int -> string
        self.reverse_relationship_type_mapping = {}  # int -> string
        self.global_to_local_mapping = {}  # For reference: global_id -> (node_type, local_id)

    def load_and_prepare_primekg(self, nodes_csv_path: str, edges_csv_path: str):
        """
        Load PrimeKG data and prepare it for bulk_load_heterogeneous_graph.
        Each node type gets sequential IDs starting from 0.

        Args:
            nodes_csv_path: Path to nodes CSV file
            edges_csv_path: Path to edges CSV file

        Returns:
            Tuple of (node_types_dict, edge_types_dict, global_to_local_mapping).
            The first two are ready for DGL loading; the third maps every
            original node ID to its (node_type, local_id) pair.
        """
        print("Loading PrimeKG data...")
        start_time = time.time()

        # Start from a fresh mapping so that re-using the loader for another
        # dataset neither keeps stale IDs around nor mutates the dict that a
        # previous call handed back to its caller.
        self.global_to_local_mapping = {}

        # Load raw data
        print("  Reading CSV files...")
        nodes_df = pd.read_csv(nodes_csv_path, low_memory=False)
        edges_df = pd.read_csv(edges_csv_path, low_memory=False)

        print(f"  Loaded {len(nodes_df):,} nodes and {len(edges_df):,} edges")

        # Create type mappings
        print("  Creating type mappings...")
        self._create_type_mappings(nodes_df, edges_df)

        # Prepare node data (sequential IDs starting from 0 for each type)
        print("  Preparing node data...")
        node_types_dict = self._prepare_node_data(nodes_df)

        # Prepare edge data (using local IDs)
        print("  Preparing edge data...")
        edge_types_dict = self._prepare_edge_data(edges_df, nodes_df)

        total_time = time.time() - start_time
        print(f"\nData preparation completed in {total_time:.2f}s")

        # Print summary
        self._print_summary(node_types_dict, edge_types_dict)

        return node_types_dict, edge_types_dict, self.global_to_local_mapping

    def _create_type_mappings(self, nodes_df: pd.DataFrame, edges_df: pd.DataFrame):
        """Create mappings between string types and integer representations.

        Type names are sorted before numbering, so the integer assigned to a
        type depends only on the set of names present, not on row order.
        """
        # Node type mappings
        unique_node_types = sorted(nodes_df['node_type'].unique())
        self.node_type_mapping = {node_type: i for i, node_type in enumerate(unique_node_types)}
        self.reverse_node_type_mapping = {i: node_type for node_type, i in self.node_type_mapping.items()}

        # Relationship type mappings
        unique_rel_types = sorted(edges_df['relationship_type'].unique())
        self.relationship_type_mapping = {rel_type: i for i, rel_type in enumerate(unique_rel_types)}
        self.reverse_relationship_type_mapping = {i: rel_type for rel_type, i in self.relationship_type_mapping.items()}

        print(f"    Found {len(unique_node_types)} node types: {unique_node_types}")
        print(f"    Found {len(unique_rel_types)} relationship types: {unique_rel_types}")

    def _prepare_node_data(self, nodes_df: pd.DataFrame) -> Dict[str, pd.DataFrame]:
        """
        Group nodes by type and prepare DataFrames with sequential IDs starting from 0.

        Also records every node in ``self.global_to_local_mapping`` so that
        edges can later be translated to local IDs.

        Returns:
            Dict mapping node_type_string -> DataFrame with columns ['node_id', 'name', 'metadata_source', 'node_type_id', 'original_global_id']
        """
        node_types_dict = {}

        # Add numeric type ID to nodes (on a copy, the caller's frame is left untouched)
        nodes_df_copy = nodes_df.copy()
        nodes_df_copy['node_type_id'] = nodes_df_copy['node_type'].map(self.node_type_mapping)

        # Group by node type and assign sequential IDs starting from 0
        for node_type_str, group_df in nodes_df_copy.groupby('node_type'):
            # Sort by original ID so local IDs are reproducible across runs
            group_df = group_df.sort_values('id').reset_index(drop=True)

            num_nodes = len(group_df)
            global_ids = group_df['id'].values
            local_ids = np.arange(num_nodes)  # 0, 1, 2, ..., num_nodes-1

            # Store the mapping for edge processing
            self.global_to_local_mapping.update(
                (global_id, (node_type_str, local_id))
                for global_id, local_id in zip(global_ids, local_ids)
            )

            # Prepare DataFrame for DGL
            node_types_dict[node_type_str] = pd.DataFrame({
                'node_id': local_ids,  # Sequential IDs starting from 0
                'name': group_df['name'].values,
                'metadata_source': group_df['metadata_source'].values,
                'node_type_id': group_df['node_type_id'].values,
                'original_global_id': global_ids  # Keep original for reference
            })
            print(f"    {node_type_str}: {num_nodes:,} nodes (IDs: 0 to {num_nodes-1})")

        return node_types_dict

    def _prepare_edge_data(self, edges_df: pd.DataFrame, nodes_df: pd.DataFrame) -> Dict[Tuple[str, str, str], pd.DataFrame]:
        """
        Prepare edge data grouped by (src_type, edge_type, dst_type) using local IDs.

        Must run after ``_prepare_node_data`` because it relies on
        ``self.global_to_local_mapping``. Edges whose endpoints are not present
        in ``nodes_df`` are dropped (a warning is printed).

        Returns:
            Dict mapping (src_type, edge_type, dst_type) -> DataFrame with columns
            ['src', 'dst', 'relationship_type_id', 'original_src_id', 'original_dst_id']
        """
        # Create node ID to type mapping for fast lookup
        node_id_to_type = dict(zip(nodes_df['id'], nodes_df['node_type']))

        # Split the global mapping by node type ONCE. Rebuilding these lookup
        # tables inside the loop below would rescan every node for every edge
        # type. The inner containers are plain dicts on purpose: Series.map
        # would call __missing__ on a defaultdict instead of producing NaN.
        local_ids_by_type = defaultdict(dict)
        for global_id, (node_type, local_id) in self.global_to_local_mapping.items():
            local_ids_by_type[node_type][global_id] = local_id

        # Add relationship type IDs
        edges_df_copy = edges_df.copy()
        edges_df_copy['relationship_type_id'] = edges_df_copy['relationship_type'].map(self.relationship_type_mapping)

        # Add source and target node types
        edges_df_copy['src_type'] = edges_df_copy['source_id'].map(node_id_to_type)
        edges_df_copy['dst_type'] = edges_df_copy['target_id'].map(node_id_to_type)

        # Filter out edges with unknown nodes
        valid_mask = (edges_df_copy['src_type'].notna()) & (edges_df_copy['dst_type'].notna())
        valid_edges = edges_df_copy[valid_mask]

        if len(valid_edges) < len(edges_df_copy):
            print(f"    Warning: Filtered out {len(edges_df_copy) - len(valid_edges)} edges with unknown nodes")

        # Group by (src_type, relationship_type, dst_type)
        edge_types_dict = {}

        for (src_type, rel_type, dst_type), group_df in valid_edges.groupby(['src_type', 'relationship_type', 'dst_type']):
            print(f"    Processing {src_type} --[{rel_type}]--> {dst_type}: {len(group_df):,} edges")

            group_df_reset = group_df.reset_index(drop=True)

            # .get() avoids inserting empty entries into the defaultdict
            src_type_mapping = local_ids_by_type.get(src_type, {})
            dst_type_mapping = local_ids_by_type.get(dst_type, {})

            # Vectorized mapping using pandas map (unmapped IDs become NaN)
            group_df_reset['src_local'] = group_df_reset['source_id'].map(src_type_mapping)
            group_df_reset['dst_local'] = group_df_reset['target_id'].map(dst_type_mapping)

            # Filter valid edges (both src and dst must be mapped)
            mapped_mask = (group_df_reset['src_local'].notna()) & (group_df_reset['dst_local'].notna())
            valid_edges_df = group_df_reset[mapped_mask]

            if len(valid_edges_df) == 0:
                print(f"      Warning: No valid edges found for {src_type}-{rel_type}->{dst_type}")
                continue

            # Create edge DataFrame with local node IDs
            edge_df = pd.DataFrame({
                'src': valid_edges_df['src_local'].astype(int).values,  # Local IDs (0-based for each node type)
                'dst': valid_edges_df['dst_local'].astype(int).values,  # Local IDs (0-based for each node type)
                'relationship_type_id': valid_edges_df['relationship_type_id'].values,
                'original_src_id': valid_edges_df['source_id'].values,  # Keep original for reference
                'original_dst_id': valid_edges_df['target_id'].values   # Keep original for reference
            })

            edge_types_dict[(src_type, rel_type, dst_type)] = edge_df
            print(f"      Created {len(edge_df):,} valid edges")

        return edge_types_dict

    def _print_summary(self, node_types_dict: Dict[str, pd.DataFrame],
                      edge_types_dict: Dict[Tuple[str, str, str], pd.DataFrame]):
        """Print summary of prepared data and sanity-check the local ID ranges.

        Raises:
            AssertionError: if a node type's local IDs do not span 0..len-1.
        """
        print("\n" + "="*60)
        print("PRIMEKG DATA PREPARATION SUMMARY")
        print("="*60)

        print("\nNode Type Mappings:")
        for str_type, int_type in self.node_type_mapping.items():
            count = len(node_types_dict.get(str_type, []))
            print(f"  {int_type}: {str_type} ({count:,} nodes, IDs: 0 to {count-1})")

        print("\nRelationship Type Mappings:")
        for str_type, int_type in self.relationship_type_mapping.items():
            print(f"  {int_type}: {str_type}")

        print("\nPrepared Node Types:")
        total_nodes = 0
        for node_type, df in node_types_dict.items():
            min_id = df['node_id'].min()
            max_id = df['node_id'].max()
            print(f"  {node_type}: {len(df):,} nodes (local IDs: {min_id} to {max_id})")
            total_nodes += len(df)
        print(f"  TOTAL: {total_nodes:,} nodes")

        print("\nPrepared Edge Types:")
        total_edges = 0
        for (src_type, edge_type, dst_type), df in edge_types_dict.items():
            print(f"  {src_type} --[{edge_type}]--> {dst_type}: {len(df):,} edges")
            total_edges += len(df)
        print(f"  TOTAL: {total_edges:,} edges")

        print("\nData Format Verification:")
        for node_type, df in node_types_dict.items():
            assert df['node_id'].min() == 0, f"Node IDs for {node_type} don't start at 0!"
            assert df['node_id'].max() == len(df) - 1, f"Node IDs for {node_type} are not sequential!"
            print(f"  ✅ {node_type}: Sequential IDs 0 to {len(df)-1}")

        print("="*60)

    def get_type_mappings(self):
        """Return the type mappings for reference."""
        return {
            'node_types': self.node_type_mapping,
            'relationship_types': self.relationship_type_mapping,
            'reverse_node_types': self.reverse_node_type_mapping,
            'reverse_relationship_types': self.reverse_relationship_type_mapping
        }

    def get_global_to_local_mapping(self):
        """Return a shallow copy of the global to local ID mapping."""
        return self.global_to_local_mapping.copy()

    def global_id_to_local(self, global_id: int) -> Tuple[str, int]:
        """Convert a global node ID to (node_type, local_id).

        Raises:
            ValueError: if the global ID is unknown.
        """
        try:
            return self.global_to_local_mapping[global_id]
        except KeyError:
            raise ValueError(f"Global ID {global_id} not found in mapping") from None

    def local_id_to_global(self, node_type: str, local_id: int) -> int:
        """Convert (node_type, local_id) to global node ID.

        This is a linear scan over all known nodes; it is meant for occasional
        look-ups, not for bulk conversion.

        Raises:
            ValueError: if no node with that type and local ID exists.
        """
        for global_id, (nt, lid) in self.global_to_local_mapping.items():
            if nt == node_type and lid == local_id:
                return global_id
        raise ValueError(f"Local ID ({node_type}, {local_id}) not found in mapping")

In [None]:
from DeepGraphDB import DeepGraphDB
import torch

# Notebook driver: prepare PrimeKG with PrimeKGLoader and push the result into
# a DeepGraphDB instance. The names defined here are module-level globals.
db = DeepGraphDB()
# Initialize the loader
loader = PrimeKGLoader()

# Load and prepare data
nodes_csv = "nodes.csv"  # Replace with your actual path
edges_csv = "edges.csv"  # Replace with your actual path

# Returns per-type node frames, per-(src, rel, dst) edge frames, and the
# global_id -> (node_type, local_id) mapping.
node_types_dict, edge_types_dict, mapping = loader.load_and_prepare_primekg(nodes_csv, edges_csv)

    
# Get type mappings for reference
mappings = loader.get_type_mappings()
print("\nType mappings created:")
print("Node types:", mappings['node_types'])
print("Relationship types:", mappings['relationship_types'])

# Verify data format
print("\nData format verification:")
for node_type, df in node_types_dict.items():
    print(f"  {node_type}: node_id range {df['node_id'].min()}-{df['node_id'].max()}")

# Now you can use this data with your DGL graph analyzer
print("\nReady to load into DGL!")
# NOTE(review): the hint printed below names `analyzer`, but the object that is
# actually used in this cell is `db`.
print("Use: analyzer.bulk_load_heterogeneous_graph(node_types_dict, edge_types_dict)")
db.bulk_load_heterogeneous_graph(node_types_dict, edge_types_dict)
db.set_mappings(loader.node_type_mapping, loader.relationship_type_mapping)
db.set_global_to_local_mapping(mapping)

# Random 128-dimensional features with one row per integer in
# 0..max(global ID), so rows exist even for IDs that are not in the mapping.
# NOTE(review): presumably placeholder features indexed by global ID — confirm
# against DeepGraphDB.load_node_features_for_gnn.
x = torch.rand(max(db.global_to_local_mapping.keys())+1, 128)
db.load_node_features_for_gnn(x)

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.nn as dglnn
import dgl.function as fn
from dgl.dataloading import (
    DataLoader, 
    MultiLayerNeighborSampler, 
    as_edge_prediction_sampler
)
from sklearn.metrics import roc_auc_score, average_precision_score
import itertools
from typing import Dict, List, Tuple, Optional
from collections import defaultdict

# Set PyTorch backend for DGL
# NOTE(review): `dgl` is already imported earlier in this cell, so this
# assignment probably comes too late to influence DGL's backend selection,
# which is expected to read DGLBACKEND at import time — confirm, and if so
# move it above `import dgl`.
os.environ["DGLBACKEND"] = "pytorch"

class HeteroGraphSAGE(nn.Module):
    """
    Heterogeneous GraphSAGE implementation with attention mechanism.

    Each node type gets its own input projection to ``hidden_feats``; the
    projected features then go through ``num_layers`` HeteroGraphConv layers
    (one SAGEConv per edge type, summed per destination node type), each
    followed by a per-node-type LayerNorm.

    Args:
        node_types: Iterable of node type names.
        edge_types: Iterable of edge types used as HeteroGraphConv keys.
        in_feats: Input feature size, either one int for all node types or a
            dict keyed by node type.
        hidden_feats: Width of the hidden representations.
        out_feats: Width of the final representations. When ``use_attention``
            is set it must be divisible by the 4 attention heads.
        num_layers: Number of message-passing layers.
        aggregator_type: Aggregator passed to every SAGEConv.
        use_attention: Whether to apply the per-node-type attention module to
            the output of the last layer.
    """
    def __init__(self, node_types, edge_types, in_feats, hidden_feats, out_feats,
                 num_layers=2, aggregator_type='mean', use_attention=True):
        super().__init__()
        self.node_types = node_types
        self.edge_types = edge_types
        self.num_layers = num_layers
        self.use_attention = use_attention

        # Input projection for different node types
        self.input_proj = nn.ModuleDict({
            ntype: nn.Linear(in_feats[ntype] if isinstance(in_feats, dict) else in_feats, hidden_feats)
            for ntype in node_types
        })

        # GraphSAGE layers
        self.sage_layers = nn.ModuleList()
        self.layer_norms = nn.ModuleList()

        for i in range(num_layers):
            is_last = i == num_layers - 1
            out_dim = out_feats if is_last else hidden_feats

            # ReLU on every layer except the last. SAGEConv's `norm` hook is
            # intentionally left unset: it used to receive F.relu as well,
            # which only repeated the activation (ReLU is idempotent).
            conv_dict = {
                etype: dglnn.SAGEConv(
                    in_feats=hidden_feats,
                    out_feats=out_dim,
                    aggregator_type=aggregator_type,
                    activation=None if is_last else F.relu
                )
                for etype in edge_types
            }

            self.sage_layers.append(
                dglnn.HeteroGraphConv(conv_dict, aggregate='sum')
            )

            # Layer norm for each layer with correct dimensions
            self.layer_norms.append(nn.ModuleDict({
                ntype: nn.LayerNorm(out_dim) for ntype in node_types
            }))

        # Attention mechanism for heterogeneous message passing
        if use_attention:
            self.attention = nn.ModuleDict({
                ntype: nn.MultiheadAttention(
                    embed_dim=out_feats,
                    num_heads=4,
                    dropout=0.1,
                    batch_first=True
                ) for ntype in node_types
            })

        self.dropout = nn.Dropout(0.2)

    def forward(self, blocks, x):
        """Compute node representations.

        Args:
            blocks: One graph or message-flow block per layer. Layers and
                blocks are paired with ``zip``, so passing fewer blocks than
                layers silently runs only the first ``len(blocks)`` layers.
            x: Dict node type -> input feature tensor. Node types missing
                from ``x`` are skipped.

        Returns:
            Dict node type -> representation tensor, as produced by the last
            layer that ran.
        """
        # Input projection
        h = {
            ntype: self.input_proj[ntype](x[ntype])
            for ntype in self.node_types
            if ntype in x
        }

        last_layer = len(self.sage_layers) - 1

        # Forward through SAGE layers
        for i, (layer, block) in enumerate(zip(self.sage_layers, blocks)):
            h_new = layer(block, h)

            # Apply attention if enabled (only on final layer)
            if self.use_attention and i == last_layer:
                for ntype in h_new:
                    if h_new[ntype].dim() == 2:
                        # Every node becomes a length-1 sequence, so it only
                        # attends to itself.
                        h_input = h_new[ntype].unsqueeze(1)  # [N, 1, D]
                        attn_out, _ = self.attention[ntype](h_input, h_input, h_input)
                        h_new[ntype] = attn_out.squeeze(1)  # [N, D]

            # Residual connection from the previous layer
            if i > 0:
                is_block = getattr(block, 'is_block', False)
                for ntype in h_new:
                    if ntype not in h:
                        continue
                    prev = h[ntype]
                    if is_block:
                        # In a message-flow block the destination nodes are
                        # the leading rows of the source nodes. Without this
                        # slice the row counts never match and the residual
                        # would be skipped in mini-batch mode while still
                        # being applied on full graphs.
                        prev = prev[:block.num_dst_nodes(ntype)]
                    if prev.shape == h_new[ntype].shape:
                        h_new[ntype] = h_new[ntype] + prev

            # Apply layer norm and dropout with correct dimensions
            for ntype in h_new:
                h_new[ntype] = self.layer_norms[i][ntype](h_new[ntype])
                if i < last_layer:
                    h_new[ntype] = self.dropout(h_new[ntype])

            h = h_new

        return h

class MultiEdgeTypePredictor(nn.Module):
    """
    Advanced predictor that can handle multiple edge types simultaneously.

    Edge scores are computed from the concatenated source/destination
    representations through a shared transform, an additive edge-type
    embedding, and a small MLP head that is specific to the edge type.
    Edge types that were not registered at construction time fall back to a
    plain dot product.

    Args:
        in_features: Size of each node representation.
        hidden_features: Width of the shared transform.
        edge_types: List of (src_type, edge_type, dst_type) triples to build
            dedicated heads for.
    """
    def __init__(self, in_features, hidden_features=128, edge_types=None):
        super().__init__()
        self.edge_types = edge_types or []

        # Shared feature transformation
        self.feature_transform = nn.Sequential(
            nn.Linear(in_features * 2, hidden_features),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.BatchNorm1d(hidden_features)
        )

        # Edge type specific predictors
        self.edge_predictors = nn.ModuleDict({
            f"{etype[0]}_{etype[1]}_{etype[2]}": nn.Sequential(
                nn.Linear(hidden_features, hidden_features // 2),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(hidden_features // 2, 1)
            ) for etype in self.edge_types
        })

        # Edge type embeddings for better prediction
        self.edge_type_embeddings = nn.Embedding(len(self.edge_types), hidden_features)

    def forward(self, graph, h, etype):
        """Score every edge of ``etype`` in ``graph``.

        Args:
            graph: Graph whose ``edges(etype=...)`` yields the (src, dst) node
                indices to score.
            h: Dict node type -> representation tensor, indexed by those node
                indices.
            etype: (src_type, edge_type, dst_type) triple.

        Returns:
            1-D tensor with one score per edge.
        """
        # Node representations are read from `h` directly; nothing is stored
        # on the graph, so no local scope is needed.
        src_type, _, dst_type = etype
        src_nodes, dst_nodes = graph.edges(etype=etype)

        # Concatenate source and destination features
        src_feat = h[src_type][src_nodes]
        dst_feat = h[dst_type][dst_nodes]
        edge_feat = torch.cat([src_feat, dst_feat], dim=1)

        # Transform features
        transformed_feat = self.feature_transform(edge_feat)

        etype_key = f"{etype[0]}_{etype[1]}_{etype[2]}"
        if etype_key in self.edge_predictors and etype in self.edge_types:
            # Get edge type embedding
            etype_idx = torch.tensor([self.edge_types.index(etype)], device=edge_feat.device)
            etype_emb = self.edge_type_embeddings(etype_idx).expand(transformed_feat.size(0), -1)

            # Combine with edge features
            combined_feat = transformed_feat + etype_emb

            # Drop only the trailing size-1 dimension. A bare squeeze() would
            # turn a single-edge batch into a 0-dim tensor, which cannot be
            # concatenated with other score vectors.
            return self.edge_predictors[etype_key](combined_feat).squeeze(-1)

        # Fallback to simple dot product
        return (src_feat * dst_feat).sum(dim=1)

class HeteroMLPPredictor(nn.Module):
    """Enhanced MLP predictor with residual connections.

    Scores an edge from the concatenation of its source and destination
    representations: a linear input projection, two residual MLP blocks and a
    final linear layer producing one logit per edge.

    Args:
        in_features: Size of each node representation.
        hidden_features: Width of the MLP.
    """
    def __init__(self, in_features, hidden_features=128):
        super().__init__()
        self.input_proj = nn.Linear(in_features * 2, hidden_features)

        # Residual blocks
        self.residual_blocks = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_features, hidden_features),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(hidden_features, hidden_features),
                nn.ReLU()
            ) for _ in range(2)
        ])

        self.output_proj = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(hidden_features, 1)
        )

    def forward(self, graph, h, etype):
        """Score every edge of ``etype`` in ``graph``.

        Args:
            graph: Graph whose ``edges(etype=...)`` yields the (src, dst) node
                indices to score.
            h: Dict node type -> representation tensor.
            etype: (src_type, edge_type, dst_type) triple.

        Returns:
            1-D tensor with one score per edge.
        """
        # Nothing is written to the graph, so no local scope is required.
        src_type, _, dst_type = etype
        src_nodes, dst_nodes = graph.edges(etype=etype)

        src_feat = h[src_type][src_nodes]
        dst_feat = h[dst_type][dst_nodes]
        edge_feat = torch.cat([src_feat, dst_feat], dim=1)

        # Input projection
        x = self.input_proj(edge_feat)

        # Residual blocks
        for block in self.residual_blocks:
            x = block(x) + x

        # Drop only the trailing size-1 dimension so that a single-edge batch
        # still yields a 1-D tensor (a bare squeeze() would return a scalar).
        return self.output_proj(x).squeeze(-1)

class AdvancedHeteroLinkPredictor(nn.Module):
    """
    Advanced heterogeneous link prediction model with multiple edge type support.

    Combines a HeteroGraphSAGE encoder, which passes messages over all
    ``edge_types``, with an edge scorer that is built only for
    ``target_etypes``.

    Args:
        node_types: Node type names.
        edge_types: Edge types used for message passing.
        in_feats: Input feature size (int or dict per node type).
        hidden_feats: Hidden width of the encoder (and of the multi-edge
            predictor).
        out_feats: Size of the node embeddings produced by the encoder.
        num_layers: Number of encoder layers.
        use_attention: Forwarded to the encoder.
        predictor_type: 'multi_edge' selects MultiEdgeTypePredictor; any other
            value selects HeteroMLPPredictor with its default hidden width.
        target_etypes: Edge types to predict; defaults to ``edge_types``.
    """
    def __init__(self, node_types, edge_types, in_feats, hidden_feats, out_feats,
                 num_layers=3, use_attention=True, predictor_type='multi_edge', target_etypes=None):
        super().__init__()
        self.node_types = node_types
        self.edge_types = edge_types
        self.target_etypes = target_etypes or edge_types

        # Advanced GNN backbone (uses all edge types for message passing)
        self.gnn = HeteroGraphSAGE(
            node_types=node_types,
            edge_types=edge_types,
            in_feats=in_feats,
            hidden_feats=hidden_feats,
            out_feats=out_feats,
            num_layers=num_layers,
            use_attention=use_attention
        )

        # Advanced predictor (only for target edge types)
        if predictor_type == 'multi_edge':
            self.predictor = MultiEdgeTypePredictor(out_feats, hidden_feats, self.target_etypes)
        else:
            self.predictor = HeteroMLPPredictor(out_feats)

    def forward(self, pos_graph, neg_graph, blocks, x, etype):
        """Encode the nodes and score the positive and negative edges.

        Args:
            pos_graph: Graph holding the positive edges of ``etype``.
            neg_graph: Graph holding the negative edges of ``etype``.
            blocks: One graph/block per encoder layer.
            x: Dict node type -> input features.
            etype: (src_type, edge_type, dst_type) triple to score.

        Returns:
            Tuple ``(pos_score, neg_score)``.
        """
        # Get node representations
        h = self.gnn(blocks, x)

        # Compute scores
        pos_score = self.predictor(pos_graph, h, etype)
        neg_score = self.predictor(neg_graph, h, etype)

        return pos_score, neg_score

    def get_embeddings(self, graph, x):
        """
        Get node embeddings for the entire graph.

        The encoder pairs its layers with the supplied graphs one-to-one, so
        the full graph is repeated once per layer; handing over a single graph
        would run only the first layer.

        Returns:
            Dict node type -> detached embedding tensor. Node types for which
            the encoder produced no output are omitted.
        """
        num_layers = len(self.gnn.sage_layers)
        h = self.gnn([graph] * num_layers, x)
        return {ntype: h[ntype].detach() for ntype in self.node_types if ntype in h}

def advanced_negative_sampling(graph, etype, k=1, method='uniform'):
    """
    Advanced negative sampling with different strategies.

    Args:
        graph: Heterogeneous graph holding the positive edges of ``etype``.
        etype: (src_type, edge_type, dst_type) triple.
        k: Number of negative edges drawn per positive edge.
        method: 'uniform' draws destinations uniformly; 'popularity_based'
            draws them proportionally to their in-degree for ``etype`` and
            falls back to uniform sampling when degrees are unavailable or
            all zero.

    Returns:
        A heterograph with the same node counts as ``graph`` that contains
        only the sampled negative edges of ``etype``.

    Raises:
        ValueError: if ``method`` is not recognised.
    """
    if method == 'uniform':
        return negative_sampling(graph, etype, k)

    if method != 'popularity_based':
        raise ValueError(f"Unknown negative sampling method: {method}")

    # Sample negative examples based on node popularity
    _, _, dst_type = etype
    src, _ = graph.edges(etype=etype)

    # Calculate node degrees for popularity-based sampling
    try:
        dst_degrees = graph.in_degrees(graph.nodes(dst_type), etype=etype).float()
    except Exception:
        # Best-effort fallback to uniform sampling if the degree computation
        # fails. `Exception` (rather than a bare except) lets
        # KeyboardInterrupt and SystemExit propagate.
        return negative_sampling(graph, etype, k)

    total_degree = dst_degrees.sum()
    if total_degree == 0:
        # Fallback to uniform sampling if no degrees
        return negative_sampling(graph, etype, k)

    dst_probs = dst_degrees / total_degree

    # Sample negative destinations based on popularity
    neg_dst = torch.multinomial(dst_probs, len(src) * k, replacement=True)
    neg_src = src.repeat_interleave(k)

    # Create negative graph
    return dgl.heterograph(
        {etype: (neg_src, neg_dst)},
        num_nodes_dict={ntype: graph.num_nodes(ntype) for ntype in graph.ntypes}
    )

def negative_sampling(graph, etype, k=1):
    """
    Build a graph of uniformly sampled negative edges for ``etype``.

    Every positive edge contributes ``k`` negatives that keep its source node
    and pair it with a destination drawn uniformly from all nodes of the
    destination type. The result has the same node counts as ``graph``.
    """
    _, _, dst_type = etype
    pos_src, _ = graph.edges(etype=etype)

    # Each positive source is repeated k times in place
    corrupted_src = pos_src.repeat_interleave(k)

    # One uniformly random destination per negative edge
    total_negatives = len(pos_src) * k
    corrupted_dst = torch.randint(
        0,
        graph.num_nodes(dst_type),
        (total_negatives,),
        device=pos_src.device,
    )

    node_counts = {ntype: graph.num_nodes(ntype) for ntype in graph.ntypes}
    return dgl.heterograph(
        {etype: (corrupted_src, corrupted_dst)},
        num_nodes_dict=node_counts,
    )

def multi_metric_evaluation(pos_scores, neg_scores):
    """
    Compute multiple evaluation metrics.

    Args:
        pos_scores: 1-D tensor of scores for positive edges.
        neg_scores: 1-D tensor of scores for negative edges.

    Returns:
        Dict with 'AUC', 'AP' and 'Hit@1' / 'Hit@5' / 'Hit@10'. Hit@K is the
        number of positives among the K highest-scored edges divided by
        min(K, number of positives).
    """
    scores = torch.cat([pos_scores, neg_scores]).detach().cpu().numpy()
    labels = torch.cat([
        torch.ones(pos_scores.shape[0]),
        torch.zeros(neg_scores.shape[0])
    ]).cpu().numpy()

    auc = roc_auc_score(labels, scores)
    ap = average_precision_score(labels, scores)

    # The ranking and the number of positives do not depend on K, so they are
    # computed once instead of once per K.
    ranked_labels = labels[np.argsort(scores)[::-1]]
    num_positives = np.sum(labels)

    hit_at_k = {
        f'Hit@{k}': np.sum(ranked_labels[:k]) / min(k, num_positives)
        for k in (1, 5, 10)
    }

    return {
        'AUC': auc,
        'AP': ap,
        **hit_at_k
    }

# Margin-based ranking loss over all positive/negative pairs
def compute_margin_loss(pos_score, neg_score, margin=1.0):
    """
    Return the mean hinge loss max(0, margin - pos + neg) over every
    (positive, negative) score pair.
    """
    pos_column = pos_score.unsqueeze(1)  # [batch_size, 1]
    neg_row = neg_score.unsqueeze(0)     # [1, num_neg]

    # Broadcasting the column against the row yields the pairwise hinge terms
    pairwise_hinge = (margin - pos_column + neg_row).clamp(min=0)
    return pairwise_hinge.mean()

def compute_loss(pos_score, neg_score, loss_type='bce'):
    """
    Dispatch to the requested link-prediction loss.

    'bce' treats the scores as logits with label 1 for positives and 0 for
    negatives; 'margin' delegates to compute_margin_loss. Any other value
    raises ValueError.
    """
    if loss_type == 'margin':
        return compute_margin_loss(pos_score, neg_score)

    if loss_type != 'bce':
        raise ValueError(f"Unknown loss type: {loss_type}")

    logits = torch.cat([pos_score, neg_score])
    targets = torch.cat([
        torch.ones_like(pos_score),
        torch.zeros_like(neg_score),
    ])
    return F.binary_cross_entropy_with_logits(logits, targets)


In [None]:
import torch
import dgl
import numpy as np
from collections import defaultdict
from sklearn.model_selection import train_test_split

def split_edges_consistent(graph, target_etypes, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, random_state=42):
    """
    Split edges consistently across different edge types in a heterogeneous graph.

    Args:
        graph: DGL heterogeneous graph
        target_etypes: List of edge types to split
        train_ratio: Ratio for training set
        val_ratio: Ratio for validation set
        test_ratio: Ratio for test set
        random_state: Random seed for reproducibility

    Returns:
        Dictionary mapping each edge type to
        {'train' | 'val' | 'test': {'src': tensor, 'dst': tensor}} holding the
        endpoints of the edges assigned to that split. Edge types without
        edges get empty tensors with the dtype and device of the graph's own
        edge tensors.
    """
    assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6, "Ratios must sum to 1.0"

    edge_splits = {}
    # Kept for backward compatibility: this reseeds NumPy's *global* RNG, so
    # later np.random calls in the notebook are affected as well. The splits
    # below are driven by `random_state` directly.
    np.random.seed(random_state)

    for etype in target_etypes:
        src, dst = graph.edges(etype=etype)
        num_edges = len(src)

        if num_edges == 0:
            # Derive the empty splits from src/dst instead of torch.tensor([])
            # (float32 on CPU) so every split has a consistent dtype/device.
            edge_splits[etype] = {
                split: {'src': src.new_empty(0), 'dst': dst.new_empty(0)}
                for split in ('train', 'val', 'test')
            }
            continue

        # Create edge indices
        edge_indices = np.arange(num_edges)

        # First split: train vs (val + test)
        train_indices, temp_indices = train_test_split(
            edge_indices,
            train_size=train_ratio,
            random_state=random_state
        )

        # Second split: val vs test from the remaining edges; the ratio is
        # rescaled because it now applies to the held-out part only.
        val_size = val_ratio / (val_ratio + test_ratio)
        val_indices, test_indices = train_test_split(
            temp_indices,
            train_size=val_size,
            random_state=random_state
        )

        # Store the splits
        edge_splits[etype] = {
            'train': {
                'src': src[train_indices],
                'dst': dst[train_indices]
            },
            'val': {
                'src': src[val_indices],
                'dst': dst[val_indices]
            },
            'test': {
                'src': src[test_indices],
                'dst': dst[test_indices]
            }
        }

        print(f"Edge type {etype}: {len(train_indices)} train, {len(val_indices)} val, {len(test_indices)} test")

    return edge_splits

def create_train_graph(graph, edge_splits, target_etypes):
    """
    Build the graph used for training.

    Edge types outside ``target_etypes`` are copied in full so they stay
    available for message passing. Target edge types contribute only their
    'train' split, and are left out entirely when that split is empty. Node
    counts and all node features are carried over from ``graph``.
    """
    edges_by_type = {}
    for canonical in graph.canonical_etypes:
        if canonical in target_etypes:
            kept = edge_splits[canonical]['train']
            if len(kept['src']) > 0:
                edges_by_type[canonical] = (kept['src'], kept['dst'])
            continue

        # Non-target relation: keep every edge
        all_src, all_dst = graph.edges(etype=canonical)
        edges_by_type[canonical] = (all_src, all_dst)

    node_counts = {ntype: graph.num_nodes(ntype) for ntype in graph.ntypes}
    train_graph = dgl.heterograph(edges_by_type, num_nodes_dict=node_counts)

    # Carry the node features over to the new graph
    for ntype in graph.ntypes:
        features = graph.nodes[ntype].data
        if not features:
            continue
        for feat_name, feat_value in features.items():
            train_graph.nodes[ntype].data[feat_name] = feat_value

    return train_graph

def evaluate_on_split(model, graph, edge_splits, target_etypes, split_name, device, num_neg_samples=5):
    """
    Evaluate the model on one data split, averaging over several negative samples.

    For every target edge type that has at least one edge in the split, a
    positive graph is built from the split's edges, ``num_neg_samples``
    negative graphs are drawn with ``advanced_negative_sampling``, and each
    positive/negative pair is scored by the model and handed to
    ``multi_metric_evaluation``.

    Args:
        model: Model to evaluate; it is switched to eval mode and left there
        graph: Graph used for message passing (it is passed as every GNN block)
            and as the source of the 'x' node features. Pass a graph that does
            not contain the evaluated edges if leakage-free metrics are needed.
        edge_splits: Mapping edge type -> split name -> {'src': ..., 'dst': ...}
        target_etypes: Edge types to evaluate
        split_name: Key of the split to evaluate, e.g. 'val' or 'test'
        device: Device the positive and negative graphs are moved to
        num_neg_samples: Number of negative graphs drawn per edge type

    Returns:
        Always a ``(split_metrics, metrics_std)`` tuple; callers must unpack it.
        split_metrics: defaultdict mapping "{etype}_{metric}" to a one-element
            list holding the mean of that metric over the negative samples
        metrics_std: dict mapping "{etype}_{metric}_std" to the standard
            deviation of that metric over the negative samples
    """
    model.eval()
    split_metrics = defaultdict(list)
    per_sample_metrics = defaultdict(lambda: defaultdict(list))  # etype -> metric -> values

    # None of these depend on the edge type, so build them once.
    num_nodes_dict = {ntype: graph.num_nodes(ntype) for ntype in graph.ntypes}
    input_features = {ntype: graph.nodes[ntype].data['x'] for ntype in graph.ntypes}
    blocks = [graph, graph, graph]  # one block per GNN layer

    with torch.no_grad():
        for target_etype in target_etypes:
            split_edges = edge_splits[target_etype][split_name]

            if len(split_edges['src']) == 0:
                continue

            # Positive graph holding only the held-out edges of this edge type
            pos_graph = dgl.heterograph(
                {target_etype: (split_edges['src'], split_edges['dst'])},
                num_nodes_dict=num_nodes_dict
            ).to(device)

            for _ in range(num_neg_samples):
                # A fresh negative graph per iteration reduces sampling noise
                neg_graph = advanced_negative_sampling(
                    pos_graph, target_etype, k=1, method='uniform'
                ).to(device)

                pos_score, neg_score = model(pos_graph, neg_graph, blocks, input_features, target_etype)

                metrics = multi_metric_evaluation(pos_score, neg_score)
                for metric_name, value in metrics.items():
                    per_sample_metrics[target_etype][metric_name].append(value)

            # Average each metric across the negative samples
            for metric_name, values in per_sample_metrics[target_etype].items():
                split_metrics[f"{target_etype}_{metric_name}"].append(np.mean(values))

    # Spread of each metric across the negative samples
    metrics_std = {}
    for etype, etype_metrics in per_sample_metrics.items():
        for metric_name, values in etype_metrics.items():
            metrics_std[f"{etype}_{metric_name}_std"] = np.std(values)

    return split_metrics, metrics_std

# ==================== Graph and model setup ====================
graph = db.graph

print(f"Graph created with {graph.num_nodes()} nodes and {graph.num_edges()} edges")
print(f"Node types: {graph.ntypes}")
print(f"Edge types: {graph.etypes}")

# Create model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

node_types = graph.ntypes
edge_types = graph.etypes

# NOTE(review): assumes every node type carries 128-dimensional 'x' features - confirm
in_feats = {ntype: 128 for ntype in node_types}
target_entities = ['drug', 'disease', 'geneprotein', 'effectphenotype']

# Choose multiple edge types for prediction: both endpoints must be target entities
target_etypes = [ctype for ctype in graph.canonical_etypes if ctype[0] in target_entities and ctype[2] in target_entities]
# target_etypes = [ctype for ctype in graph.canonical_etypes]

print(f"Target edge types for prediction: {target_etypes}")

hidden_feats = 256
out_feats = 128

model = AdvancedHeteroLinkPredictor(
    node_types=node_types,
    edge_types=edge_types,  # All edge types for GNN layers
    in_feats=in_feats,
    hidden_feats=hidden_feats,
    out_feats=out_feats,
    num_layers=3,
    use_attention=True,
    predictor_type='multi_edge',
    target_etypes=target_etypes  # Only target edge types for prediction
).to(device)

print(f"Model created with {sum(p.numel() for p in model.parameters())} parameters")

# ==================== Edge Splitting ====================
print("\nSplitting edges into train/val/test sets...")
edge_splits = split_edges_consistent(
    graph, 
    target_etypes, 
    train_ratio=0.7, 
    val_ratio=0.15, 
    test_ratio=0.15,
    random_state=42
)

# Create training graph (only contains training edges for target edge types)
train_graph = create_train_graph(graph, edge_splits, target_etypes)
print(f"Training graph created with {train_graph.num_edges()} edges")

# Move graphs to device
graph = graph.to(device)
train_graph = train_graph.to(device)

# Training configuration
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=10, factor=0.5)

# Loop-invariant inputs, built once instead of once per mini-batch.
num_nodes_dict = {ntype: graph.num_nodes(ntype) for ntype in graph.ntypes}
input_features = {ntype: graph.nodes[ntype].data['x'] for ntype in node_types}
# Message passing runs on the training graph only. The full graph still holds
# the validation/test edges of the target edge types, so using it here would
# leak the held-out labels into the node representations.
blocks = [train_graph, train_graph, train_graph]


def _mean_auc_by_etype(metrics, etypes):
    """Return {etype: mean AUC} for every edge type that has AUC values in metrics."""
    auc_by_etype = {}
    for etype in etypes:
        auc_values = metrics.get(f"{etype}_AUC")
        if auc_values:
            auc_by_etype[etype] = np.mean(auc_values)
    return auc_by_etype


# ==================== Training Loop ====================
print("Starting training with proper train/val splits...")
best_val_metrics = {etype: 0.0 for etype in target_etypes}
patience_counter = 0
# Patience is counted in evaluations, not epochs.
# NOTE(review): 100 epochs with one evaluation every 5 epochs gives only 20
# evaluations, so this patience is exhausted only if no evaluation ever
# improves; lower it if early stopping is expected to fire.
max_patience = 20
num_epochs = 100
eval_every = 5

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    
    # Train on each edge type using only training edges
    for target_etype in target_etypes:
        train_edges = edge_splits[target_etype]['train']
        
        if len(train_edges['src']) == 0:
            continue
        
        src, dst = train_edges['src'], train_edges['dst']
        
        # Create mini-batches
        batch_size = min(1000000, len(src))
        
        for start_idx in range(0, len(src), batch_size):
            batch_src = src[start_idx:start_idx + batch_size]
            batch_dst = dst[start_idx:start_idx + batch_size]
            
            # Create positive graph for batch
            pos_graph = dgl.heterograph(
                {target_etype: (batch_src, batch_dst)},
                num_nodes_dict=num_nodes_dict
            ).to(device)
            
            # Generate negative samples
            neg_graph = advanced_negative_sampling(
                pos_graph, target_etype, k=1, method='uniform'
            ).to(device)
            
            # Forward pass
            pos_score, neg_score = model(pos_graph, neg_graph, blocks, input_features, target_etype)
            
            # Compute loss
            loss = compute_loss(pos_score, neg_score)
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            
            total_loss += loss.item()
    
    # ==================== Validation Evaluation ====================
    if epoch % eval_every == 0:
        # evaluate_on_split returns (metrics, metrics_std); it must be unpacked,
        # otherwise the AUC lookups below never match and nothing is tracked.
        val_metrics, val_metrics_std = evaluate_on_split(
            model, train_graph, edge_splits, target_etypes, 'val', device
        )
        
        print(f"\nEpoch {epoch:03d} | Loss: {total_loss:.4f}")
        
        # Track best validation metrics and calculate average
        epoch_improved = False
        val_auc_by_etype = _mean_auc_by_etype(val_metrics, target_etypes)
        for etype, avg_auc in val_auc_by_etype.items():
            print(f"  {etype} Val AUC: {avg_auc:.4f}")
            
            if avg_auc > best_val_metrics[etype]:
                best_val_metrics[etype] = avg_auc
                epoch_improved = True
        
        # Print average AUC across all edge types
        epoch_aucs = list(val_auc_by_etype.values())
        if epoch_aucs:
            avg_auc_all = np.mean(epoch_aucs)
            print(f"  Average Val AUC: {avg_auc_all:.4f}")
        
        # Early stopping logic
        if epoch_improved:
            patience_counter = 0
        else:
            patience_counter += 1
            
        if patience_counter >= max_patience:
            print(f"\nEarly stopping at epoch {epoch} (no improvement for {max_patience} evaluations)")
            break
    
    # Scheduler step (driven by the summed training loss of this epoch)
    scheduler.step(total_loss)

print("\nTraining completed!")

# ==================== Final Evaluation ====================
print("\n" + "="*50)
print("FINAL EVALUATION")
print("="*50)

# Validation set evaluation
print("\nValidation Set Results:")
val_metrics, val_metrics_std = evaluate_on_split(
    model, train_graph, edge_splits, target_etypes, 'val', device
)
val_auc_by_etype = _mean_auc_by_etype(val_metrics, target_etypes)
for etype, avg_auc in val_auc_by_etype.items():
    print(f"  {etype} AUC: {avg_auc:.4f}")

val_aucs = list(val_auc_by_etype.values())
if val_aucs:
    avg_val_auc = np.mean(val_aucs)
    print(f"  Average Val AUC: {avg_val_auc:.4f}")

# Test set evaluation  
print("\nTest Set Results:")
test_metrics, test_metrics_std = evaluate_on_split(
    model, train_graph, edge_splits, target_etypes, 'test', device
)
test_auc_by_etype = _mean_auc_by_etype(test_metrics, target_etypes)
for etype, avg_auc in test_auc_by_etype.items():
    print(f"  {etype} AUC: {avg_auc:.4f}")

test_aucs = list(test_auc_by_etype.values())
if test_aucs:
    avg_test_auc = np.mean(test_aucs)
    print(f"  Average Test AUC: {avg_test_auc:.4f}")

print("\nBest validation metrics achieved during training:")
for etype, best_auc in best_val_metrics.items():
    print(f"  {etype}: {best_auc:.4f}")

best_val_aucs = list(best_val_metrics.values())
if best_val_aucs:
    avg_best_val_auc = np.mean(best_val_aucs)
    print(f"  Average Best Val AUC: {avg_best_val_auc:.4f}")

In [None]:
# Compute node embeddings from the graph, feeding each node type its 'x' features.
embs = model.get_embeddings(
    graph,
    {node_type: graph.nodes[node_type].data['x'] for node_type in node_types},
)