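# Knowledge Graph Embedding (KGE) Training Notebook

Trains TransE node embeddings on the iBKH knowledge graph and exports the learned node-embedding matrix to `inputs/GraphETM/embedding_full.pt` for use by GraphETM.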

In [None]:
### Imports
# Standard library
import random
import time
from pathlib import Path

# External
import numpy as np
import torch
import torch.optim as optim
import torch_geometric as pyg
import wandb
from torch_geometric.nn.kge import TransE

# Local
from model.kge_trainer import KGETrainer
from utils import IBKHDataset

### Weights & Biases login
# Authenticate with W&B so the training cell can start a tracked run via wandb.init.
wandb.login()

In [2]:
# Seeds
SEED = 10  # single source of truth for the run's random seed
pyg.seed_everything(SEED)  # seeds random, np, torch, torch.cuda

---
# Training

In [3]:
### Data
# PyG `Data` object with the iBKH graph (edge_index / edge_type). It is built
# from the raw iBKH files once, then reused from this cache on later runs.
TRIPLET_CACHE = Path('inputs/KGE/input_triplet.pt')

if TRIPLET_CACHE.exists():
    # weights_only=False is needed to unpickle a full `Data` object. Unpickling
    # can execute arbitrary code, so only load cache files this project wrote.
    data = torch.load(TRIPLET_CACHE, weights_only=False)
    print('Data loaded from cache...')
else:
    ibkh_dataset = IBKHDataset(data_dir='data/iBKH')
    data = ibkh_dataset.build_data()
    # torch.save does not create missing parent directories.
    TRIPLET_CACHE.parent.mkdir(parents=True, exist_ok=True)
    torch.save(data, TRIPLET_CACHE)
    print(f'Data built from data/iBKH and cached to {TRIPLET_CACHE}')

Data loaded from cache...


In [4]:
### Device
# Prefer CUDA, then MPS, then CPU. `torch.backends.mps.is_available()` is the
# documented MPS availability check.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
# To force CPU, uncomment:
# device = torch.device('cpu')
print(f'Using device: {device}')

# Move the graph tensors to the training device. `Data.to` returns the Data
# object; rebind explicitly rather than relying on in-place mutation.
data = data.to(device)
data

Using device: mps


Data(edge_index=[2, 11470691], edge_type=[11470691], num_nodes=122777, num_edge_types=6)

In [5]:
### Parameters
# Drives model construction, the triplet loader and training. Note that the
# 'dataloader' entry holds the full head/relation/tail edge tensors.
config = {
    'model': dict(
        num_nodes=data.num_nodes,
        num_relations=data.num_edge_types,
        hidden_channels=64,
        margin=1.0, p_norm=1),

    'dataloader': dict(
        head_index=data.edge_index[0],
        rel_type=data.edge_type,
        tail_index=data.edge_index[1],
        batch_size=4096,
        shuffle=True,
    ),

    'training': dict(
        lr=0.01,
        epochs=15,
    ),

    'device': device,
}

# Hyperparameters only, for experiment tracking. The edge tensors in
# config['dataloader'] (one entry per triplet, ~11.5M here) and the
# torch.device object are not hyperparameters and must not be serialized
# into the W&B run config.
wandb_config = {
    'model': config['model'],
    'dataloader': {k: v for k, v in config['dataloader'].items() if not torch.is_tensor(v)},
    'training': config['training'],
    'device': str(config['device']),
}

### Model
model = TransE(**config['model']).to(device)
train_loader = model.loader(**config['dataloader'])

wandb_run = wandb.init(
    project='GraphETM-KGE',
    # Derived from the config so the run name cannot drift from the real width.
    name=f"iBKH-TransE-{config['model']['hidden_channels']}",
    config=wandb_config,
    save_code=True,
)

kge_trainer = KGETrainer(
    model,
    train_dataloader=train_loader,
    val_dataloader=None,
    device=device,
    wandb_run=wandb_run,
)

### Training
optimizer = optim.Adam(kge_trainer.model.parameters(), lr=config['training']['lr'])
kge_trainer.fit(
    epochs=config['training']['epochs'],
    optimizer=optimizer,
)

Epoch 0: Training:   0%|          | 0/42015 [00:00<?, ?batch/s]

0,1
train/loss,█▂▂▂▁▁▁▁▁▁▁▁▁▁▁

0,1
train/loss,0.09632


In [6]:
### Extract Embeddings
# Learned TransE node-embedding table (one row per graph node), detached and
# moved to CPU so the saved file does not depend on the training device.
rho = kge_trainer.model.node_emb.weight.detach().cpu()
assert rho.shape == (config['model']['num_nodes'], config['model']['hidden_channels']), rho.shape

EMBEDDING_PATH = Path('inputs/GraphETM/embedding_full.pt')
# torch.save does not create missing parent directories.
EMBEDDING_PATH.parent.mkdir(parents=True, exist_ok=True)
torch.save(rho, EMBEDDING_PATH)

In [8]:
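# (Disabled) Save the EHR and SC row subsets of the full embedding matrix `rho`,
# selected by the index arrays in inputs/GraphETM/id_embed_{sc,ehr}.npy.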
# sc_indices  = np.load('inputs/GraphETM/id_embed_sc.npy')
# ehr_indices = np.load('inputs/GraphETM/id_embed_ehr.npy')

# torch.save(rho[sc_indices,  :], 'inputs/GraphETM/embedding_sc.pt')
# torch.save(rho[ehr_indices, :], 'inputs/GraphETM/embedding_ehr.pt')

In [19]:
#DONE