First, we train a DDPM model on the outputs of the embedding model so that we can obtain a good approximation of the Fisher information matrix (FIM). Then we use the FIM as a metric tensor, which defines the distance function:

distance = sqrt((h - t)^T @ FIM @ (h - t))

In [None]:
import os.path as osp
import torch
import onnx
from torch_geometric.datasets import FB15k_237
from denoising_diffusion_pytorch import GaussianDiffusion, UNet

# Pick the compute device once: CUDA when PyTorch can see a GPU, CPU otherwise.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


In [None]:
def build_ddpm_model(dataset_name='FB15k_237', embedding_model_name='RotatE', config=None):
    """Load the pretrained embedding model and create the diffusion components.

    Args:
        dataset_name: Dataset identifier; used as the prefix of the weights
            file name.
        embedding_model_name: Embedding model identifier; used in the weights
            file name.
        config: Currently unused; accepted so existing callers keep working.

    Returns:
        A tuple ``(embedding_model, noise_scheduler, unet)``.
    """
    # The weights file is looked up next to this source file.  ``__file__`` is
    # not defined in a notebook / interactive session, so fall back to the
    # current working directory there instead of failing with NameError.
    try:
        parent_dir = osp.dirname(osp.abspath(__file__))
    except NameError:
        parent_dir = osp.abspath(osp.curdir)
    embedding_model_path = osp.join(
        parent_dir,
        f'{dataset_name}_{embedding_model_name}_embedding_model_weights.onnx',
    )

    # NOTE(review): ``onnx.load`` is documented to return an ONNX ModelProto,
    # which presumably has no ``.cpu()`` and cannot be traced by
    # ``torch.jit.trace`` -- confirm how the weights were exported and load
    # them with the matching PyTorch API.
    embedding_model = onnx.load(embedding_model_path)
    embedding_model = torch.jit.trace(embedding_model.cpu(), torch.randn(1, 3))

    # Set up the noise scheduler.
    # NOTE(review): verify these keyword arguments (and the ``UNet`` name)
    # against the installed denoising_diffusion_pytorch release -- TODO confirm.
    noise_scheduler = GaussianDiffusion(
        timesteps=1000,
        loss_type='l2'
    )

    # Set up the U-Net model on the globally selected device.
    unet = UNet(
        dim=64,
        dim_mults=(1, 2, 4, 8),
        channels=3,
        with_time_emb=True,
    ).to(device)

    return embedding_model, noise_scheduler, unet

def embed(embedding_model, data):
    """Run ``embedding_model`` on a graph object.

    The model is called with ``data.x``, ``data.edge_index`` and
    ``data.edge_attr`` (in that order) and its result is returned unchanged.
    """
    return embedding_model(data.x, data.edge_index, data.edge_attr)

def stein_score(model, embedding_model, data, t, noise_scheduler=None):
    """Return the gradient of ``sum(model(noisy_x, t) ** 2)`` w.r.t. ``noisy_x``.

    Args:
        model: Denoising network, called as ``model(noisy_x, t)``.
        embedding_model: Model handed to ``embed`` to produce clean embeddings.
        data: Graph batch handed to ``embed``.
        t: Diffusion timestep forwarded to ``q_sample`` and to ``model``.
        noise_scheduler: Object providing ``q_sample(x, t, noise)``.  When
            omitted, the module-level ``noise_scheduler`` is used, which is
            what this function relied on before the parameter existed.

    Returns:
        A tensor with the same shape as the noised embeddings.

    Raises:
        NameError: If no scheduler is passed and no module-level
            ``noise_scheduler`` exists.
    """
    scheduler = noise_scheduler
    if scheduler is None:
        scheduler = globals().get('noise_scheduler')
        if scheduler is None:
            raise NameError("name 'noise_scheduler' is not defined")

    # Building the noised input needs no gradient history.
    with torch.no_grad():
        x = embed(embedding_model, data).to(device)
        noise = torch.randn_like(x)
        noisy_x = scheduler.q_sample(x, t, noise)

    # torch.autograd.grad needs a graph from noisy_x to pred: the input must be
    # a leaf that requires grad and the forward pass must run with grad
    # enabled.  Doing both inside torch.no_grad() always raised RuntimeError.
    noisy_x = noisy_x.detach().requires_grad_(True)
    with torch.enable_grad():
        pred = model(noisy_x, t)
        # NOTE(review): a diffusion score is usually derived from the predicted
        # noise itself; confirm the squared-norm gradient is what is intended.
        score = torch.autograd.grad(torch.sum(pred ** 2), noisy_x)[0]
    return score


def compute_stein_scores(model, embedding_model, data_loader, noise_scheduler):
    """Collect one score tensor per (batch, timestep) pair.

    Args:
        model: Denoising network forwarded to ``stein_score``.
        embedding_model: Embedding model forwarded to ``stein_score``.
        data_loader: Iterable of graph batches; each batch is moved to the
            global ``device``.
        noise_scheduler: Scheduler used both to draw timesteps and to noise
            the embeddings.

    Returns:
        A list of score tensors, in iteration order.
    """
    stein_scores = []
    for data in data_loader:
        data = data.to(device)
        # NOTE(review): assumes the scheduler exposes sample_timesteps(n) and
        # that iterating its result yields timesteps q_sample accepts -- confirm.
        timesteps = noise_scheduler.sample_timesteps(data.num_nodes)
        stein_scores.extend(
            stein_score(model, embedding_model, data, t, noise_scheduler)
            for t in timesteps
        )
    return stein_scores

def estimate_fim(stein_scores):
    """Estimate the Fisher information matrix from a list of score tensors.

    Each score is treated as a stack of D-dimensional vectors, where D is the
    size of its last axis.  The outer products of those vectors are summed
    over every score and the total is divided by ``len(stein_scores)``.

    Args:
        stein_scores: Non-empty sequence of tensors that share the same
            last-axis size.

    Returns:
        A ``(D, D)`` tensor allocated on the device of the first score.

    Raises:
        ValueError: If ``stein_scores`` is empty.
    """
    if not stein_scores:
        raise ValueError('estimate_fim requires at least one score tensor')

    first = stein_scores[0]
    dim = first.shape[-1]
    # Allocate next to the data instead of on the global device, so scores
    # that live on another device can still be accumulated.
    fim = torch.zeros((dim, dim), device=first.device)
    for score in stein_scores:
        # Flatten leading axes so 1-D and batched (>2-D) scores contribute
        # proper (D, D) outer products; 2-D scores are unchanged.
        rows = score.reshape(-1, score.shape[-1]).to(fim.device)
        fim += rows.T @ rows
    # NOTE(review): this averages over list entries, not over individual score
    # vectors, so the result scales with the number of vectors per entry --
    # confirm that normalisation is intended.
    fim /= len(stein_scores)
    return fim

def _config_value(config, key):
    """Read ``key`` from a dict-style or attribute-style config object."""
    if isinstance(config, dict):
        return config[key]
    return getattr(config, key)


def train_ddpm_model(unet, embedding_model, noise_scheduler, train_loader, config):
    """Train ``unet`` as the denoiser on embeddings of the training graphs.

    Args:
        unet: Denoising network whose parameters are optimised with Adam.
        embedding_model: Model handed to ``embed``; it is not optimised here.
        noise_scheduler: Object providing ``training_losses(unet, x)``.
        train_loader: Iterable of graph batches; each batch is moved to the
            global ``device``.
        config: Either a dict or an attribute-style object providing ``lr``
            and ``epochs``.

    Returns:
        The same ``unet`` instance, after training.
    """
    optimizer = torch.optim.Adam(unet.parameters(), lr=_config_value(config, 'lr'))

    for _ in range(_config_value(config, 'epochs')):
        unet.train()
        for data in train_loader:
            data = data.to(device)
            # Only unet's parameters are handed to the optimizer, so no
            # gradient history is needed through the embedding model.
            with torch.no_grad():
                x = embed(embedding_model, data)
            # NOTE(review): assumes training_losses(unet, x) returns a scalar
            # tensor that supports .backward() -- confirm for the scheduler
            # implementation in use.
            loss = noise_scheduler.training_losses(unet, x)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return unet


In [None]:
# Configuration
# Plain dict; the statements below read it with config['key'].
config = {
    'data_dir': 'data',
    'model_dir': 'models',
    'batch_size': 32,
    'lr': 1e-4,
    'epochs': 100
}

# Load dataset and create data loaders
dataset_name = 'FB15k_237'
data_path = osp.join(config['data_dir'], dataset_name)

# Each split is indexed with [0] and moved to the selected device.
train_data = FB15k_237(data_path, split='train')[0].to(device)
val_data = FB15k_237(data_path, split='val')[0].to(device)
test_data = FB15k_237(data_path, split='test')[0].to(device)

# NOTE(review): confirm that the objects returned above actually expose a
# ``.loader(batch_size=..., shuffle=...)`` method in the installed
# torch_geometric version.
train_loader = train_data.loader(batch_size=config['batch_size'], shuffle=True)
val_loader = val_data.loader(batch_size=config['batch_size'], shuffle=False)
test_loader = test_data.loader(batch_size=config['batch_size'], shuffle=False)

# Build the DDPM model
embedding_model, noise_scheduler, unet = build_ddpm_model(dataset_name, 'RotatE', config)

# Train the DDPM model
# train_ddpm_model takes five positional arguments and has no validation-loader
# parameter, so passing val_loader as well raised TypeError.
unet = train_ddpm_model(unet, embedding_model, noise_scheduler, train_loader, config)

In [None]:
# Save the trained DDPM model
import os

# The export opens the target path for writing but does not create missing
# parent directories, so make sure the model directory exists first.
os.makedirs(config['model_dir'], exist_ok=True)
model_path = osp.join(config['model_dir'], 'ddpm_model.onnx')
# NOTE(review): elsewhere in this file the denoiser is called as model(x, t);
# exporting with a single example input may fail at trace time -- confirm and
# pass an example timestep as well if the network requires it.
torch.onnx.export(unet, embed(embedding_model, train_data.to(device)), model_path)