Author: Blaine Hill

In this notebook, we train a score matching model as described by [Yang Song](https://yang-song.net/blog/2021/score/) to do KG Reasoning tasks such as [link prediction](https://paperswithcode.com/task/link-prediction)

$$s_\theta(x) = \nabla_x \log p_\theta(x) = -\nabla_x f_\theta(x) - \nabla_x \log Z_\theta = -\nabla_x f_\theta(x)$$

In [1]:
import os

os.environ["WANDB_NOTEBOOK_NAME"] = "score_matching_model.ipynb"

import wandb

wandb.login()
# wandb.login(relogin=True)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset

import sys
import os.path as osp
from ipykernel import get_connection_file

import math
import numpy as np
import yaml
from datetime import datetime
import random

notebook_path = osp.abspath(
    osp.join(os.getcwd(), osp.basename(get_connection_file()))
)
parent_dir = osp.dirname(osp.dirname(notebook_path))
sys.path.append(parent_dir)

# Import utility functions and model utilities
from utils.utils import *
from utils.embedding_model_utils import load_dataset
from utils.score_matching_model_utils import *

# from utils.score_matching_model.utils import *

from utils.score_matching_model.ScoreModel import ScoreModel

device = "cuda" if torch.cuda.is_available() else "cpu"

set_seed()

[34m[1mwandb[0m: Currently logged in as: [33mbthill1[0m ([33muiuc_idealab_2024[0m). Use [1m`wandb login --relogin`[0m to force relogin
  return torch._C._cuda_getDeviceCount() > 0


In [2]:
run_sweep = False
sweep_config_file_path = "sweep_config.yaml"

if run_sweep:
    with open(sweep_config_file_path, "r") as file:
        config = yaml.safe_load(file)
else:
    config = {
        "embedding_model_dir": "../trained_embedding_models/FB15k_237_TransE_kg_completion_2024.05.17.19.12.56_6134bf01_embedding_model/",
        "score_model_hidden_dim": 512,
        "batch_size": 512,
        "max_epochs": 3,
        "num_sde_timesteps": 20,  # Adjust based on your desired SDE iterations
        "lr": 1e-4,
        "k": [1, 3, 10],  # used for top-k evaluation
        "verbose": True,
        "num_epochs_without_improvement_until_early_finish": 5,
        "validate_after_this_many_epochs": 1,
    }

In [3]:
def build_model(config, view_english=False):
    embedding_model_dir = config["embedding_model_dir"]
    embedding_model, embedding_model_name, dataset_name = (
        initialize_trained_embedding_model(embedding_model_dir, device)
    )
    embedding_model.eval()

    train_data, val_data, test_data, data_path = load_dataset(
        dataset_name, parent_dir=parent_dir
    )

    original_model = ScoreModel(
        embedding_model,
        config["score_model_hidden_dim"],
        config["num_sde_timesteps"],
        config["task"],
        config["aux_dict"],
    ).to(device)

    # Calculate batch size based on the number of GPUs available
    num_gpus = torch.cuda.device_count()
    batch_size_per_gpu = config["batch_size"]
    batch_size = (
        batch_size_per_gpu * num_gpus if num_gpus > 0 else batch_size_per_gpu
    )
    num_workers = num_gpus if num_gpus > 0 else 1
    config["num_gpus"] = num_gpus

    # Helper function to prepare loader arguments
    def prepare_loader_args(data, batch_size, num_workers):
        edge_index_0 = data.edge_index[0]
        edge_type = data.edge_type
        edge_index_1 = data.edge_index[1]

        x = data.x if "x" in data else None

        if x is not None:
            dataset = TensorDataset(edge_index_0, edge_type, edge_index_1, x)
        else:
            dataset = TensorDataset(edge_index_0, edge_type, edge_index_1)

        loader_args = {
            "batch_size": batch_size,
            "shuffle": True,
            "num_workers": num_workers,
        }

        return dataset, loader_args

    # Prepare loader arguments for train, validation, and test datasets
    train_dataset, train_loader_args = prepare_loader_args(
        train_data, batch_size, num_workers
    )
    val_dataset, val_loader_args = prepare_loader_args(
        val_data, batch_size, num_workers
    )
    test_dataset, test_loader_args = prepare_loader_args(
        test_data, batch_size, num_workers
    )

    train_loader = DataLoader(train_dataset, **train_loader_args)
    val_loader = DataLoader(val_dataset, **val_loader_args)
    test_loader = DataLoader(test_dataset, **test_loader_args)

    optimizer = Adam(original_model.parameters(), lr=config["lr"])

    # Use DataParallel for multi-GPU setups on single server
    if num_gpus > 1:
        score_model = torch.nn.DataParallel(original_model)
    else:
        score_model = original_model

    # Use DataParallel for multi-GPU setups on single server
    if num_gpus > 1:
        score_model = torch.nn.DataParallel(original_model)
    else:
        score_model = original_model

    if view_english:
        entity_dict, relation_dict = load_dicts(
            data_path
        )  # note: if you want to load english for RotatE, store the dicts to map from indices to freebase ids/ relations in data_path/processed
        score_model.entity_dict = entity_dict
        score_model.relation_dict = relation_dict

    score_model.train_data = train_data
    score_model.val_data = val_data
    score_model.test_data = test_data
    score_model.train_loader = train_loader
    score_model.val_loader = val_loader
    score_model.test_loader = test_loader
    score_model.optimizer = optimizer
    score_model.config = config
    score_model.dataset_name = dataset_name
    score_model.embedding_model_name = embedding_model_name
    score_model.embedding_model_dir = embedding_model_dir
    score_model.original_model = (
        original_model  # Keep a reference to the original model
    )

    # Determine the path to save results and create directory if it does not exist
    if "save_path" in config:
        save_path = config["save_path"]
        if not os.path.exists(save_path):
            raise ValueError(
                f"Directory {save_path} does not exist in which to save the trained score matching models. Please create it before saving."
            )
        save_path = osp.join(
            save_path,
            f"{config['prefix']}_score_matching_model",
        )
    else:
        save_path = osp.join(
            parent_dir,
            "trained_score_matching_models",
            f"{config['prefix']}_score_matching_model",
        )

    os.makedirs(save_path, exist_ok=True)
    score_model.save_path = save_path

    # Save the model configuration
    save_score_matching_model_config(score_model)

    return score_model


def train_model(score_model):
    score_model.train()
    loss_epoch = 0
    for batch in score_model.train_loader:
        # Extract data from the batch
        h, r, t = batch[0].to(device), batch[1].to(device), batch[2].to(device)

        # Check if 'x' is present in the batch
        x = batch[3].to(device) if len(batch) > 3 else None

        for timestep in range(score_model.config["num_sde_timesteps"]):
            loss = score_model.original_model.loss(
                h, r, t, timestep, x, score_model.config["task"]
            )
            loss.backward()
        score_model.optimizer.step()

        loss_epoch += loss.item()

    return loss_epoch / len(score_model.train_loader)


def test_model(score_model, val=False, view_english=False):
    score_model.eval()
    total_loss = total_examples = 0
    only_relation_prediction = val
    loader = score_model.val_loader if val else score_model.test_loader
    all_metrics = []
    with torch.no_grad():
        for batch in loader:
            # Extract data from the batch
            h, r, t = (
                batch[0].to(device),
                batch[1].to(device),
                batch[2].to(device),
            )

            # Check if 'x' is present in the batch
            x = batch[3].to(device) if len(batch) > 3 else None

            if x is not None:
                x = x.to(device)

            for timestep in range(score_model.config["num_sde_timesteps"]):
                loss = score_model.original_model.loss(
                    h, r, t, timestep, x, score_model.config["task"]
                )
                total_loss += loss.item()
                total_examples += h.numel()

            # Calculate metrics for the batch
            with torch.enable_grad():
                metrics = score_model.original_model.test(
                    h=h,
                    r=r,
                    t=t,
                    x=x,
                    k=score_model.config["k"],
                    task=score_model.config["task"],
                    only_relation_prediction=only_relation_prediction,  # for kg_completion: during validation, only evaluate on relation prediction - on the test set do head, relation, and tail prediction
                )
            all_metrics.append(metrics)

    # Aggregate metrics across all batches
    if score_model.original_model.task in "kg_completion":
        if only_relation_prediction:
            relation_mean_ranks = []
            relation_mrrs = []
            relation_hits_at_ks = []
            # Collect the metrics from each batch
            for metrics in all_metrics:
                relation_mean_rank, relation_mrr, relation_hits_at_k = metrics
                relation_mean_ranks.append(relation_mean_rank)
                relation_mrrs.append(relation_mrr)
                relation_hits_at_ks.append(relation_hits_at_k)
            # Aggregate the metrics across all batches
            relation_mean_rank = sum(relation_mean_ranks) / len(all_metrics)
            relation_mrr = sum(relation_mrrs) / len(all_metrics)

            # Aggregate the hits@k metrics
            all_k_values = set().union(
                *(d.keys() for d in relation_hits_at_ks)
            )
            relation_hits_at_k = {
                f"relation_hits_at_{k}": sum(
                    d.get(k, 0) for d in relation_hits_at_ks
                )
                / len(all_metrics)
                for k in all_k_values
            }
            # Create the performance_metrics dictionary
            performance_metrics = {
                "loss": total_loss / total_examples,
                "relation_mean_rank": relation_mean_rank,
                "relation_mrr": relation_mrr,
                **relation_hits_at_k,
            }
        else:
            head_mean_ranks = []
            relation_mean_ranks = []
            tail_mean_ranks = []
            head_mrrs = []
            relation_mrrs = []
            tail_mrrs = []
            head_hits_at_ks = []
            relation_hits_at_ks = []
            tail_hits_at_ks = []

            # Collect the metrics from each batch
            for metrics in all_metrics:
                (
                    head_mean_rank,
                    relation_mean_rank,
                    tail_mean_rank,
                    head_mrr,
                    relation_mrr,
                    tail_mrr,
                    head_hits_at_k,
                    relation_hits_at_k,
                    tail_hits_at_k,
                ) = metrics
                head_mean_ranks.append(head_mean_rank)
                relation_mean_ranks.append(relation_mean_rank)
                tail_mean_ranks.append(tail_mean_rank)
                head_mrrs.append(head_mrr)
                relation_mrrs.append(relation_mrr)
                tail_mrrs.append(tail_mrr)
                head_hits_at_ks.append(head_hits_at_k)
                relation_hits_at_ks.append(relation_hits_at_k)
                tail_hits_at_ks.append(tail_hits_at_k)

            # Aggregate the metrics across all batches
            head_mean_rank = sum(head_mean_ranks) / len(all_metrics)
            relation_mean_rank = sum(relation_mean_ranks) / len(all_metrics)
            tail_mean_rank = sum(tail_mean_ranks) / len(all_metrics)
            head_mrr = sum(head_mrrs) / len(all_metrics)
            relation_mrr = sum(relation_mrrs) / len(all_metrics)
            tail_mrr = sum(tail_mrrs) / len(all_metrics)

            # Aggregate the hits@k metrics
            all_k_values = set().union(
                *(d.keys() for d in relation_hits_at_ks)
            )
            head_hits_at_k = {
                f"head_hits_at_{k}": sum(d.get(k, 0) for d in head_hits_at_ks)
                / len(all_metrics)
                for k in all_k_values
            }
            relation_hits_at_k = {
                f"relation_hits_at_{k}": sum(
                    d.get(k, 0) for d in relation_hits_at_ks
                )
                / len(all_metrics)
                for k in all_k_values
            }
            tail_hits_at_k = {
                f"tail_hits_at_{k}": sum(d.get(k, 0) for d in tail_hits_at_ks)
                / len(all_metrics)
                for k in all_k_values
            }

            # Create the performance_metrics dictionary
            performance_metrics = {
                "loss": total_loss / total_examples,
                "head_mean_rank": head_mean_rank,
                "relation_mean_rank": relation_mean_rank,
                "tail_mean_rank": tail_mean_rank,
                "head_mrr": head_mrr,
                "relation_mrr": relation_mrr,
                "tail_mrr": tail_mrr,
                **head_hits_at_k,
                **relation_hits_at_k,
                **tail_hits_at_k,
            }

    elif score_model.original_model.task == "node_classification":
        accuracy = sum(all_metrics) / len(all_metrics)
        performance_metrics = {
            "loss": total_loss / total_examples,
            "accuracy": accuracy,
        }

    else:
        raise ValueError(
            f"score_model task isn't valid: {score_model.original_model.task}"
        )

    return performance_metrics

In [4]:
def main(config=None):
    run_timestamp = datetime.now().strftime("%Y.%m.%d.%H.%M.%S")
    with wandb.init(
        project=f"ScoreMatchingDiffKG_ScoreMatching",
        name=f"{run_timestamp}_score_matching_model_run",  # Use a temporary name
        config=config if config is not None else {},
    ):
        config = wandb.config
        (
            config["dataset_name"],
            config["embedding_model_name"],
            config["task"],
            config["aux_dict"],
        ) = extract_info_from_embedding_model_dir(
            config["embedding_model_dir"]
        )
        config["prefix"] = generate_prefix(config, run_timestamp)

        wandb.run.name = f"{config['prefix']}_score_matching_model_run"
        if (
            config["task"] == "node_classification"
            and config["dataset_name"]
            not in [
                "Cora",
                "Citeseer",
                "Pubmed",
            ]
        ) or (
            config["task"] != "node_classification"
            and config["dataset_name"]
            in [
                "Cora",
                "Citeseer",
                "Pubmed",
            ]
        ):
            print(f"Skipping {config['task']} on {config['dataset_name']}")
            return

        wandb.run.name = f"{config['prefix']}_score_matching_model_run"
        score_model = build_model(config)
        wandb.watch(score_model)

        num_epochs_without_improvement_until_early_finish = config[
            "num_epochs_without_improvement_until_early_finish"
        ]
        validate_after_this_many_epochs = config[
            "validate_after_this_many_epochs"
        ]

        best_metric_to_optimize = float(
            "inf"
        )  # either val mean rank or val accuracy later
        metric_to_optimize = (
            "relation_mean_rank"
            if config["task"] == "kg_completion"
            else "accuracy"
        )  # if config["task"] == "node_classification"
        epochs_without_improvement = 0

        for epoch in range(config.max_epochs):
            loss = train_model(score_model)

            if config["verbose"]:
                print(f"Epoch: {epoch:03d}, Train Loss: {loss:.10f}")

            train_metrics = {"train_epoch": epoch, "train_loss": loss}
            if epoch % validate_after_this_many_epochs == 0 and epoch > 0:
                val_metrics = test_model(score_model, val=True)
                if (
                    val_metrics[metric_to_optimize] <= best_metric_to_optimize
                ):  # minimize
                    best_metric_to_optimize = val_metrics[metric_to_optimize]
                    save_trained_score_matching_weights_and_performance(
                        score_model, epoch + 1, val_metrics
                    )
                    epochs_without_improvement = 0
                else:
                    epochs_without_improvement += 1

                if config["verbose"]:
                    metrics_info = ", ".join(
                        [
                            f'Val {key.replace("_", " ").title()}: {value:.10f}'
                            for key, value in val_metrics.items()
                            if isinstance(value, (int, float))
                        ]
                    )
                    print(metrics_info)

                if (
                    epochs_without_improvement
                    >= num_epochs_without_improvement_until_early_finish
                ):
                    print("Stopping early due to increasing training loss.")
                    break

                val_metrics = {
                    f"val_{key}": value for key, value in val_metrics.items()
                }

            # log to wandb
            wandb.log(
                {**train_metrics, **val_metrics}
                if "val_metrics" in locals()
                else {**train_metrics}
            )

        score_model.load_state_dict(
            torch.load(
                osp.join(
                    score_model.save_path,
                    f"{score_model.config['prefix']}_score_matching_model_weights.pth",
                )
            )
        )

        test_metrics = test_model(score_model)

        if config["verbose"]:
            test_metrics_info = ", ".join(
                [
                    f'Test {key.replace("_", " ").title()}: {value:.4f}'
                    for key, value in test_metrics.items()
                    if isinstance(value, (int, float))
                ]
            )
            print(test_metrics_info)

        with open(
            osp.join(
                score_model.save_path,
                f"{score_model.config['prefix']}_score_matching_model_performance.txt",
            ),
            "a",
        ) as file:
            file.write("Test Metrics:\n")
            for metric, value in test_metrics.items():
                file.write(f"{metric}: {value}\n")

        test_metrics = {
            f"test_{key}": value for key, value in test_metrics.items()
        }
        wandb.log({**test_metrics})
        wandb.save(score_model.save_path, base_path=parent_dir)

        return score_model

In [5]:
run_sweep_or_main(
    run_sweep=run_sweep,
    project_name="ScoreMatchingDiffKG_Score_Sweep",
    main=main,
    config=config,
)

Epoch: 000, Train Loss: 0.0002059253
Epoch: 001, Train Loss: 0.0000210212
score requires grad: True
score grad fn: <AddmmBackward0 object at 0x7f69767c3be0>
score grad after backward: tensor([[1.],
        [1.],
        [1.],
        ...,
        [1.],
        [1.],
        [1.]])
Shape of score before squeeze: torch.Size([2048, 1])
Shape of score after squeeze: torch.Size([2048])
Gradient of squeezed score before einsum: None


  print("Gradient of squeezed score before einsum:", squeezed_score.grad)
  print("Shape of squeezed score gradient:", squeezed_score.grad.shape)
Traceback (most recent call last):
  File "/tmp/ipykernel_43895/3287206081.py", line 69, in main
    val_metrics = test_model(score_model, val=True)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_43895/4032591174.py", line 180, in test_model
    metrics = score_model.original_model.test(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/blaineh2/ScoreMatchingDiffKG/utils/score_matching_model/ScoreModel.py", line 290, in test
    return self.evaluate_prediction_task(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/blaineh2/ScoreMatchingDiffKG/utils/score_matching_model/ScoreModel.py", line 349, in evaluate_prediction_task
    self.calculate_similarity_and_sort(
  File "/home/blaineh2/ScoreMatchingDiffKG/utils/score_matching_model/ScoreModel.py", line 512, in calculate_similarity_and_sort
    print

0,1
train_epoch,▁
train_loss,▁

0,1
train_epoch,0.0
train_loss,0.00021


AttributeError: 'NoneType' object has no attribute 'shape'