## Author: Blaine Hill

In this notebook, we program out how to embed a KG such as FB15k_237 using RotatE. The weights are saved under embedding_model_weights.pth

In [15]:
import os

os.environ["WANDB_NOTEBOOK_NAME"] = "embedding_model.ipynb"

import wandb

wandb.login()

import torch
from torch_geometric.datasets import FB15k_237
from torch_geometric.nn import RotatE
import torch.optim as optim

import sys
import os
import os.path as osp
from ipykernel import get_connection_file

from datetime import datetime

import yaml
from icecream import ic

notebook_path = osp.abspath(osp.join(os.getcwd(), osp.basename(get_connection_file())))
parent_dir = osp.dirname(osp.dirname(notebook_path))
sys.path.append(parent_dir)

from utils.tqdm_utils import silence_tqdm

Here we decide whether to train the model on specific hyperparameters stored in config or to run a Weights and Biases sweep to locate the best hyperparameters as defined in `sweep_config.yaml`

Set `run_sweep=True` to run the sweep and `False` to train the model on the defined config variable.

In [16]:
run_sweep = False
if run_sweep:
    with open("sweep_config.yaml", "r") as file:
        sweep_config = yaml.safe_load(file)
else:
    config = {
        "epochs": 500,
        "batch_size": 1000,
        "lr": 0.001,
        "weight_decay": 1e-6,
        "k": 10,  # used for top-k evaluation
        "hidden_channels": 50,
    }

In [17]:
dataset_name = "FB15k_237"
embedding_model_name = "RotatE"

device = "cuda" if torch.cuda.is_available() else "cpu"
data_path = osp.join(parent_dir, "data", dataset_name)

train_data = FB15k_237(data_path, split="train")[0].to(device)
val_data = FB15k_237(data_path, split="val")[0].to(device)
test_data = FB15k_237(data_path, split="test")[0].to(device)

In [18]:
def build_model(config):

    model = RotatE(
        num_nodes=train_data.num_nodes,
        num_relations=train_data.num_edge_types,
        hidden_channels=config.hidden_channels,
    ).to(device)

    train_loader = model.loader(
        head_index=train_data.edge_index[0],
        rel_type=train_data.edge_type,
        tail_index=train_data.edge_index[1],
        batch_size=config.batch_size,
        shuffle=True,
    )
    val_loader = model.loader(
        head_index=val_data.edge_index[0],
        rel_type=val_data.edge_type,
        tail_index=val_data.edge_index[1],
        batch_size=config.batch_size,
        shuffle=True,
    )
    test_loader = model.loader(
        head_index=test_data.edge_index[0],
        rel_type=test_data.edge_type,
        tail_index=test_data.edge_index[1],
        batch_size=config.batch_size,
        shuffle=True,
    )

    optimizer = optim.Adagrad(
        model.parameters(), lr=config.lr, weight_decay=config.weight_decay
    )

    return model, optimizer, train_loader, val_loader, test_loader

In [19]:
def corrupt_triples(head_index, rel_type, tail_index, num_entities, num_neg=1):
    n = head_index.size(0)
    neg_head_index = head_index.expand(num_neg, n).transpose(0, 1).contiguous().view(-1)
    neg_tail_index = tail_index.expand(num_neg, n).transpose(0, 1).contiguous().view(-1)
    neg_rel_type = rel_type.expand(num_neg, n).transpose(0, 1).contiguous().view(-1)

    # Randomly corrupt either head or tail.
    selector = torch.rand(n * num_neg) < 0.5
    neg_head_index[selector] = torch.randint(
        num_entities, size=neg_head_index[selector].size(), dtype=torch.long
    )
    neg_tail_index[~selector] = torch.randint(
        num_entities, size=neg_tail_index[~selector].size(), dtype=torch.long
    )

    return neg_head_index, neg_rel_type, neg_tail_index


def train_model(model, optimizer, train_loader):
    model.train()
    total_loss = total_examples = 0
    for head_index, rel_type, tail_index in train_loader:
        optimizer.zero_grad()

        # Negative sampling.
        neg_head_index, neg_rel_type, neg_tail_index = corrupt_triples(
            head_index, rel_type, tail_index, num_entities=train_data.num_nodes, num_neg=1
        )

        # Positive loss.
        loss = model.loss(head_index, rel_type, tail_index)

        # Negative loss.
        loss += model.loss(neg_head_index, neg_rel_type, neg_tail_index)

        loss.backward()
        optimizer.step()
        # scale on specific number of elements in batch
        total_loss += float(loss) * head_index.size(0)
        total_examples += 2 * head_index.numel()  # TODO: check if this correct
    return total_loss / total_examples


@silence_tqdm
@torch.no_grad()
def test_model(model, config, test_loader):
    model.eval()
    total_loss = total_examples = 0
    for head_index, rel_type, tail_index in test_loader:
        loss = model.loss(head_index, rel_type, tail_index)
        total_loss += float(loss) * head_index.numel()
        total_examples += head_index.numel()

    mean_rank, mrr, hits_at_k = model.test(
        head_index=head_index,
        rel_type=rel_type,
        tail_index=tail_index,
        batch_size=config.batch_size,
        k=config.k,
    )

    return total_loss / total_examples, mean_rank, mrr, hits_at_k

In [20]:
def main(config=None, verbose=False):
    with wandb.init(
        project=f"ScoreMatchingDiffKG_Embedding",
        name=f"{dataset_name}_{embedding_model_name}_embedding_model {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        config=config if config is not None else {},
    ):
        config = wandb.config
        model, optimizer, train_loader, val_loader, test_loader = build_model(config)
        wandb.watch(model)
        for epoch in range(config.epochs):
            loss = train_model(model, optimizer, train_loader)
            if verbose:
                print(f"Epoch: {epoch:03d}, Train Loss: {loss:.4f}")
            train_metrics = {"train_epoch": epoch, "train_loss": loss}
            if epoch % 10 == 0 and epoch > 0:
                loss, mean_rank, mrr, hits_at_10 = test_model(model, config, val_loader)
                if verbose:
                    print(
                        f"Val Mean Rank: {mean_rank:.2f}, Val Mean Reciprocal Rank: {mrr:.2f}, Val Hits@10: {hits_at_10:.4f}"
                    )
                val_metrics = {
                    "val_loss": loss,
                    "val_mean_rank": mean_rank,
                    "val_mrr": mrr,
                    "val_hits_at_10": hits_at_10,
                }
            # log to wandb
            wandb.log(
                {**train_metrics, **val_metrics}
                if "val_metrics" in locals()
                else {**train_metrics}
            )

        # once everything is finished, test model
        loss, mean_rank, mrr, hits_at_10 = test_model(model, config, test_loader)
        if verbose:
            print(
                f"Test Mean Rank: {mean_rank:.2f}, Test Mean Reciprocal Rank: {mrr:.2f}, Test Hits@10: {hits_at_10:.4f}"
            )
        test_metrics = {
            "test_loss": loss,
            "test_mean_rank": mean_rank,
            "test_mrr": mrr,
            "test_hits_at_10": hits_at_10,
        }
        wandb.log({**test_metrics})

        # Save the trained model
        path = osp.join(
            os.getcwd(), f"{dataset_name}_{embedding_model_name}_embedding_model_weights.pth"
        )
        torch.save(model.state_dict(), path)

        # Fetch a batch from train_loader
        for batch in train_loader:
            # Assuming batch contains head_index, rel_type, tail_index, and possibly other data
            head_index, rel_type, tail_index = batch
            break  # Only need one batch for this purpose

        # Use the fetched batch to provide dummy inputs for the export
        # Ensure these variables are moved to the same device as model if necessary
        torch.onnx.export(
            model,
            (head_index, rel_type, tail_index),  # Use actual data as dummy inputs
            f"{dataset_name}_{embedding_model_name}_embedding_model_weights.onnx",
            opset_version=11,
            do_constant_folding=True,
            input_names=[
                "head_index",
                "rel_type",
                "tail_index",
            ],  # Adjust input names as needed
            dynamic_axes={
                "head_index": {0: "batch_size"},
                "rel_type": {0: "batch_size"},
                "tail_index": {0: "batch_size"},
            },
        )
        wandb.save(f"{dataset_name}_{embedding_model_name}_embedding_model_weights.onnx")

        return model

In [21]:
if run_sweep:

    with open("sweep_config.yaml", "r") as file:
        sweep_config = yaml.safe_load(file)

    sweep_id = wandb.sweep(project=f"ScoreMatchingDiffKG_Embedding_Sweep", sweep=sweep_config)

    wandb.agent(sweep_id, function=main)
else:
    model = main(config, verbose=True)

Epoch: 000, Train Loss: 0.7735
Epoch: 001, Train Loss: 0.7657
Epoch: 002, Train Loss: 0.7609
Epoch: 003, Train Loss: 0.7572
Epoch: 004, Train Loss: 0.7541
Epoch: 005, Train Loss: 0.7515
