# Knowledge Graph Embedding System Overview

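In this notebook, we show how to embed a knowledge graph (KG) such as FB15k_237 using a KG embedding model such as RotatE. By default, the trained weights are saved at the top of the project directory under `trained_embedding_models/<prefix>_embedding_model/<prefix>_embedding_model_weights.pth`, where `<prefix>` is the run prefix generated for each training run.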

##### System Components
1. **Model Training**: The system includes functions like `train_model` and `test_model` for training and evaluating the model on validation or test datasets.
2. **Hyperparameter Sweep**: You can choose to run a hyperparameter sweep by setting `run_sweep=True` in the configuration.
3. **Main Function**: The `main()` function serves as the core function for training the model, managing the training process, and saving the best model whenever the validation metric improves.

##### Instructions
- To run a hyperparameter sweep, set `run_sweep=True`. This will trigger a sweep to find the best hyperparameters.
- Ensure you are logged into [Weights and Biases](https://wandb.ai/site) using `wandb.login()` before running the system.
- Monitor the training progress and the validation metric; the best model is saved during training whenever that metric improves.
- Use the provided functions like `train_model` and `test_model` to train and evaluate the model effectively.

In [1]:
import os
import sys
import random
import numpy as np
import yaml
from datetime import datetime
from icecream import ic
import torch
from torch import nn
import torch.optim as optim
from torch.optim import Optimizer
from torch.nn import Module
from datetime import datetime
from ipykernel import get_connection_file
from typing import Dict, Optional

# Set the notebook name for Weights and Biases tracking.
# This is done before `import wandb` so the variable is already in the
# environment when the library is loaded.
os.environ["WANDB_NOTEBOOK_NAME"] = "embedding_model.ipynb"

import wandb

wandb.login()  # Ensure you are logged into Weights and Biases

# Set the device to GPU if available, otherwise fallback to CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Define the path to the notebook and the parent directory for imports.
# NOTE: `notebook_path` is the current working directory joined with the
# basename of the kernel connection file (not the .ipynb file name), so it is
# only used as an anchor: `parent_dir` resolves to the parent of the current
# working directory.
notebook_path = os.path.abspath(
    os.path.join(os.getcwd(), os.path.basename(get_connection_file()))
)
parent_dir = os.path.dirname(os.path.dirname(notebook_path))
# Make the project's `utils` package importable from this notebook.
sys.path.append(parent_dir)

# Import utility functions and model utilities.
# NOTE(review): these star imports presumably provide the helpers used below
# without a visible import (e.g. `set_seed`, `load_dataset`,
# `get_embedding_model_class`, `generate_prefix`, `osp`) — confirm against the
# utils modules.
from utils.utils import *
from utils.embedding_model_utils import *
from utils.embedding_models.utils import *


# Seed the random number generators with the helper's default settings
# (presumably for reproducibility — see the utils module for what is seeded).
set_seed()

[34m[1mwandb[0m: Currently logged in as: [33mbthill1[0m ([33muiuc_idealab_2024[0m). Use [1m`wandb login --relogin`[0m to force relogin
  return torch._C._cuda_getDeviceCount() > 0


Here we decide whether to train the model with the specific hyperparameters stored in `config`, or to run a Weights and Biases sweep that searches for the best hyperparameters as defined in a `sweep_config.yaml` file.

Set `run_sweep=True` to run the sweep using one of the sweep_config.yaml files and `False` to train the model on the defined config variable.

In [2]:
# Choose the run mode: a W&B hyperparameter sweep, or a single run that uses
# the fixed hyperparameters defined below.
run_sweep = True
sweep_config_file_path = "sweep_config_ComplEx.yaml"

if not run_sweep:
    # Fixed configuration for a single training run.
    config = dict(
        dataset_name="FB15k_237",
        embedding_model_name="TransE",
        task="kg_completion",
        max_epochs=100000,
        batch_size=512,
        lr=0.075,
        weight_decay=1e-07,
        k=[1, 3, 10],  # used for top-k evaluation
        hidden_channels=512,
        margin=0.5,
        p_norm=2,
        verbose=True,
        num_epochs_without_improvement_until_early_finish=1,
        validate_after_this_many_epochs=1,
    )
else:
    # Sweep mode: the whole sweep definition is read from the YAML file.
    with open(sweep_config_file_path, "r") as file:
        config = yaml.safe_load(file)

In [3]:
def build_model(config: Dict) -> nn.Module:
    """
    Constructs and initializes the model based on the provided configuration.

    Besides instantiating the embedding model, this function builds the
    train/validation/test loaders and the Adagrad optimizer, attaches them
    (together with the datasets, the config and the save path) to the returned
    model as attributes, creates the output directory and saves the model
    configuration.

    Side effects on ``config``: ``num_nodes`` and ``num_relations`` are always
    written; ``head_node_feature_dim`` and ``aux_dict`` are written when the
    training data provides them.

    Args:
        config (Dict): Configuration containing model and training settings.
            Must provide ``dataset_name``, ``embedding_model_name``, ``task``,
            ``hidden_channels``, ``batch_size``, ``lr``, ``weight_decay`` and
            ``prefix``. ``margin``, ``p_norm`` and ``save_path`` are optional.

    Returns:
        nn.Module: The initialized model ready for training. When more than
        one GPU is available this is a ``DataParallel`` wrapper; the
        unwrapped model is always reachable as ``model.original_model``.

    Raises:
        ValueError: If ``config["save_path"]`` is given but does not exist.
    """
    # Load dataset based on configuration (the returned data path is not needed here)
    train_data, val_data, test_data, _data_path = load_dataset(
        config["dataset_name"], parent_dir=parent_dir
    )

    # Node features and auxiliary metadata are optional
    node_features = getattr(train_data, "x", None)
    head_node_feature_dim = (
        node_features.shape[1] if node_features is not None else None
    )
    aux_dict = getattr(train_data, "aux_dict", None)

    config["num_nodes"] = train_data.num_nodes
    config["num_relations"] = train_data.num_edge_types

    # Required model parameters
    model_params = {
        "num_nodes": config["num_nodes"],
        "num_relations": config["num_relations"],
        "hidden_channels": config["hidden_channels"],
        "task": config["task"],
    }
    # Optional hyperparameters are forwarded only when the config defines them
    for optional_key in ("margin", "p_norm"):
        if optional_key in config:
            model_params[optional_key] = config[optional_key]

    # Conditionally add optional parameters to both config and model_params
    if head_node_feature_dim:
        config["head_node_feature_dim"] = head_node_feature_dim
        model_params["head_node_feature_dim"] = head_node_feature_dim
    if aux_dict:
        config["aux_dict"] = aux_dict
        model_params["aux_dict"] = aux_dict

    # Get the model class from the model name provided in config
    model_class = get_embedding_model_class(config["embedding_model_name"])

    # Instantiate the model with the parameters
    original_model = model_class(**model_params).to(device)

    # Scale the batch size (and loader workers) with the number of GPUs
    num_gpus = torch.cuda.device_count()
    batch_size_per_gpu = config["batch_size"]
    batch_size = (
        batch_size_per_gpu * num_gpus if num_gpus > 0 else batch_size_per_gpu
    )
    num_workers = num_gpus if num_gpus > 0 else 1

    def prepare_loader_args(data, batch_size, num_workers, shuffle):
        """Prepares and returns loader arguments based on the data and configuration."""
        loader_args = {
            "head_index": data.edge_index[0],
            "rel_type": data.edge_type,
            "tail_index": data.edge_index[1],
            "batch_size": batch_size,
            "shuffle": shuffle,
            "num_workers": num_workers,
        }
        # Forward extra features only when they actually hold data; this
        # matches the `is not None` check used for the feature dimension above.
        for feature_name in ("x", "y"):
            feature = getattr(data, feature_name, None)
            if feature is not None:
                loader_args[feature_name] = feature
        return loader_args

    # Only the training loader is shuffled: evaluation does not benefit from
    # shuffling, and a fixed order keeps validation/test batches reproducible.
    train_loader = original_model.loader(
        **prepare_loader_args(train_data, batch_size, num_workers, shuffle=True)
    )
    val_loader = original_model.loader(
        **prepare_loader_args(val_data, batch_size, num_workers, shuffle=False)
    )
    test_loader = original_model.loader(
        **prepare_loader_args(test_data, batch_size, num_workers, shuffle=False)
    )

    # Initialize the optimizer
    optimizer = optim.Adagrad(
        original_model.parameters(),
        lr=config["lr"],
        weight_decay=config["weight_decay"],
    )

    # Use DataParallel for multi-GPU setups on single server
    if num_gpus > 1:
        model = torch.nn.DataParallel(original_model)
    else:
        model = original_model

    # Attach data and loaders to the model for easy access
    model.train_data = train_data
    model.val_data = val_data
    model.test_data = test_data
    model.train_loader = train_loader
    model.val_loader = val_loader
    model.test_loader = test_loader
    model.optimizer = optimizer
    model.config = config
    model.batch_size = batch_size  # this is different from config batch_size when using multiple GPUs
    model.original_model = (
        original_model  # Keep a reference to the original model
    )

    # Determine where results are saved. A user-supplied base directory must
    # already exist; the run-specific sub-directory is created below.
    if "save_path" in config:
        save_path = config["save_path"]
        if not os.path.exists(save_path):
            raise ValueError(
                f"Directory {save_path} does not exist in which to save the trained embedding models. Please create it before saving."
            )
        save_path = os.path.join(
            save_path,
            f"{config['prefix']}_embedding_model",
        )
    else:
        save_path = os.path.join(
            parent_dir,
            "trained_embedding_models",
            f"{config['prefix']}_embedding_model",
        )

    os.makedirs(save_path, exist_ok=True)
    model.save_path = save_path

    # Save the model configuration
    save_embedding_model_config(model)

    return model

In [4]:
def train_model(model: Module) -> float:
    """
    Runs a single training epoch over ``model.train_loader``.

    For every batch the loss of the observed (positive) triples is added to
    the loss of an equally sized set of sampled negative triples, and one
    optimizer step is taken on the sum.

    Args:
        model (Module): The model to be trained. It must carry the attributes
            ``train_loader``, ``optimizer`` and ``original_model``.

    Returns:
        float: The accumulated batch losses (each weighted by its batch size)
        divided by the number of positive plus negative examples seen.
    """
    model.train()  # switch to training mode
    embedder = model.original_model
    running_loss = 0
    examples_seen = 0

    for batch in model.train_loader:
        model.optimizer.zero_grad()

        # Triple indices, moved to the compute device
        heads = batch["head_index"].to(device)
        relations = batch["rel_type"].to(device)
        tails = batch["tail_index"].to(device)

        # Optional feature matrix and labels
        x = batch.get("x", None)
        y = batch.get("y", None)
        if x is not None:
            x = x.to(device)
        if y is not None:
            y = y.to(device)

        # Loss on the observed triples
        positive_loss = embedder.loss(heads, relations, tails, x, y)

        # Loss on corrupted triples produced by negative sampling
        corrupted = embedder.random_sample(heads, relations, tails)
        neg_heads, neg_relations, neg_tails = (
            part.to(device) for part in corrupted
        )
        negative_loss = embedder.loss(
            neg_heads, neg_relations, neg_tails, x, y
        )

        # Backpropagate the combined loss and update the parameters
        loss = positive_loss + negative_loss
        loss.backward()
        model.optimizer.step()

        running_loss += float(loss) * heads.size(0)
        # Every triple is counted twice: once as positive, once as negative
        examples_seen += 2 * heads.numel()

    return running_loss / examples_seen


def test_model(model: Module, val: bool = False) -> dict:
    """
    Tests the model on the validation or test dataset and returns performance metrics.

    Args:
        model (Module): The model to be tested.
        val (bool): Flag indicating whether to test on the validation dataset (default is False).

    Returns:
        dict: A dictionary containing performance metrics based on the model's task.
    """
    model.eval()  # Set the model to evaluation mode
    total_loss = total_examples = 0
    loader = model.val_loader if val else model.test_loader
    all_metrics = []
    only_relation_prediction = val  # during validation, only evaluate on relation prediction - on the test set do head, relation, and tail prediction

    for batch in loader:
        # Extract data from the batch
        head_index, rel_type, tail_index = (
            batch["head_index"],
            batch["rel_type"],
            batch["tail_index"],
        )
        x = batch.get("x", None)
        y = batch.get("y", None)

        # Move data to the appropriate device
        head_index, rel_type, tail_index = (
            head_index.to(device),
            rel_type.to(device),
            tail_index.to(device),
        )
        if x is not None:
            x = x.to(device)
        if y is not None:
            y = y.to(device)

        # Compute loss for the batch
        loss = model.original_model.loss(
            head_index,
            rel_type,
            tail_index,
            x,
            y,
            model.original_model.task,
            model.original_model.aux_dict,
        )
        total_loss += float(loss) * head_index.numel()
        total_examples += head_index.numel()

        # Calculate metrics for the batch
        metrics = model.original_model.test(
            head_index=head_index,
            rel_type=rel_type,
            tail_index=tail_index,
            x=x,
            y=y,
            batch_size=model.batch_size,
            k=model.config.k,
            task=model.original_model.task,
            only_relation_prediction=only_relation_prediction,  # for kg_completion: during validation, only evaluate on relation prediction - on the test set do head, relation, and tail prediction
        )
        all_metrics.append(metrics)

    # Aggregate metrics across all batches
    if model.original_model.task in "kg_completion":
        if only_relation_prediction:
            relation_mean_ranks = []
            relation_mrrs = []
            relation_hits_at_ks = []
            # Collect the metrics from each batch
            for metrics in all_metrics:
                relation_mean_rank, relation_mrr, relation_hits_at_k = metrics
                relation_mean_ranks.append(relation_mean_rank)
                relation_mrrs.append(relation_mrr)
                relation_hits_at_ks.append(relation_hits_at_k)
            # Aggregate the metrics across all batches
            relation_mean_rank = sum(relation_mean_ranks) / len(all_metrics)
            relation_mrr = sum(relation_mrrs) / len(all_metrics)

            # Aggregate the hits@k metrics
            all_k_values = set().union(
                *(d.keys() for d in relation_hits_at_ks)
            )
            relation_hits_at_k = {
                f"relation_hits_at_{k}": sum(
                    d.get(k, 0) for d in relation_hits_at_ks
                )
                / len(all_metrics)
                for k in all_k_values
            }
            # Create the performance_metrics dictionary
            performance_metrics = {
                "loss": total_loss / total_examples,
                "relation_mean_rank": relation_mean_rank,
                "relation_mrr": relation_mrr,
                **relation_hits_at_k,
            }
        else:
            head_mean_ranks = []
            relation_mean_ranks = []
            tail_mean_ranks = []
            head_mrrs = []
            relation_mrrs = []
            tail_mrrs = []
            head_hits_at_ks = []
            relation_hits_at_ks = []
            tail_hits_at_ks = []

            # Collect the metrics from each batch
            for metrics in all_metrics:
                (
                    head_mean_rank,
                    relation_mean_rank,
                    tail_mean_rank,
                    head_mrr,
                    relation_mrr,
                    tail_mrr,
                    head_hits_at_k,
                    relation_hits_at_k,
                    tail_hits_at_k,
                ) = metrics
                head_mean_ranks.append(head_mean_rank)
                relation_mean_ranks.append(relation_mean_rank)
                tail_mean_ranks.append(tail_mean_rank)
                head_mrrs.append(head_mrr)
                relation_mrrs.append(relation_mrr)
                tail_mrrs.append(tail_mrr)
                head_hits_at_ks.append(head_hits_at_k)
                relation_hits_at_ks.append(relation_hits_at_k)
                tail_hits_at_ks.append(tail_hits_at_k)

            # Aggregate the metrics across all batches
            head_mean_rank = sum(head_mean_ranks) / len(all_metrics)
            relation_mean_rank = sum(relation_mean_ranks) / len(all_metrics)
            tail_mean_rank = sum(tail_mean_ranks) / len(all_metrics)
            head_mrr = sum(head_mrrs) / len(all_metrics)
            relation_mrr = sum(relation_mrrs) / len(all_metrics)
            tail_mrr = sum(tail_mrrs) / len(all_metrics)

            # Aggregate the hits@k metrics
            all_k_values = set().union(
                *(d.keys() for d in relation_hits_at_ks)
            )
            head_hits_at_k = {
                f"head_hits_at_{k}": sum(d.get(k, 0) for d in head_hits_at_ks)
                / len(all_metrics)
                for k in all_k_values
            }
            relation_hits_at_k = {
                f"relation_hits_at_{k}": sum(
                    d.get(k, 0) for d in relation_hits_at_ks
                )
                / len(all_metrics)
                for k in all_k_values
            }
            tail_hits_at_k = {
                f"tail_hits_at_{k}": sum(d.get(k, 0) for d in tail_hits_at_ks)
                / len(all_metrics)
                for k in all_k_values
            }

            # Create the performance_metrics dictionary
            performance_metrics = {
                "loss": total_loss / total_examples,
                "head_mean_rank": head_mean_rank,
                "relation_mean_rank": relation_mean_rank,
                "tail_mean_rank": tail_mean_rank,
                "head_mrr": head_mrr,
                "relation_mrr": relation_mrr,
                "tail_mrr": tail_mrr,
                **head_hits_at_k,
                **relation_hits_at_k,
                **tail_hits_at_k,
            }

    elif model.original_model.task == "node_classification":
        accuracy = sum(all_metrics) / len(all_metrics)
        performance_metrics = {
            "loss": total_loss / total_examples,
            "accuracy": accuracy,
        }

    else:
        raise ValueError(
            f"model task isn't valid: {model.original_model.task}"
        )

    return performance_metrics

### Saving the Model on Validation Loss Improvement

To ensure that the best model is retained even if the hyperparameter sweep is canceled, the model is saved whenever the monitored validation metric improves (the relation mean rank for `kg_completion`, the accuracy for `node_classification`). The metric is checked at each validation step during training, and the weights are written to disk every time a new best value is reached, so the best model obtained so far survives an interrupted or canceled sweep.

### Main Function Overview

The `main()` function serves as the core function for training the model and managing the training process. It initializes the training configuration, sets up the project environment using Weights and Biases for tracking, builds the model based on the provided configuration, and executes the training loop. Within `main()`, key functionalities include setting up the project, skipping task/dataset combinations that do not match, training the model, monitoring training progress, and saving the best model whenever the validation metric improves. Additionally, it stops training early once the validation metric has failed to improve for a configurable number of consecutive validation checks, and it provides detailed logging of training metrics and test performance. In short, `main()` orchestrates the training process and ensures that the best model found during training is saved.

In [5]:
def main(
    config: Optional[dict] = None,
    num_epochs_without_improvement_until_early_finish: int = 5,
    validate_after_this_many_epochs: int = 10,
):
    """
    Trains one embedding model inside a Weights and Biases run and tests it.

    The model is validated every ``validate_after_this_many_epochs`` epochs
    (never at epoch 0). Whenever the monitored validation metric is at least
    as good as the best value so far, the weights are saved. Training stops
    early after ``num_epochs_without_improvement_until_early_finish``
    consecutive validation checks without improvement. Afterwards the saved
    weights are reloaded (when a checkpoint exists), the model is evaluated
    on the test set, and the test metrics are appended to the run's
    performance file and logged to W&B.

    Args:
        config (Optional[dict]): Run configuration passed to ``wandb.init``.
            Inside a sweep this is ``None`` and the agent supplies the values.
        num_epochs_without_improvement_until_early_finish (int): Early-stop
            patience, counted in validation checks. Overridden by the config
            entry of the same name when present.
        validate_after_this_many_epochs (int): Validation interval in epochs.
            Overridden by the config entry of the same name when present.

    Returns:
        The trained model, or ``None`` when the task/dataset combination is
        skipped.
    """
    run_timestamp = datetime.now().strftime("%Y.%m.%d.%H.%M.%S")

    with wandb.init(
        project="ScoreMatchingDiffKG_Embedding",
        name=f"{run_timestamp}_embedding_model_run",
        config=config if config else {},
    ):
        config = wandb.config
        config["prefix"] = generate_prefix(config, run_timestamp)
        wandb.run.name = f"{config['prefix']}_embedding_model_run"

        # Node classification is only run on the citation datasets, and the
        # citation datasets are only used for node classification.
        is_node_task = wandb.config.task == "node_classification"
        is_citation_dataset = wandb.config.dataset_name in (
            "Cora",
            "Citeseer",
            "Pubmed",
        )
        if is_node_task != is_citation_dataset:
            print(
                f"Skipping {wandb.config.task} on {wandb.config.dataset_name}"
            )
            return

        model = build_model(config)
        wandb.watch(model)

        num_epochs_without_improvement_until_early_finish = config.get(
            "num_epochs_without_improvement_until_early_finish",
            num_epochs_without_improvement_until_early_finish,
        )
        validate_after_this_many_epochs = config.get(
            "validate_after_this_many_epochs", validate_after_this_many_epochs
        )

        # Mean rank is better when lower, accuracy is better when higher.
        metric_to_optimize = (
            "relation_mean_rank"
            if config["task"] == "kg_completion"
            else "accuracy"
        )
        higher_is_better = metric_to_optimize == "accuracy"
        best_metric_to_optimize = float(
            "-inf" if higher_is_better else "inf"
        )
        epochs_without_improvement = 0

        for epoch in range(config.max_epochs):
            loss = train_model(model)

            if config["verbose"]:
                print(f"Epoch: {epoch:03d}, Train Loss: {loss:.10f}")

            # Metrics logged for this epoch; validation metrics are added only
            # on epochs where validation actually ran.
            epoch_log = {"train_epoch": epoch, "train_loss": loss}
            stop_early = False

            if epoch % validate_after_this_many_epochs == 0 and epoch > 0:
                val_metrics = test_model(model, val=True)
                current_value = val_metrics[metric_to_optimize]
                improved = (
                    current_value >= best_metric_to_optimize
                    if higher_is_better
                    else current_value <= best_metric_to_optimize
                )
                if improved:
                    best_metric_to_optimize = current_value
                    save_trained_embedding_weights_and_performance(
                        model, epoch + 1, val_metrics
                    )
                    epochs_without_improvement = 0
                else:
                    epochs_without_improvement += 1

                if config["verbose"]:
                    metrics_info = ", ".join(
                        [
                            f'Val {key.replace("_", " ").title()}: {value:.10f}'
                            for key, value in val_metrics.items()
                            if isinstance(value, (int, float))
                        ]
                    )
                    print(metrics_info)

                epoch_log.update(
                    {f"val_{key}": value for key, value in val_metrics.items()}
                )
                stop_early = (
                    epochs_without_improvement
                    >= num_epochs_without_improvement_until_early_finish
                )

            # Log before a possible early stop so the final epoch is recorded
            wandb.log(epoch_log)

            if stop_early:
                print(
                    f"Stopping early: validation {metric_to_optimize} has not "
                    f"improved for {epochs_without_improvement} consecutive "
                    "validation checks."
                )
                break

        # Restore the best checkpoint. None exists when validation never ran
        # (e.g. max_epochs not larger than the validation interval); in that
        # case the final model state is evaluated as-is.
        weights_path = os.path.join(
            model.save_path,
            f"{model.config['prefix']}_embedding_model_weights.pth",
        )
        if os.path.exists(weights_path):
            model.load_state_dict(torch.load(weights_path))
        else:
            print(
                f"No saved weights found at {weights_path}; "
                "testing the model in its final training state."
            )
        test_metrics = test_model(model)

        if config["verbose"]:
            test_metrics_info = ", ".join(
                [
                    f'Test {key.replace("_", " ").title()}: {value:.4f}'
                    for key, value in test_metrics.items()
                    if isinstance(value, (int, float))
                ]
            )
            print(test_metrics_info)

        # Append the test results to the run's performance file
        with open(
            os.path.join(
                model.save_path,
                f"{model.config['prefix']}_embedding_model_performance.txt",
            ),
            "a",
        ) as file:
            file.write("Test Metrics:\n")
            for metric, value in test_metrics.items():
                file.write(f"{metric}: {value}\n")

        test_metrics = {
            f"test_{key}": value for key, value in test_metrics.items()
        }
        wandb.log({**test_metrics})
        wandb.save(model.save_path, base_path=parent_dir)

        return model

In [None]:
def run_sweep_or_main(run_sweep: bool, config: Optional[dict] = None) -> None:
    """
    Dispatches to a single training run or to a Weights and Biases sweep.

    Args:
        run_sweep (bool): When True, ``config`` is registered as a sweep
            definition and an agent runs ``main`` for each sweep trial.
            When False, ``main`` is called once with ``config``.
        config (Optional[dict]): The sweep definition (sweep mode) or the
            training configuration (single-run mode).

    Returns:
        None
    """
    if not run_sweep:
        # Single run; the model returned by main() is not needed here.
        main(config=config)
        return

    # Each sweep gets its own timestamped W&B project.
    sweep_timestamp = datetime.now().strftime("%Y.%m.%d.%H.%M.%S")
    sweep_id = wandb.sweep(
        sweep=config,
        project=f"ScoreMatchingDiffKG_Embedding_Sweep_{sweep_timestamp}",
    )
    wandb.agent(sweep_id, function=main)


# Cell entry point: start either the sweep or a single training run, using the
# `run_sweep` flag and `config` defined in the configuration cell above.
run_sweep_or_main(run_sweep=run_sweep, config=config)

Create sweep with ID: jr31bq1b
Sweep URL: https://wandb.ai/uiuc_idealab_2024/ScoreMatchingDiffKG_Embedding_Sweep_2024.05.17.17.43.53/sweeps/jr31bq1b


[34m[1mwandb[0m: Agent Starting Run: 83y0mcnz with config:
[34m[1mwandb[0m: 	batch_size: 512
[34m[1mwandb[0m: 	dataset_name: FB15k_237
[34m[1mwandb[0m: 	embedding_model_name: ComplEx
[34m[1mwandb[0m: 	hidden_channels: 512
[34m[1mwandb[0m: 	k: [1, 3, 10]
[34m[1mwandb[0m: 	lr: 0.0020062256919422847
[34m[1mwandb[0m: 	max_epochs: 2500
[34m[1mwandb[0m: 	num_epochs_without_improvement_until_early_finish: 5
[34m[1mwandb[0m: 	task: kg_completion
[34m[1mwandb[0m: 	validate_after_this_many_epochs: 10
[34m[1mwandb[0m: 	verbose: False
[34m[1mwandb[0m: 	weight_decay: 1e-05


Stopping early due to increasing training loss.
