Author: Blaine Hill

In this notebook, we train a score matching model, as described by [Yang Song](https://yang-song.net/blog/2021/score/), to perform knowledge graph (KG) reasoning tasks such as [link prediction](https://paperswithcode.com/task/link-prediction).

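Writing the model density as $p_\theta(x) = e^{-f_\theta(x)} / Z_\theta$, the score function is the gradient of the log-density with respect to the input. The normalizing constant $Z_\theta$ does not depend on $x$, so its gradient vanishes:

$$s_\theta(x) = \nabla_x \log p_\theta(x) = -\nabla_x f_\theta(x) - \nabla_x \log Z_\theta = -\nabla_x f_\theta(x)$$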

In [None]:
import os

# Set the notebook name in the environment before wandb is imported below,
# so the variable is already in place when wandb.login() / wandb.init() run.
os.environ["WANDB_NOTEBOOK_NAME"] = "score_matching_model.ipynb"

import wandb

# Log in to Weights & Biases for this kernel session.
wandb.login()
# Alternative call kept for reference (presumably forces a fresh login -- confirm):
# wandb.login(relogin=True)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset

import sys
import os.path as osp
from ipykernel import get_connection_file

import math
import numpy as np
from icecream import ic
import yaml
from datetime import datetime
import random
from tqdm import tqdm

# Only the base name of the kernel connection file is used; it is joined onto
# the current working directory to form a path that lives inside that directory.
connection_file_name = osp.basename(get_connection_file())
notebook_path = osp.abspath(osp.join(os.getcwd(), connection_file_name))

# Stripping the file name and then one more level yields the parent of the
# current working directory, which is where the `utils` package is imported from.
notebook_dir = osp.dirname(notebook_path)
parent_dir = osp.dirname(notebook_dir)
sys.path.append(parent_dir)

# Import utility functions and model utilities
from utils.utils import *
from utils.embedding_model_utils import load_dataset
from utils.score_matching_model_utils import *

# from utils.score_matching_model.utils import *

from utils.score_matching_model.ScoreModel import ScoreModel

# set_seed comes from the project's star-imported utilities; it is called
# once here (the original cell called it twice with nothing in between that
# would consume random state -- presumably a copy/paste leftover).
set_seed()

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Switch between a W&B hyperparameter sweep (True) and a single run (False).
run_sweep = False

if not run_sweep:
    # Settings for a single training run; main() passes these to wandb.init().
    config = dict(
        embedding_model_dir="../trained_embedding_models/Pubmed_TransE_node_classification_2024.05.15.11.58.29_70244147_embedding_model/",
        score_model_hidden_dim=512,
        batch_size=512,  # per GPU; build_model multiplies it by the GPU count
        max_epochs=3,
        num_sde_timesteps=20,  # timesteps looped over for every batch
        lr=1e-4,
        similarity_metric="cosine",
        k=[1, 3, 10],  # cut-offs handed to the model's test() for top-k metrics
        verbose=True,
        num_epochs_without_improvement_until_early_finish=5,
        validate_after_this_many_epochs=1,
    )
else:
    # Sweep mode: the search space is read from the YAML file next to the notebook.
    with open("sweep_config.yaml", "r") as sweep_file:
        sweep_config = yaml.safe_load(sweep_file)

In [None]:
def build_model(config, view_english=False):
    """Assemble the score matching model, its data loaders and its optimizer.

    Loads the trained embedding model stored in ``config["embedding_model_dir"]``,
    loads the dataset splits for the dataset that embedding model was trained
    on, wraps the embedding model in a ``ScoreModel`` and attaches loaders,
    optimizer, config and bookkeeping attributes to the returned model object.

    Args:
        config: Mapping-like object providing ``embedding_model_dir``,
            ``score_model_hidden_dim``, ``num_sde_timesteps``,
            ``similarity_metric``, ``task``, ``aux_dict``, ``batch_size``,
            ``lr`` and ``prefix``. ``save_path`` is optional.
        view_english: When True, also load the entity and relation
            dictionaries with ``load_dicts`` and attach them to the model.

    Returns:
        The ``ScoreModel`` instance (wrapped in ``torch.nn.DataParallel`` when
        more than one GPU is visible). The unwrapped model is always reachable
        through the ``original_model`` attribute.

    Raises:
        ValueError: If ``config["save_path"]`` is given but does not exist.
    """
    embedding_model_dir = config["embedding_model_dir"]
    embedding_model, embedding_model_name, dataset_name = (
        initialize_trained_embedding_model(embedding_model_dir, device)
    )
    embedding_model.eval()

    train_data, val_data, test_data, data_path = load_dataset(
        dataset_name, parent_dir=parent_dir
    )

    original_model = ScoreModel(
        embedding_model,
        config["score_model_hidden_dim"],
        config["num_sde_timesteps"],
        config["similarity_metric"],
        config["task"],
        config["aux_dict"],
    ).to(device)

    # The configured batch size is per GPU; scale it by the number of GPUs
    # (treating a CPU-only machine like a single device).
    num_gpus = torch.cuda.device_count()
    num_devices = max(num_gpus, 1)
    batch_size = config["batch_size"] * num_devices
    num_workers = num_devices

    def make_loader(data, shuffle):
        """Wrap the (head, relation, tail[, x]) tensors of one split in a DataLoader."""
        tensors = [data.edge_index[0], data.edge_type, data.edge_index[1]]

        # NOTE(review): TensorDataset requires all tensors to share the same
        # first dimension, so `x` must have one row per edge here -- confirm
        # that load_dataset provides it in that layout.
        x = data.x if "x" in data else None
        if x is not None:
            tensors.append(x)

        return DataLoader(
            TensorDataset(*tensors),
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
        )

    # Only the training split is shuffled. Evaluation metrics are averaged
    # per batch, so a fixed order keeps validation/test results independent
    # of the random state.
    train_loader = make_loader(train_data, shuffle=True)
    val_loader = make_loader(val_data, shuffle=False)
    test_loader = make_loader(test_data, shuffle=False)

    optimizer = Adam(original_model.parameters(), lr=config["lr"])

    # Use DataParallel for multi-GPU setups on a single server
    if num_gpus > 1:
        score_model = torch.nn.DataParallel(original_model)
    else:
        score_model = original_model

    if view_english:
        # note: if you want to load english for RotatE, store the dicts that
        # map from indices to freebase ids / relations in data_path/processed
        entity_dict, relation_dict = load_dicts(data_path)
        score_model.entity_dict = entity_dict
        score_model.relation_dict = relation_dict

    score_model.train_data = train_data
    score_model.val_data = val_data
    score_model.test_data = test_data
    score_model.train_loader = train_loader
    score_model.val_loader = val_loader
    score_model.test_loader = test_loader
    score_model.optimizer = optimizer
    score_model.config = config
    score_model.dataset_name = dataset_name
    score_model.embedding_model_name = embedding_model_name
    score_model.embedding_model_dir = embedding_model_dir
    # Keep a reference to the unwrapped model so loss()/test() can be called
    # on it directly even when score_model is a DataParallel wrapper.
    score_model.original_model = original_model

    # Determine where results are saved. A user-supplied base directory must
    # already exist; the default base directory lives under parent_dir. The
    # run-specific sub-directory is created below in both cases.
    run_dir_name = f"{config['prefix']}_score_matching_model"
    if "save_path" in config:
        base_dir = config["save_path"]
        if not os.path.exists(base_dir):
            raise ValueError(
                f"Directory {base_dir} does not exist in which to save the trained score matching models. Please create it before saving."
            )
        save_path = osp.join(base_dir, run_dir_name)
    else:
        save_path = osp.join(
            parent_dir, "trained_score_matching_models", run_dir_name
        )

    os.makedirs(save_path, exist_ok=True)
    score_model.save_path = save_path

    # Save the model configuration
    save_score_matching_model_config(score_model)

    return score_model


def train_model(score_model):
    """Train the model for one epoch and return the mean training loss.

    For every batch the loss is evaluated and back-propagated once per SDE
    timestep, so the gradients of all timesteps are summed before a single
    optimizer step is taken for that batch.

    Args:
        score_model: Model returned by ``build_model``; its ``train_loader``,
            ``optimizer``, ``config`` and ``original_model`` attributes are used.

    Returns:
        float: Loss averaged over timesteps within each batch and then over
        batches. Returns 0.0 if the training loader yields no batches.
    """
    score_model.train()
    num_timesteps = score_model.config["num_sde_timesteps"]
    task = score_model.config["task"]

    loss_epoch = 0.0
    num_batches = 0
    for batch in tqdm(score_model.train_loader, leave=True):
        # Extract data from the batch
        h, r, t = batch[0].to(device), batch[1].to(device), batch[2].to(device)

        # Check if 'x' is present in the batch
        x = batch[3].to(device) if len(batch) > 3 else None

        # Clear the gradients of the previous batch; without this they keep
        # accumulating across batches and epochs.
        score_model.optimizer.zero_grad()

        batch_loss = 0.0
        for timestep in range(num_timesteps):
            loss = score_model.original_model.loss(
                h, r, t, timestep, x, task
            )
            loss.backward()
            batch_loss += loss.item()
        score_model.optimizer.step()

        # Report the average over all timesteps instead of only the last one.
        loss_epoch += batch_loss / max(num_timesteps, 1)
        num_batches += 1

    return loss_epoch / max(num_batches, 1)


def test_model(score_model, val=False, view_english=False):
    """Evaluate the model on the validation or test split.

    Batches are unpacked positionally as ``(head, relation, tail[, x])``,
    exactly as in ``train_model``, because the loaders are built from a
    ``TensorDataset``.

    Args:
        score_model: Model returned by ``build_model``.
        val: Evaluate on ``val_loader`` when True, otherwise on ``test_loader``.
        view_english: Currently unused; kept for interface compatibility.

    Returns:
        dict: Always contains ``"loss"`` (mean over all batch/timestep loss
        evaluations). For the prediction tasks it also contains
        ``"mean_rank"``, ``"mrr"`` and one ``"hits_at_<k>"`` entry per k; for
        node classification it contains ``"accuracy"``. All metrics are
        unweighted means over batches.

    Raises:
        ValueError: If the loader yields no batches or the configured task is
            not one of the supported tasks.
    """
    score_model.eval()
    loader = score_model.val_loader if val else score_model.test_loader
    task = score_model.config["task"]
    num_timesteps = score_model.config["num_sde_timesteps"]

    total_loss = 0.0
    num_loss_terms = 0
    all_metrics = []
    with torch.no_grad():
        for batch in loader:
            # Extract data from the batch and move it to the right device
            h, r, t = (
                batch[0].to(device),
                batch[1].to(device),
                batch[2].to(device),
            )
            x = batch[3].to(device) if len(batch) > 3 else None

            for timestep in range(num_timesteps):
                # Same call signature as in train_model
                loss = score_model.original_model.loss(
                    h, r, t, timestep, x, task
                )
                total_loss += loss.item()
                num_loss_terms += 1

            # Calculate metrics for the batch
            metrics = score_model.original_model.test(
                h=h,
                r=r,
                t=t,
                x=x,
                k=score_model.config["k"],
                task=task,
            )
            all_metrics.append(metrics)

    num_batches = len(all_metrics)
    if num_batches == 0:
        raise ValueError("Cannot evaluate: the data loader yielded no batches.")

    mean_loss = total_loss / max(num_loss_terms, 1)

    # Aggregate metrics across all batches
    if task in [
        "relation_prediction",
        "head_prediction",
        "tail_prediction",
    ]:
        # Each entry is indexed as (mean_rank, mrr, hits_at_k mapping)
        mean_rank = sum(m[0] for m in all_metrics) / num_batches
        mrr = sum(m[1] for m in all_metrics) / num_batches
        hits_at_k = {
            k: sum(m[2][k] for m in all_metrics) / num_batches
            for k in all_metrics[0][2]
        }
        performance_metrics = {
            "loss": mean_loss,
            "mean_rank": mean_rank,
            "mrr": mrr,
        }
        for k_value, hits_value in hits_at_k.items():
            performance_metrics[f"hits_at_{k_value}"] = hits_value

    elif task == "node_classification":
        accuracy = sum(all_metrics) / num_batches
        performance_metrics = {
            "loss": mean_loss,
            "accuracy": accuracy,
        }

    else:
        raise ValueError(f"model task isn't valid: {task}")

    return performance_metrics


def main(config=None):
    """Train, validate and test a score matching model inside one W&B run.

    The function derives dataset/task information from
    ``config["embedding_model_dir"]``, builds the model, trains it with early
    stopping on the training loss, saves a checkpoint whenever the validation
    loss improves, and finally evaluates on the test split.

    Args:
        config: Run configuration passed to ``wandb.init``. May be None when
            the function is launched by a W&B sweep agent, in which case the
            configuration is read from ``wandb.config``.

    Returns:
        The trained model, or None when the task/dataset combination is
        skipped.
    """
    run_timestamp = datetime.now().strftime("%Y.%m.%d.%H.%M.%S")
    # NOTE(review): the project name contains the timestamp, so every run
    # lands in its own W&B project -- confirm that this is intended.
    with wandb.init(
        project=f"ScoreMatchingDiffKG_ScoreMatching_{run_timestamp}",
        name=f"{run_timestamp}_score_matching_model_run",  # Use a temporary name
        config=config if config is not None else {},
    ):
        config = wandb.config
        (
            config["dataset_name"],
            config["embedding_model_name"],
            config["task"],
            config["aux_dict"],
        ) = extract_info_from_string(config["embedding_model_dir"])
        config["prefix"] = generate_prefix(config, run_timestamp)

        # Replace the temporary run name now that the prefix is known
        wandb.run.name = f"{config['prefix']}_score_matching_model_run"

        # Node classification is only run on these datasets, and these
        # datasets are only used for node classification.
        node_classification_datasets = ["Cora", "Citeseer", "Pubmed"]
        is_node_classification = config["task"] == "node_classification"
        is_node_classification_dataset = (
            config["dataset_name"] in node_classification_datasets
        )
        if is_node_classification != is_node_classification_dataset:
            print(f"Skipping {config['task']} on {config['dataset_name']}")
            return

        score_model = build_model(config)
        wandb.watch(score_model)

        num_epochs_without_improvement_until_early_finish = config[
            "num_epochs_without_improvement_until_early_finish"
        ]
        validate_after_this_many_epochs = config[
            "validate_after_this_many_epochs"
        ]

        best_train_loss = best_val_loss = float("inf")
        epochs_without_improvement = 0

        for epoch in range(config.max_epochs):
            loss = train_model(score_model)

            if config["verbose"]:
                print(f"Epoch: {epoch:03d}, Train Loss: {loss:.10f}")

            if loss <= best_train_loss:
                best_train_loss = loss
                epochs_without_improvement = 0
            else:
                epochs_without_improvement += 1

                if (
                    epochs_without_improvement
                    >= num_epochs_without_improvement_until_early_finish
                ):
                    print("Stopping early due to increasing training loss.")
                    break

            train_metrics = {"train_epoch": epoch, "train_loss": loss}
            # Reset every epoch so that stale validation metrics from an
            # earlier epoch are never logged again.
            val_metrics = {}
            if epoch % validate_after_this_many_epochs == 0 and epoch > 0:
                val_metrics = test_model(score_model, val=True)

                if config["verbose"]:
                    metrics_info = ", ".join(
                        [
                            f'Val {key.replace("_", " ").title()}: {value:.10f}'
                            for key, value in val_metrics.items()
                            if isinstance(value, (int, float))
                        ]
                    )
                    print(metrics_info)

                if val_metrics["loss"] < best_val_loss:
                    best_val_loss = val_metrics["loss"]
                    save_trained_score_matching_model_and_config(
                        score_model, epoch + 1, val_metrics
                    )

                val_metrics = {
                    f"val_{key}": value for key, value in val_metrics.items()
                }

            # log to wandb
            wandb.log({**train_metrics, **val_metrics})

        # Restore the checkpoint written at the best validation loss. No
        # checkpoint exists if validation never ran (e.g. a single epoch), in
        # which case the weights currently in memory are evaluated instead.
        weights_path = osp.join(
            score_model.save_path,
            f"{score_model.config['prefix']}_score_matching_model_weights.pth",
        )
        if osp.exists(weights_path):
            score_model.load_state_dict(torch.load(weights_path))
        else:
            print(
                f"No saved weights found at {weights_path}; "
                "testing the model with its current weights."
            )

        test_metrics = test_model(score_model)

        if config["verbose"]:
            test_metrics_info = ", ".join(
                [
                    f'Test {key.replace("_", " ").title()}: {value:.4f}'
                    for key, value in test_metrics.items()
                    if isinstance(value, (int, float))
                ]
            )
            print(test_metrics_info)

        # Append the test metrics to the run's performance file
        with open(
            osp.join(
                score_model.save_path,
                f"{score_model.config['prefix']}_score_matching_model_performance.txt",
            ),
            "a",
        ) as file:
            file.write("Test Metrics:\n")
            for metric, value in test_metrics.items():
                file.write(f"{metric}: {value}\n")

        test_metrics = {
            f"test_{key}": value for key, value in test_metrics.items()
        }
        wandb.log({**test_metrics})
        wandb.save(score_model.save_path, base_path=parent_dir)

        return score_model

In [None]:
if not run_sweep:
    # Single run with the config dictionary defined earlier in the notebook.
    model = main(config)
else:
    # Sweep mode: read the sweep definition, register it with W&B and let an
    # agent call main() once per sampled configuration.
    with open("sweep_config.yaml", "r") as sweep_file:
        sweep_config = yaml.safe_load(sweep_file)

    sweep_id = wandb.sweep(
        sweep=sweep_config, project="ScoreMatchingDiffKG_Score_Sweep"
    )

    wandb.agent(sweep_id, function=main)