#### Author: Blaine Hill

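In this notebook, we show how to embed a knowledge graph (KG) such as FB15k_237 using a KG embedding model such as RotatE. Unless a `save_path` is given in the config, the trained weights are saved at the top level of the project directory under `trained_embedding_models/<prefix>_embedding_model/<prefix>_embedding_model_weights.pth`, where `<prefix>` combines the dataset name, the model name, the run timestamp, and a unique string generated from the config.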

In [None]:
import os
import sys
import random
import numpy as np
import yaml
from datetime import datetime
from icecream import ic
import torch
from torch import nn
import torch.optim as optim
from torch.optim import Optimizer
from torch.nn import Module
from datetime import datetime
from ipykernel import get_connection_file
from typing import Dict, Optional

# Weights & Biases reads this variable to associate runs with the notebook,
# so it must be set before `wandb` is imported below.
os.environ["WANDB_NOTEBOOK_NAME"] = "embedding_model.ipynb"

import wandb

wandb.login()  # Authenticate with Weights & Biases before any run is started

# Use the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# `notebook_path` is <cwd>/<kernel connection-file name>. Because that file
# name contains no directory separator, stripping two path components leaves
# the parent of the current working directory: the project root, which is put
# on sys.path so the `utils` package below can be imported.
_connection_file_name = os.path.basename(get_connection_file())
notebook_path = os.path.abspath(os.path.join(os.getcwd(), _connection_file_name))
parent_dir = os.path.dirname(os.path.dirname(notebook_path))
sys.path.append(parent_dir)

# Import utility functions and model utilities
from utils.utils import *
from utils.embedding_models.utils import *


def set_seed(seed_value=123):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random`` module, NumPy's global generator, and PyTorch
    (CPU and all CUDA devices). It also tells cuDNN to use deterministic
    kernels and to skip benchmark-based autotuning, so that runs with the same
    seed are reproducible.

    Args:
        seed_value (int): Seed given to every generator. Defaults to 123.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed_value)
    # Seed every visible GPU as well (relevant for multi-GPU DataParallel runs).
    torch.cuda.manual_seed_all(seed_value)

    # Give up cuDNN's benchmark-driven kernel selection in favour of determinism.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


set_seed()

Here we decide whether to train the model with the fixed hyperparameters stored in the `config` variable, or to run a Weights and Biases sweep to find the best hyperparameters as defined in `sweep_config.yaml`.

Set `run_sweep=True` to run the sweep and `False` to train the model on the defined config variable.

In [None]:
# Decide whether to run a hyperparameter sweep or use a fixed configuration.
#   run_sweep = True  -> load the W&B sweep definition from the YAML file below.
#   run_sweep = False -> train once with the hand-written `config` dictionary.
run_sweep = True
sweep_config_file_path = "sweep_config.yaml"
if run_sweep:
    # Decode explicitly as UTF-8 so the YAML is read the same way on every
    # platform, regardless of the locale's default encoding.
    with open(sweep_config_file_path, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)
else:
    config = {
        "dataset_name": "Cora",
        "embedding_model_name": "ComplEx",
        "task": "node_classification",  # relation_prediction=link prediction
        "epochs": 10000,
        "batch_size": 512,
        "lr": 0.001,
        "weight_decay": 1e-6,
        "k": 10,  # used for top-k evaluation
        "hidden_channels": 512,
        "verbose": True,
        "max_epochs_without_improvement": 5,
        "validate_after_this_many_epochs": 1,
    }

In [None]:
def build_model(config: Dict) -> nn.Module:
    """
    Construct the embedding model, its data loaders and its optimizer from a config.

    Args:
        config (Dict): Configuration mapping. Keys read here: ``dataset_name``,
            ``embedding_model_name``, ``hidden_channels``, ``task``,
            ``batch_size``, ``lr``, ``weight_decay``, and either ``save_path``
            (an existing directory) or ``prefix`` (used to derive a directory
            under ``<parent_dir>/trained_embedding_models``).

    Returns:
        nn.Module: The model, wrapped in ``torch.nn.DataParallel`` when more than
        one GPU is visible. The following are attached as attributes: the three
        data splits, their loaders, the optimizer, the config, the unwrapped
        model (``original_model``) and ``save_path``.

    Raises:
        ValueError: If ``config["save_path"]`` is given but does not exist.
    """
    # Load the train/validation/test splits for the configured dataset
    train_data, val_data, test_data, data_path = load_dataset(
        config["dataset_name"], parent_dir=parent_dir, device=device
    )

    # Node-feature dimensionality, or None when the dataset carries no features
    head_node_feature_dim = (
        train_data.x.shape[1]
        if hasattr(train_data, "x") and train_data.x is not None
        else None
    )
    aux_dict = train_data.aux_dict if hasattr(train_data, "aux_dict") else None

    # Instantiate the model class selected by name in the config
    model_class = get_model_class(config["embedding_model_name"])
    original_model = model_class(
        num_nodes=train_data.num_nodes,
        num_relations=train_data.num_edge_types,
        hidden_channels=config["hidden_channels"],
        head_node_feature_dim=head_node_feature_dim,
        task=config["task"],
        aux_dict=aux_dict,
    ).to(device)

    # The configured batch size is per GPU; scale it by the number of GPUs
    num_gpus = torch.cuda.device_count()
    batch_size_per_gpu = config["batch_size"]
    batch_size = (
        batch_size_per_gpu * num_gpus if num_gpus > 0 else batch_size_per_gpu
    )
    num_workers = num_gpus if num_gpus > 0 else 1

    def prepare_loader_args(data, batch_size, num_workers):
        """Build the keyword arguments for ``original_model.loader`` from a data split."""
        loader_args = {
            "head_index": data.edge_index[0],
            "rel_type": data.edge_type,
            "tail_index": data.edge_index[1],
            "batch_size": batch_size,
            "shuffle": True,
            "num_workers": num_workers,
        }
        # Pass optional node features / labels through when the split has them
        if hasattr(data, "x"):
            loader_args["x"] = data.x
        if hasattr(data, "y"):
            loader_args["y"] = data.y
        return loader_args

    # Prepare loader arguments for train, validation, and test datasets
    train_loader_args = prepare_loader_args(train_data, batch_size, num_workers)
    val_loader_args = prepare_loader_args(val_data, batch_size, num_workers)
    test_loader_args = prepare_loader_args(test_data, batch_size, num_workers)

    # Create data loaders for training, validation, and testing
    train_loader = original_model.loader(**train_loader_args)
    val_loader = original_model.loader(**val_loader_args)
    test_loader = original_model.loader(**test_loader_args)

    optimizer = optim.Adagrad(
        original_model.parameters(),
        lr=config["lr"],
        weight_decay=config["weight_decay"],
    )

    # Use DataParallel for multi-GPU setups on a single server
    if num_gpus > 1:
        model = torch.nn.DataParallel(original_model)
    else:
        model = original_model

    # Attach data and loaders to the model for easy access
    model.train_data = train_data
    model.val_data = val_data
    model.test_data = test_data
    model.train_loader = train_loader
    model.val_loader = val_loader
    model.test_loader = test_loader
    model.optimizer = optimizer
    model.config = config

    # Keep a reference to the unwrapped model.
    if model is original_model:
        # nn.Module.__setattr__ registers Module-valued attributes as child
        # modules. Registering the model as its own child makes train()/eval()
        # (via children()) and state_dict() recurse without end, so store the
        # self-reference as a plain instance attribute instead.
        object.__setattr__(model, "original_model", original_model)
    else:
        # Under DataParallel the wrapped module is registered as a child.
        model.original_model = original_model

    # Determine the directory to save results in, creating it if needed
    if "save_path" in config:
        save_path = config["save_path"]
        if not os.path.exists(save_path):
            raise ValueError(
                f"Directory {save_path} does not exist in which to save the trained embedding models. Please create it before saving."
            )
    else:
        save_path = os.path.join(
            parent_dir,
            "trained_embedding_models",
            f"{config['prefix']}_embedding_model",
        )
        os.makedirs(save_path, exist_ok=True)
    model.save_path = save_path

    # Persist the model configuration alongside the weights
    save_model_config(model)

    return model

In [None]:
def train_model(model: Module) -> float:
    """
    Run one training epoch and return the mean loss.

    For each batch of ``model.train_loader``, the unwrapped model
    (``model.original_model``) computes a loss on the observed triples. It then
    draws negative triples with ``random_sample`` and computes a loss on those
    as well. The sum of the two losses is back-propagated, and
    ``model.optimizer`` takes one step.

    Args:
        model (Module): Model produced by ``build_model``. It carries
            ``train_loader``, ``optimizer`` and ``original_model``.

    Returns:
        float: Sum of ``batch_loss * batch_size`` over all batches, divided by
        twice the number of training triples (each triple counts once as a
        positive and once as a negative example).
    """
    core = model.original_model
    model.train()
    running_loss = 0
    seen = 0

    def on_device(tensor):
        # Features and labels are optional and may be missing from a batch.
        return None if tensor is None else tensor.to(device)

    for batch in model.train_loader:
        model.optimizer.zero_grad()

        heads = batch["head_index"].to(device)
        rels = batch["rel_type"].to(device)
        tails = batch["tail_index"].to(device)
        feats = on_device(batch.get("x", None))
        labels = on_device(batch.get("y", None))

        # Loss on the observed (positive) triples
        pos_loss = core.loss(heads, rels, tails, feats, labels)

        # Corrupt the batch to obtain negative triples, then score them
        neg_heads, neg_rels, neg_tails = (
            t.to(device)
            for t in core.random_sample(
                heads, rels, tails, core.task, core.aux_dict
            )
        )
        neg_loss = core.loss(
            neg_heads,
            neg_rels,
            neg_tails,
            feats,
            labels,
            core.task,
            core.aux_dict,
        )

        combined = pos_loss + neg_loss
        combined.backward()
        model.optimizer.step()

        running_loss += float(combined) * heads.size(0)
        seen += 2 * heads.numel()

    return running_loss / seen


def test_model(model: Module, val: bool = False) -> dict:
    """
    Evaluate the model on the validation or test split.

    The loss is averaged over every triple produced by the split's loader.
    The task-specific metrics are computed by ``model.original_model.test`` on
    the whole split (``model.val_data`` / ``model.test_data``), not just on
    the final loader batch.

    Args:
        model (Module): Model produced by ``build_model``.
        val (bool): Evaluate the validation split when True, otherwise the
            test split.

    Returns:
        dict: ``loss`` plus either ``mean_rank``/``mrr``/``hits_at_k`` for the
        ranking tasks or ``accuracy`` for node classification. ``loss`` is
        NaN if the loader yields no examples.

    Raises:
        ValueError: If the model's task is not a supported evaluation task.
    """
    core = model.original_model
    task = core.task
    loader = model.val_loader if val else model.test_loader
    split = model.val_data if val else model.test_data

    def _to_device(tensor):
        # Features and labels are optional and may be None.
        return None if tensor is None else tensor.to(device)

    model.eval()  # Set the model to evaluation mode
    total_loss = total_examples = 0

    # Nothing is optimised here, so don't build autograd graphs.
    with torch.no_grad():
        for batch in loader:
            head_index = batch["head_index"].to(device)
            rel_type = batch["rel_type"].to(device)
            tail_index = batch["tail_index"].to(device)
            x = _to_device(batch.get("x", None))
            y = _to_device(batch.get("y", None))

            loss = core.loss(
                head_index,
                rel_type,
                tail_index,
                x,
                y,
                task,
                core.aux_dict,
            )
            total_loss += float(loss) * head_index.numel()
            total_examples += head_index.numel()

        # Score the full split, built from the same fields its loader was
        # built from. `test` receives `batch_size`, presumably so it batches
        # internally -- confirm against the model implementation.
        metrics = core.test(
            head_index=split.edge_index[0].to(device),
            rel_type=split.edge_type.to(device),
            tail_index=split.edge_index[1].to(device),
            x=_to_device(getattr(split, "x", None)),
            y=_to_device(getattr(split, "y", None)),
            # Item access works for both plain dicts and wandb.Config
            batch_size=model.config["batch_size"],
            k=model.config["k"],
            task=task,
        )

    # Avoid dividing by zero when the loader is empty.
    mean_loss = total_loss / total_examples if total_examples else float("nan")

    if task in ("relation_prediction", "head_prediction", "tail_prediction"):
        mean_rank, mrr, hits_at_k = metrics
        return {
            "loss": mean_loss,
            "mean_rank": mean_rank,
            "mrr": mrr,
            "hits_at_k": hits_at_k,
        }
    if task == "node_classification":
        return {"loss": mean_loss, "accuracy": metrics}
    raise ValueError(f"Unsupported task for evaluation: {task!r}")

In [None]:
# Datasets this notebook pairs exclusively with the node-classification task.
_NODE_CLASSIFICATION_DATASETS = ("Cora", "Citeseer", "Pubmed")


def _is_supported_combination(task: str, dataset_name: str) -> bool:
    """Return True when ``task`` should be run on ``dataset_name``.

    Node classification runs only on ``_NODE_CLASSIFICATION_DATASETS``, and
    those datasets are used only for node classification.
    """
    is_node_task = task == "node_classification"
    is_node_dataset = dataset_name in _NODE_CLASSIFICATION_DATASETS
    return is_node_task == is_node_dataset


def _format_metrics(label: str, metrics: dict, precision: int) -> str:
    """Join the numeric entries of ``metrics`` as ``"<label> <Key>: <value>"`` pairs."""
    return ", ".join(
        f'{label} {key.replace("_", " ").title()}: {value:.{precision}f}'
        for key, value in metrics.items()
        if isinstance(value, (int, float))
    )


def main(
    config: Optional[dict] = None,
    max_epochs_without_improvement: int = 5,
    validate_after_this_many_epochs: int = 10,
):
    """
    Train, checkpoint and evaluate one embedding model inside a W&B run.

    Args:
        config (Optional[dict]): Hyperparameters passed to ``wandb.init``. When
            None (as when ``wandb.agent`` calls ``main``), the run's config
            comes from W&B.
        max_epochs_without_improvement (int): Early-stopping patience on the
            training loss. Used only if the config does not set it.
        validate_after_this_many_epochs (int): Validation interval in epochs.
            Used only if the config does not set it.

    Returns:
        The trained model, or None when the task/dataset combination is skipped.
    """
    run_timestamp = datetime.now().strftime("%Y.%m.%d.%H.%M.%S")

    with wandb.init(
        project=f"ScoreMatchingDiffKG_Embedding",
        name=f"{run_timestamp}_run",
        config=config if config else {},
    ):
        config = wandb.config
        config["prefix"] = (
            f'{config["dataset_name"]}_{config["embedding_model_name"]}_{run_timestamp}_{generate_unique_string(config)}'
        )
        dataset_name = wandb.config.dataset_name
        task = wandb.config.task

        if not _is_supported_combination(task, dataset_name):
            print(f"Skipping {task} on {dataset_name}")
            return

        wandb.run.name = f"{config['prefix']}_run"

        model = build_model(config)
        wandb.watch(model)

        max_epochs_without_improvement = config.get(
            "max_epochs_without_improvement", max_epochs_without_improvement
        )
        validate_after_this_many_epochs = config.get(
            "validate_after_this_many_epochs", validate_after_this_many_epochs
        )
        best_train_loss = best_val_loss = float("inf")
        epochs_without_improvement = 0

        for epoch in range(config.epochs):
            loss = train_model(model)

            if config["verbose"]:
                print(f"Epoch: {epoch:03d}, Train Loss: {loss:.10f}")

            if loss <= best_train_loss:
                best_train_loss = loss
                epochs_without_improvement = 0
            else:
                epochs_without_improvement += 1
                if epochs_without_improvement >= max_epochs_without_improvement:
                    print("Stopping early due to increasing training loss.")
                    break

            # Fresh log entry each epoch, so validation metrics appear only
            # for epochs that actually ran validation.
            log_entry = {"train_epoch": epoch, "train_loss": loss}

            if epoch % validate_after_this_many_epochs == 0 and epoch > 0:
                val_metrics = test_model(model, val=True)

                if config["verbose"]:
                    print(_format_metrics("Val", val_metrics, 10))

                if val_metrics["loss"] < best_val_loss:
                    best_val_loss = val_metrics["loss"]
                    save_trained_embedding_model_and_config(
                        model, epoch, val_metrics
                    )

                log_entry.update(
                    {f"val_{key}": value for key, value in val_metrics.items()}
                )

            wandb.log(log_entry)

        # Restore the best checkpoint. If no validation ever ran (early stop or
        # too few epochs), none was written; evaluate the current weights.
        weights_path = os.path.join(
            model.save_path,
            f"{model.config['prefix']}_embedding_model_weights.pth",
        )
        if os.path.exists(weights_path):
            model.load_state_dict(torch.load(weights_path, map_location=device))
        else:
            print(
                f"No checkpoint found at {weights_path}; "
                "evaluating the final in-memory weights instead."
            )
        test_metrics = test_model(model)

        if config["verbose"]:
            print(_format_metrics("Test", test_metrics, 4))

        with open(
            os.path.join(
                model.save_path,
                f"{model.config['prefix']}_embedding_model_performance.txt",
            ),
            "a",
            encoding="utf-8",
        ) as file:
            file.write("Test Metrics:\n")
            for metric, value in test_metrics.items():
                file.write(f"{metric}: {value}\n")

        wandb.log({f"test_{key}": value for key, value in test_metrics.items()})
        # wandb.save takes a file path or glob, not a directory. Upload every
        # file in the save directory and keep the directory name in the run.
        save_dir = os.path.normpath(model.save_path)
        wandb.save(
            os.path.join(save_dir, "*"),
            base_path=os.path.dirname(save_dir),
        )

        return model

In [18]:
def run_sweep_or_main(
    run_sweep: bool, config: Optional[dict] = None
) -> Optional[Module]:
    """
    Run a W&B hyperparameter sweep, or a single training run via ``main``.

    Args:
        run_sweep (bool): If True, register ``config`` as a sweep definition
            and start a ``wandb.agent`` that calls ``main`` for each trial.
            Otherwise, train once with ``config``.
        config (Optional[dict]): Sweep definition when ``run_sweep`` is True,
            or training hyperparameters otherwise.

    Returns:
        Optional[Module]: The model returned by ``main`` for a single run (None
        if ``main`` skipped the task/dataset combination). Always None for a
        sweep.
    """
    if run_sweep:
        sweep_id = wandb.sweep(
            project=f"ScoreMatchingDiffKG_Embedding_Sweep", sweep=config
        )
        wandb.agent(sweep_id, function=main)
        return None
    # Hand the trained model back instead of discarding it, so it remains
    # available in the notebook after training.
    return main(config=config)


# Assign the result so the notebook does not try to render the model's repr
trained_model = run_sweep_or_main(run_sweep=run_sweep, config=config)