## Author: Blaine Hill

In this notebook, we program out how to embed a KG such as FB15k_237 using RotatE. The weights are saved under embedding_model_weights.pth

In [1]:
import os
os.environ["WANDB_NOTEBOOK_NAME"] = "embedding_model.ipynb"

import wandb
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mbthill1[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
import torch
from torch_geometric.datasets import FB15k_237
from torch_geometric.nn import RotatE
from torch_geometric.loader import DataLoader
import torch.optim as optim

import os
import os.path as osp

In [3]:

from datetime import datetime
from torch_geometric.datasets import FB15k_237
from torch_geometric.nn import ComplEx

dataset_name='FB15k_237'
embedding_model_name='RotatE'

device = 'cuda' if torch.cuda.is_available() else 'cpu'
path = osp.join(os.getcwd(), '..', 'data', dataset_name)

train_data = FB15k_237(path, split='train')[0].to(device)
val_data = FB15k_237(path, split='val')[0].to(device)
test_data = FB15k_237(path, split='test')[0].to(device)

  return torch._C._cuda_getDeviceCount() > 0


In [4]:
wandb.init(
    project=f"ScoreMatchingDiffKG_Embedding",
    name=f"{dataset_name}_{embedding_model_name}_embedding_model {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
    config={
        "epochs": 3,
        "batch_size": 1000,
        "lr": 0.001,
        "weight_decay": 1e-6,
        "k": 10 #used for top-k evaluation
    }
)

config = wandb.config

In [5]:
model = RotatE(num_nodes=train_data.num_nodes, num_relations=train_data.num_edge_types, hidden_channels=50).to(device) 

wandb.watch(model) #tracks gradients

[]

In [6]:
# Train model on FB15k dataset
train_loader = model.loader(
    head_index=train_data.edge_index[0],
    rel_type=train_data.edge_type,
    tail_index=train_data.edge_index[1],
    batch_size=config.batch_size,
    shuffle=True,
)
# val_loader = model.loader(
#     head_index=val_data.edge_index[0],
#     rel_type=val_data.edge_type,
#     tail_index=val_data.edge_index[1],
#     batch_size=config.batch_size,
#     shuffle=True,
# )
# test_loader = model.loader(
#     head_index=test_data.edge_index[0],
#     rel_type=test_data.edge_type,
#     tail_index=test_data.edge_index[1],
#     batch_size=config.batch_size,
#     shuffle=True,
# )
optimizer = optim.Adagrad(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)

In [7]:
def train(dataloader):
    model.train()
    total_loss = total_examples = 0
    for head_index, rel_type, tail_index in dataloader:
        optimizer.zero_grad()
        loss = model.loss(head_index, rel_type, tail_index)
        loss.backward()
        optimizer.step()
        #scale on specific number of elements in batch
        total_loss += float(loss) * head_index.numel()
        total_examples += head_index.numel()
    return total_loss / total_examples

In [8]:
@torch.no_grad()
def test(data):
    model.eval()
    # Calculate loss
    loss = model.loss(
        head_index=data.edge_index[0],
        rel_type=data.edge_type,
        tail_index=data.edge_index[1],
    )
    # Evaluate the model using model.test()
    mean_rank, mrr, hits_at_k = model.test(
        head_index=data.edge_index[0],
        rel_type=data.edge_type,
        tail_index=data.edge_index[1],
        batch_size=config.batch_size,
        k=config.k,
    )
    return loss, mean_rank, mrr, hits_at_k

In [9]:
for epoch in range(config.epochs):
    epoch += 1 #for string printing
    loss = train(train_loader)
    print(f'Epoch: {epoch:03d}, Train Loss: {loss:.4f}')
    train_metrics ={
        "train_epoch": epoch-1, #fixes the index
        "train_loss": loss
    }
    if epoch % 2 == 0:
        loss, mean_rank, mrr, hits_at_10 = test(val_data)
        print(f'Val Mean Rank: {mean_rank:.2f}, Val Mean Reciprocal Rank: {mrr:.2f}, Val Hits@10: {hits_at_10:.4f}')
        val_metrics ={
            "val_loss": loss,
            "val_mean_rank": mean_rank,
            "val_mrr": mrr,
            "val_hits_at_10": hits_at_10
        }
    #log to wandb
    wandb.log({**train_metrics, **val_metrics} if 'val_metrics' in locals() else {**train_metrics})

#once everything is finished, test model
loss, mean_rank, mrr, hits_at_10 = test(test_data)
print(f'Test Mean Rank: {mean_rank:.2f}, Test Mean Reciprocal Rank: {mrr:.2f}, Test Hits@10: {hits_at_10:.4f}')
test_metrics ={
            "test_loss": loss,
            "test_mean_rank": mean_rank,
            "test_mrr": mrr,
            "test_hits_at_10": hits_at_10
        }
wandb.log({**test_metrics})

Epoch: 001, Train Loss: 0.7774
Epoch: 002, Train Loss: 0.7709


100%|██████████| 17535/17535 [04:04<00:00, 71.73it/s]


Val Mean Rank: 9269.36, Val Mean Reciprocal Rank: 0.00, Val Hits@10: 0.0001
Epoch: 003, Train Loss: 0.7667


100%|██████████| 20466/20466 [04:43<00:00, 72.08it/s]

Test Mean Rank: 8993.22, Test Mean Reciprocal Rank: 0.00, Test Hits@10: 0.0002





In [10]:
path = osp.join(os.getcwd(), f'{dataset_name}_embedding_model_weights.pth')
torch.save(model.state_dict(), path)

model.to_onnx()
wandb.save(f'{dataset_name}_embedding_model_weights.onnx')

AttributeError: 'RotatE' object has no attribute 'to_onnx'