Author: Blaine Hill

In this notebook, we train a score matching model as described by [Yang Song](https://yang-song.net/blog/2021/score/) to do [link prediction](https://paperswithcode.com/task/link-prediction)

$$s_\theta(x) = \nabla_x \log p_\theta(x) = -\nabla_x f_\theta(x) - \nabla_x \log Z_\theta = -\nabla_x f_\theta(x)$$

#TODO: figure out best hidden channel for RotatE on FB15k_237 from other research papers

#TODO: figure out the dataset split for FB15k_237 to ensure it is the same (comparing apples to apples)

In [1]:
import os

os.environ["WANDB_NOTEBOOK_NAME"] = "score_matching_model.ipynb"

import wandb

wandb.login()
# wandb.login(relogin=True)


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset

import sys
import os.path as osp
from ipykernel import get_connection_file

import math
import numpy as np
from icecream import ic
import yaml
from datetime import datetime

notebook_path = osp.abspath(
    osp.join(os.getcwd(), osp.basename(get_connection_file()))
)
parent_dir = osp.dirname(osp.dirname(notebook_path))
sys.path.append(parent_dir)

from utils.utils import (
    load_dataset,
    get_model_class,
    convert_indices_to_english,
    load_dicts,
)

device = "cuda" if torch.cuda.is_available() else "cpu"

[34m[1mwandb[0m: Currently logged in as: [33mbthill1[0m ([33muiuc_idealab_2024[0m). Use [1m`wandb login --relogin`[0m to force relogin
  return torch._C._cuda_getDeviceCount() > 0


In [2]:
def initialize_trained_embedding_model(
    embedding_model_path,
):  # Load the saved model weights
    complete_state_dict = torch.load(embedding_model_path)

    # Extract necessary information from the model state dictionary
    num_nodes = complete_state_dict["num_nodes"]
    num_relations = complete_state_dict["num_relations"]
    hidden_channels = complete_state_dict["hidden_channels"]
    embedding_model_name = complete_state_dict["embedding_model_name"]
    dataset_name = complete_state_dict["dataset_name"]
    if any(
        x is None
        for x in [
            num_nodes,
            num_relations,
            hidden_channels,
            embedding_model_name,
            dataset_name,
        ]
    ):
        raise ValueError(
            "embedding model num_nodes, num_relations, hidden_channels, embedding_model_name, or dataset_name not found in the saved embedding model complete state dict"
        )

    embedding_model_class = get_model_class(embedding_model_name)

    # Initialize a temporary instance of the model class
    temp_model = embedding_model_class(
        num_nodes=num_nodes,
        num_relations=num_relations,
        hidden_channels=hidden_channels,
    )

    # Create a new dictionary containing only the model's state
    model_state_dict = {
        k: v
        for k, v in complete_state_dict.items()
        if k in temp_model.state_dict().keys()
    }

    # Initialize the actual embedding model
    embedding_model = embedding_model_class(
        num_nodes=num_nodes,
        num_relations=num_relations,
        hidden_channels=hidden_channels,
    )
    embedding_model.load_state_dict(model_state_dict)
    embedding_model.to(device)
    return embedding_model, embedding_model_name, dataset_name

In [3]:
run_sweep = False
if run_sweep:
    with open("sweep_config.yaml", "r") as file:
        sweep_config = yaml.safe_load(file)
else:
    config = {
        "embedding_model_info": "../embedding_model/FB15k_237_RotatE_embedding_model_weights.pth",
        "score_model_hidden_dim": 512,
        "batch_size": 64,
        "epochs": 3,
        "num_steps": 20,  # Adjust based on your desired SDE iterations
        "lr": 1e-4,
        "similarity_metric": "cosine",
        "verbose": True,
    }

In [4]:
class ScoreModel(nn.Module):
    def __init__(
        self,
        embedding_model,
        node_emb_dim,
        rel_emb_dim,
        score_model_hidden_dim=512,
    ):
        super(ScoreModel, self).__init__()
        self.embedding_model = embedding_model
        self.node_emb_dim = node_emb_dim
        self.rel_emb_dim = rel_emb_dim
        self.score_net = nn.Sequential(
            nn.Linear(
                node_emb_dim + rel_emb_dim,
                score_model_hidden_dim,
                dtype=torch.float,
            ),
            nn.ReLU(),
            nn.Linear(
                score_model_hidden_dim,
                score_model_hidden_dim,
                dtype=torch.float,
            ),
            nn.ReLU(),
            nn.Linear(score_model_hidden_dim, 1, dtype=torch.float),
        )

    def forward(self, h_emb, r_emb, t_emb, timestep=None):
        # Implement your desired distance measure here (e.g., L2 distance)
        distance = torch.linalg.norm(h_emb + r_emb - t_emb, dim=-1)

        # Gradually increase the weight of the distance term during SDE steps
        weight = 0.0  # No weight increase if timestep is None
        if timestep is not None:

            def sigmoid(x):
                return 1 / (1 + math.exp(-x))

            weight = sigmoid(timestep / (self.config["num_steps"] - 1))
        score = weight * distance
        return score

In [5]:
def build_model(config):
    embedding_model_path = config["embedding_model_info"]
    embedding_model, embedding_model_name, dataset_name = (
        initialize_trained_embedding_model(embedding_model_path)
    )
    embedding_model.eval()

    train_data, val_data, test_data, data_path = load_dataset(
        dataset_name, parent_dir=parent_dir, device=device
    )

    # Get entity and relation embeddings
    with torch.no_grad():
        entity_embeddings_real = embedding_model.node_emb.weight.detach()
        relation_embeddings_real = embedding_model.rel_emb.weight.detach()

        if hasattr(
            embedding_model, "node_emb_im"
        ):  # occurs with RotatE, ComplEx
            entity_embeddings_im = embedding_model.node_emb_im.weight.detach()
            entity_embeddings = torch.cat(
                [entity_embeddings_real, entity_embeddings_im], dim=-1
            )
        else:
            entity_embeddings = entity_embeddings_real

        if hasattr(embedding_model, "rel_emb_im"):  # occurs with ComplEx
            relation_embeddings_im = embedding_model.rel_emb_im.weight.detach()
            relation_embeddings = torch.cat(
                [relation_embeddings_real, relation_embeddings_im], dim=-1
            )
        else:
            relation_embeddings = relation_embeddings_real

    score_model = ScoreModel(
        embedding_model,
        entity_embeddings.shape[-1],
        relation_embeddings.shape[-1],
        config["score_model_hidden_dim"],
    ).to(device)

    # Create data loaders
    train_dataset = TensorDataset(
        train_data.edge_index[0],
        train_data.edge_type,
        train_data.edge_index[1],
    )
    train_loader = DataLoader(
        train_dataset, batch_size=config["batch_size"], shuffle=True
    )

    val_dataset = TensorDataset(
        val_data.edge_index[0], val_data.edge_type, val_data.edge_index[1]
    )
    val_loader = DataLoader(
        val_dataset, batch_size=config["batch_size"], shuffle=False
    )

    test_dataset = TensorDataset(
        test_data.edge_index[0], test_data.edge_type, test_data.edge_index[1]
    )
    test_loader = DataLoader(
        test_dataset, batch_size=config["batch_size"], shuffle=False
    )

    optimizer = Adam(score_model.parameters(), lr=config["lr"])

    if config["verbose"]:
        entity_dict, relation_dict = load_dicts(
            data_path
        )  # note: if you want to load english for RotatE, store the dicts to map from indices to freebase ids/ relations in data_path/processed
        score_model.entity_dict = entity_dict
        score_model.relation_dict = relation_dict

    score_model.train_data = train_data
    score_model.val_data = val_data
    score_model.test_data = test_data
    score_model.train_loader = train_loader
    score_model.val_loader = val_loader
    score_model.test_loader = test_loader
    score_model.optimizer = optimizer
    score_model.config = config
    score_model.dataset_name = dataset_name
    score_model.embedding_model_name = embedding_model_name
    score_model.embedding_model_path = embedding_model_path

    return score_model


def denoising_score_matching_loss(score_model, h, r, t, timestep):
    """
    Denoising score-matching loss with noise-conditional score networks.
    """
    h_emb, r_emb, t_emb = (
        score_model.embedding_model.node_emb(h),
        score_model.embedding_model.rel_emb(r),
        score_model.embedding_model.node_emb(t),
    )
    true_score = score_model(
        h_emb, r_emb, t_emb
    )  # simply do not pass in the timestep
    noisy_score = score_model(h_emb, r_emb, t_emb, timestep)
    return ((true_score - noisy_score) ** 2).mean()


def train_model(score_model):
    score_model.train()
    loss_epoch = 0
    for batch in score_model.train_loader:
        h, r, t = batch
        h, r, t = h.to(device), r.to(device), t.to(device)

        score_model.optimizer.zero_grad()
        for timestep in range(score_model.config["num_steps"]):
            loss = denoising_score_matching_loss(
                score_model, h, r, t, timestep
            )
            loss.backward()
        score_model.optimizer.step()

        loss_epoch += loss.item()

    return loss_epoch / len(score_model.train_loader)


def test_model(score_model, val=False):
    score_model.eval()
    test_loss = 0
    loader = score_model.val_loader if val else score_model.test_loader
    with torch.no_grad():
        for batch in loader:
            h, r, t = batch
            h, r, t = h.to(device), r.to(device), t.to(device)
            for timestep in range(score_model.config["num_steps"]):
                loss = denoising_score_matching_loss(
                    score_model, h, r, t, timestep
                )
                test_loss += loss.item()

    test_loss = test_loss / (
        len(loader) * score_model.config["num_steps"]
    )  # Average across steps

    return test_loss


def reverse_sde_link_prediction(h_emb, t_emb, score_model):
    """
    Refine the relation embedding using reverse SDE for better link prediction.
    h_emb: Embeddings of the head entities (batch_size, node_emb_dim).
    t_emb: Embeddings of the tail entities (batch_size, node_emb_dim).
    score_model: The trained score model.
    """

    # Initialize with random noise
    r_emb = torch.randn(h_emb.size(0), score_model.rel_emb_dim, device=device)

    # Define the time steps for the reverse SDE process
    t_steps = torch.linspace(
        score_model.config["num_steps"] - 1, 0, score_model.config["num_steps"]
    ).to(device)

    # Define the SDE function for the reverse process
    def sde_func(t, r_emb):
        with torch.enable_grad():
            r_emb.requires_grad_(True)
            score = score_model(h_emb, r_emb, t_emb, t)
            grad_r_emb = torch.autograd.grad(
                score.sum(dim=0), r_emb, create_graph=True
            )[0]
        return -grad_r_emb  # Reverse direction

    # Perform the reverse SDE integration
    with torch.no_grad():
        for t in reversed(t_steps):
            r_emb = sde_func(t, r_emb)

    return r_emb  # returns predictions for r of shape batch, rel_emb_dim


def compute_link_prediction_metrics(
    score_model, val=False, view_english=False
):
    all_ranks = []
    all_hits_at_1, all_hits_at_3, all_hits_at_10 = [], [], []

    # Get all relation embeddings from the model
    all_relations = score_model.embedding_model.rel_emb.weight.detach()
    loader = score_model.val_loader if val else score_model.test_loader
    for batch in loader:
        h, true_r, t = batch
        h, true_r, t = h.to(device), true_r.to(device), t.to(device)

        h_emb = score_model.embedding_model.node_emb(h)
        t_emb = score_model.embedding_model.node_emb(t)
        # Assuming reverse_sde_link_prediction function returns refined relation embeddings
        refined_r_emb = reverse_sde_link_prediction(h_emb, t_emb, score_model)

        if score_model.config["similarity_metric"] == "cosine":
            # Calculate cosine similarity
            similarity = F.cosine_similarity(
                refined_r_emb.unsqueeze(1), all_relations.unsqueeze(0), dim=-1
            )

            # Since cosine similarity is higher for closer vectors, we use largest=True
            values1, indices1 = similarity.topk(1, dim=1, largest=True)
            values3, indices3 = similarity.topk(3, dim=1, largest=True)
            values10, indices10 = similarity.topk(10, dim=1, largest=True)
        elif score_model.config["similarity_metric"] == "l2":
            # Calculate Euclidean distance
            dist = torch.norm(
                refined_r_emb.unsqueeze(1) - all_relations.unsqueeze(0),
                p=2,
                dim=-1,
            )

            # For Euclidean distance, closer vectors have smaller distances, so largest=False
            values1, indices1 = dist.topk(1, dim=1, largest=False)
            values3, indices3 = dist.topk(3, dim=1, largest=False)
            values10, indices10 = dist.topk(10, dim=1, largest=False)
        else:
            raise NotImplementedError(
                f"Haven't implemented a similarity metric yet for {score_model.config['similarity_metric']}"
            )
        if view_english and score_model.dataset_name == "FB15k_237":
            # for getting english from freebase ID for FB15k_237 data
            # show_only_first=True means that only the first entry in the batch will be processed and returned in english
            ic(
                convert_indices_to_english(
                    indices10, score_model.relation_dict, is_entities=False
                )
            )
            ic(
                convert_indices_to_english(
                    true_r, score_model.relation_dict, is_entities=False
                )
            )
            ic(convert_indices_to_english(h, score_model.entity_dict))
            ic(convert_indices_to_english(t, score_model.entity_dict))

        # Check if true relation is within the top K predictions
        hits_at_1 = (
            (indices1 == true_r.unsqueeze(1)).any(dim=1).float().mean().item()
        )
        all_hits_at_1.append(hits_at_1)
        hits_at_3 = (
            (indices3 == true_r.unsqueeze(1)).any(dim=1).float().mean().item()
        )
        all_hits_at_3.append(hits_at_3)
        hits_at_10 = (
            (indices10 == true_r.unsqueeze(1)).any(dim=1).float().mean().item()
        )
        all_hits_at_10.append(hits_at_10)

        # Calculate rank of the true relation
        for i in range(indices1.size(0)):
            true_relation_idx = (indices10[i] == true_r[i]).nonzero(
                as_tuple=True
            )[0]
            if true_relation_idx.numel() > 0:
                all_ranks.append(true_relation_idx.item() + 1)

    mean_rank = np.mean(all_ranks)
    mrr = np.mean([1.0 / rank for rank in all_ranks if rank > 0])
    hits_at_1 = np.mean(all_hits_at_1)
    hits_at_3 = np.mean(all_hits_at_3)
    hits_at_10 = np.mean(all_hits_at_10)

    link_prediction_performance_metrics = {
        "mean_rank": mean_rank,
        "mrr": mrr,
        "hits_at_1": hits_at_1,
        "hits_at_3": hits_at_3,
        "hits_at_10": hits_at_10,
    }

    return link_prediction_performance_metrics


def main(config=None):
    with wandb.init(
        project=f"ScoreMatchingDiffKG_ScoreMatching",
        name=f'{config["embedding_model_info"][0]}_score_matching_model {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}',
        config=config if config is not None else {},
    ):
        config = wandb.config
        score_model = build_model(config)
        wandb.watch(score_model)
        for epoch in range(config.epochs):
            loss = train_model(score_model)
            if config["verbose"]:
                print(f"Epoch: {epoch:03d}, Train Loss: {loss:.4f}")
            train_metrics = {"train_epoch": epoch, "train_loss": loss}
            if epoch % 10 == 0 and epoch > 0:
                # set val to be True since we are running over our validation data
                val = True
                val_loss = test_model(score_model, val=val)
                link_prediction_performance_metrics = (
                    compute_link_prediction_metrics(score_model, val=val)
                )
                if config["verbose"]:
                    print(
                        f'Val Mean Rank: {link_prediction_performance_metrics["mean_rank"]:.2f}, Val Mean Reciprocal Rank: {link_prediction_performance_metrics["mrr"]:.2f}, Val Hits@1: {link_prediction_performance_metrics["hits_at_1"]:.4f}, Val Hits@3: {link_prediction_performance_metrics["hits_at_3"]:.4f}, Val Hits@10: {link_prediction_performance_metrics["hits_at_10"]:.4f}'
                    )
                val_metrics = {
                    "val_loss": val_loss,
                    "val_mean_rank": link_prediction_performance_metrics[
                        "mean_rank"
                    ],
                    "val_mrr": link_prediction_performance_metrics["mrr"],
                    "val_hits_at_1": link_prediction_performance_metrics[
                        "hits_at_1"
                    ],
                    "val_hits_at_3": link_prediction_performance_metrics[
                        "hits_at_3"
                    ],
                    "val_hits_at_10": link_prediction_performance_metrics[
                        "hits_at_10"
                    ],
                }

            # log to wandb
            wandb.log(
                {**train_metrics, **val_metrics}
                if "val_metrics" in locals()
                else {**train_metrics}
            )

        # once everything is finished, test model
        test_loss = test_model(score_model)
        link_prediction_performance_metrics = compute_link_prediction_metrics(
            score_model
        )
        if config["verbose"]:
            print(
                f'Test Mean Rank: {link_prediction_performance_metrics["mean_rank"]:.2f}, Test Mean Reciprocal Rank: {link_prediction_performance_metrics["mrr"]:.2f}, Test Hits@1: {link_prediction_performance_metrics["hits_at_1"]:.4f}, Test Hits@3: {link_prediction_performance_metrics["hits_at_3"]:.4f}, Test Hits@10: {link_prediction_performance_metrics["hits_at_10"]:.4f}'
            )
        test_metrics = {
            "test_loss": test_loss,
            "test_mean_rank": link_prediction_performance_metrics["mean_rank"],
            "test_mrr": link_prediction_performance_metrics["mrr"],
            "test_hits_at_1": link_prediction_performance_metrics["hits_at_1"],
            "test_hits_at_3": link_prediction_performance_metrics["hits_at_3"],
            "test_hits_at_10": link_prediction_performance_metrics[
                "hits_at_10"
            ],
        }
        wandb.log({**test_metrics})

        # Save the trained model
        path = osp.join(
            os.getcwd(),
            f"{score_model.dataset_name}_score_matching_model_weights.pth",
        )
        torch.save(score_model.state_dict(), path)

        # Fetch a batch from train_loader
        for batch in score_model.train_loader:
            # Assuming batch contains head_index, rel_type, tail_index, and possibly other data
            head_index, rel_type, tail_index = batch
            head_index = head_index.float()
            rel_type = rel_type.float()
            tail_index = tail_index.float()
            break  # Only need one batch for this purpose
        # Use the fetched batch to provide dummy inputs for the export
        # Ensure these variables are moved to the same device as model if necessary
        torch.onnx.export(
            score_model,
            (
                head_index,
                rel_type,
                tail_index,
            ),  # Use actual data as dummy inputs
            f"{score_model.dataset_name}_score_matching_model_weights.onnx",
            opset_version=11,
            do_constant_folding=True,
            input_names=[
                "head_index",
                "rel_type",
                "tail_index",
            ],  # Adjust input names as needed
            dynamic_axes={
                "head_index": {0: "batch_size"},
                "rel_type": {0: "batch_size"},
                "tail_index": {0: "batch_size"},
            },
        )
        wandb.save(
            f"{score_model.dataset_name}_score_matching_model_weights.onnx"
        )

        return score_model

In [6]:
if run_sweep:

    with open("sweep_config.yaml", "r") as file:
        sweep_config = yaml.safe_load(file)

    sweep_id = wandb.sweep(
        project=f"ScoreMatchingDiffKG_Score_Sweep", sweep=sweep_config
    )

    wandb.agent(sweep_id, function=main)
else:
    model = main(config)

Epoch: 000, Train Loss: 318.3321
Epoch: 001, Train Loss: 271.2310
Epoch: 002, Train Loss: 229.8930
Test Mean Rank: 2.58, Test Mean Reciprocal Rank: 0.70, Test Hits@1: 0.2936, Test Hits@3: 0.3843, Test Hits@10: 0.5051


0,1
test_hits_at_1,▁
test_hits_at_10,▁
test_hits_at_3,▁
test_loss,▁
test_mean_rank,▁
test_mrr,▁
train_epoch,▁▅█
train_loss,█▄▁

0,1
test_hits_at_1,0.29363
test_hits_at_10,0.50511
test_hits_at_3,0.38426
test_loss,152.05083
test_mean_rank,2.57821
test_mrr,0.69826
train_epoch,2.0
train_loss,229.89305
