In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.nn as dglnn
from dgl.nn import GraphConv
import numpy as np
from typing import Dict, List, Tuple, Optional
from DeepGraphDB import DeepGraphDB
from sklearn.metrics import roc_auc_score

class HeteroEntityEmbedding(nn.Module):
    """Per-type embedding tables for heterogeneous entities.

    For every type ``t`` in ``entity_types`` with ``node_counts[t] > 0`` an
    ``nn.Embedding(node_counts[t], embedding_dim)`` is created and
    Xavier-uniform initialised. Types with zero nodes get no table.
    """

    def __init__(self, entity_types: List[str], node_counts: Dict[str, int],
                 embedding_dim: int = 256):
        super().__init__()
        self.entity_types = entity_types
        self.embedding_dim = embedding_dim
        self.node_counts = node_counts

        # One lookup table per populated entity type, keyed by type name
        self.embeddings = nn.ModuleDict()
        for type_name in entity_types:
            count = node_counts[type_name]
            if count <= 0:
                continue
            table = nn.Embedding(count, embedding_dim)
            nn.init.xavier_uniform_(table.weight)
            self.embeddings[type_name] = table

    def forward(self, entity_ids: torch.Tensor, entity_types: List[str]) -> torch.Tensor:
        """Return one ``embedding_dim`` row per entity.

        Position ``i`` is looked up in the table for ``entity_types[i]`` using
        id ``entity_ids[i]``. Positions whose type has no table are left as
        zero vectors.
        """
        device = entity_ids.device
        out = torch.zeros(entity_ids.size(0), self.embedding_dim,
                          device=device, dtype=torch.float)

        # Collect batch positions per type so each table is queried once
        positions_by_type: Dict[str, List[int]] = {}
        for pos, type_name in enumerate(entity_types):
            positions_by_type.setdefault(type_name, []).append(pos)

        for type_name, positions in positions_by_type.items():
            if type_name not in self.embeddings:
                continue
            idx = torch.tensor(positions, device=device, dtype=torch.long)
            out[idx] = self.embeddings[type_name](entity_ids[idx])

        return out

class RelationEmbedding(nn.Module):
    """Single learnable lookup table with one ``embedding_dim`` row per relation type."""

    def __init__(self, relation_types: List[str], embedding_dim: int = 256):
        super().__init__()
        self.relation_types = relation_types
        self.embedding_dim = embedding_dim
        self.num_relations = len(relation_types)

        table = nn.Embedding(self.num_relations, embedding_dim)
        nn.init.xavier_uniform_(table.weight)
        self.relation_embedding = table

    def forward(self, relation_ids: torch.Tensor) -> torch.Tensor:
        """Look up the rows for ``relation_ids`` (integer relation indices)."""
        return self.relation_embedding(relation_ids)

class NBFLayer(nn.Module):
    """One message + update step applied independently to each row of the batch.

    A message is built from a head embedding and a relation embedding
    (according to ``message_func``) and projected to ``hidden_dim``. It is
    concatenated with the running state ``prev_emb``, passed through a
    2-layer MLP, added back to ``prev_emb`` (residual) and LayerNorm-ed.

    NOTE(review): despite the name, this layer does no aggregation over graph
    neighbours; each row is processed on its own.

    Args:
        input_dim: width of ``head_emb``, ``rel_emb`` and ``prev_emb``.
        hidden_dim: output width. The residual ``updated + prev_emb`` requires
            ``hidden_dim == input_dim``.
        num_relations: stored on the module but not used by this layer.
        message_func: ``'distmult'`` (element-wise product), ``'rotate'``
            (complex multiplication of the two halves of each vector); any
            other value falls back to concatenating head and relation.
    """

    def __init__(self, input_dim: int, hidden_dim: int, num_relations: int,
                 message_func: str = 'distmult'):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_relations = num_relations
        self.message_func = message_func

        # Width of the message handed to message_layer depends on how it is built
        if message_func == 'distmult':
            message_dim = input_dim  # element-wise product keeps the width
        elif message_func == 'rotate':
            # Real and imaginary halves have input_dim // 2 entries each, so the
            # rotated message has 2 * (input_dim // 2) features -- NOT 2 * input_dim.
            # (For odd input_dim the last feature is ignored.)
            message_dim = 2 * (input_dim // 2)
        else:
            message_dim = 2 * input_dim  # concatenation of head and relation
        self.message_layer = nn.Linear(message_dim, hidden_dim)

        # NOTE(review): not used in forward(); kept so the state_dict layout
        # (and previously saved checkpoints) stays unchanged.
        self.relation_linear = nn.Linear(input_dim, hidden_dim)

        # Update function: MLP over [message, prev_emb]
        self.update_layer = nn.Sequential(
            nn.Linear(hidden_dim + input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Applied after the residual connection
        self.layer_norm = nn.LayerNorm(hidden_dim)

    def message_function(self, head_emb: torch.Tensor, rel_emb: torch.Tensor,
                         tail_emb: torch.Tensor) -> torch.Tensor:
        """Combine head and relation into a message and project it to ``hidden_dim``.

        ``tail_emb`` is accepted for interface compatibility but is not used.
        """
        if self.message_func == 'distmult':
            message = head_emb * rel_emb
        elif self.message_func == 'rotate':
            # Interpret [:half] as real part and [half:2*half] as imaginary part,
            # then multiply head by relation as complex numbers. Using the same
            # slice length for both parts keeps the shapes aligned for odd dims.
            half = self.input_dim // 2
            real_head, imag_head = head_emb[..., :half], head_emb[..., half:2 * half]
            real_rel, imag_rel = rel_emb[..., :half], rel_emb[..., half:2 * half]

            real_msg = real_head * real_rel - imag_head * imag_rel
            imag_msg = real_head * imag_rel + imag_head * real_rel
            message = torch.cat([real_msg, imag_msg], dim=-1)
        else:
            message = torch.cat([head_emb, rel_emb], dim=-1)

        return self.message_layer(message)

    def forward(self, head_emb: torch.Tensor, rel_emb: torch.Tensor,
                prev_emb: torch.Tensor) -> torch.Tensor:
        """Return ``LayerNorm(MLP([message(head, rel), prev_emb]) + prev_emb)``."""
        message = self.message_function(head_emb, rel_emb, prev_emb)
        updated = self.update_layer(torch.cat([message, prev_emb], dim=-1))
        return self.layer_norm(updated + prev_emb)

class SimpleNBFNet(nn.Module):
    """Triplet scorer built from typed embeddings and stacked NBFLayer steps.

    For each ``(head, relation, tail)`` row, the model:

    1. looks up head/tail (per entity type) and relation embeddings;
    2. projects all three with the shared ``input_projection``;
    3. applies dropout;
    4. starts from the tail state and runs ``num_layers`` NBFLayer steps
       driven by (head, relation);
    5. maps the final state to one logit.

    NOTE(review): no message passing over the graph structure happens here;
    each triplet is scored only from its own three embeddings.
    """

    def __init__(self, entity_types: List[str], relation_types: List[str],
                 node_counts: Dict[str, int], embedding_dim: int = 256,
                 hidden_dim: int = 256, num_layers: int = 3,
                 message_func: str = 'distmult', dropout: float = 0.1):
        super().__init__()

        self.entity_types = entity_types
        self.relation_types = relation_types
        self.node_counts = node_counts
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout

        # Embeddings
        self.entity_embedding = HeteroEntityEmbedding(
            entity_types, node_counts, embedding_dim
        )
        self.relation_embedding = RelationEmbedding(
            relation_types, embedding_dim
        )

        # Shared projection applied to entity AND relation embeddings
        self.input_projection = nn.Linear(embedding_dim, hidden_dim)

        # NBF layers (input and hidden widths are both hidden_dim, as the
        # residual inside NBFLayer requires)
        self.nbf_layers = nn.ModuleList([
            NBFLayer(hidden_dim, hidden_dim, len(relation_types), message_func)
            for _ in range(num_layers)
        ])

        # Output head producing one logit per triplet
        self.output_projection = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1)
        )

        self.dropout_layer = nn.Dropout(dropout)

    def forward(self, triplets: torch.Tensor, head_types: List[str],
                tail_types: List[str]) -> torch.Tensor:
        """Score a batch of triplets.

        Args:
            triplets: ``(batch_size, 3)`` long tensor of ``[head, relation, tail]``
                ids; head/tail ids are indices within their own entity type.
            head_types: entity type of each head, one per row.
            tail_types: entity type of each tail, one per row.

        Returns:
            ``(batch_size,)`` tensor of raw logits (apply ``sigmoid`` for probabilities).

        Raises:
            ValueError: if ``triplets`` is not ``(batch_size, 3)`` or the type
                lists do not have exactly one entry per row. Without this check
                a short type list silently yields zero embeddings.
        """
        if triplets.dim() != 2 or triplets.size(1) != 3:
            raise ValueError(
                f"triplets must have shape (batch_size, 3), got {tuple(triplets.shape)}"
            )
        batch_size = triplets.size(0)
        if len(head_types) != batch_size or len(tail_types) != batch_size:
            raise ValueError(
                f"expected {batch_size} head/tail types, got "
                f"{len(head_types)} head types and {len(tail_types)} tail types"
            )

        heads, relations, tails = triplets[:, 0], triplets[:, 1], triplets[:, 2]

        # Look up raw embeddings
        head_emb = self.entity_embedding(heads, head_types)
        rel_emb = self.relation_embedding(relations)
        tail_emb = self.entity_embedding(tails, tail_types)

        # Project to hidden dimension, then dropout
        head_emb = self.dropout_layer(self.input_projection(head_emb))
        rel_emb = self.dropout_layer(self.input_projection(rel_emb))
        tail_emb = self.dropout_layer(self.input_projection(tail_emb))

        # Start from the tail state and refine it with (head, relation) messages
        current_emb = tail_emb
        for layer in self.nbf_layers:
            current_emb = self.dropout_layer(layer(head_emb, rel_emb, current_emb))

        return self.output_projection(current_emb).squeeze(-1)

    def get_embeddings(self, entity_ids: torch.Tensor,
                       entity_types: List[str]) -> torch.Tensor:
        """Return raw (un-projected) entity embeddings for the given ids/types."""
        return self.entity_embedding(entity_ids, entity_types)

class NBFNetTrainer:
    """Adam-based trainer treating link prediction as binary classification.

    Positive triplets are labelled 1 and negative triplets 0. The model is
    expected to return raw logits.
    """

    def __init__(self, model: SimpleNBFNet, learning_rate: float = 1e-3,
                 weight_decay: float = 1e-5):
        self.model = model
        self.optimizer = torch.optim.Adam(
            model.parameters(), lr=learning_rate, weight_decay=weight_decay
        )
        # Sigmoid is folded into the loss because the model outputs logits
        self.criterion = nn.BCEWithLogitsLoss()

    def train_step(self, pos_triplets: torch.Tensor, neg_triplets: torch.Tensor,
                   pos_head_types: List[str], pos_tail_types: List[str],
                   neg_head_types: List[str], neg_tail_types: List[str]) -> float:
        """Run one optimisation step on a positive and a negative batch.

        Both batches contribute to a single mean BCE loss. Gradients are
        clipped to a global L2 norm of 1.0 before the optimizer step.

        Returns:
            The loss value as a Python float.
        """
        self.model.train()
        self.optimizer.zero_grad()

        pos_scores = self.model(pos_triplets, pos_head_types, pos_tail_types)
        neg_scores = self.model(neg_triplets, neg_head_types, neg_tail_types)

        pos_labels = torch.ones(pos_scores.size(0), device=pos_scores.device)
        neg_labels = torch.zeros(neg_scores.size(0), device=neg_scores.device)

        all_scores = torch.cat([pos_scores, neg_scores], dim=0)
        all_labels = torch.cat([pos_labels, neg_labels], dim=0)
        loss = self.criterion(all_scores, all_labels)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()

        return loss.item()

    def evaluate(self, pos_triplets: torch.Tensor, neg_triplets: torch.Tensor,
                 pos_head_types: List[str], pos_tail_types: List[str],
                 neg_head_types: List[str], neg_tail_types: List[str]) -> Dict[str, float]:
        """Compute classification metrics on a positive and a negative batch.

        Accuracy uses a 0.5 probability threshold. ``total_accuracy`` is the
        mean of the per-class accuracies. Leaves the model in eval mode.

        Returns:
            Dict with ``pos_accuracy``, ``neg_accuracy``, ``total_accuracy``,
            ``avg_pos_score``, ``avg_neg_score`` (mean probabilities) and
            ``auc``. ``auc`` is NaN when either batch is empty, because ROC AUC
            is undefined with a single class (sklearn would raise instead).
        """
        self.model.eval()
        with torch.no_grad():
            pos_scores = self.model(pos_triplets, pos_head_types, pos_tail_types)
            neg_scores = self.model(neg_triplets, neg_head_types, neg_tail_types)
            pos_probs = torch.sigmoid(pos_scores)
            neg_probs = torch.sigmoid(neg_scores)

            # Per-class accuracy at threshold 0.5
            pos_acc = (pos_probs > 0.5).float().mean().item()
            neg_acc = (1 - (neg_probs > 0.5).float()).mean().item()
            total_acc = (pos_acc + neg_acc) / 2

            if pos_probs.numel() == 0 or neg_probs.numel() == 0:
                auc = float('nan')
            else:
                scores_np = torch.cat([pos_probs, neg_probs]).cpu().numpy()
                labels_np = torch.cat(
                    [torch.ones_like(pos_probs), torch.zeros_like(neg_probs)]
                ).cpu().numpy()
                auc = roc_auc_score(labels_np, scores_np)

            return {
                'pos_accuracy': pos_acc,
                'neg_accuracy': neg_acc,
                'total_accuracy': total_acc,
                'avg_pos_score': pos_probs.mean().item(),
                'avg_neg_score': neg_probs.mean().item(),
                'auc': auc
            }

def prepare_batch_data(triplets: List[Tuple], triplets_ntypes: List[Tuple],
                       batch_size: int = 1024) -> List[Tuple]:
    """Split triplets into consecutive fixed-size batches.

    Args:
        triplets: sequence of ``(head_id, relation_id, tail_id)`` integer triples.
        triplets_ntypes: ``(head_type, tail_type)`` for each triplet, aligned
            with ``triplets``. If empty or None, every entity is typed ``'entity'``.
        batch_size: maximum number of triplets per batch; must be positive.

    Returns:
        List of ``(triplet_tensor, head_types, tail_types)``. ``triplet_tensor``
        is a ``torch.long`` tensor of shape ``(n, 3)`` and both type lists have
        ``n`` entries. The last batch may be shorter. Order is preserved (no shuffling).

    Raises:
        ValueError: if ``batch_size`` is not positive, or ``triplets_ntypes`` is
            non-empty but not the same length as ``triplets``. Such a mismatch
            would otherwise produce batches whose type lists are misaligned.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if triplets_ntypes and len(triplets_ntypes) != len(triplets):
        raise ValueError(
            f"triplets_ntypes has {len(triplets_ntypes)} entries but triplets has {len(triplets)}"
        )

    batches = []
    for start in range(0, len(triplets), batch_size):
        stop = start + batch_size
        batch_triplets = triplets[start:stop]
        triplet_tensor = torch.tensor(batch_triplets, dtype=torch.long)

        if triplets_ntypes:
            batch_ntypes = triplets_ntypes[start:stop]
            head_types = [nt[0] for nt in batch_ntypes]
            tail_types = [nt[1] for nt in batch_ntypes]
        else:
            # Homogeneous fallback: a single generic entity type
            head_types = ['entity'] * len(batch_triplets)
            tail_types = ['entity'] * len(batch_triplets)

        batches.append((triplet_tensor, head_types, tail_types))

    return batches

# Example usage and training loop
def train_nbfnet(pos_triplets, pos_triplets_ntypes, neg_triplets, neg_triplets_ntypes,
                 entity_types, relation_types, node_counts,
                 num_epochs: int = 100, batch_size: int = 1024,
                 device: Optional[torch.device] = None):
    """Build a SimpleNBFNet and train it on paired positive/negative batches.

    Each step pairs the i-th positive batch with the i-th negative batch,
    truncated to a common length. Every 10 epochs the mean loss and the
    accuracy on the last training batch are printed.

    Args:
        pos_triplets, pos_triplets_ntypes: positive triplets and their
            ``(head_type, tail_type)`` pairs (see ``prepare_batch_data``).
        neg_triplets, neg_triplets_ntypes: negative triplets and their types.
        entity_types, relation_types, node_counts: forwarded to ``SimpleNBFNet``.
        num_epochs: number of passes over the paired batches.
        batch_size: triplets per batch.
        device: optional device for the model and batches. ``None`` keeps the
            previous behaviour (model and data stay where they were created).

    Returns:
        ``(model, trainer)``.

    Raises:
        ValueError: if the positive or the negative triplet list yields no batches.
    """
    model = SimpleNBFNet(
        entity_types=entity_types,
        relation_types=relation_types,
        node_counts=node_counts,
        embedding_dim=256,
        hidden_dim=256,
        num_layers=3,
        message_func='distmult'
    )
    if device is not None:
        model = model.to(device)

    trainer = NBFNetTrainer(model, learning_rate=1e-3)

    pos_batches = prepare_batch_data(pos_triplets, pos_triplets_ntypes, batch_size)
    neg_batches = prepare_batch_data(neg_triplets, neg_triplets_ntypes, batch_size)

    num_batches = min(len(pos_batches), len(neg_batches))
    if num_batches == 0:
        # Without this guard the loss average divides by zero
        raise ValueError("train_nbfnet needs at least one positive and one negative triplet")

    def _align(pos_batch, neg_batch):
        """Truncate a batch pair to a common length, optionally moving tensors to `device`.

        Returns arguments in ``NBFNetTrainer.train_step``/``evaluate`` order.
        """
        pos_t, pos_h, pos_tl = pos_batch
        neg_t, neg_h, neg_tl = neg_batch
        n = min(pos_t.size(0), neg_t.size(0))
        pos_t, neg_t = pos_t[:n], neg_t[:n]
        if device is not None:
            pos_t, neg_t = pos_t.to(device), neg_t.to(device)
        return pos_t, neg_t, pos_h[:n], pos_tl[:n], neg_h[:n], neg_tl[:n]

    for epoch in range(num_epochs):
        total_loss = 0
        for pos_batch, neg_batch in zip(pos_batches, neg_batches):
            step_args = _align(pos_batch, neg_batch)
            total_loss += trainer.train_step(*step_args)

        if epoch % 10 == 0:
            # NOTE(review): this evaluates on the last *training* batch, so it is
            # an in-sample check rather than a held-out metric.
            eval_metrics = trainer.evaluate(*step_args)
            print(f"Epoch {epoch}: Loss={total_loss/num_batches:.4f}, "
                  f"Acc={eval_metrics['total_accuracy']:.4f}")

    return model, trainer

In [None]:
def _sample_nodes_of_types(candidate_types, num_samples, node_type_ranges):
    """Draw ``num_samples`` random (node type, node id) pairs in one vectorised pass.

    Each sample's type is chosen uniformly from ``candidate_types``. Its id is
    chosen uniformly from ``[0, node_type_ranges[type])``. Samples whose type
    has no nodes cannot get an id and are dropped.

    Returns:
        keep: bool mask of length ``num_samples`` marking the retained samples.
        node_ids: int array of ids for the retained samples.
        node_types: str array of types for the retained samples.
    """
    type_array = np.asarray(candidate_types)
    type_sizes = np.asarray([node_type_ranges[t] for t in candidate_types], dtype=np.int64)

    picks = np.random.randint(0, len(type_array), size=num_samples)
    sizes = type_sizes[picks]
    keep = sizes > 0
    if keep.any():
        # Array-valued ``high`` draws one id per sample, each bounded by its own type size
        node_ids = np.random.randint(0, sizes[keep])
    else:
        node_ids = np.empty(0, dtype=np.int64)
    return keep, node_ids, type_array[picks][keep]


def get_triplets_vectorized(g, edge_mapping):
    """Extract positive triplets and sample type-violating negative triplets.

    Heterogeneous path (``g`` has ``canonical_etypes``): for every edge
    ``(src, etype, dst)`` the positive triplet is
    ``(src_id, edge_mapping[etype], dst_id)``, typed ``(src_ntype, dst_ntype)``.
    One negative per edge is attempted by corrupting the head or the tail
    (50/50). The replacement is drawn from node types that never occur as a
    head (respectively tail) of that relation name anywhere in the graph.
    An edge gets no negative when its corruption side has no such invalid type,
    or when the sampled type has zero nodes. As a result, ``neg_triplets`` can
    be shorter than ``triplets``.

    NOTE(review): because negatives always violate the relation's type
    signature, a model can separate them using entity types alone; these
    negatives are "easy" ones.

    Homogeneous path (no ``canonical_etypes``): every edge gets relation id 0
    and one negative with a uniformly random head or tail (which may collide
    with a true edge). No type lists are produced.

    Args:
        g: DGL-style graph.
        edge_mapping: dict from relation name to integer relation id.

    Returns:
        ``(triplets, triplets_ntypes, neg_triplets, neg_triplets_ntypes)``. The
        first and third are lists of ``(head, relation, tail)`` int tuples; the
        other two are lists of ``(head_type, tail_type)`` tuples.
    """
    triplets = []
    triplets_ntypes = []
    neg_triplets = []
    neg_triplets_ntypes = []

    if hasattr(g, 'canonical_etypes'):
        node_type_ranges = {ntype: g.num_nodes(ntype) for ntype in g.ntypes}
        all_node_types = set(g.ntypes)

        # A relation name may occur in several canonical etypes; collect all its
        # head/tail types before deriving the invalid ones.
        valid_heads = {}
        valid_tails = {}
        for src_ntype, etype, dst_ntype in g.canonical_etypes:
            valid_heads.setdefault(etype, set()).add(src_ntype)
            valid_tails.setdefault(etype, set()).add(dst_ntype)
        # Sorted so sampling is reproducible under a fixed NumPy seed
        # (set iteration order of strings varies between interpreter runs).
        relation_invalid_heads = {et: sorted(all_node_types - s) for et, s in valid_heads.items()}
        relation_invalid_tails = {et: sorted(all_node_types - s) for et, s in valid_tails.items()}

        for canonical_etype in g.canonical_etypes:
            src_ntype, etype, dst_ntype = canonical_etype
            src, dst = g.edges(etype=canonical_etype)
            num_edges = len(src)
            if num_edges == 0:
                continue

            relation_idx = edge_mapping[etype]
            # .cpu() so this also works when the graph lives on a GPU
            src_np = src.cpu().numpy()
            dst_np = dst.cpu().numpy()

            triplets.extend(zip(src_np.tolist(), [relation_idx] * num_edges, dst_np.tolist()))
            triplets_ntypes.extend(zip([src_ntype] * num_edges, [dst_ntype] * num_edges))

            # Decide per edge whether to corrupt the head (True) or the tail (False)
            corrupt_head_mask = np.random.random(num_edges) < 0.5

            head_idx = np.flatnonzero(corrupt_head_mask)
            invalid_heads = relation_invalid_heads[etype]
            if invalid_heads and head_idx.size > 0:
                keep, new_heads, new_head_types = _sample_nodes_of_types(
                    invalid_heads, head_idx.size, node_type_ranges
                )
                kept = head_idx[keep]
                neg_triplets.extend(zip(
                    new_heads.tolist(), [relation_idx] * kept.size, dst_np[kept].tolist()
                ))
                neg_triplets_ntypes.extend(zip(new_head_types.tolist(), [dst_ntype] * kept.size))

            tail_idx = np.flatnonzero(~corrupt_head_mask)
            invalid_tails = relation_invalid_tails[etype]
            if invalid_tails and tail_idx.size > 0:
                keep, new_tails, new_tail_types = _sample_nodes_of_types(
                    invalid_tails, tail_idx.size, node_type_ranges
                )
                kept = tail_idx[keep]
                neg_triplets.extend(zip(
                    src_np[kept].tolist(), [relation_idx] * kept.size, new_tails.tolist()
                ))
                neg_triplets_ntypes.extend(zip([src_ntype] * kept.size, new_tail_types.tolist()))

    else:
        # Homogeneous graph: single relation id 0, no type information
        src, dst = g.edges()
        src_np = src.cpu().numpy()
        dst_np = dst.cpu().numpy()
        num_edges = len(src_np)
        triplets.extend(zip(src_np.tolist(), [0] * num_edges, dst_np.tolist()))

        total_nodes = g.num_nodes()
        if total_nodes > 0 and num_edges > 0:
            corrupt_head_mask = np.random.random(num_edges) < 0.5
            corrupted_heads = np.random.randint(0, total_nodes, size=num_edges)
            corrupted_tails = np.random.randint(0, total_nodes, size=num_edges)

            neg_src = np.where(corrupt_head_mask, corrupted_heads, src_np)
            neg_dst = np.where(corrupt_head_mask, dst_np, corrupted_tails)
            neg_triplets = list(zip(neg_src.tolist(), [0] * num_edges, neg_dst.tolist()))

    return triplets, triplets_ntypes, neg_triplets, neg_triplets_ntypes

def get_entity_embeddings(model, entity_ids, entity_types):
    """Return learned (un-projected) entity embeddings as a NumPy array.

    Args:
        model: module exposing ``get_embeddings(ids, types)``. Inputs are
            moved to the device of its first parameter.
        entity_ids: tensor, or any sequence / NumPy array of integer ids.
            Non-tensors are converted to ``torch.long``.
        entity_types: one entity type name per id.

    Returns:
        ``numpy.ndarray`` produced by ``model.get_embeddings``, moved to CPU.

    Side effect: puts ``model`` in eval mode (it is not switched back).
    """
    model.eval()
    device = next(model.parameters()).device

    # as_tensor also accepts tuples and NumPy arrays (the old check handled lists only)
    if not torch.is_tensor(entity_ids):
        entity_ids = torch.as_tensor(entity_ids, dtype=torch.long)
    entity_ids = entity_ids.to(device)

    with torch.no_grad():
        embeddings = model.get_embeddings(entity_ids, entity_types)

    return embeddings.cpu().numpy()

def predict_triplet_scores(model, triplets, head_types, tail_types):
    """Score triplets with ``model`` and return logits and sigmoid probabilities.

    Args:
        model: module called as ``model(triplets, head_types, tail_types)``
            that returns logits. Inputs go to the device of its first parameter.
        triplets: tensor, or any nested sequence / NumPy array of
            ``[head, relation, tail]`` integer ids. Non-tensors are converted
            to ``torch.long``.
        head_types, tail_types: entity type name per triplet.

    Returns:
        ``(scores, probabilities)`` as NumPy arrays on CPU.

    Side effect: puts ``model`` in eval mode (it is not switched back).
    """
    model.eval()
    device = next(model.parameters()).device

    # as_tensor also accepts tuples and NumPy arrays (the old check handled lists only)
    if not torch.is_tensor(triplets):
        triplets = torch.as_tensor(triplets, dtype=torch.long)
    triplets = triplets.to(device)

    with torch.no_grad():
        scores = model(triplets, head_types, tail_types)
        probabilities = torch.sigmoid(scores)

    return scores.cpu().numpy(), probabilities.cpu().numpy()

In [None]:
import random
from tqdm.notebook import tqdm

# Location of the serialized graph; adjust for your machine.
GRAPH_PATH = "/home/cc/PHD/dglframework/DeepKG/DeepGraphDB/graphs/primekg.bin"

gdb = DeepGraphDB()
gdb.load_graph(GRAPH_PATH)

# Node count for the whole graph (no node type given)
num_entities = gdb.graph.num_nodes()

# Positive triplets plus type-violating negatives (see get_triplets_vectorized)
pos_triplets, pos_triplets_ntypes, neg_triplets, neg_triplets_ntypes = get_triplets_vectorized(
    gdb.graph, gdb.edge_types_mapping
)

print(f"Graph stats: {num_entities} entities, {len(gdb.edge_types_mapping)} relations, {len(pos_triplets)} triplets")

In [None]:
if hasattr(gdb.graph, 'ntypes'):
    # Heterogeneous graph
    entity_types = list(gdb.graph.ntypes)
    node_counts = {ntype: gdb.graph.num_nodes(ntype) for ntype in entity_types}

    # Relation names from canonical edge types. Sorted so the list is the same
    # on every run (set iteration order is not).
    # NOTE(review): relation ids come from gdb.edge_types_mapping; this assumes
    # its values are 0..len(relation_types)-1 so they index the embedding table.
    if hasattr(gdb.graph, 'canonical_etypes'):
        relation_types = sorted({etype[1] for etype in gdb.graph.canonical_etypes})
    else:
        relation_types = list(gdb.edge_types_mapping.keys()) if hasattr(gdb, 'edge_types_mapping') else ['default_relation']
else:
    # Homogeneous graph
    entity_types = ['entity']
    node_counts = {'entity': gdb.graph.num_nodes()}
    relation_types = ['relation']

print(f"Entity types: {entity_types}")
print(f"Node counts: {node_counts}")
print(f"Relation types: {relation_types}")

# Initialize model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleNBFNet(
    entity_types=entity_types,
    relation_types=relation_types,
    node_counts=node_counts,
    embedding_dim=256,
    hidden_dim=256,
    num_layers=6,       # deeper than the class default of 3
    message_func='distmult',
    dropout=0.25
).to(device)

trainer = NBFNetTrainer(model, learning_rate=1e-3, weight_decay=1e-5)

# Prepare data
batch_size = 250000
pos_batches = prepare_batch_data(pos_triplets, pos_triplets_ntypes, batch_size)
neg_batches = prepare_batch_data(neg_triplets, neg_triplets_ntypes, batch_size)

print(f"Prepared {len(pos_batches)} positive batches and {len(neg_batches)} negative batches")

num_epochs = 100
eval_every = 5          # epochs between evaluations
num_eval_batches = 5    # batches held out (per polarity) for evaluation
best_acc = 0.0
METRIC_KEYS = ('total_accuracy', 'auc', 'pos_accuracy', 'neg_accuracy',
               'avg_pos_score', 'avg_neg_score')

# TODO: Split triples based on entity types or relation types
# NOTE(review): batches are contiguous slices of the triplet lists, which
# get_triplets_vectorized emits grouped by edge type. Shuffling whole batches
# therefore makes the held-out batches cover only a few relation types.
random.shuffle(pos_batches)
random.shuffle(neg_batches)

eval_pos_batches = pos_batches[:num_eval_batches]
eval_neg_batches = neg_batches[:num_eval_batches]

train_pos_batches = pos_batches[num_eval_batches:]
train_neg_batches = neg_batches[num_eval_batches:]

num_batches = min(len(train_pos_batches), len(train_neg_batches))
if num_batches == 0:
    # Otherwise the loss average below would divide by zero
    raise ValueError(
        f"No training batches left after holding out {num_eval_batches} for evaluation; "
        f"reduce batch_size (currently {batch_size})"
    )


def _move_and_align_batches(pos_batch, neg_batch, target_device):
    """Truncate a positive/negative batch pair to equal length and move tensors to `target_device`.

    Returns arguments in NBFNetTrainer.train_step / evaluate order.
    """
    pos_t, pos_h, pos_tl = pos_batch
    neg_t, neg_h, neg_tl = neg_batch
    n = min(pos_t.size(0), neg_t.size(0))
    return (pos_t[:n].to(target_device), neg_t[:n].to(target_device),
            pos_h[:n], pos_tl[:n], neg_h[:n], neg_tl[:n])


for epoch in tqdm(range(num_epochs)):
    model.train()
    total_loss = 0
    for pos_batch, neg_batch in zip(train_pos_batches, train_neg_batches):
        loss = trainer.train_step(*_move_and_align_batches(pos_batch, neg_batch, device))
        total_loss += loss

    if (epoch + 1) % eval_every == 0:
        model.eval()
        # One value per eval batch; no seed entry, so the means are unbiased
        evals = {key: [] for key in METRIC_KEYS}
        for pos_batch, neg_batch in zip(eval_pos_batches, eval_neg_batches):
            eval_metrics = trainer.evaluate(*_move_and_align_batches(pos_batch, neg_batch, device))
            for key in METRIC_KEYS:
                evals[key].append(eval_metrics[key])
        means = {key: float(np.mean(values)) for key, values in evals.items()}

        current_acc = means['total_accuracy']
        if current_acc > best_acc:
            best_acc = current_acc
            # Save best model
            torch.save(model.state_dict(), 'best_nbfnet_model.pth')

        print(f"Epoch {epoch+1}: Loss={total_loss/num_batches:.4f}, "
            f"Acc={current_acc:.4f} (Best: {best_acc:.4f}) - AUC={means['auc']:.4f}")
        print(f"  Pos Acc: {means['pos_accuracy']:.4f}, "
            f"Neg Acc: {means['neg_accuracy']:.4f}")
        print(f"  Avg Pos Score: {means['avg_pos_score']:.4f}, "
            f"Avg Neg Score: {means['avg_neg_score']:.4f}")

# NOTE(review): the examples below use the final-epoch weights, not the saved
# best checkpoint. Type names not present in `entity_types` produce all-zero
# embeddings (HeteroEntityEmbedding skips unknown types); verify these names.

# Example: Get embeddings for specific entities
sample_entity_ids = [0, 1, 2, 3, 4]
sample_entity_types = ['gene', 'gene', 'disease', 'drug', 'gene']  # Example types

embeddings = get_entity_embeddings(model, sample_entity_ids, sample_entity_types)
print(f"Embeddings shape: {embeddings.shape}")

# Example: Predict scores for new triplets
test_triplets = [[0, 0, 1], [2, 1, 3]]  # [head, relation, tail]
test_head_types = ['gene', 'disease']
test_tail_types = ['gene', 'drug']

scores, probs = predict_triplet_scores(model, test_triplets, test_head_types, test_tail_types)
print(f"Triplet scores: {scores}")
print(f"Triplet probabilities: {probs}")