# Overview of Project Structure



This project implements a knowledge graph embedding model using TransE. The code is modularized into three key files:

- `data.py`: Handles dataset loading, converts the triples to ID form, and splits the training triples into unrestricted and confidential subsets based on a configurable `confidential_ratio`. It also generates positive/negative samples and prints relevant statistics.
- `model.py`: Defines the TransE model class. It initializes entity and relation embeddings, computes distances between triples, performs forward passes, and manages model parameters.
- `trainer.py`: Contains the training logic. It trains the model using the configured parameters and privacy settings and logs the performance metrics.


# Install Required Libraries




This notebook uses:
- `pykeen`: For working with knowledge graph embeddings.
- `opacus`: To enable differential privacy during training.

In [None]:
%%capture
!pip install pykeen==1.10.1 class-resolver==0.3.10

In [None]:
%%capture
!pip install opacus

# Define data.py: Knowledge Graph Data Handler




Creates the `KGDataHandler` class, which is responsible for:

- Loading training, validation, and test triples from files.
- Initializing entity and relation mappings using PyKEEN.
- Splitting training data into *confidential* and *unrestricted* subsets based on a given ratio.
- Generating both positive and negative samples for training.
- Providing key dataset statistics for transparency and debugging.

In [None]:
%%writefile data.py
import os
import random
import torch
from typing import List
from pykeen.datasets import PathDataset
from typing import List, Tuple, Dict, Set, Optional

class KGDataHandler:
    """Knowledge Graph Data Handler

    Loads the train/valid/test triple files found in a dataset directory,
    converts the string triples to integer IDs using the PyKEEN mappings,
    splits the training triples into a confidential and an unrestricted
    subset, and draws positive/negative training batches.
    """

    def __init__(self, fb_path: str, confidential_ratio: float = 0.3):
        """
        Initialize the KG data handler

        Args:
            fb_path: Path to the directory holding train.txt, valid.txt and test.txt
            confidential_ratio: Fraction of training data that requires privacy protection

        Raises:
            ValueError: If confidential_ratio is not within [0, 1]
        """
        if not 0.0 <= confidential_ratio <= 1.0:
            raise ValueError(
                f"confidential_ratio must be within [0, 1], got {confidential_ratio}"
            )

        self.fb_path = fb_path
        self.confidential_ratio = confidential_ratio

        # Build each path once and reuse it for both raw loading and PyKEEN
        self.train_path = os.path.join(fb_path, 'train.txt')
        self.valid_path = os.path.join(fb_path, 'valid.txt')
        self.test_path = os.path.join(fb_path, 'test.txt')

        self.train_data = self.load_triplets(self.train_path)
        self.valid_data = self.load_triplets(self.valid_path)
        self.test_data = self.load_triplets(self.test_path)

        self.dataset = PathDataset(
            training_path=self.train_path,
            testing_path=self.test_path,
            validation_path=self.valid_path
        )

        self._prepare_data()

    def load_triplets(self, file_path: str) -> List[List[str]]:
        """
        Load triplets from a tab-separated file

        Blank lines (for example a trailing newline at the end of the file)
        are skipped so that every returned entry comes from a non-empty line.

        Args:
            file_path: Path to the file

        Returns:
            List of triplets [head, relation, tail]
        """
        triplets = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                stripped = line.strip()
                if stripped:
                    triplets.append(stripped.split('\t'))
        return triplets

    def _prepare_data(self):
        """Prepare data for training

        Builds the entity/relation sets, the ID mappings, the ID triples and
        the confidential/unrestricted split.
        """
        # Entities and relations seen in the raw training / testing files
        self.entities = set()
        self.relations = set()
        self.entities_t = set()
        self.relations_t = set()

        for h, r, t in self.train_data:
            self.entities.update((h, t))
            self.relations.add(r)

        for h, r, t in self.test_data:
            self.entities_t.update((h, t))
            self.relations_t.add(r)

        # Get mappings from dataset
        self.entity_to_id = self.dataset.training.entity_to_id
        self.relation_to_id = self.dataset.training.relation_to_id
        self.entity_to_id_t = self.dataset.testing.entity_to_id
        self.relation_to_id_t = self.dataset.testing.relation_to_id

        self.entity_count = len(self.entity_to_id)
        self.relation_count = len(self.relation_to_id)

        # Convert string triples to ID triples
        self.train_triples = [
            (self.entity_to_id[h], self.relation_to_id[r], self.entity_to_id[t])
            for h, r, t in self.train_data
        ]

        # Test triples are filtered AND converted with the training mappings,
        # because the model's embedding tables are sized from them. Triples
        # with an entity or relation unseen during training are dropped.
        # NOTE(review): presumably PathDataset builds the testing factory with
        # the same mappings, so this matches the *_t lookups - confirm.
        self.test_triples = [
            (self.entity_to_id[h], self.relation_to_id[r], self.entity_to_id[t])
            for h, r, t in self.test_data
            if (h in self.entity_to_id and r in self.relation_to_id and t in self.entity_to_id)
        ]

        # Split training data into confidential and unrestricted.
        # The split follows file order; no shuffling is applied.
        split_point_conf = int(len(self.train_triples) * self.confidential_ratio)
        self.confidential_triples = self.train_triples[:split_point_conf]
        self.unrestricted_triples = self.train_triples[split_point_conf:]

        # All triples for filtered evaluation (train + test)
        self.all_triples = self.train_triples + self.test_triples

    def print_data_stats(self):
        """Print statistics about the data"""
        separator = "--------------------------------------------------------------------------------"
        print(separator)
        print(f"Total entities: {self.entity_count}")
        print(separator)
        print(f"Total relations: {self.relation_count}")
        print(separator)
        print(f"Confidential triples: {len(self.confidential_triples)}")
        print(separator)
        print(f"Unrestricted triples: {len(self.unrestricted_triples)}")
        print(separator)
        print(f"Testing triples: {len(self.test_triples)}")

    def get_positive_and_negative_samples(
        self,
        triples: List[Tuple[int, int, int]],
        batch_size: int,
        neg_ratio: int = 10,
        entity_count: Optional[int] = None
    ):
        """
        Get positive samples and corresponding negative samples

        A batch of at most `batch_size` triples is drawn without replacement.
        For every positive triple, `neg_ratio` negatives are created by
        replacing either the head or the tail (chosen with equal probability)
        with a different, uniformly drawn entity ID. Negatives are not
        checked against the known triples.

        Args:
            triples: List of triples [(head, relation, tail), ...]
            batch_size: Maximum number of positive samples
            neg_ratio: Number of negative samples per positive sample
            entity_count: Number of entities in KG (defaults to self.entity_count)

        Returns:
            Positive and negative samples as long tensors, placed on CUDA when
            it is available and on the CPU otherwise. The negatives of the
            i-th positive occupy rows i * neg_ratio ... (i + 1) * neg_ratio - 1.

        Raises:
            ValueError: If negatives are requested but fewer than two entities exist
        """
        if entity_count is None:
            entity_count = self.entity_count

        if len(triples) > batch_size:
            batch_indices = random.sample(range(len(triples)), batch_size)
            batch_triples = [triples[i] for i in batch_indices]
        else:
            batch_triples = triples

        if neg_ratio > 0 and len(batch_triples) > 0 and entity_count < 2:
            # With a single entity there is no replacement that differs from
            # the original, so corruption is impossible.
            raise ValueError("at least two entities are required to build negative samples")

        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Create positive samples tensor
        pos_samples = torch.tensor(batch_triples, dtype=torch.long).to(device)

        # Create negative samples by corrupting either head or tail
        neg_samples = []
        for h, r, t in batch_triples:
            for _ in range(neg_ratio):
                corrupt_head = random.random() < 0.5
                original = h if corrupt_head else t
                # Draw uniformly from the entity_count - 1 other IDs: sample
                # from a range one shorter and skip over the original ID.
                # Unlike rejection sampling this always terminates.
                replacement = random.randrange(entity_count - 1)
                if replacement >= original:
                    replacement += 1
                if corrupt_head:
                    neg_samples.append((replacement, r, t))
                else:
                    neg_samples.append((h, r, replacement))

        neg_samples = torch.tensor(neg_samples, dtype=torch.long).to(device)
        return pos_samples, neg_samples


Overwriting data.py


# Define model.py: Model Definition



Creates a file containing the `TransEModel` class, which encapsulates:
- Initialization of entity and relation embeddings
- L2 distance computation (used as a scoring function)
- A forward pass to compute triple scores
- Normalization of embeddings


In [None]:
%%writefile model.py
import torch # import torch module
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class TransEModel(nn.Module):
    """TransE Knowledge Graph Embedding Model

    Holds one embedding table for entities and one for relations and scores
    a triple (h, r, t) by the L2 distance ||h + r - t||_2 (lower is better).
    """

    def __init__(self, entity_count: int, relation_count: int, embedding_dim: int = 100, device: str = "cuda"):
        """
        Initialize the TransE model

        Args:
            entity_count: Number of entities in the knowledge graph
            relation_count: Number of relations in the knowledge graph
            embedding_dim: Dimension of the embedding vectors
            device: Device to run the model on (cuda or cpu). When a CUDA
                device is requested but CUDA is not available, the model
                falls back to the CPU and emits a warning.
        """
        super().__init__()

        # The default device is "cuda"; without this fallback the constructor
        # fails on CPU-only machines when the tables are moved to the device.
        if str(device).startswith("cuda") and not torch.cuda.is_available():
            import warnings
            warnings.warn(
                f"Device '{device}' requested but CUDA is not available; using 'cpu' instead."
            )
            device = "cpu"

        self.entity_count = entity_count
        self.relation_count = relation_count
        self.embedding_dim = embedding_dim
        self.device = device

        self.entity_embeddings = nn.Embedding(entity_count, embedding_dim).to(device)
        self.relation_embeddings = nn.Embedding(relation_count, embedding_dim).to(device)

        # Uniform initialization in [-6/sqrt(dim), 6/sqrt(dim)] for both tables
        bound = float(6 / np.sqrt(embedding_dim))
        torch.nn.init.uniform_(self.entity_embeddings.weight.data, -bound, bound)
        torch.nn.init.uniform_(self.relation_embeddings.weight.data, -bound, bound)

        # Save initial embeddings (as NumPy copies) for later comparison
        self.initial_entity_embeddings = self.entity_embeddings.weight.data.clone().cpu().numpy()
        self.initial_relation_embeddings = self.relation_embeddings.weight.data.clone().cpu().numpy()

    def l2_distance(self, head, relation, tail):
        """Compute L2 distance for TransE: ||h + r - t||_2.

        The norm is taken over the last (embedding) dimension, so the inputs
        may be batched (batch, dim) or single vectors (dim,).
        """
        return torch.norm(head + relation - tail, p=2, dim=-1)

    def forward(self, triples, normalize=True):
        """
        Forward pass to compute scores for triples

        Args:
            triples: Tensor of shape (batch_size, 3) containing (head, relation, tail) ID triples
            normalize: Whether to L2-normalize the looked-up head, relation and
                tail vectors before scoring (the stored tables are not modified)

        Returns:
            Tensor of scores (L2 distances) for each triple
        """
        heads = self.entity_embeddings(triples[:, 0])
        relations = self.relation_embeddings(triples[:, 1])
        tails = self.entity_embeddings(triples[:, 2])

        # Optional normalization
        if normalize:
            heads, relations, tails = (
                F.normalize(vectors, p=2, dim=1) for vectors in (heads, relations, tails)
            )

        return self.l2_distance(heads, relations, tails)

    def normalize_embeddings(self):
        """Normalize every entity and relation embedding row to unit L2 length"""
        with torch.no_grad():
            # The weight data is rebound to a new tensor (not updated in
            # place); kept this way so existing callers see unchanged behavior.
            self.entity_embeddings.weight.data = F.normalize(self.entity_embeddings.weight.data, p=2, dim=1)
            self.relation_embeddings.weight.data = F.normalize(self.relation_embeddings.weight.data, p=2, dim=1)

    def get_parameters(self):
        """Get all parameters of the model (entity table first, then relation table)"""
        return list(self.entity_embeddings.parameters()) + list(self.relation_embeddings.parameters())

    def get_state_dict(self):
        """Get state dict for saving the model

        Returns:
            Dictionary with the state dicts of both embedding tables under
            the keys 'entity_embeddings' and 'relation_embeddings'
        """
        return {
            'entity_embeddings': self.entity_embeddings.state_dict(),
            'relation_embeddings': self.relation_embeddings.state_dict()
        }

    def load_state_dict_from_dict(self, state_dict):
        """Load state dict from dictionary

        Args:
            state_dict: Dictionary with the keys 'entity_embeddings' and
                'relation_embeddings'; any other keys are ignored
        """
        self.entity_embeddings.load_state_dict(state_dict['entity_embeddings'])
        self.relation_embeddings.load_state_dict(state_dict['relation_embeddings'])

Overwriting model.py


# Define trainer.py: TransE Model Trainer with Differential Privacy


This script defines the `TransETrainer` class, which is responsible for:

- Managing the complete training process of the TransE model.
- Supporting differentially private training on confidential triples through gradient-norm clipping (applied to the aggregated batch gradient) and Gaussian noise injection.
- Alternating between training on confidential and unrestricted triples based on dataset balance.
- Applying a custom margin-based ranking loss with optional L2 regularization.
- Monitoring differential privacy guarantees using Opacus's RDPAccountant.
- Evaluating model performance using standard filtered ranking metrics (MR, MRR, Hits@K).
- Incorporating early stopping based on periodic evaluation (on a sample of the test triples) or privacy budget (ε) exhaustion.
- Normalizing embeddings after updates to ensure stable optimization in the TransE space.

In [None]:
%%writefile trainer.py

import numpy as np
import random
from torch.optim import AdamW
from opacus.accountants import RDPAccountant
import torch
from model import TransEModel
from data import KGDataHandler

class TransETrainer:
    """TransE Model Trainer with Differential Privacy Support

    Alternates optimisation steps on confidential triples (clipped and
    noised gradients, tracked by an RDP accountant) and on unrestricted
    triples (plain gradients), and evaluates the model with filtered
    ranking metrics.
    """

    def __init__(
        self,
        model: TransEModel,
        data_handler: KGDataHandler,
        learning_rate: float = 0.005,
        noise_multiplier: float = 0.7,
        batch_size: int = 256,
        norm_clipping: float = 1.0,
        margin: float = 0.5,
        epochs: int = 300,
        reg_lambda: float = 1e-5,
        device: str = "cuda" if torch.cuda.is_available() else "cpu"
    ):
        """
        Initialize the trainer

        Args:
            model: TransE model
            data_handler: Knowledge graph data handler
            learning_rate: Step size for optimizer updates
            noise_multiplier: Amount of noise added to gradients for differential privacy
            batch_size: Number of samples per training batch
            norm_clipping: Maximum global L2 norm of the batch gradient (before adding noise)
            margin: Margin used in ranking loss to separate positive/negative triples
            epochs: Total number of training iterations over the dataset
            reg_lambda: Regularization parameter
            device: Device to run the model on (cuda or cpu)
        """
        self.model = model
        self.data_handler = data_handler
        self.learning_rate = learning_rate
        self.noise_multiplier = noise_multiplier
        self.batch_size = batch_size
        self.norm_clipping = norm_clipping
        self.margin = margin
        self.epochs = epochs
        self.reg_lambda = reg_lambda
        self.device = device

        # Initialize optimizer and scheduler
        self.parameters = self.model.get_parameters()
        self.optimizer = AdamW(self.parameters, lr=learning_rate, weight_decay=1e-5)
        # The `verbose` flag of the LR schedulers is deprecated in recent
        # PyTorch releases, so it is not passed; learning-rate reductions are
        # reported by train_with_early_stopping instead.
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', factor=0.5, patience=3
        )

        # Privacy accountant
        self.accountant = RDPAccountant()

    def loss_function(self, pos_scores, neg_scores):
        """
        Margin-based ranking loss for TransE with regularization

        Every positive score is paired with its block of negatives; the
        negatives of the i-th positive are expected at consecutive positions
        in `neg_scores`.

        Args:
            pos_scores: Scores for positive samples
            neg_scores: Scores for negative samples

        Returns:
            Loss value
        """
        # Repeat each positive score once per negative so both tensors align
        negatives_per_positive = neg_scores.size(0) // pos_scores.size(0)
        pos_expanded = pos_scores.repeat_interleave(negatives_per_positive)

        ranking_loss = torch.mean(torch.relu(pos_expanded - neg_scores + self.margin))

        # Add regularization term: sum of the L2 norms of the parameter tensors
        if self.reg_lambda > 0:
            reg_loss = 0
            for param in self.parameters:
                reg_loss += torch.norm(param, p=2)
            return ranking_loss + self.reg_lambda * reg_loss

        return ranking_loss

    def _sample_batch(self, triples, neg_ratio):
        """Draw positive/negative samples and place them on the trainer's device."""
        pos_samples, neg_samples = self.data_handler.get_positive_and_negative_samples(
            triples, self.batch_size, neg_ratio=neg_ratio
        )
        # The data handler picks its own device; move explicitly so the
        # samples always live where this trainer was configured to run.
        return pos_samples.to(self.device), neg_samples.to(self.device)

    def optimize_confidential(self, triples):
        """
        Optimize parameters for confidential statements with gradient clipping and noise

        The gradient of the batch loss is clipped to a global L2 norm of
        `norm_clipping`, Gaussian noise is added, the optimizer is stepped,
        the step is recorded in the privacy accountant and the embeddings
        are re-normalized.

        Args:
            triples: List of triples [(head, relation, tail), ...]

        Returns:
            Loss value
        """
        pos_samples, neg_samples = self._sample_batch(triples, neg_ratio=5)

        # Forward pass
        pos_score = self.model.forward(pos_samples, normalize=False)
        neg_score = self.model.forward(neg_samples, normalize=False)

        # Compute loss
        loss = self.loss_function(pos_score, neg_score)

        # Backward pass to get gradients
        self.optimizer.zero_grad()
        loss.backward()

        grads = [param.grad for param in self.parameters if param.grad is not None]

        # Global gradient norm over all parameters
        total_norm = sum(grad.norm(2).item() ** 2 for grad in grads) ** 0.5

        # NOTE(review): clipping is applied to the aggregated batch gradient,
        # not to per-sample gradients, and the noise std is divided by
        # sqrt(batch_size). The epsilon reported by the accountant assumes
        # the standard per-sample DP-SGD mechanism - confirm that the
        # reported guarantee is valid for this procedure.
        clip_coef = min(1.0, self.norm_clipping / (total_norm + 1e-6))
        noise_std = self.noise_multiplier * self.norm_clipping / self.batch_size**0.5

        with torch.no_grad():
            for grad in grads:
                grad.mul_(clip_coef)
            for grad in grads:
                # Create the noise on the gradient's own device/dtype so the
                # addition cannot fail on a device mismatch
                noise = torch.normal(
                    mean=0.0,
                    std=noise_std,
                    size=grad.shape,
                    device=grad.device,
                    dtype=grad.dtype
                )
                grad.add_(noise)

        # Update parameters
        self.optimizer.step()

        # The batch holds at most len(triples) samples, so use the realized
        # batch size; the sampling rate then never exceeds 1.
        sample_rate = min(1.0, pos_samples.size(0) / len(triples))
        self.accountant.step(noise_multiplier=self.noise_multiplier, sample_rate=sample_rate)

        # Normalize embeddings after update
        self.model.normalize_embeddings()

        return loss.item()

    def optimize_unrestricted(self, triples):
        """
        Optimize parameters for unrestricted statements without differential privacy

        Args:
            triples: List of triples [(head, relation, tail), ...]

        Returns:
            Loss value
        """
        pos_samples, neg_samples = self._sample_batch(triples, neg_ratio=10)

        # Forward pass
        pos_score = self.model.forward(pos_samples, normalize=False)
        neg_score = self.model.forward(neg_samples, normalize=False)

        # Compute loss
        loss = self.loss_function(pos_score, neg_score)

        # Backward pass
        self.optimizer.zero_grad()
        loss.backward()

        # Update parameters
        self.optimizer.step()

        # Normalize embeddings after update
        self.model.normalize_embeddings()

        return loss.item()

    def _filtered_rank(self, candidates, true_entity, known_entities):
        """Return the filtered rank of `true_entity`; row e of `candidates` must use entity e."""
        scores = self.model.forward(candidates)
        # In TransE, LOWER scores are better (distance-based)
        better = scores < scores[true_entity]
        # Filtered setting: other entities known to complete the triple must
        # not count against the true entity
        excluded = [e for e in known_entities if 0 <= e < better.numel()]
        if excluded:
            better[torch.tensor(excluded, dtype=torch.long, device=better.device)] = False
        return 1 + int(better.sum().item())

    def evaluate_model(self, test_triples, all_triples, batch_size=128, k=10):
        """
        Evaluate model using the filtered setting

        For every test triple the head and then the tail is replaced by each
        entity; the rank of the true entity is one plus the number of
        replacements with a strictly lower score, ignoring replacements that
        form a triple contained in `all_triples`.

        Args:
            test_triples: List of test triples
            all_triples: List of all triples (train + test)
            batch_size: Kept for interface compatibility; every test triple is
                scored against all entities in a single forward pass
            k: K for Hits@K metric

        Returns:
            Dictionary of evaluation metrics (MR, MRR, Hits@1, Hits@3 and
            Hits@K). Hits@K is stored under 'Hits@10' as before and
            additionally under 'Hits@<k>'.
        """
        # Known completions per (relation, tail) / (head, relation); sets
        # give O(1) membership
        head_filter = {}
        tail_filter = {}
        for h, r, t in all_triples:
            tail_filter.setdefault((h, r), set()).add(t)
            head_filter.setdefault((r, t), set()).add(h)

        head_ranks = []
        tail_ranks = []

        entity_ids = torch.arange(
            self.data_handler.entity_count, dtype=torch.long, device=self.device
        )

        with torch.no_grad():
            for h, r, t in test_triples:
                relation_col = torch.full_like(entity_ids, r)

                # Corrupt head: every entity in the head slot
                head_candidates = torch.stack(
                    (entity_ids, relation_col, torch.full_like(entity_ids, t)), dim=1
                )
                head_ranks.append(
                    self._filtered_rank(head_candidates, h, head_filter.get((r, t), ()))
                )

                # Corrupt tail: every entity in the tail slot
                tail_candidates = torch.stack(
                    (torch.full_like(entity_ids, h), relation_col, entity_ids), dim=1
                )
                tail_ranks.append(
                    self._filtered_rank(tail_candidates, t, tail_filter.get((h, r), ()))
                )

        # Calculate metrics
        all_ranks = head_ranks + tail_ranks
        total = len(all_ranks)

        def hits_at(threshold):
            return sum(1 for rank in all_ranks if rank <= threshold) / total if total else 0

        hits_at_k = hits_at(k)
        metrics = {
            'MR': sum(all_ranks) / total if total else 0,
            'MRR': sum(1.0 / rank for rank in all_ranks) / total if total else 0,
            'Hits@1': hits_at(1),
            'Hits@3': hits_at(3),
            # Existing callers read 'Hits@10', so the key name is kept even
            # when k differs from 10
            'Hits@10': hits_at_k
        }
        metrics[f'Hits@{k}'] = hits_at_k
        return metrics

    def _snapshot_embeddings(self):
        """Return an independent copy of both embedding tables' state dicts."""
        # state_dict() tensors are cloned so the snapshot does not change
        # when training continues after it was taken
        return {
            'entity_embeddings': {
                name: tensor.detach().clone()
                for name, tensor in self.model.entity_embeddings.state_dict().items()
            },
            'relation_embeddings': {
                name: tensor.detach().clone()
                for name, tensor in self.model.relation_embeddings.state_dict().items()
            }
        }

    def train_with_early_stopping(self, patience=5, max_epsilon=10.0, delta=1e-5):
        """
        Train model with early stopping

        Training stops when Hits@10 has not improved for `patience`
        consecutive evaluation rounds, when the privacy budget is exceeded,
        or after `self.epochs` epochs. In the first two cases the best
        recorded embeddings are loaded back into the model.

        Args:
            patience: Number of evaluation rounds with no improvement before early stopping
            max_epsilon: Privacy budget; training stops once epsilon exceeds it
            delta: Delta used when querying the privacy accountant

        Returns:
            Tuple of (best model state, average loss per epoch,
            list of (epoch, epsilon) pairs, list of evaluation metrics)
        """
        best_hits = 0
        no_improve = 0
        best_model_state = None

        mU = 0  # Counter for unrestricted optimization steps
        mC = 0  # Counter for confidential optimization steps
        losses = []
        epsilon_values = []
        all_metrics = []

        confidential = self.data_handler.confidential_triples
        unrestricted = self.data_handler.unrestricted_triples
        steps_per_epoch = (len(unrestricted) + len(confidential)) // self.batch_size

        for epoch in range(self.epochs):
            print(f"Epoch {epoch + 1}/{self.epochs}")
            steps = 0
            epoch_losses = []

            # Nothing to train on when both subsets are empty
            while confidential or unrestricted:
                # Keep the ratio of confidential to unrestricted steps close
                # to the ratio of the two subset sizes
                if (not unrestricted or
                        mC < mU * len(confidential) / len(unrestricted)):
                    loss = self.optimize_confidential(confidential)
                    mC += 1
                else:
                    loss = self.optimize_unrestricted(unrestricted)
                    mU += 1

                epoch_losses.append(loss)
                steps += 1

                if steps >= steps_per_epoch:
                    break

            avg_loss = sum(epoch_losses) / len(epoch_losses) if epoch_losses else 0
            print(f"  Epoch {epoch+1}: mU={mU}, mC={mC}, avg_loss={avg_loss:.4f}")
            losses.append(avg_loss)

            # Update learning rate based on average loss and report reductions
            previous_lr = self.optimizer.param_groups[0]['lr']
            self.scheduler.step(avg_loss)
            current_lr = self.optimizer.param_groups[0]['lr']
            if current_lr < previous_lr:
                print(f"  Reducing learning rate to {current_lr:.4e}")

            # Evaluate every 10 epochs.
            # NOTE(review): the sample is drawn from the TEST triples, so
            # early stopping is tuned on test data - consider a validation split.
            if epoch % 10 == 0:
                val_sample = random.sample(
                    self.data_handler.test_triples,
                    min(500, len(self.data_handler.test_triples))
                )
                metrics = self.evaluate_model(val_sample, self.data_handler.all_triples)
                print(f"  Validation: MR={metrics['MR']:.2f}, MRR={metrics['MRR']:.4f}, " +
                      f"Hits@10={metrics['Hits@10']:.4f}")
                metrics["epoch"] = epoch
                all_metrics.append(metrics)

                # Check for improvement
                current_hits = metrics['Hits@10']
                if current_hits > best_hits:
                    best_hits = current_hits
                    no_improve = 0
                    # Save best model as an independent copy
                    best_model_state = self._snapshot_embeddings()
                    best_model_state['epoch'] = epoch
                    best_model_state['metrics'] = metrics
                else:
                    no_improve += 1

                if no_improve >= patience:
                    print(f"Early stopping at epoch {epoch+1}")
                    # Restore best model
                    if best_model_state:
                        self.model.load_state_dict_from_dict(best_model_state)
                    break

            if epoch % 5 == 0:
                eps = self.accountant.get_epsilon(delta=delta)
                print(f"  Current privacy guarantee: (ε = {eps:.2f}, δ = {delta:g})")
                epsilon_values.append((epoch, eps))
                if eps > max_epsilon:
                    print("Epsilon exceed, stopping at epoch ", (epoch + 1))
                    if best_model_state:
                        self.model.load_state_dict_from_dict(best_model_state)
                    return best_model_state, losses, epsilon_values, all_metrics

        return best_model_state, losses, epsilon_values, all_metrics


Overwriting trainer.py
