Author: Blaine Hill

In this notebook, we train a score-matching model, as described by [Yang Song](https://yang-song.net/blog/2021/score/), to perform knowledge-graph (KG) reasoning tasks such as [link prediction](https://paperswithcode.com/task/link-prediction).

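For an energy-based model $p_\theta(x) = e^{-f_\theta(x)} / Z_\theta$, the score is

$$s_\theta(x) = \nabla_x \log p_\theta(x) = -\nabla_x f_\theta(x) - \nabla_x \log Z_\theta = -\nabla_x f_\theta(x),$$

where the last equality holds because the normalizing constant $Z_\theta$ does not depend on $x$, so its gradient is zero.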

In [None]:
import os

# Tell wandb which notebook this run comes from. The variable is set before
# wandb is imported and logged in so it is already in the environment.
os.environ["WANDB_NOTEBOOK_NAME"] = "score_matching_model.ipynb"

import wandb

# Authenticate with Weights & Biases for this session.
wandb.login()
# Use the relogin=True variant below to force re-authentication.
# wandb.login(relogin=True)


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset

import sys
import os.path as osp
from ipykernel import get_connection_file

import math
import numpy as np
from icecream import ic
import yaml
from datetime import datetime
import random

# Make the repository root importable so the `utils` package below resolves.
# NOTE(review): get_connection_file() names the kernel's connection file, not
# this notebook, so `notebook_path` is really "<cwd>/<kernel file>"; only its
# directory (the cwd) matters, making `parent_dir` the parent of the cwd.
_kernel_file = osp.basename(get_connection_file())
notebook_path = osp.abspath(osp.join(os.getcwd(), _kernel_file))
parent_dir = osp.split(osp.split(notebook_path)[0])[0]
sys.path.extend([parent_dir])

# Import utility functions and model utilities
from utils.utils import *
from utils.score_matching_model_utils import *
# from utils.score_matching_model.utils import *

from utils.score_matching_models.ScoreModel import ScoreModel


# Seed once, before any random work happens. set_seed comes from utils.utils;
# presumably it seeds the Python/NumPy/torch RNGs — confirm there.
set_seed()

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Choose between a W&B hyper-parameter sweep (settings read from
# sweep_config.yaml) and a single run driven by the hand-written `config`.
run_sweep = True
if not run_sweep:
    config = {
        "embedding_model_dir": "../trained_embedding_models/YAGO3_10_ComplEx_relation_prediction_2024.05.13.21.03.48_1cce66c3_embedding_model/",
        "score_model_hidden_dim": 512,
        "batch_size": 64,
        "epochs": 3,
        "num_timesteps": 20,  # Adjust based on your desired SDE iterations
        "lr": 1e-4,
        "similarity_metric": "cosine",
        "verbose": True,
        "k": [1, 3, 10],  # used for top-k evaluation
    }
else:
    with open("sweep_config.yaml", "r") as sweep_file:
        sweep_config = yaml.safe_load(sweep_file)

In [None]:
class _EdgeDictDataset(torch.utils.data.Dataset):
    """Map-style dataset that yields one ``{name: tensor_row}`` dict per triple.

    The default DataLoader collate function stacks these dicts into a dict of
    batched tensors, which is the batch format ``train_model`` / ``test_model``
    read (``batch["head_index"]``, ``batch.get("x")``, ...).
    """

    def __init__(self, tensors):
        lengths = {name: value.size(0) for name, value in tensors.items()}
        if len(set(lengths.values())) != 1:
            raise ValueError(f"All edge tensors must have the same length, got {lengths}")
        self.tensors = tensors
        self.num_edges = next(iter(lengths.values()))

    def __len__(self):
        return self.num_edges

    def __getitem__(self, idx):
        return {name: value[idx] for name, value in self.tensors.items()}


def _edge_tensors(data):
    """Collect the per-triple tensors of one KG split, plus edge-aligned extras."""
    tensors = {
        "head_index": data.edge_index[0],
        "rel_type": data.edge_type,
        "tail_index": data.edge_index[1],
    }
    num_edges = tensors["rel_type"].size(0)
    for key in ("x", "y"):
        value = getattr(data, key, None)
        # NOTE(review): only tensors with one row per edge can be batched with
        # the triples; node-level tensors (one row per entity) and missing
        # values are skipped. Confirm this matches what ScoreModel.loss/.test
        # expect for x and y.
        if torch.is_tensor(value) and value.dim() > 0 and value.size(0) == num_edges:
            tensors[key] = value
    return tensors


def _make_edge_loader(data, batch_size, num_workers, shuffle):
    """Build a DataLoader over the triples of ``data`` that yields dict batches."""
    tensors = _edge_tensors(data)
    # Multi-process loading of CUDA-resident tensors fails in forked workers
    # (CUDA cannot be re-initialised there), so load in-process in that case.
    if any(value.is_cuda for value in tensors.values()):
        num_workers = 0
    return DataLoader(
        _EdgeDictDataset(tensors),
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )


def _resolve_score_model_save_path(config):
    """Return (creating it if needed) the directory for this run's artefacts."""
    if "save_path" in config:
        base_dir = config["save_path"]
        if not os.path.exists(base_dir):
            raise ValueError(
                f"Directory {base_dir} does not exist in which to save the trained score matching models. Please create it before saving."
            )
    else:
        base_dir = osp.join(parent_dir, "trained_score_matching_models")
    save_path = osp.join(base_dir, f"{config['prefix']}_score_matching_model")
    os.makedirs(save_path, exist_ok=True)
    return save_path


def build_model(config, view_english=False):
    """Assemble the score model, data loaders, optimizer and output directory.

    Args:
        config: run configuration (dict or wandb config). Keys read here:
            ``embedding_model_dir``, ``score_model_hidden_dim``, ``batch_size``
            (per GPU), ``lr``, ``prefix`` and, optionally, ``save_path``.
        view_english: when True, attach the entity/relation dictionaries
            returned by ``load_dicts`` so predictions can be printed as text.

    Returns:
        The score model — wrapped in ``torch.nn.DataParallel`` when more than
        one GPU is visible — with the data splits, dict-batch loaders,
        optimizer, config, dataset/embedding metadata, ``original_model`` and
        ``save_path`` attached as attributes.

    Raises:
        ValueError: if ``config["save_path"]`` is given but does not exist.
    """
    embedding_model_dir = config["embedding_model_dir"]
    embedding_model, embedding_model_name, dataset_name = (
        initialize_trained_embedding_model(embedding_model_dir)
    )
    embedding_model.eval()

    train_data, val_data, test_data, data_path = load_dataset(
        dataset_name, parent_dir=parent_dir, device=device
    )

    original_model = ScoreModel(
        embedding_model,
        config["score_model_hidden_dim"],
    ).to(device)
    # The optimizer owns the parameters of the unwrapped model; DataParallel
    # shares those same parameters, so this is correct for both set-ups.
    optimizer = Adam(original_model.parameters(), lr=config["lr"])

    # Scale the batch size with the number of GPUs available.
    num_gpus = torch.cuda.device_count()
    batch_size_per_gpu = config["batch_size"]
    batch_size = (
        batch_size_per_gpu * num_gpus if num_gpus > 0 else batch_size_per_gpu
    )
    num_workers = num_gpus if num_gpus > 0 else 1

    # Only the training split is shuffled; evaluation order stays fixed.
    train_loader = _make_edge_loader(train_data, batch_size, num_workers, shuffle=True)
    val_loader = _make_edge_loader(val_data, batch_size, num_workers, shuffle=False)
    test_loader = _make_edge_loader(test_data, batch_size, num_workers, shuffle=False)

    # Use DataParallel for multi-GPU setups on a single server.
    if num_gpus > 1:
        score_model = torch.nn.DataParallel(original_model)
    else:
        score_model = original_model

    if view_english:
        entity_dict, relation_dict = load_dicts(
            data_path
        )  # note: if you want to load english for RotatE, store the dicts to map from indices to freebase ids/ relations in data_path/processed
        score_model.entity_dict = entity_dict
        score_model.relation_dict = relation_dict

    score_model.train_data = train_data
    score_model.val_data = val_data
    score_model.test_data = test_data
    score_model.train_loader = train_loader
    score_model.val_loader = val_loader
    score_model.test_loader = test_loader
    score_model.optimizer = optimizer
    score_model.config = config
    score_model.dataset_name = dataset_name
    score_model.embedding_model_name = embedding_model_name
    score_model.embedding_model_dir = embedding_model_dir
    # Keep a plain back-reference to the unwrapped model. nn.Module.__setattr__
    # would register it as a submodule: without DataParallel that makes the
    # model its own child (train()/state_dict() then recurse forever), and with
    # DataParallel it duplicates every weight in the state dict.
    object.__setattr__(score_model, "original_model", original_model)

    score_model.save_path = _resolve_score_model_save_path(config)

    # Save the model configuration next to the run's artefacts.
    save_embedding_model_config(score_model)
    return score_model


# def denoising_score_matching_loss(score_model, h, r, t, timestep):
#     """
#     Denoising score-matching loss with noise-conditional score networks.
#     """
#     h_emb, r_emb, t_emb = (
#         score_model.embedding_model.node_emb(h),
#         score_model.embedding_model.rel_emb(r),
#         score_model.embedding_model.node_emb(t),
#     )
#     true_score = score_model(
#         h_emb, r_emb, t_emb
#     )  # simply do not pass in the timestep
#     noisy_score = score_model(h_emb, r_emb, t_emb, timestep)
#     return ((true_score - noisy_score) ** 2).mean()


def train_model(score_model):
    """Run one training epoch and return the average batch loss.

    For every batch, the loss is evaluated at each integer timestep in
    ``range(config["num_timesteps"])``. Each timestep's loss is
    back-propagated separately, so gradients accumulate over the timesteps,
    and one optimizer step is taken per batch.

    Args:
        score_model: model returned by ``build_model``; it must carry
            ``train_loader`` (dict batches), ``optimizer``, ``config`` (with
            ``num_timesteps`` and ``task``) and ``original_model``.

    Returns:
        float: mean over batches of each batch's mean per-timestep loss
        (0.0 when the loader is empty).
    """
    score_model.train()
    config = score_model.config
    num_timesteps = config["num_timesteps"]
    task = config["task"]

    loss_epoch = 0.0
    num_batches = 0
    for batch in score_model.train_loader:
        # Move the triple (and optional extra features) to the model's device.
        h = batch["head_index"].to(device)
        r = batch["rel_type"].to(device)
        t = batch["tail_index"].to(device)
        x = batch.get("x", None)
        y = batch.get("y", None)
        if x is not None:
            x = x.to(device)
        if y is not None:
            y = y.to(device)

        # Clear the previous batch's gradients; otherwise they keep
        # accumulating across the whole epoch.
        score_model.optimizer.zero_grad()
        batch_loss = 0.0
        for timestep in range(num_timesteps):
            loss = score_model.original_model.loss(
                score_model, h, r, t, timestep, x, y, task
            )
            loss.backward()
            batch_loss += loss.item()
        score_model.optimizer.step()

        loss_epoch += batch_loss / max(num_timesteps, 1)
        num_batches += 1

    return loss_epoch / num_batches if num_batches else 0.0


def test_model(score_model, val=False):
    score_model.eval()
    test_loss = 0
    loader = score_model.val_loader if val else score_model.test_loader
    all_metrics = []
    with torch.no_grad():
        for batch in loader:
            # Extract data from the batch
            h, r, t = (
                batch["head_index"],
                batch["rel_type"],
                batch["tail_index"],
            )
            x = batch.get("x", None)
            y = batch.get("y", None)

            # Move data to the appropriate device
            h, r, t = (
                h.to(device),
                r.to(device),
                t.to(device),
            )
            if x is not None:
                x = x.to(device)
            if y is not None:
                y = y.to(device)
            for timestep in range(score_model.config["num_timesteps"]):
                loss = score_model.original_model.loss(
                    score_model, h, r, t, timestep
                )
                test_loss += loss.item()

            # Calculate metrics for the batch
            metrics = model.original_model.test(
                head_index=h,
                rel_type=r,
                tail_index=t,
                x=x,
                y=y,
                batch_size=model.config["batch_size"],
                k=model.config["k"],
                task=model.original_model.config["task"],
            )
            all_metrics.append(metrics)

    # Aggregate metrics across all batches
    if model.original_model.config["task"] in ["relation_prediction", "head_prediction", "tail_prediction"]:
        mean_rank = sum(m[0] for m in all_metrics) / len(all_metrics)
        mrr = sum(m[1] for m in all_metrics) / len(all_metrics)
        hits_at_k = {k: sum(m[2][k] for m in all_metrics) / len(all_metrics) for k in m[2]}
        performance_metrics = {
            "loss": total_loss / total_examples,
            "mean_rank": mean_rank,
            "mrr": mrr,
        }
        for k_value, hits_value in hits_at_k.items():
            performance_metrics[f"hits_at_{k_value}"] = hits_value

    elif model.original_model.config["task"] == "node_classification":
        accuracy = sum(all_metrics) / len(all_metrics)
        performance_metrics = {
            "loss": total_loss / total_examples,
            "accuracy": accuracy,
        }
    
    else: 
        raise ValueError(f"model task isn't valid: {model.original_model.config["task"]}")

    return performance_metrics


def reverse_sde_link_prediction(h_emb, t_emb, score_model):
    """Produce a relation embedding for each (head, tail) pair by running the
    score model backwards over the training timesteps.

    Starting from Gaussian noise, the relation embedding is updated once per
    integer timestep, from ``num_timesteps - 1`` down to ``0``. These are the
    same integer timesteps ``train_model`` uses, so training and inference see
    the same timestep values.

    Args:
        h_emb: head-entity embeddings; ``h_emb.size(0)`` is the batch size.
        t_emb: tail-entity embeddings, passed straight to ``score_model``.
        score_model: the trained score model; called as
            ``score_model(h_emb, r_emb, t_emb, timestep)``.

    Returns:
        Tensor of shape ``(batch, rel_emb_dim)`` with no autograd history.
    """
    # DataParallel wrappers expose model attributes only via original_model.
    base_model = getattr(score_model, "original_model", score_model)
    num_timesteps = score_model.config["num_timesteps"]

    # Initialize with random noise.
    r_emb = torch.randn(h_emb.size(0), base_model.rel_emb_dim, device=device)

    with torch.no_grad():
        for timestep in range(num_timesteps - 1, -1, -1):
            with torch.enable_grad():
                # Fresh leaf each step: gradients never flow back through
                # earlier steps, so memory stays constant across timesteps.
                r_var = r_emb.detach().requires_grad_(True)
                score = score_model(h_emb, r_var, t_emb, timestep)
                # Reduce to a scalar so autograd.grad works whether the model
                # returns one score per example or a (batch, dim) tensor.
                (grad_r_emb,) = torch.autograd.grad(score.sum(), r_var)
            # NOTE(review): this replaces r_emb with the negated gradient rather
            # than taking a step r_emb + eps * score (+ noise), as a
            # Langevin / reverse-SDE update would. Kept as in the original
            # pending confirmation of the intended update rule.
            r_emb = -grad_r_emb.detach()

    return r_emb  # returns predictions for r of shape batch, rel_emb_dim


def _unpack_triple_batch(batch):
    """Return ``(head_index, rel_type, tail_index)`` from one loader batch.

    Accepts dict batches (keys ``"head_index"``, ``"rel_type"``,
    ``"tail_index"``) as well as plain ``(head, relation, tail)`` sequences
    such as ``TensorDataset`` batches.
    """
    if isinstance(batch, dict):
        return batch["head_index"], batch["rel_type"], batch["tail_index"]
    head_index, rel_type, tail_index = batch
    return head_index, rel_type, tail_index


def compute_link_prediction_metrics(
    score_model, val=False, view_english=False
):
    """Rank every relation for each (head, tail) pair and report ranking metrics.

    For each triple, a relation embedding is generated with
    ``reverse_sde_link_prediction`` and compared with every relation embedding.
    The rank of the true relation is ``1 + number of relations that score
    strictly better``: higher cosine similarity, or smaller L2 distance. Ranks
    are computed over the full relation set, not just a top-10 list.

    Args:
        score_model: model returned by ``build_model``.
        val: use ``val_loader`` when True, otherwise ``test_loader``.
        view_english: for FB15k_237, print the top predictions and the
            triples as text (requires ``build_model(..., view_english=True)``).

    Returns:
        dict with ``mean_rank``, ``mrr``, ``hits_at_1``, ``hits_at_3`` and
        ``hits_at_10``, each averaged over all evaluated triples (NaN when the
        loader is empty).

    Raises:
        NotImplementedError: for a ``similarity_metric`` other than
            ``"cosine"`` or ``"l2"``.
    """
    similarity_metric = score_model.config["similarity_metric"]
    if similarity_metric == "cosine":
        higher_is_better = True  # closer vectors have larger cosine similarity
    elif similarity_metric == "l2":
        higher_is_better = False  # closer vectors have smaller distances
    else:
        raise NotImplementedError(
            f"Haven't implemented a similarity metric yet for {score_model.config['similarity_metric']}"
        )

    # DataParallel wrappers expose model attributes only via original_model.
    base_model = getattr(score_model, "original_model", score_model)
    embedding_model = base_model.embedding_model
    all_relations = embedding_model.rel_emb.weight.detach()
    num_relations = all_relations.size(0)
    loader = score_model.val_loader if val else score_model.test_loader

    all_ranks = []
    with torch.no_grad():
        for batch in loader:
            h, true_r, t = _unpack_triple_batch(batch)
            h, true_r, t = h.to(device), true_r.to(device), t.to(device)

            h_emb = embedding_model.node_emb(h)
            t_emb = embedding_model.node_emb(t)
            refined_r_emb = reverse_sde_link_prediction(h_emb, t_emb, score_model)

            if higher_is_better:
                scores = F.cosine_similarity(
                    refined_r_emb.unsqueeze(1), all_relations.unsqueeze(0), dim=-1
                )
                beats_true = scores > scores.gather(1, true_r.view(-1, 1))
            else:
                scores = torch.norm(
                    refined_r_emb.unsqueeze(1) - all_relations.unsqueeze(0),
                    p=2,
                    dim=-1,
                )
                beats_true = scores < scores.gather(1, true_r.view(-1, 1))
            ranks = beats_true.sum(dim=1) + 1
            all_ranks.extend(ranks.tolist())

            if view_english and score_model.dataset_name == "FB15k_237":
                # Show the top predictions as text; topk is clamped so it
                # also works with fewer than 10 relations.
                top_indices = scores.topk(
                    min(10, num_relations), dim=1, largest=higher_is_better
                ).indices
                ic(
                    convert_indices_to_english(
                        top_indices, score_model.relation_dict, is_entities=False
                    )
                )
                ic(
                    convert_indices_to_english(
                        true_r, score_model.relation_dict, is_entities=False
                    )
                )
                ic(convert_indices_to_english(h, score_model.entity_dict))
                ic(convert_indices_to_english(t, score_model.entity_dict))

    if not all_ranks:
        nan = float("nan")
        return {
            "mean_rank": nan,
            "mrr": nan,
            "hits_at_1": nan,
            "hits_at_3": nan,
            "hits_at_10": nan,
        }

    rank_array = np.asarray(all_ranks, dtype=float)
    link_prediction_performance_metrics = {
        "mean_rank": np.mean(rank_array),
        "mrr": np.mean(1.0 / rank_array),
        "hits_at_1": np.mean(rank_array <= 1),
        "hits_at_3": np.mean(rank_array <= 3),
        "hits_at_10": np.mean(rank_array <= 10),
    }
    return link_prediction_performance_metrics

    
def _loss_from_eval(result):
    """Return the scalar loss from a ``test_model`` result.

    ``test_model`` returns a metrics dict whose ``"loss"`` entry is the scalar
    loss; a bare number is passed through unchanged.
    """
    return result["loss"] if isinstance(result, dict) else result


def _print_link_prediction_metrics(label, metrics):
    """Print one split's link-prediction metrics, prefixed with ``label``."""
    print(
        f'{label} Mean Rank: {metrics["mean_rank"]:.2f}, {label} Mean Reciprocal Rank: {metrics["mrr"]:.2f}, {label} Hits@1: {metrics["hits_at_1"]:.4f}, {label} Hits@3: {metrics["hits_at_3"]:.4f}, {label} Hits@10: {metrics["hits_at_10"]:.4f}'
    )


def _split_metrics(split, loss, metrics):
    """Build the wandb payload for one evaluated split (keys ``<split>_*``)."""
    payload = {f"{split}_loss": loss}
    for name in ("mean_rank", "mrr", "hits_at_1", "hits_at_3", "hits_at_10"):
        payload[f"{split}_{name}"] = metrics[name]
    return payload


def main(config=None):
    """Train, validate, test and export one score-matching model under wandb.

    Validation runs every ``config["val_every"]`` epochs (default 10, skipping
    epoch 0) and always on the final epoch. The weights with the lowest
    validation loss are restored before testing. The model is then saved as a
    ``.pth`` state dict in the working directory and exported to ONNX.

    Args:
        config: run configuration, or None when launched by a wandb sweep
            agent (the sweep supplies it through ``wandb.config``).

    Returns:
        The trained score model.
    """
    run_timestamp = datetime.now().strftime("%Y.%m.%d.%H.%M.%S")
    with wandb.init(
        project=f"ScoreMatchingDiffKG_ScoreMatching",
        name=f"{run_timestamp}_score_matching_model_run",  # Use a temporary name
        config=config if config is not None else {},
    ):
        config = wandb.config
        config["dataset_name"], config["embedding_model_name"], config["task"] = extract_info_from_string(config["embedding_model_dir"])
        config["prefix"] = generate_prefix(config, run_timestamp)

        wandb.run.name = f"{config['prefix']}_score_matching_model_run"
        score_model = build_model(config)
        wandb.watch(score_model)

        num_epochs = config["epochs"]
        val_every = config.get("val_every", 10)
        best_val_loss = float("inf")
        best_model_state = None

        for epoch in range(num_epochs):
            loss = train_model(score_model)
            if config["verbose"]:
                print(f"Epoch: {epoch:03d}, Train Loss: {loss:.4f}")
            # Fresh payload each epoch so validation metrics are only logged
            # on the epochs where they were actually computed.
            log_payload = {"train_epoch": epoch, "train_loss": loss}

            is_val_epoch = (epoch > 0 and epoch % val_every == 0) or epoch == num_epochs - 1
            if is_val_epoch:
                val_loss = _loss_from_eval(test_model(score_model, val=True))
                link_prediction_performance_metrics = (
                    compute_link_prediction_metrics(score_model, val=True)
                )
                if config["verbose"]:
                    _print_link_prediction_metrics("Val", link_prediction_performance_metrics)
                log_payload.update(
                    _split_metrics("val", val_loss, link_prediction_performance_metrics)
                )

                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    # state_dict() returns references to the live parameters;
                    # clone them or later training silently overwrites the
                    # "best" snapshot.
                    best_model_state = {
                        name: tensor.detach().clone()
                        for name, tensor in score_model.state_dict().items()
                    }

            wandb.log(log_payload)

        if best_model_state is not None:
            score_model.load_state_dict(best_model_state)

        # Once training is finished, evaluate on the test split.
        test_loss = _loss_from_eval(test_model(score_model))
        link_prediction_performance_metrics = compute_link_prediction_metrics(score_model)
        if config["verbose"]:
            _print_link_prediction_metrics("Test", link_prediction_performance_metrics)
        wandb.log(_split_metrics("test", test_loss, link_prediction_performance_metrics))

        # Save the trained (best) model weights.
        path = osp.join(
            os.getcwd(),
            f"{score_model.dataset_name}_score_matching_model_weights.pth",
        )
        torch.save(score_model.state_dict(), path)

        # Fetch a batch from the train_loader to use as dummy ONNX inputs.
        # NOTE(review): confirm fetch_and_prepare_batch handles the dict batches
        # produced by build_model's loaders.
        batch = fetch_and_prepare_batch(score_model.train_loader, device)
        head_index, rel_type, tail_index = batch  # Unpack the batch

        file_path = f"{config['dataset_name']}_best_score_matching_model_weights.onnx"
        export_model_to_onnx(
            score_model,
            (head_index, rel_type, tail_index),
            file_path,
            input_names=["head_index", "rel_type", "tail_index"],
            dynamic_axes={
                "head_index": {0: "batch_size"},
                "rel_type": {0: "batch_size"},
                "tail_index": {0: "batch_size"},
            },
        )

        # Upload the file that was actually exported.
        wandb.save(file_path)

        return score_model

In [4]:
if not run_sweep:
    # Single run driven by the hand-written `config` from the configuration cell.
    model = main(config)
else:
    # Hyper-parameter sweep: register the sweep with W&B, then let an agent
    # call `main` once per sampled configuration.
    with open("sweep_config.yaml", "r") as sweep_file:
        sweep_config = yaml.safe_load(sweep_file)

    sweep_id = wandb.sweep(project=f"ScoreMatchingDiffKG_Score_Sweep", sweep=sweep_config)
    wandb.agent(sweep_id, function=main)