In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn
from torch.utils.data import DataLoader, Dataset
from collections import defaultdict
import numpy as np
from typing import Dict, List, Tuple, Optional, Union
import warnings
from DeepGraphDB import DeepGraphDB
import random

class HeteroNBFLayer(nn.Module):
    """
    Neural Bellman-Ford layer for a DGL heterograph.

    One call to :meth:`forward` performs a single round of relational message
    passing: every edge builds a message from its source node state and the
    embedding of its relation (MESSAGE), every destination node reduces the
    messages it received (AGGREGATE), and the boundary condition stored under
    ``'h0'`` is added back before normalisation and dropout.

    Args:
        input_dim: Size of the node / relation embeddings.
        hidden_dim: Stored on the instance; not used by the layer itself.
        message_func: One of ``"transe"``, ``"distmult"``, ``"rotate"``.
        aggregate_func: One of ``"sum"``, ``"mean"``, ``"max"``, ``"pna"``.
        layer_norm: Whether to apply ``nn.LayerNorm`` to the updated features.
        dropout: Dropout probability applied to the updated features.

    Raises:
        ValueError: If ``message_func`` or ``aggregate_func`` is unknown, or if
            ``"rotate"`` is requested with an odd ``input_dim``.
    """

    # Option names accepted by the constructor.
    _MESSAGE_FUNCS = ("transe", "distmult", "rotate")
    _AGGREGATE_FUNCS = ("sum", "mean", "max", "pna")

    def __init__(self, input_dim: int, hidden_dim: int,
                 message_func: str = "distmult", aggregate_func: str = "pna",
                 layer_norm: bool = True, dropout: float = 0.1):
        super().__init__()

        # Validate up front so a typo fails here instead of deep inside a DGL
        # user-defined function during the first forward pass.
        if message_func not in self._MESSAGE_FUNCS:
            raise ValueError(f"Unknown message function: {message_func}")
        if message_func == "rotate" and input_dim % 2 != 0:
            raise ValueError("RotatE requires even embedding dimension")
        if aggregate_func not in self._AGGREGATE_FUNCS:
            raise ValueError(f"Unknown aggregate function: {aggregate_func}")

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.message_func = message_func
        self.aggregate_func = aggregate_func

        # Layer normalization and dropout
        self.layer_norm = nn.LayerNorm(input_dim) if layer_norm else None
        self.dropout = nn.Dropout(dropout)

        # Principal Neighborhood Aggregation concatenates four statistics
        # (mean, max, min, std) and projects them back to input_dim.
        if aggregate_func == "pna":
            self.pna_linear = nn.Linear(input_dim * 4, input_dim)

    def forward(self, hg: dgl.DGLGraph, node_features: Dict[str, torch.Tensor],
                relation_embeddings: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Run one round of message passing over every canonical edge type.

        Side effects: writes ``node_features`` into ``hg`` under ``'h'`` and
        leaves the aggregated result under ``'h_new'``; reads ``'h0'`` when the
        caller has stored it on the graph.

        Args:
            hg: DGL HeteroGraph
            node_features: Dict mapping node types to feature tensors
            relation_embeddings: Dict mapping relation types to embedding tensors

        Returns:
            Updated node features by type. Node types without an ``'h_new'``
            entry are returned unchanged.
        """
        # Store features in graph so the message functions can read them
        for ntype, features in node_features.items():
            hg.nodes[ntype].data['h'] = features

        # One (message, reduce) pair per canonical edge type; the relation name
        # selects which embedding the message function uses.
        funcs = {}
        for etype in hg.canonical_etypes:
            _, rel_type, _ = etype
            funcs[etype] = (self.message_func_hetero(rel_type, relation_embeddings),
                            self.aggregate_func_hetero)

        # Results of different edge types hitting the same node type are summed
        hg.multi_update_all(funcs, 'sum')

        updated_features = {}
        for ntype in hg.ntypes:
            node_data = hg.nodes[ntype].data
            if 'h_new' not in node_data:
                # If no messages received, keep original features
                updated_features[ntype] = node_features[ntype]
                continue

            h_new = node_data['h_new']

            # Add residual connection (boundary condition)
            if 'h0' in node_data:
                h_new = h_new + node_data['h0']

            # Explicit None check: do not rely on the truthiness of a module
            if self.layer_norm is not None:
                h_new = self.layer_norm(h_new)
            updated_features[ntype] = self.dropout(h_new)

        return updated_features

    def message_func_hetero(self, rel_type: str, relation_embeddings: Dict[str, torch.Tensor]):
        """
        Build the DGL message function for one relation type.

        Args:
            rel_type: Key into ``relation_embeddings``.
            relation_embeddings: Dict mapping relation types to embedding tensors.

        Returns:
            A callable taking a DGL edge batch and returning ``{'msg': tensor}``.

        Raises:
            ValueError: (when the returned callable runs) if ``self.message_func``
                is not a supported name or RotatE is used with an odd dimension.
        """
        def message_func(edges):
            h_src = edges.src['h']
            rel_emb = relation_embeddings[rel_type]

            # Broadcast relation embedding to one row per edge
            if rel_emb.dim() == 1:
                rel_emb = rel_emb.unsqueeze(0).expand(h_src.size(0), -1)
            elif rel_emb.size(0) == 1 and h_src.size(0) > 1:
                rel_emb = rel_emb.expand(h_src.size(0), -1)

            if self.message_func == "transe":
                # Translation: h + r
                message = h_src + rel_emb
            elif self.message_func == "distmult":
                # Element-wise multiplication: h ⊙ r
                message = h_src * rel_emb
            elif self.message_func == "rotate":
                # Complex rotation: first half is the real part, second half
                # the imaginary part of each embedding.
                if self.input_dim % 2 != 0:
                    raise ValueError("RotatE requires even embedding dimension")
                h_re, h_im = h_src.chunk(2, dim=-1)
                r_re, r_im = rel_emb.chunk(2, dim=-1)
                message_re = h_re * r_re - h_im * r_im
                message_im = h_re * r_im + h_im * r_re
                message = torch.cat([message_re, message_im], dim=-1)
            else:
                # Still reachable if the attribute is reassigned after __init__
                raise ValueError(f"Unknown message function: {self.message_func}")

            return {'msg': message}

        return message_func

    def aggregate_func_hetero(self, nodes):
        """
        Reduce the mailbox of a DGL node batch into ``'h_new'``.

        Args:
            nodes: DGL node batch; ``nodes.mailbox['msg']`` is reduced along
                dim 1 (the neighbour axis).

        Returns:
            ``{'h_new': tensor}``, or ``{}`` when the mailbox has no ``'msg'``.

        Raises:
            ValueError: If ``self.aggregate_func`` is not a supported name.
        """
        if 'msg' not in nodes.mailbox:
            # No messages received
            return {}

        messages = nodes.mailbox['msg']  # Shape: (num_nodes, num_neighbors, hidden_dim)

        if self.aggregate_func == "sum":
            aggregated = torch.sum(messages, dim=1)
        elif self.aggregate_func == "mean":
            aggregated = torch.mean(messages, dim=1)
        elif self.aggregate_func == "max":
            aggregated = torch.max(messages, dim=1)[0]
        elif self.aggregate_func == "pna":
            # Principal Neighborhood Aggregation
            mean_msg = torch.mean(messages, dim=1)
            max_msg = torch.max(messages, dim=1)[0]
            min_msg = torch.min(messages, dim=1)[0]
            if messages.size(1) > 1:
                std_msg = torch.std(messages, dim=1)
            else:
                # The sample standard deviation of a single message is 0/0 and
                # torch.std returns NaN, which would contaminate every later
                # layer. A single neighbour has no spread, so use zero.
                std_msg = torch.zeros_like(mean_msg)

            pna_features = torch.cat([mean_msg, max_msg, min_msg, std_msg], dim=-1)
            aggregated = self.pna_linear(pna_features)
        else:
            raise ValueError(f"Unknown aggregate function: {self.aggregate_func}")

        return {'h_new': aggregated}


class HeteroNBFNet(nn.Module):
    """
    Neural Bellman-Ford Network for DGL HeteroGraph.

    The network seeds the source nodes with a learned query embedding
    (INDICATOR), runs ``num_layers`` :class:`HeteroNBFLayer` rounds and scores
    the requested target nodes with a linear projection.

    Args:
        node_types: Node type names (stored on the instance).
        edge_types: Canonical edge types as ``(src_type, rel_type, dst_type)``.
        embedding_dim: Size of node, relation and query embeddings.
        num_layers: Number of message-passing layers.
        message_func: Passed through to every :class:`HeteroNBFLayer`.
        aggregate_func: Passed through to every :class:`HeteroNBFLayer`.
        layer_norm: Passed through to every :class:`HeteroNBFLayer`.
        dropout: Passed through to every :class:`HeteroNBFLayer`.
    """
    def __init__(self, node_types: List[str], edge_types: List[Tuple[str, str, str]],
                 embedding_dim: int = 32, num_layers: int = 6,
                 message_func: str = "distmult", aggregate_func: str = "pna",
                 layer_norm: bool = True, dropout: float = 0.1):
        super().__init__()

        self.node_types = node_types
        self.edge_types = edge_types  # List of (src_type, rel_type, dst_type)
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers

        # Sorted so that parameter registration (and therefore seeded
        # initialisation and optimizer ordering) does not depend on the
        # per-process iteration order of a set of strings.
        self.relation_types = sorted(set(etype[1] for etype in edge_types))

        # Relation embeddings for each relation type
        self.relation_embeddings = nn.ModuleDict({
            rel_type: nn.Embedding(1, embedding_dim)
            for rel_type in self.relation_types
        })

        # Query embeddings for INDICATOR function (one per relation type)
        self.query_embeddings = nn.ModuleDict({
            rel_type: nn.Embedding(1, embedding_dim)
            for rel_type in self.relation_types
        })

        # NBF layers
        self.layers = nn.ModuleList([
            HeteroNBFLayer(embedding_dim, embedding_dim, message_func,
                           aggregate_func, layer_norm, dropout)
            for _ in range(num_layers)
        ])

        # Output projection for link prediction scores
        self.output_linear = nn.Linear(embedding_dim, 1)

        # Initialize parameters
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise all embeddings and the output weight; zero the output bias."""
        for rel_emb in self.relation_embeddings.values():
            nn.init.xavier_uniform_(rel_emb.weight)
        for query_emb in self.query_embeddings.values():
            nn.init.xavier_uniform_(query_emb.weight)
        nn.init.xavier_uniform_(self.output_linear.weight)
        nn.init.zeros_(self.output_linear.bias)

    def _resolve_device(self, source_nodes: Dict[str, torch.Tensor]) -> torch.device:
        """Device of the first source index tensor, or of the model when there is none."""
        for indices in source_nodes.values():
            return indices.device
        # Empty dict: fall back to where the parameters live instead of
        # failing with StopIteration.
        return next(self.parameters()).device

    def indicator_function(self, hg: dgl.DGLGraph, source_nodes: Dict[str, torch.Tensor],
                           query_relation: str) -> Dict[str, torch.Tensor]:
        """
        INDICATOR function: Initialize boundary conditions for heterogeneous graph

        Every node starts at zero; source nodes are set to the query embedding
        of ``query_relation``.

        Args:
            hg: DGL HeteroGraph
            source_nodes: Dict mapping node types to source node indices
            query_relation: Query relation type

        Returns:
            Initial node features by type
        """
        device = self._resolve_device(source_nodes)
        zero_index = torch.zeros(1, dtype=torch.long, device=device)
        # Each embedding table has exactly one row; look it up and drop the
        # leading dimension to get a vector of size embedding_dim.
        query_emb = self.query_embeddings[query_relation](zero_index).squeeze(0)

        h0 = {}
        for ntype in hg.ntypes:
            num_nodes = hg.number_of_nodes(ntype)

            # Initialize all nodes with zeros (same dtype as the embedding so
            # the indexed assignment below is valid)
            h0_type = torch.zeros(num_nodes, self.embedding_dim,
                                  device=device, dtype=query_emb.dtype)

            # Set source nodes with query embeddings (broadcast over rows)
            if ntype in source_nodes and len(source_nodes[ntype]) > 0:
                h0_type[source_nodes[ntype]] = query_emb

            h0[ntype] = h0_type

        return h0

    def forward(self, hg: dgl.DGLGraph, source_nodes: Dict[str, torch.Tensor],
                target_nodes: Dict[str, torch.Tensor], query_relation: str) -> torch.Tensor:
        """
        Forward pass of HeteroNBFNet

        Side effects: stores the boundary condition under ``'h0'`` and the
        final node representations under ``'h'`` on ``hg``.

        Args:
            hg: DGL HeteroGraph
            source_nodes: Dict mapping node types to source node indices
            target_nodes: Dict mapping node types to target node indices
            query_relation: Query relation type

        Returns:
            Link prediction scores: one logit per target index, concatenated in
            the iteration order of ``target_nodes`` (node types with no indices
            are skipped). An empty tensor when there is nothing to score.
        """
        device = self._resolve_device(source_nodes)

        # Initialize boundary conditions (INDICATOR function)
        h0 = self.indicator_function(hg, source_nodes, query_relation)

        # Store boundary conditions in graph so every layer can add them back
        for ntype, features in h0.items():
            hg.nodes[ntype].data['h0'] = features

        # Initialize current features
        current_features = {ntype: features.clone() for ntype, features in h0.items()}

        # Get relation embeddings (one vector per relation type)
        zero_index = torch.zeros(1, dtype=torch.long, device=device)
        rel_embeddings = {
            rel_type: self.relation_embeddings[rel_type](zero_index).squeeze(0)
            for rel_type in self.relation_types
        }

        # Apply NBF layers iteratively
        for layer in self.layers:
            current_features = layer(hg, current_features, rel_embeddings)

        # Each layer only writes its *input* to 'h'. Store the output of the
        # last layer so get_embeddings_by_type returns the final embeddings
        # rather than the penultimate ones.
        for ntype, features in current_features.items():
            hg.nodes[ntype].data['h'] = features

        # Extract final representations for target nodes and compute scores
        scores = []
        for ntype, target_indices in target_nodes.items():
            if len(target_indices) > 0 and ntype in current_features:
                target_features = current_features[ntype][target_indices]
                scores.append(self.output_linear(target_features).squeeze(-1))

        if scores:
            return torch.cat(scores, dim=0)
        # Return empty tensor if no valid targets
        return torch.tensor([], device=device)

    def get_embeddings_by_type(self, hg: dgl.DGLGraph) -> Dict[str, torch.Tensor]:
        """
        Return node embeddings stored on the graph, organized by type

        Args:
            hg: DGL HeteroGraph with computed node features

        Returns:
            Dictionary mapping node types to embedding tensors. A node type is
            included when it has ``'h'`` or, failing that,
            ``'h_{num_layers-1}'`` in its node data.
        """
        embeddings_by_type = {}
        last_layer_key = f'h_{self.num_layers-1}'

        for ntype in hg.ntypes:
            node_data = hg.nodes[ntype].data
            if 'h' in node_data:
                embeddings_by_type[ntype] = node_data['h']
            elif last_layer_key in node_data:
                # Try to get last layer features
                embeddings_by_type[ntype] = node_data[last_layer_key]

        return embeddings_by_type

In [None]:
class HeteroKGDataset(Dataset):
    """
    Dataset for heterogeneous knowledge graph triplets.

    Each item is a positive ``(head, relation, tail)`` triplet together with
    negatives obtained by replacing the tail with another entity of the same
    type.

    Args:
        triplets: ``(head_name, relation, tail_name)`` tuples.
        entity_types: Maps entity name to its type.
        entity_to_id: Maps entity name to its integer id.
        num_negatives: Upper bound on negatives generated per positive.
    """
    def __init__(self, triplets: List[Tuple[str, str, str]],
                 entity_types: Dict[str, str],
                 entity_to_id: Dict[str, int],
                 num_negatives: int = 32):
        self.triplets = triplets
        self.entity_types = entity_types
        self.entity_to_id = entity_to_id
        self.num_negatives = num_negatives

        # Group entities by type for negative sampling
        self.entities_by_type = defaultdict(list)
        for entity_name, entity_type in entity_types.items():
            if entity_name in entity_to_id:
                self.entities_by_type[entity_type].append(entity_to_id[entity_name])

        # Reverse lookup built once, so converting a sampled id back to a name
        # is a dict access instead of a scan over every entity. setdefault
        # keeps the first name when two names share an id, matching what a
        # front-to-back scan would return. Built from entity_to_id as it is at
        # construction time.
        self._id_to_name = {}
        for name, eid in entity_to_id.items():
            self._id_to_name.setdefault(eid, name)

    def __len__(self):
        """Number of triplets."""
        return len(self.triplets)

    def __getitem__(self, idx):
        """
        Return the triplet at ``idx`` with freshly sampled negatives.

        Returns:
            ``{'positive': triplet, 'negatives': [triplet, ...]}``. When the
            head or tail has no id or no type the sample is
            ``{'positive': None, 'negatives': []}``.
        """
        head_name, relation, tail_name = self.triplets[idx]

        has_ids = head_name in self.entity_to_id and tail_name in self.entity_to_id
        has_types = head_name in self.entity_types and tail_name in self.entity_types
        if not (has_ids and has_types):
            # Return empty sample for invalid triplets. A missing type is
            # treated like a missing id rather than raising KeyError.
            return {
                'positive': None,
                'negatives': []
            }

        tail_id = self.entity_to_id[tail_name]
        tail_candidates = self.entities_by_type[self.entity_types[tail_name]]

        # Generate negative samples by rejection: draw a same-type entity and
        # retry while it equals the true tail. Negatives may repeat.
        negatives = []
        num_candidates = len(tail_candidates)
        if num_candidates > 1:
            for _ in range(min(self.num_negatives, num_candidates - 1)):
                # Drawing an index avoids converting the candidate list to an
                # array on every draw, as np.random.choice(list) would.
                neg_tail = tail_candidates[np.random.randint(num_candidates)]
                while neg_tail == tail_id:
                    neg_tail = tail_candidates[np.random.randint(num_candidates)]
                negatives.append((head_name, relation, self._id_to_entity_name(neg_tail)))

        return {
            'positive': (head_name, relation, tail_name),
            'negatives': negatives
        }

    def _id_to_entity_name(self, entity_id: int) -> str:
        """Convert entity ID back to name, or ``unknown_<id>`` when it has none."""
        if entity_id in self._id_to_name:
            return self._id_to_name[entity_id]
        return f"unknown_{entity_id}"


class HeteroNBFNetTrainer:
    """
    Trainer for HeteroNBFNet with batch training support.

    Args:
        model: The network to train; moved to ``device`` on construction.
        device: Device used for the model, the graph and all index tensors.
    """

    def __init__(self, model: HeteroNBFNet, device: str = 'cuda'):
        self.model = model.to(device)
        self.device = device
        self.optimizer = None
        self.criterion = nn.BCEWithLogitsLoss()

    def setup_optimizer(self, lr: float = 5e-3):
        """Create the Adam optimizer over all model parameters."""
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

    def train_batch(self, hg: dgl.DGLGraph, batch_data: Dict,
                    entity_types: Dict[str, str], entity_to_id: Dict[str, int]):
        """
        Train on a batch of heterogeneous triplets.

        Triplets are grouped by relation; each relation gets one forward pass
        and the per-relation losses are averaged before a single optimizer
        step.

        Args:
            hg: DGL HeteroGraph to run message passing on.
            batch_data: ``{'positives': [triplet, ...],
                'negatives': [[triplet, ...], ...]}``.
            entity_types: Maps entity name to node type.
            entity_to_id: Maps entity name to node index.

        Returns:
            The averaged loss as a float, or ``0.0`` when nothing was trained.

        Raises:
            RuntimeError: If :meth:`setup_optimizer` has not been called.
        """
        if self.optimizer is None:
            raise RuntimeError(
                "Optimizer is not configured; call setup_optimizer() before train_batch()"
            )

        self.model.train()
        self.optimizer.zero_grad()

        # Prepare positive and negative samples
        positive_triplets = [item for item in batch_data['positives'] if item is not None]
        if not positive_triplets:
            return 0.0
        negative_triplets = [triplet for negs in batch_data['negatives'] for triplet in negs]

        # Group by relation type, carrying the label with each triplet so it
        # never has to be inferred from a position in a list.
        labelled_by_relation = defaultdict(list)
        for label, group in ((1.0, positive_triplets), (0.0, negative_triplets)):
            for triplet in group:
                _, relation, _ = triplet
                labelled_by_relation[relation].append((triplet, label))

        # Move the graph once per batch instead of once per relation
        graph = hg.to(self.device)

        total_loss = 0.0
        num_relations = 0

        # Process each relation separately
        for relation, labelled_triplets in labelled_by_relation.items():
            # Group source and target nodes by type. Labels are grouped by tail
            # type as well: the model returns scores concatenated per target
            # node type, so labels must be laid out the same way to stay
            # aligned when a relation has several tail types.
            source_nodes = defaultdict(list)
            target_nodes = defaultdict(list)
            labels_by_type = defaultdict(list)

            for (head_name, _, tail_name), label in labelled_triplets:
                if head_name not in entity_to_id or tail_name not in entity_to_id:
                    continue
                if head_name not in entity_types or tail_name not in entity_types:
                    continue

                tail_type = entity_types[tail_name]
                source_nodes[entity_types[head_name]].append(entity_to_id[head_name])
                target_nodes[tail_type].append(entity_to_id[tail_name])
                labels_by_type[tail_type].append(label)

            if not target_nodes:
                continue

            # Convert to tensors
            source_nodes_tensor = {
                ntype: torch.tensor(nodes, device=self.device)
                for ntype, nodes in source_nodes.items()
            }
            target_nodes_tensor = {
                ntype: torch.tensor(nodes, device=self.device)
                for ntype, nodes in target_nodes.items()
            }
            # Same iteration order as target_nodes_tensor
            labels_tensor = torch.tensor(
                [label for ntype in target_nodes_tensor for label in labels_by_type[ntype]],
                device=self.device,
            )

            # Forward pass
            scores = self.model(graph, source_nodes_tensor, target_nodes_tensor, relation)

            if len(scores) > 0:
                # Compute loss
                total_loss = total_loss + self.criterion(scores, labels_tensor)
                num_relations += 1

        if num_relations == 0:
            return 0.0

        # Average loss across relations
        avg_loss = total_loss / num_relations

        # Backward pass
        avg_loss.backward()
        self.optimizer.step()

        return avg_loss.item()


def collate_hetero_batch(batch):
    """
    Merge HeteroKGDataset samples into one batch dictionary.

    Args:
        batch: Samples shaped like ``{'positive': triplet_or_None,
            'negatives': [triplet, ...]}``.

    Returns:
        ``{'positives': [...], 'negatives': [[...], ...]}``. Samples whose
        positive is ``None`` are left out of ``'positives'``; every sample,
        valid or not, contributes its negatives list.
    """
    kept_positives = [sample['positive'] for sample in batch
                      if sample['positive'] is not None]
    negative_lists = [sample['negatives'] for sample in batch]
    return {'positives': kept_positives, 'negatives': negative_lists}


# Example usage and training script
def train_hetero_nbfnet(hg: dgl.DGLGraph,
                       train_triplets: List[Tuple[str, str, str]], 
                       entity_types: Dict[str, str],
                       val_triplets: Optional[List[Tuple[str, str, str]]] = None,
                       batch_size: int = 256, num_epochs: int = 20,
                       embedding_dim: int = 32, num_layers: int = 6,
                       lr: float = 5e-3, device: str = 'cuda',
                       entity_to_id: Optional[Dict[str, int]] = None):
    """
    Complete training pipeline for HeteroNBFNet

    Args:
        hg: DGL HeteroGraph to train on.
        train_triplets: ``(head, relation, tail)`` training triplets.
        entity_types: Maps entity name to node type.
        val_triplets: Optional extra triplets; only used to extend the entity
            vocabulary when ``entity_to_id`` is built here.
        batch_size: Number of positive triplets per batch.
        num_epochs: Number of passes over ``train_triplets``.
        embedding_dim: Embedding size of the model.
        num_layers: Number of message-passing layers.
        lr: Adam learning rate.
        device: Requested device; falls back to CPU (with a warning) when a
            CUDA device is requested but unavailable.
        entity_to_id: Optional mapping from entity name to node index in
            ``hg``. When omitted, ids are assigned in order of first
            appearance in the triplets.

    Returns:
        ``(model, trainer, hg, entity_to_id)``.
    """
    # Degrade gracefully on CPU-only machines instead of failing in .to()
    if torch.device(device).type == 'cuda' and not torch.cuda.is_available():
        warnings.warn("CUDA was requested but is not available; falling back to CPU")
        device = 'cpu'

    # Create entity mappings
    if entity_to_id is None:
        # NOTE(review): ids assigned here follow first appearance in the
        # triplets and are later used as node indices into `hg`. Nothing in
        # this function ties them to the graph's own node ids, so pass
        # `entity_to_id` explicitly when the two can differ.
        entity_to_id = {}
        for triplet_list in (train_triplets, val_triplets or []):
            for triplet in triplet_list:
                for entity in (triplet[0], triplet[2]):  # head and tail
                    entity_to_id.setdefault(entity, len(entity_to_id))

    print(f"Total entities: {len(entity_to_id)}")
    print(f"Entity types: {set(entity_types.values())}")

    print("Graph created with:")
    print(f"  Node types: {hg.ntypes}")
    print(f"  Edge types: {hg.canonical_etypes}")
    for ntype in hg.ntypes:
        print(f"  {ntype}: {hg.number_of_nodes(ntype)} nodes")

    # Create model
    print("Initializing HeteroNBFNet...")
    model = HeteroNBFNet(
        node_types=hg.ntypes,
        edge_types=hg.canonical_etypes,
        embedding_dim=embedding_dim,
        num_layers=num_layers
    )

    # Create trainer
    trainer = HeteroNBFNetTrainer(model, device)
    trainer.setup_optimizer(lr)

    # Create datasets and dataloaders
    train_dataset = HeteroKGDataset(train_triplets, entity_types, entity_to_id)
    if not train_dataset.entities_by_type:
        # No key of entity_types matches an entity in the triplets, so no
        # negatives can be sampled and training cannot learn anything.
        warnings.warn(
            "entity_types shares no entity with the triplets; "
            "check that it maps entity name -> node type"
        )
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_hetero_batch
    )

    # Training loop
    print("Starting training...")

    for epoch in range(num_epochs):
        total_loss = 0
        num_batches = 0

        for batch in train_loader:
            total_loss += trainer.train_batch(hg, batch, entity_types, entity_to_id)
            num_batches += 1

        # max(..., 1) guards the empty-loader case
        avg_loss = total_loss / max(num_batches, 1)
        print(f"Epoch {epoch + 1}: Loss = {avg_loss:.4f}")

    return model, trainer, hg, entity_to_id

In [None]:
def get_triplets_vectorized(g, mapping):
    """
    Flatten every edge of a heterograph into ``(src, relation, dst)`` tuples.

    Args:
        g: Graph exposing ``canonical_etypes`` and ``edges(etype=...)``.
        mapping: Maps the edge-type name (middle element of a canonical edge
            type) to the value placed in the middle of each triplet.

    Returns:
        A list of ``(src_id, mapping[etype], dst_id)`` tuples, edge type by
        edge type in the order of ``g.canonical_etypes``.
    """
    collected = []

    for edge_key in g.canonical_etypes:
        _src_type, rel_name, _dst_type = edge_key
        heads, tails = g.edges(etype=edge_key)
        rel_value = mapping[rel_name]

        # Plain Python ints on both ends, relation value in the middle
        collected += [
            (head, rel_value, tail)
            for head, tail in zip(heads.tolist(), tails.tolist())
        ]

    return collected

def create_dict_from_tuples(tuples_array):
    """
    Group the second element of each pair by its first element.

    Args:
        tuples_array: Iterable of ``(key, value)`` pairs. It is consumed in a
            single pass, so one-shot iterators and generators work too.

    Returns:
        A dict mapping each key to the list of its values in input order.
        Keys appear in order of first occurrence.

    Raises:
        ValueError: If an element does not unpack into exactly two items.
    """
    result = {}

    # Single pass: create the list on first sight of a key, then append
    for key_string, first_int in tuples_array:
        result.setdefault(key_string, []).append(first_int)

    return result

# Load the graph from disk.
# NOTE(review): absolute, machine-specific path — consider a configurable
# data directory so the notebook runs elsewhere.
gdb = DeepGraphDB()
gdb.load_graph("/home/cc/PHD/dglframework/DeepKG/DeepGraphDB/graphs/primekg.bin")

# Not used below; the training call builds its own entity_types argument.
entity_types = {  }

# One (src, relation, dst) tuple per edge; set() drops duplicates.
triplets = list(set(get_triplets_vectorized(gdb.graph, gdb.edge_types_mapping)))

# Split data: 70% train / 30% validation after an unseeded shuffle
random.shuffle(triplets)
split_idx = int(0.7 * len(triplets))
train_triplets = triplets[:split_idx]
val_triplets = triplets[split_idx:]

# Train model
# NOTE(review): create_dict_from_tuples returns {first element: [second
# elements]}, while train_hetero_nbfnet documents entity_types as
# Dict[str, str] and looks entities up by the head/tail values found in the
# triplets. Confirm the structure of gdb.reverse_node_mapping matches that.
model, trainer, hg, entity_to_id = train_hetero_nbfnet(
    hg=gdb.graph,
    train_triplets=train_triplets,
    entity_types=create_dict_from_tuples(gdb.reverse_node_mapping.keys()),
    val_triplets=val_triplets,
    batch_size=32000,
    num_epochs=50,
    # Use the GPU when there is one, otherwise run on CPU
    device='cuda' if torch.cuda.is_available() else 'cpu'
)

print("Training completed successfully!")
print("HeteroNBFNet with DGL HeteroGraph is working!")

In [None]:
# Rebuild the same grouping that is passed to train_hetero_nbfnet as
# entity_types, so it can be inspected on its own.
en = create_dict_from_tuples(gdb.reverse_node_mapping.keys())

In [None]:
# Show the grouped values (one list per key).
# NOTE(review): this displays every list in full — may be very large.
en.values()