#### Author: Blaine Hill

In this notebook, we program out how to embed a KG such as FB15k_237 using RotatE. The weights are saved under embedding_model_weights.pth 

In [1]:
import os

os.environ["WANDB_NOTEBOOK_NAME"] = "embedding_model.ipynb"

import wandb

wandb.login()
# wandb.login(relogin=True)

import torch
import torch.optim as optim

import sys
import os.path as osp
from ipykernel import get_connection_file

from datetime import datetime
import yaml
from icecream import ic

notebook_path = osp.abspath(
    osp.join(os.getcwd(), osp.basename(get_connection_file()))
)
parent_dir = osp.dirname(osp.dirname(notebook_path))
sys.path.append(parent_dir)

from utils.utils import load_dataset, get_model_class
from utils.tqdm_utils import silence_tqdm

device = "cuda" if torch.cuda.is_available() else "cpu"

[34m[1mwandb[0m: Currently logged in as: [33mbthill1[0m ([33muiuc_idealab_2024[0m). Use [1m`wandb login --relogin`[0m to force relogin
  return torch._C._cuda_getDeviceCount() > 0


Here we decide whether to train the model on specific hyperparameters stored in config or to run a Weights and Biases sweep to locate the best hyperparameters as defined in `sweep_config.yaml`

Set `run_sweep=True` to run the sweep and `False` to train the model on the defined config variable.

In [2]:
run_sweep = True
if run_sweep:
    with open("sweep_config.yaml", "r") as file:
        sweep_config = yaml.safe_load(file)
else:
    config = {
        "dataset_name": "FB15k_237",
        "embedding_model_name": "RotatE",
        "epochs": 2,
        "batch_size": 64,
        "lr": 0.001,
        "weight_decay": 1e-6,
        "k": 10,  # used for top-k evaluation
        "hidden_channels": 50,
        "verbose": True,
    }

In [3]:
def build_model(config):
    train_data, val_data, test_data, data_path = load_dataset(
        config["dataset_name"], parent_dir=parent_dir, device=device
    )
    model_class = get_model_class(config["embedding_model_name"])
    model = model_class(
        num_nodes=train_data.num_nodes,
        num_relations=train_data.num_edge_types,
        hidden_channels=config.hidden_channels,
    ).to(device)
    train_loader = model.loader(
        head_index=train_data.edge_index[0],
        rel_type=train_data.edge_type,
        tail_index=train_data.edge_index[1],
        batch_size=config.batch_size,
        shuffle=True,
    )
    val_loader = model.loader(
        head_index=val_data.edge_index[0],
        rel_type=val_data.edge_type,
        tail_index=val_data.edge_index[1],
        batch_size=config.batch_size,
        shuffle=True,
    )
    test_loader = model.loader(
        head_index=test_data.edge_index[0],
        rel_type=test_data.edge_type,
        tail_index=test_data.edge_index[1],
        batch_size=config.batch_size,
        shuffle=True,
    )

    optimizer = optim.Adagrad(
        model.parameters(), lr=config.lr, weight_decay=config.weight_decay
    )

    model.train_data = train_data
    model.val_data = val_data
    model.test_data = test_data
    model.train_loader = train_loader
    model.val_loader = val_loader
    model.test_loader = test_loader
    model.optimizer = optimizer
    model.config = config

    return model

In [4]:
def corrupt_triples(head_index, rel_type, tail_index, num_entities, num_neg=1):
    n = head_index.size(0)
    neg_head_index = (
        head_index.expand(num_neg, n).transpose(0, 1).contiguous().view(-1)
    )
    neg_tail_index = (
        tail_index.expand(num_neg, n).transpose(0, 1).contiguous().view(-1)
    )
    neg_rel_type = (
        rel_type.expand(num_neg, n).transpose(0, 1).contiguous().view(-1)
    )

    # Randomly corrupt either head or tail.
    selector = torch.rand(n * num_neg) < 0.5
    neg_head_index[selector] = torch.randint(
        num_entities, size=neg_head_index[selector].size(), dtype=torch.long
    )
    neg_tail_index[~selector] = torch.randint(
        num_entities, size=neg_tail_index[~selector].size(), dtype=torch.long
    )

    return neg_head_index, neg_rel_type, neg_tail_index


def train_model(model):
    model.train()
    total_loss = total_examples = 0
    for head_index, rel_type, tail_index in model.train_loader:
        model.optimizer.zero_grad()

        # Negative sampling.
        neg_head_index, neg_rel_type, neg_tail_index = corrupt_triples(
            head_index,
            rel_type,
            tail_index,
            num_entities=model.train_data.num_nodes,
            num_neg=1,
        )

        # Positive loss.
        loss = model.loss(head_index, rel_type, tail_index)

        # Negative loss.
        loss += model.loss(neg_head_index, neg_rel_type, neg_tail_index)

        loss.backward()
        model.optimizer.step()
        # scale on specific number of elements in batch
        total_loss += float(loss) * head_index.size(0)
        total_examples += 2 * head_index.numel()
    return total_loss / total_examples


@silence_tqdm
@torch.no_grad()
def test_model(model, val=False):
    model.eval()
    total_loss = total_examples = 0
    loader = model.val_loader if val else model.test_loader
    for head_index, rel_type, tail_index in loader:
        loss = model.loss(head_index, rel_type, tail_index)
        total_loss += float(loss) * head_index.numel()
        total_examples += head_index.numel()

    mean_rank, mrr, hits_at_k = model.test(
        head_index=head_index,
        rel_type=rel_type,
        tail_index=tail_index,
        batch_size=model.config.batch_size,
        k=model.config.k,
    )

    performance_metrics = {
        "loss": total_loss / total_examples,
        "mean_rank": mean_rank,
        "mrr": mrr,
        "hits_at_k": hits_at_k,
    }

    return performance_metrics

In [5]:
def main(config=None):
    # intialize wandb
    with wandb.init(
        project=f"ScoreMatchingDiffKG_Embedding",
        name=f"run_{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",  # Use a temporary name
        config=config if config is not None else {},
    ):
        config = wandb.config

        # Use config values after they are set
        wandb.run.name = f'{config["dataset_name"]}_{config["embedding_model_name"]}_embedding_model {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}'

        model = build_model(config)
        wandb.watch(model)

        best_val_loss = float("inf")
        best_model_state = None

        # train model
        for epoch in range(config.epochs):
            loss = train_model(model)
            if config["verbose"]:
                print(f"Epoch: {epoch:03d}, Train Loss: {loss:.4f}")
            train_metrics = {"train_epoch": epoch, "train_loss": loss}
            # test model on validation data
            if epoch % 10 == 0 and epoch > 0:
                performance_metrics = test_model(model, val=True)
                if config["verbose"]:
                    print(
                        f'Val Mean Rank: {performance_metrics["mean_rank"]:.2f}, Val Mean Reciprocal Rank: {performance_metrics["mrr"]:.2f}, Val Hits@k: {performance_metrics["hits_at_k"]:.4f}'
                    )
                val_metrics = {
                    "val_loss": performance_metrics["loss"],
                    "val_mean_rank": performance_metrics["mean_rank"],
                    "val_mrr": performance_metrics["mrr"],
                    "val_hits_at_10": performance_metrics["hits_at_k"],
                }

                if performance_metrics["loss"] < best_val_loss:
                    best_val_loss = performance_metrics["loss"]
                    best_model_state = model.state_dict()

            # log to wandb
            wandb.log(
                {**train_metrics, **val_metrics}
                if "val_metrics" in locals()
                else {**train_metrics}
            )

        # Save the best model after all epochs
        if best_model_state:

            model.load_state_dict(best_model_state)
            # once everything is finished, test model
            performance_metrics = test_model(model)
            if config["verbose"]:
                print(
                    f'Test Mean Rank: {performance_metrics["mean_rank"]:.2f}, Test Mean Reciprocal Rank: {performance_metrics["mrr"]:.2f}, Test Hits@k: {performance_metrics["hits_at_k"]:.4f}'
                )
            test_metrics = {
                "test_loss": performance_metrics["loss"],
                "test_mean_rank": performance_metrics["mean_rank"],
                "test_mrr": performance_metrics["mrr"],
                "test_hits_at_k": performance_metrics["hits_at_k"],
            }
            wandb.log({**test_metrics})

            # Save the trained model
            path = osp.join(
                os.getcwd(),
                f'{config["dataset_name"]}_{config["embedding_model_name"]}_embedding_model_weights.pth',
            )
            model_state_dict = model.state_dict()
            model_state_dict["num_nodes"] = model.train_data.num_nodes
            model_state_dict["num_relations"] = model.train_data.num_edge_types
            model_state_dict["hidden_channels"] = model.config[
                "hidden_channels"
            ]
            model_state_dict["dataset_name"] = config["dataset_name"]
            model_state_dict["embedding_model_name"] = config[
                "embedding_model_name"
            ]

            # Save the model as pth
            torch.save(model_state_dict, path)

            # Fetch a batch from train_loader
            for batch in model.train_loader:
                # Assuming batch contains head_index, rel_type, tail_index, and possibly other data
                head_index, rel_type, tail_index = batch
                break  # Only need one batch for this purpose
            head_index, rel_type, tail_index = (
                head_index.to(device),
                rel_type.to(device),
                tail_index.to(device),
            )

            # Use the fetched batch to provide dummy inputs for the export
            # Ensure these variables are moved to the same device as model if necessary
            torch.onnx.export(
                model,
                (
                    head_index,
                    rel_type,
                    tail_index,
                ),  # Use actual data as dummy inputs
                f'{config["dataset_name"]}_{config["embedding_model_name"]}_embedding_model_weights.onnx',
                opset_version=11,
                do_constant_folding=True,
                input_names=[
                    "head_index",
                    "rel_type",
                    "tail_index",
                ],
                dynamic_axes={
                    "head_index": {0: "batch_size"},
                    "rel_type": {0: "batch_size"},
                    "tail_index": {0: "batch_size"},
                },
            )
            wandb.save(
                f'{config["dataset_name"]}_{config["embedding_model_name"]}_embedding_model_weights.onnx'
            )

        return model

In [6]:
if run_sweep:

    with open("sweep_config.yaml", "r") as file:
        sweep_config = yaml.safe_load(file)

    sweep_id = wandb.sweep(
        project=f"ScoreMatchingDiffKG_Embedding_Sweep", sweep=sweep_config
    )

    wandb.agent(sweep_id, function=main)
else:
    model = main(config)

Epoch: 000, Train Loss: 1.0000
Epoch: 001, Train Loss: 1.0000


0,1
train_epoch,▁█
train_loss,▁▁

0,1
train_epoch,1.0
train_loss,1.0
