## Author: Blaine Hill

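In this notebook, we train a score-matching model as described by [Yang Song](https://yang-song.net/blog/2021/score/). For an energy-based model $p_\theta(x) = e^{-f_\theta(x)} / Z_\theta$, the normalizing constant $Z_\theta$ does not depend on $x$, so $\nabla_x \log Z_\theta = 0$ and the score reduces to the negative gradient of the energy: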


$$s_\theta(x) = \nabla_x \log p_\theta(x) = -\nabla_x f_\theta(x) - \nabla_x \log Z_\theta = -\nabla_x f_\theta(x)$$

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
from torch_geometric.nn import RotatE
from torch_geometric.datasets import FB15k_237
import sys
import os
import os.path as osp
import math
import numpy as np
from icecream import ic
from torchdiffeq import odeint

# Make the project root importable and locate the FB15k-237 data directory
# (resolved relative to the current working directory's parent).
parent_dir = os.path.abspath(os.pardir)
sys.path.append(parent_dir)
data_path = osp.join(parent_dir, "data", "FB15k_237")

# Prefer the GPU when one is available; every split is moved to that device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the first graph object of each split, in train -> val -> test order.
train_data, val_data, test_data = (
    FB15k_237(data_path, split=split_name)[0].to(device)
    for split_name in ("train", "val", "test")
)


class ScoreModel(nn.Module):
    """Time-weighted translational distance score for (head, relation, tail) triples.

    ``forward`` returns ``weight(timestep) * ||h + r - t||_2`` over the last
    dimension, where ``weight`` is ``0`` when ``timestep`` is ``None`` and
    ``sigmoid(timestep / (num_steps - 1))`` otherwise. Over
    ``timestep = 0 .. num_steps - 1`` this weight rises from 0.5 to
    ``sigmoid(1) ~= 0.73``.

    NOTE(review): ``score_net`` is constructed, and its parameters are
    registered on the module, but ``forward`` never uses it. ``node_emb_dim``
    and ``rel_emb_dim`` only size that unused MLP.
    """

    def __init__(self, embedding_model, node_emb_dim, rel_emb_dim, config):
        super().__init__()
        self.config = config  # hyperparameters; forward reads config["num_steps"]
        self.embedding_model = embedding_model
        self.node_emb_dim = node_emb_dim
        self.rel_emb_dim = rel_emb_dim
        self.score_net = nn.Sequential(
            nn.Linear(node_emb_dim + rel_emb_dim, 512, dtype=torch.float),
            nn.ReLU(),
            nn.Linear(512, 512, dtype=torch.float),
            nn.ReLU(),
            nn.Linear(512, 1, dtype=torch.float),
        )

    def forward(self, h_emb, r_emb, t_emb, timestep=None):
        """Return the timestep-weighted L2 distance ``||h_emb + r_emb - t_emb||``.

        Args:
            h_emb, r_emb, t_emb: embedding tensors that broadcast against each
                other; the norm is taken over the last dimension.
            timestep: ``None``, a Python number, or a tensor broadcastable
                against the distance (e.g. a 0-d tensor, or one timestep per
                triple).

        Returns:
            The distance tensor (inputs' broadcast shape without the last
            dimension) scaled by the timestep weight. With ``timestep=None``
            the weight is 0, so the result is all zeros.
        """
        distance = torch.linalg.norm(h_emb + r_emb - t_emb, dim=-1)

        if timestep is None:
            # Zero weight: identical to the original ``0.0 * distance``.
            return 0.0 * distance

        # max(..., 1) avoids a ZeroDivisionError when num_steps == 1 (the only
        # valid timestep is then 0, giving weight sigmoid(0) = 0.5).
        denom = max(self.config["num_steps"] - 1, 1)
        # torch.sigmoid handles Python numbers, 0-d and batched tensors, and
        # keeps device tensors on-device (math.exp forced a host sync).
        steps = torch.as_tensor(timestep, dtype=distance.dtype, device=distance.device)
        weight = torch.sigmoid(steps / denom)
        return weight * distance


def _make_loader(data, batch_size, shuffle):
    """Wrap a graph split's (head, relation, tail) edge triples in a DataLoader."""
    dataset = TensorDataset(data.edge_index[0], data.edge_type, data.edge_index[1])
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


def build_model(config):
    """Load pretrained RotatE embeddings and assemble model, optimizer and loaders.

    Reads from ``config``: ``device``, ``batch_size``, ``lr`` and, optionally,
    ``hidden_channels`` (default 50; it must match the saved checkpoint).

    Returns:
        ``(score_model, optimizer, train_loader, val_loader, test_loader)``.
        Only the train loader shuffles.

    NOTE(review): ``embedding_model`` becomes a submodule of ``score_model``,
    so the Adam optimizer over ``score_model.parameters()`` also updates the
    pretrained RotatE embeddings; ``eval()`` does not freeze them. Confirm
    this is intended.
    """
    device = config["device"]

    embedding_model = RotatE(
        num_nodes=train_data.num_nodes,
        num_relations=train_data.num_edge_types,
        hidden_channels=config.get("hidden_channels", 50),
    ).to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    state_dict = torch.load(
        "../embedding_model/FB15k_237_RotatE_embedding_model_weights.pth",
        map_location=device,
    )
    embedding_model.load_state_dict(state_dict)
    embedding_model.eval()

    # Entity width counts the real and imaginary parts together; only the
    # widths are needed, so no embedding tensors are materialized here.
    node_emb_dim = (
        embedding_model.node_emb.weight.shape[-1]
        + embedding_model.node_emb_im.weight.shape[-1]
    )
    rel_emb_dim = embedding_model.rel_emb.weight.shape[-1]

    score_model = ScoreModel(embedding_model, node_emb_dim, rel_emb_dim, config).to(device)

    batch_size = config["batch_size"]
    train_loader = _make_loader(train_data, batch_size, shuffle=True)
    val_loader = _make_loader(val_data, batch_size, shuffle=False)
    test_loader = _make_loader(test_data, batch_size, shuffle=False)

    optimizer = Adam(score_model.parameters(), lr=config["lr"])

    return score_model, optimizer, train_loader, val_loader, test_loader


def denoising_score_matching_loss(score_model, h, r, t, timestep):
    """Mean squared gap between the timestep-free and timestep-conditioned scores.

    ``h``, ``r`` and ``t`` are index tensors looked up in the score model's
    embedding tables (``node_emb`` for entities, ``rel_emb`` for relations).
    The same embeddings are scored twice: once without a timestep and once at
    ``timestep``. No noise is added to the embeddings here.

    NOTE(review): ``ScoreModel.forward`` uses weight 0 when ``timestep`` is
    ``None``, so the reference term is all zeros and this loss equals
    ``mean(conditioned ** 2)``; confirm this is the intended objective.
    """
    tables = score_model.embedding_model
    head_vec = tables.node_emb(h)
    rel_vec = tables.rel_emb(r)
    tail_vec = tables.node_emb(t)

    reference = score_model(head_vec, rel_vec, tail_vec)  # timestep omitted on purpose
    conditioned = score_model(head_vec, rel_vec, tail_vec, timestep)
    gap = reference - conditioned
    return gap.pow(2).mean()


def train_model(score_model, optimizer, train_loader, config):
    """Run one training epoch and return the mean loss per (batch, timestep).

    For each batch, the loss from ``denoising_score_matching_loss`` is
    evaluated at every timestep in ``range(config["num_steps"])``. Gradients
    from all timesteps accumulate, then a single optimizer step follows.

    The returned value averages every evaluated loss over batches *and*
    timesteps, which is the same normalisation ``test_model`` uses, so the two
    printed numbers are comparable. Previously only the last timestep of each
    batch was counted.

    Returns:
        The mean loss, or ``nan`` if there was nothing to average (empty loader
        or ``num_steps == 0``).
    """
    score_model.train()
    device = config["device"]
    num_steps = config["num_steps"]

    total_loss = 0.0
    for h, r, t in train_loader:
        h, r, t = h.to(device), r.to(device), t.to(device)

        optimizer.zero_grad()
        for timestep in range(num_steps):
            loss = denoising_score_matching_loss(score_model, h, r, t, timestep)
            loss.backward()  # accumulate gradients across timesteps
            total_loss += loss.item()
        optimizer.step()

    num_terms = len(train_loader) * num_steps
    # An empty epoch has no mean; report nan instead of raising mid-training.
    return total_loss / num_terms if num_terms else float("nan")


def test_model(score_model, test_loader, config):
    """Evaluate the denoising loss without gradients, averaged over batches and timesteps.

    Every batch is scored at each timestep in ``range(config["num_steps"])``.
    The sum of those losses is divided by ``len(test_loader) * num_steps``.
    """
    score_model.eval()
    device = config["device"]
    num_steps = config["num_steps"]

    running = 0
    with torch.no_grad():
        for heads, rels, tails in test_loader:
            heads, rels, tails = (x.to(device) for x in (heads, rels, tails))
            for step in range(num_steps):
                step_loss = denoising_score_matching_loss(score_model, heads, rels, tails, step)
                running += step_loss.item()

    return running / (len(test_loader) * num_steps)  # Average across steps


def reverse_sde(h_emb, r_emb, t_emb, score_model, config):
    """Refine a relation embedding by integrating the score gradient back to t=0.

    The relation embedding is the ODE state. It is integrated with ``odeint``
    over ``num_steps`` evenly spaced times running from ``num_steps - 1`` down
    to ``0``, with dynamics ``dr/dt = -d(sum score)/dr``.

    NOTE(review): despite the name, this integrates a deterministic ODE; there
    is no stochastic term.

    Args:
        h_emb: embedding of the head entity (treated as a constant).
        r_emb: initial relation embedding (not modified in place).
        t_emb: embedding of the tail entity (treated as a constant).
        score_model: the trained score model.
        config: reads ``num_steps`` and ``device``.

    Returns:
        The relation embedding at the final integration time (t = 0), with the
        same shape as ``r_emb``.
    """
    t_steps = torch.linspace(
        config["num_steps"] - 1, 0, config["num_steps"], device=config["device"]
    )

    # Head and tail are constants of the dynamics; detaching them keeps each
    # gradient call from touching the embedding-lookup graph.
    h_fixed = h_emb.detach()
    t_fixed = t_emb.detach()

    def sde_func(t, r_state):
        with torch.enable_grad():
            # Work on a fresh leaf so neither the caller's tensor nor odeint's
            # state has its requires_grad flag mutated.
            r_leaf = r_state.detach().requires_grad_(True)
            score = score_model(h_fixed, r_leaf, t_fixed, t)
            # First-order gradient only; no need to build a higher-order graph.
            (grad_r,) = torch.autograd.grad(score.sum(), r_leaf)
        return -grad_r  # Reverse direction

    with torch.no_grad():
        trajectory = odeint(sde_func, r_emb, t_steps)

    # odeint returns the state at every time in t_steps (stacked on dim 0).
    # Index 0 is just the initial value r_emb; the refined state is the last one.
    return trajectory[-1]


def compute_metrics(score_model, test_loader, config, initial_noise_std=0.1):
    """Refine a random relation embedding per test batch with ``reverse_sde``.

    For every batch, the head/tail embeddings are looked up and a single
    Gaussian relation vector (std ``initial_noise_std``) is drawn and refined.
    The true relation ids are moved to the device but not yet used. No metric
    is computed yet and the function returns ``None`` (see TODO).

    NOTE(review): one noisy relation vector is shared by the whole batch rather
    than drawn per triple; confirm this is intended.
    """
    score_model.eval()
    tables = score_model.embedding_model

    for heads, true_rels, tails in test_loader:  # true relation kept for evaluation
        heads = heads.to(config["device"])
        true_rels = true_rels.to(config["device"])
        tails = tails.to(config["device"])

        head_vecs = tables.node_emb(heads)
        tail_vecs = tables.node_emb(tails)

        # Noise shaped like a single relation embedding row.
        template_row = tables.rel_emb.weight[0]
        start_rel = torch.randn_like(template_row) * initial_noise_std
        start_rel = start_rel.to(config["device"])

        refined_rel = reverse_sde(head_vecs, start_rel, tail_vecs, score_model, config)

        # TODO: need to add metric calculation now - perhaps use cosine similarity?


def main(config):
    """Build everything, train for ``config["num_epochs"]`` epochs, then run compute_metrics.

    After each epoch the train and test losses are printed with 20 decimals.
    The validation loader returned by ``build_model`` is currently unused.
    """
    model, opt, train_loader, _val_loader, test_loader = build_model(config)

    for epoch in range(config["num_epochs"]):
        epoch_train = train_model(model, opt, train_loader, config)
        epoch_test = test_model(model, test_loader, config)
        train_loss, test_loss = epoch_train, epoch_test
        report = f"\n\nEpoch {epoch + 1}\nTrain Loss: {train_loss:.20f}\nTest Loss: {test_loss:.20f}"
        print(report)

    compute_metrics(model, test_loader, config)


# Run configuration. "k" is not read by any function shown in this notebook.
config = dict(
    batch_size=64,
    num_epochs=1,
    num_steps=20,  # timesteps per batch and points in the reverse integration
    lr=1e-4,
    device=device,
    k=10,
)

main(config)