Author: Blaine Hill

In this notebook, we train a score matching model as described by [Yang Song](https://yang-song.net/blog/2021/score/)


$$s_\theta(x) = \nabla_x \log p_\theta(x) = -\nabla_x f_\theta(x) - \nabla_x \log Z_\theta = -\nabla_x f_\theta(x)$$

#TODO: figure out best hidden channel for RotatE on FB15k_237 from other research papers
#TODO: figure out the dataset split for FB15k_237 to ensure it is the same (comparing apples to apples)

In [1]:
import os

os.environ["WANDB_NOTEBOOK_NAME"] = "score_matching_model.ipynb"

import wandb

wandb.login()
# wandb.login(relogin=True)


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
from torch_geometric.nn import RotatE
from torch_geometric.datasets import FB15k_237
import sys
import os
import os.path as osp
import math
import numpy as np
from icecream import ic
import yaml
from datetime import datetime

# for getting english from freebase ID
import sys

sys.path.append("../freebase/")
from converter import EntityConverter
from wikidata.client import Client


def get_english_from_freebase_id(freebase_id):
    try:
        entity_converter = EntityConverter("https://query.wikidata.org/sparql")
        res = entity_converter.get_wikidata_id(freebase_id)
        item = Client().get(res)
        return item.label if item and item.label else "No english found for this freebase id"
    except AssertionError:
        return "This freebase id has no corresponding wikidata id"


dataset_name = "FB15k_237"

parent_dir = os.path.abspath(os.pardir)
sys.path.append(parent_dir)
data_path = osp.join(parent_dir, "data", dataset_name)

device = "cuda" if torch.cuda.is_available() else "cpu"


train_data = FB15k_237(data_path, split="train")[0].to(device)
val_data = FB15k_237(data_path, split="val")[0].to(device)
test_data = FB15k_237(data_path, split="test")[0].to(device)

[34m[1mwandb[0m: Currently logged in as: [33mbthill1[0m ([33muiuc_idealab_2024[0m). Use [1m`wandb login --relogin`[0m to force relogin
  return torch._C._cuda_getDeviceCount() > 0


In [2]:
run_sweep = False
if run_sweep:
    with open("sweep_config.yaml", "r") as file:
        sweep_config = yaml.safe_load(file)
else:
    config = {
        "batch_size": 64,
        "epochs": 100,
        "num_steps": 20,  # Adjust based on your desired SDE iterations
        "lr": 1e-4,
        "similarity_metric": "cosine",
    }

In [3]:
class ScoreModel(nn.Module):
    def __init__(self, embedding_model, node_emb_dim, rel_emb_dim, config):
        super(ScoreModel, self).__init__()
        self.config = config
        self.embedding_model = embedding_model
        self.node_emb_dim = node_emb_dim
        self.rel_emb_dim = rel_emb_dim
        self.score_net = nn.Sequential(
            nn.Linear(node_emb_dim + rel_emb_dim, 512, dtype=torch.float),
            nn.ReLU(),
            nn.Linear(512, 512, dtype=torch.float),
            nn.ReLU(),
            nn.Linear(512, 1, dtype=torch.float),
        )

    def forward(self, h_emb, r_emb, t_emb, timestep=None):
        # Implement your desired distance measure here (e.g., L2 distance)
        distance = torch.linalg.norm(h_emb + r_emb - t_emb, dim=-1)

        # Gradually increase the weight of the distance term during SDE steps
        weight = 0.0  # No weight increase if timestep is None
        if timestep is not None:

            def sigmoid(x):
                return 1 / (1 + math.exp(-x))

            weight = sigmoid(timestep / (self.config["num_steps"] - 1))
        score = weight * distance
        return score


def build_model(config):
    embedding_model = RotatE(
        num_nodes=train_data.num_nodes,
        num_relations=train_data.num_edge_types,
        hidden_channels=50,
    ).to(device)
    embedding_model.load_state_dict(
        torch.load("../embedding_model/FB15k_237_RotatE_embedding_model_weights.pth")
    )
    embedding_model.eval()

    def load_dict_from_pt(file_path):
        data = torch.load(file_path)
        return data if isinstance(data, dict) else None

    entity_dict = load_dict_from_pt(
        osp.join(
            osp.join(data_path, "processed"), "entity_dict.pt"
        )  # entity dict matches index: freebaseID
    )
    relation_dict = load_dict_from_pt(
        osp.join(osp.join(data_path, "processed"), "relation_dict.pt")
    )

    # Get entity and relation embeddings
    with torch.no_grad():
        entity_embeddings_real = embedding_model.node_emb.weight.detach()
        entity_embeddings_im = embedding_model.node_emb_im.weight.detach()
        entity_embeddings = torch.cat([entity_embeddings_real, entity_embeddings_im], dim=-1)
        relation_embeddings = embedding_model.rel_emb.weight.detach()

    score_model = ScoreModel(
        embedding_model, entity_embeddings.shape[-1], relation_embeddings.shape[-1], config
    ).to(device)

    # Create data loaders
    train_dataset = TensorDataset(
        train_data.edge_index[0], train_data.edge_type, train_data.edge_index[1]
    )
    train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True)

    val_dataset = TensorDataset(
        val_data.edge_index[0], val_data.edge_type, val_data.edge_index[1]
    )
    val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False)

    test_dataset = TensorDataset(
        test_data.edge_index[0], test_data.edge_type, test_data.edge_index[1]
    )
    test_loader = DataLoader(test_dataset, batch_size=config["batch_size"], shuffle=False)

    optimizer = Adam(score_model.parameters(), lr=config["lr"])

    return (
        score_model,
        optimizer,
        train_loader,
        val_loader,
        test_loader,
        entity_dict,
        relation_dict,
    )


def denoising_score_matching_loss(score_model, h, r, t, timestep):
    """
    Denoising score-matching loss with noise-conditional score networks.
    """
    h_emb, r_emb, t_emb = (
        score_model.embedding_model.node_emb(h),
        score_model.embedding_model.rel_emb(r),
        score_model.embedding_model.node_emb(t),
    )
    true_score = score_model(h_emb, r_emb, t_emb)  # simply do not pass in the timestep
    noisy_score = score_model(h_emb, r_emb, t_emb, timestep)
    return ((true_score - noisy_score) ** 2).mean()


def train_model(score_model, optimizer, train_loader, config):
    score_model.train()
    loss_epoch = 0
    for batch in train_loader:
        h, r, t = batch
        h, r, t = h.to(device), r.to(device), t.to(device)

        optimizer.zero_grad()
        for timestep in range(config["num_steps"]):
            loss = denoising_score_matching_loss(score_model, h, r, t, timestep)
            loss.backward()
        optimizer.step()

        loss_epoch += loss.item()

    return loss_epoch / len(train_loader)


def test_model(score_model, test_loader, config):
    score_model.eval()
    test_loss = 0
    with torch.no_grad():
        for batch in test_loader:
            h, r, t = batch
            h, r, t = h.to(device), r.to(device), t.to(device)
            for timestep in range(config["num_steps"]):
                loss = denoising_score_matching_loss(score_model, h, r, t, timestep)
                test_loss += loss.item()

    return test_loss / (len(test_loader) * config["num_steps"])  # Average across steps


def reverse_sde(h_emb, t_emb, score_model, config):
    """
    Refine the relation embedding using reverse SDE for better link prediction.
    h_emb: Embeddings of the head entities (batch_size, node_emb_dim).
    t_emb: Embeddings of the tail entities (batch_size, node_emb_dim).
    score_model: The trained score model.
    config: Configuration dictionary.
    """

    # Initialize with random noise
    r_emb = torch.randn(h_emb.size(0), score_model.rel_emb_dim, device=device)

    # Define the time steps for the reverse SDE process
    t_steps = torch.linspace(config["num_steps"] - 1, 0, config["num_steps"]).to(device)

    # Define the SDE function for the reverse process
    def sde_func(t, r_emb):
        with torch.enable_grad():
            r_emb.requires_grad_(True)
            score = score_model(h_emb, r_emb, t_emb, t)
            grad_r_emb = torch.autograd.grad(score.sum(dim=0), r_emb, create_graph=True)[0]
        return -grad_r_emb  # Reverse direction

    # Perform the reverse SDE integration
    with torch.no_grad():
        for t in reversed(t_steps):
            r_emb = sde_func(t, r_emb)

    return r_emb  # returns predictions for r of shape batch, rel_emb_dim


def compute_metrics(score_model, test_loader, config, entity_dict, relation_dict):
    all_ranks = []
    all_hits_at_1, all_hits_at_3, all_hits_at_10 = [], [], []

    # Get all relation embeddings from the model
    all_relations = score_model.embedding_model.rel_emb.weight.detach()

    for batch in test_loader:
        h, true_r, t = batch
        h, true_r, t = h.to(device), true_r.to(device), t.to(device)

        h_emb = score_model.embedding_model.node_emb(h)
        t_emb = score_model.embedding_model.node_emb(t)
        # Assuming reverse_sde function returns refined relation embeddings
        refined_r_emb = reverse_sde(h_emb, t_emb, score_model, config)

        if config["similarity_metric"] == "cosine":
            # Calculate cosine similarity
            similarity = F.cosine_similarity(
                refined_r_emb.unsqueeze(1), all_relations.unsqueeze(0), dim=-1
            )

            # Since cosine similarity is higher for closer vectors, we use largest=True
            values1, indices1 = similarity.topk(1, dim=1, largest=True)
            values3, indices3 = similarity.topk(3, dim=1, largest=True)
            values10, indices10 = similarity.topk(10, dim=1, largest=True)
        elif config["similarity_metric"] == "l2":
            # Calculate Euclidean distance
            dist = torch.norm(
                refined_r_emb.unsqueeze(1) - all_relations.unsqueeze(0), p=2, dim=-1
            )

            # For Euclidean distance, closer vectors have smaller distances, so largest=False
            values1, indices1 = dist.topk(1, dim=1, largest=False)
            values3, indices3 = dist.topk(3, dim=1, largest=False)
            values10, indices10 = dist.topk(10, dim=1, largest=False)
        else:
            raise NotImplementedError(
                f"Haven't implemented a similarity metric yet for {config['similarity_metric']}"
            )

        # def map_to_string(index, dict):
        #     return dict.get(index.item(), "Not Found")

        # english_indices = [map_to_string(index, relation_dict) for index in indices]
        # english_true_r = [map_to_string(index, relation_dict) for index in true_r]
        # english_h = [
        #     get_english_from_freebase_id(map_to_string(index, entity_dict)) for index in h
        # ]
        # english_t = [
        #     get_english_from_freebase_id(map_to_string(index, entity_dict)) for index in t
        # ]

        # # examine for reasonable predictions
        # ic(english_indices)  # top k real relations closest to r_emb in english
        # ic(english_true_r)  # what the model should have predicted in english
        # ic(english_h)  # head entity tensor
        # ic(english_t)

        # Check if true relation is within the top K predictions
        hits_at_1 = (indices1 == true_r.unsqueeze(1)).any(dim=1).float().mean().item()
        all_hits_at_1.append(hits_at_1)
        hits_at_3 = (indices3 == true_r.unsqueeze(1)).any(dim=1).float().mean().item()
        all_hits_at_3.append(hits_at_3)
        hits_at_10 = (indices10 == true_r.unsqueeze(1)).any(dim=1).float().mean().item()
        all_hits_at_10.append(hits_at_10)

        # Calculate rank of the true relation #TODO: fix MRR
        for i in range(indices1.size(0)):
            true_relation_idx = (indices10[i] == true_r[i]).nonzero(as_tuple=True)[0]
            ic(true_relation_idx.shape)
            if true_relation_idx.numel() > 0:
                all_ranks.append(true_relation_idx.item() + 1)

    mean_rank = np.mean(all_ranks)
    mrr = np.mean([1.0 / rank for rank in all_ranks if rank > 0])
    hits_at_1 = np.mean(all_hits_at_1)
    hits_at_3 = np.mean(all_hits_at_3)
    hits_at_10 = np.mean(all_hits_at_10)

    return mean_rank, mrr, hits_at_1, hits_at_3, hits_at_10


def main(config=None, verbose=False):
    with wandb.init(
        project=f"ScoreMatchingDiffKG_ScoreMatching",
        name=f"{dataset_name}_score_matching_model {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        config=config if config is not None else {},
    ):
        config = wandb.config
        model, optimizer, train_loader, val_loader, test_loader, entity_dict, relation_dict = (
            build_model(config)
        )
        wandb.watch(model)
        for epoch in range(config.epochs):
            loss = train_model(model, optimizer, train_loader, config)
            if verbose:
                print(f"Epoch: {epoch:03d}, Train Loss: {loss:.4f}")
            train_metrics = {"train_epoch": epoch, "train_loss": loss}
            if epoch % 2 == 0 and epoch > 0:
                loss = test_model(model, test_loader, config)
                mean_rank, mrr, hits_at_1, hits_at_3, hits_at_10 = compute_metrics(
                    model, test_loader, config, entity_dict, relation_dict
                )
                if verbose:
                    print(
                        f"Test Mean Rank: {mean_rank:.2f}, Test Mean Reciprocal Rank: {mrr:.2f}, Test Hits@1: {hits_at_1:.4f}, Test Hits@3: {hits_at_3:.4f}, Test Hits@10: {hits_at_10:.4f}"
                    )
                test_metrics = {
                    "test_loss": loss,
                    "test_mean_rank": mean_rank,
                    "test_mrr": mrr,
                    "test_hits_at_1": hits_at_1,
                    "test_hits_at_3": hits_at_3,
                    "test_hits_at_10": hits_at_10,
                }
            # log to wandb
            wandb.log(
                {**train_metrics, **test_metrics}
                if "test_metrics" in locals()
                else {**train_metrics}
            )

        # Save the trained model
        path = osp.join(os.getcwd(), f"{dataset_name}_score_matching_model_weights.pth")
        torch.save(model.state_dict(), path)

        # Fetch a batch from train_loader
        for batch in train_loader:
            # Assuming batch contains head_index, rel_type, tail_index, and possibly other data
            head_index, rel_type, tail_index = batch
            break  # Only need one batch for this purpose

        # Use the fetched batch to provide dummy inputs for the export
        # Ensure these variables are moved to the same device as model if necessary
        torch.onnx.export(
            model,
            (head_index, rel_type, tail_index),  # Use actual data as dummy inputs
            f"{dataset_name}_score_matching_model_weights.onnx",
            opset_version=11,
            do_constant_folding=True,
            input_names=[
                "head_index",
                "rel_type",
                "tail_index",
            ],  # Adjust input names as needed
            dynamic_axes={
                "head_index": {0: "batch_size"},
                "rel_type": {0: "batch_size"},
                "tail_index": {0: "batch_size"},
            },
        )
        wandb.save(f"{dataset_name}_score_matching_model_weights.onnx")

        return model

In [4]:
if run_sweep:

    with open("sweep_config.yaml", "r") as file:
        sweep_config = yaml.safe_load(file)

    sweep_id = wandb.sweep(project=f"ScoreMatchingDiffKG_Score_Sweep", sweep=sweep_config)

    wandb.agent(sweep_id, function=main)
else:
    model = main(config, verbose=True)

Epoch: 000, Train Loss: 21.8293
Epoch: 001, Train Loss: 18.4090
Epoch: 002, Train Loss: 15.4505
Test Mean Rank: 0.04, Test Mean Reciprocal Rank: 1.00, Test Hits@1: 0.0357, Test Hits@3: 0.0862, Test Hits@10: 0.1865
Epoch: 003, Train Loss: 12.8834
Epoch: 004, Train Loss: 10.6679
Test Mean Rank: 0.08, Test Mean Reciprocal Rank: 1.00, Test Hits@1: 0.0837, Test Hits@3: 0.1662, Test Hits@10: 0.3302
Epoch: 005, Train Loss: 8.7703
Epoch: 006, Train Loss: 7.1650
Test Mean Rank: 0.14, Test Mean Reciprocal Rank: 1.00, Test Hits@1: 0.1430, Test Hits@3: 0.2639, Test Hits@10: 0.4669
Epoch: 007, Train Loss: 5.8160
Epoch: 008, Train Loss: 4.7015
Test Mean Rank: 0.24, Test Mean Reciprocal Rank: 1.00, Test Hits@1: 0.2440, Test Hits@3: 0.3928, Test Hits@10: 0.5587
Epoch: 009, Train Loss: 3.7879
Epoch: 010, Train Loss: 3.0446
Test Mean Rank: 0.33, Test Mean Reciprocal Rank: 1.00, Test Hits@1: 0.3277, Test Hits@3: 0.4683, Test Hits@10: 0.6213
Epoch: 011, Train Loss: 2.4418
Epoch: 012, Train Loss: 1.9540
Te