## Author: Blaine Hill

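In this notebook, we train a score-matching model as described by [Yang Song](https://yang-song.net/blog/2021/score/). For an energy-based model $p_\theta(x) = e^{-f_\theta(x)} / Z_\theta$, the score is


$$s_\theta(x) = \nabla_x \log p_\theta(x) = -\nabla_x f_\theta(x) - \nabla_x \log Z_\theta = -\nabla_x f_\theta(x)$$

where the last equality holds because the normalizing constant $Z_\theta$ does not depend on $x$, so $\nabla_x \log Z_\theta = 0$.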

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
from torch_geometric.nn import RotatE
from torch_geometric.datasets import FB15k_237
import sys
import os
import os.path as osp
import math
import numpy as np
from icecream import ic
from torchdiffeq import odeint 

# Make the parent directory importable and locate its FB15k-237 data folder.
parent_dir = osp.abspath(os.pardir)
sys.path.append(parent_dir)
data_path = osp.join(parent_dir, "data", "FB15k_237")

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Each split is loaded in order (train, val, test); the single graph object of each
# split is moved to `device`.
train_data, val_data, test_data = (
    FB15k_237(data_path, split=split_name)[0].to(device)
    for split_name in ("train", "val", "test")
)


class ScoreModel(nn.Module):
    """
    Score (h, r, t) triples with a timestep-weighted translational distance.

    The distance is ``||node_emb(h) + rel_emb(r) - node_emb(t)||_2``, computed from
    ``embedding_model``'s lookup tables. It is scaled by
    ``sigmoid(timestep / (num_steps - 1))``. When ``timestep`` is None the weight is 0,
    so the score is identically zero.

    ``embedding_model`` is assigned as an attribute of this ``nn.Module``, so it is
    registered as a submodule and its parameters appear in ``self.parameters()``.
    """

    def __init__(self, embedding_model, embedding_dim, relation_dim, config):
        super().__init__()
        self.embedding_model = embedding_model
        self.embedding_dim = embedding_dim
        self.relation_dim = relation_dim
        self.config = config  # must provide "num_steps" (read in forward)

        # NOTE(review): score_net is never called in forward(); it only contributes
        # parameters. Kept so existing state_dicts / optimizer parameter lists match.
        self.score_net = nn.Sequential(
            nn.Linear(embedding_dim + relation_dim, 512, dtype=torch.float),
            nn.ReLU(),
            nn.Linear(512, 512, dtype=torch.float),
            nn.ReLU(),
            nn.Linear(512, 1, dtype=torch.float),
        )

    def forward(self, h, r, t, timestep=None):
        """
        Return the weighted distance for the given triples.

        Args:
            h, t: passed to ``embedding_model.node_emb`` (entity lookups).
            r: passed to ``embedding_model.rel_emb`` (relation lookup).
            timestep: None, a Python number, or a tensor that broadcasts against the
                distance (e.g. one timestep per candidate).

        Returns:
            ``weight * distance``, where the distance reduces the last embedding axis.
        """
        h_emb = self.embedding_model.node_emb(h)
        r_emb = self.embedding_model.rel_emb(r)
        t_emb = self.embedding_model.node_emb(t)

        # L2 distance of the translational residual h + r - t.
        distance = torch.linalg.norm(h_emb + r_emb - t_emb, dim=-1)

        if timestep is None:
            # No timestep: weight 0, so the score is identically zero.
            return 0.0 * distance

        # The weight grows with the timestep. max(..., 1) avoids dividing by zero
        # when num_steps == 1.
        denom = max(self.config["num_steps"] - 1, 1)
        # torch.sigmoid instead of math.exp so tensor timesteps work elementwise.
        progress = torch.as_tensor(timestep, dtype=distance.dtype, device=distance.device) / denom
        return torch.sigmoid(progress) * distance



def sde_func_wrapper(h, t, score_model):
    """
    Build the right-hand side ``f(time, r_emb)`` expected by ``odeint``.

    The returned function evaluates ``-grad_{r_emb} sum(score_model(h, r_emb, t))``.
    Integrating it moves ``r_emb`` in the direction that decreases the score.

    NOTE(review): ScoreModel.forward passes its second argument to
    ``embedding_model.rel_emb`` (an index lookup), so it cannot consume continuous
    embeddings. A score_model that accepts embedding tensors is required here.
    """
    def score_sde(time, r_emb):
        """Return the negative score gradient w.r.t. ``r_emb``; ``time`` is unused."""
        # Autograd must be on even when the caller runs under torch.no_grad().
        with torch.enable_grad():
            # Differentiate w.r.t. a fresh leaf so incoming tensors need not require grad.
            r_var = r_emb.detach().requires_grad_(True)
            score = score_model(h, r_var, t)
            # Sum to a scalar: autograd.grad needs grad_outputs for non-scalar outputs.
            (score_grad,) = torch.autograd.grad(score.sum(), [r_var], allow_unused=True)
        if score_grad is None:
            # The score does not depend on r_emb: no drift.
            return torch.zeros_like(r_emb)
        return -score_grad  # Update in the direction that minimizes the score
    return score_sde

def reverse_sde(h, r_init, t, score_model, config):
    """
    Refine a relation embedding by integrating the score-gradient flow over [0, 1].

    Despite the name, the integration is deterministic. ``odeint`` (torchdiffeq)
    solves ``d r / d time = f(time, r)``, with ``f`` built by ``sde_func_wrapper``. The
    solver evaluates on ``config["num_steps"]`` evenly spaced times in [0, 1], and no
    noise is injected.

    Returns:
        The state at the final time (time = 1).
    """
    time_grid = torch.linspace(0, 1.0, config["num_steps"]).to(config["device"])
    drift = sde_func_wrapper(h, t, score_model)
    # odeint returns the whole trajectory with the time axis first; keep the last state.
    trajectory = odeint(drift, r_init, time_grid)
    return trajectory[-1, :]


def build_model(config):
    """
    Build the score model, its Adam optimizer, and the train/val/test data loaders.

    Loads pretrained RotatE weights into a fresh ``RotatE`` sized from the
    module-level ``train_data`` graph, then wraps it in a ``ScoreModel``.

    Config keys:
        device, batch_size, lr: required.
        hidden_channels: RotatE hidden size (default 50; must match the saved weights).
        embedding_weights_path: path of the saved RotatE state dict
            (default "../embedding_model/FB15k_237_RotatE_embedding_model_weights.pth").

    Returns:
        (score_model, optimizer, train_loader, val_loader, test_loader)
    """
    device = config["device"]

    embedding_model = RotatE(
        num_nodes=train_data.num_nodes,
        num_relations=train_data.num_edge_types,
        hidden_channels=config.get("hidden_channels", 50),
    ).to(device)
    weights_path = config.get(
        "embedding_weights_path",
        "../embedding_model/FB15k_237_RotatE_embedding_model_weights.pth",
    )
    # map_location lets weights saved on a GPU load on a CPU-only machine.
    embedding_model.load_state_dict(torch.load(weights_path, map_location=device))
    # eval() only switches the mode and does not freeze the weights. ScoreModel
    # registers embedding_model as a submodule, so the Adam optimizer below also
    # updates these embeddings.
    embedding_model.eval()

    # The entity width passed to ScoreModel spans both node tables (node_emb + node_emb_im).
    entity_dim = (
        embedding_model.node_emb.weight.shape[-1]
        + embedding_model.node_emb_im.weight.shape[-1]
    )
    relation_dim = embedding_model.rel_emb.weight.shape[-1]
    score_model = ScoreModel(embedding_model, entity_dim, relation_dim, config).to(device)

    def make_loader(split_data, shuffle):
        # One (head, relation, tail) index triple per edge of the split.
        dataset = TensorDataset(
            split_data.edge_index[0], split_data.edge_type, split_data.edge_index[1]
        )
        return DataLoader(dataset, batch_size=config["batch_size"], shuffle=shuffle)

    train_loader = make_loader(train_data, shuffle=True)
    val_loader = make_loader(val_data, shuffle=False)
    test_loader = make_loader(test_data, shuffle=False)

    optimizer = Adam(score_model.parameters(), lr=config["lr"])

    return score_model, optimizer, train_loader, val_loader, test_loader


def denoising_score_matching_loss(score_model, h, r, t, timestep):
    """
    Mean squared difference between the model's score without and with a timestep.

    ``score_model(h, r, t)`` (no timestep) is the reference, and
    ``score_model(h, r, t, timestep)`` is the prediction.

    NOTE(review): despite the name, no noise is added to any input here.
    """
    reference = score_model(h, r, t)
    prediction = score_model(h, r, t, timestep)
    residual = reference - prediction
    return residual.pow(2).mean()


def train_model(score_model, optimizer, train_loader, config):
    """
    Run one training epoch and return the mean loss per (batch, timestep) pair.

    For each batch, the loss is back-propagated once per timestep in
    ``range(config["num_steps"])``. The gradients accumulate, and a single optimizer
    step follows. Every timestep's loss is included in the returned average, matching
    how ``test_model`` averages its loss.
    """
    score_model.train()
    num_steps = config["num_steps"]
    loss_epoch = 0.0
    for h, r, t in train_loader:
        device = config["device"]
        h, r, t = h.to(device), r.to(device), t.to(device)

        optimizer.zero_grad()
        for timestep in range(num_steps):
            loss = denoising_score_matching_loss(score_model, h, r, t, timestep)
            # Backward per timestep frees each graph early; the gradients sum up.
            loss.backward()
            # Count every timestep, not just the last one.
            loss_epoch += loss.item()
        optimizer.step()

    return loss_epoch / (len(train_loader) * num_steps)



def test_model(score_model, test_loader, config):
    """
    Return the evaluation loss averaged over every (batch, timestep) pair.

    Runs in eval mode under ``torch.no_grad()``, using
    ``denoising_score_matching_loss`` for timesteps ``0 .. config["num_steps"] - 1``.
    """
    score_model.eval()
    running_total = 0
    with torch.no_grad():
        for head, rel, tail in test_loader:
            device = config["device"]
            head = head.to(device)
            rel = rel.to(device)
            tail = tail.to(device)
            for step in range(config["num_steps"]):
                step_loss = denoising_score_matching_loss(score_model, head, rel, tail, step)
                running_total += step_loss.item()

    # Average across batches and steps.
    return running_total / (len(test_loader) * config["num_steps"])


def compute_metrics(score_model, test_loader, config):
    """
    Compute relation-prediction metrics: mean rank, mean reciprocal rank, Hits@10, Hits@3.

    For each test triple (h, r, t), the true relation ``r`` is ranked against
    ``config.get("num_candidates", 50)`` relations. These are drawn uniformly at random
    from ``range(score_model.embedding_model.num_relations)``.

    Scores are treated as distances (lower = more plausible). ScoreModel's score is a
    positively weighted ``||h + r - t||``. The rank is 1 plus the number of sampled
    candidates that score strictly lower than the true relation. Ties, including a
    sampled candidate equal to ``r``, resolve in favour of the true relation.

    NOTE(review): candidate refinement with ``reverse_sde`` is intentionally not used.
    ``reverse_sde`` produces continuous embeddings, while ``ScoreModel.forward`` looks
    relations up by index (``embedding_model.rel_emb(r)``). Refined candidates
    therefore cannot be scored, so candidates are scored directly by index.

    Args:
        score_model: called as ``score_model(h, r, t, timestep)``; must expose
            ``embedding_model.num_relations``.
        test_loader: iterable of ``(h, r, t)`` index-tensor batches.
        config: needs "device" and "num_steps"; "num_candidates" is optional.

    Returns:
        dict with "mean_rank", "mean_reciprocal_rank", "hits_at_10", "hits_at_3".
        The values are also printed.
    """
    score_model.eval()
    device = config["device"]
    num_candidates = config.get("num_candidates", 50)
    num_relations = score_model.embedding_model.num_relations

    rank_sum = 0.0
    reciprocal_rank_sum = 0.0
    hits_at_10 = 0
    hits_at_3 = 0
    num_triples = 0

    with torch.no_grad():
        for h, r, t in test_loader:
            h, r, t = h.to(device), r.to(device), t.to(device)
            batch_size = r.shape[0]

            # Sample candidate relations (r') for every triple in the batch.
            candidate_r = torch.randint(
                0, num_relations, (batch_size, num_candidates), device=device
            )
            # Column 0 is the ground-truth relation; columns 1.. are the candidates.
            all_r = torch.cat([r.unsqueeze(1), candidate_r], dim=1)
            width = all_r.shape[1]

            # The timestep only scales every score by the same positive weight, so it
            # does not change the ranking.
            scores = score_model(
                h.unsqueeze(1).repeat(1, width),
                all_r,
                t.unsqueeze(1).repeat(1, width),
                config["num_steps"],
            )

            ranks = 1 + (scores[:, 1:] < scores[:, :1]).sum(dim=1)
            rank_sum += ranks.sum().item()
            reciprocal_rank_sum += (1.0 / ranks.float()).sum().item()
            hits_at_10 += (ranks <= 10).sum().item()
            hits_at_3 += (ranks <= 3).sum().item()
            num_triples += batch_size

    # Guard against an empty loader.
    denom = max(num_triples, 1)
    metrics = {
        "mean_rank": rank_sum / denom,
        "mean_reciprocal_rank": reciprocal_rank_sum / denom,
        "hits_at_10": hits_at_10 / denom,
        "hits_at_3": hits_at_3 / denom,
    }

    print(f"\nMean Rank: {metrics['mean_rank']:.4f}")
    print(f"Mean Reciprocal Rank: {metrics['mean_reciprocal_rank']:.4f}")
    print(f"Hits@10: {metrics['hits_at_10']:.4f}")
    print(f"Hits@3: {metrics['hits_at_3']:.4f}")
    return metrics


def main(config):
    """
    Build the model and data, train for ``config["num_epochs"]`` epochs, then report metrics.

    The train and test losses are printed after each epoch.

    NOTE(review): the validation loader is built but never used.
    """
    score_model, optimizer, train_loader, _val_loader, test_loader = build_model(config)

    for epoch_number in range(1, config["num_epochs"] + 1):
        train_loss = train_model(score_model, optimizer, train_loader, config)
        test_loss = test_model(score_model, test_loader, config)
        print(f"\n\nEpoch {epoch_number}\nTrain Loss: {train_loss:.20f}\nTest Loss: {test_loss:.20f}")

    compute_metrics(score_model, test_loader, config)

# Run configuration. num_steps is the number of timesteps per batch in training and
# evaluation, and the number of time points on the reverse_sde integration grid.
config = dict(
    batch_size=64,
    num_epochs=1,
    num_steps=20,
    lr=1e-4,
    device=device,
)

main(config)

