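# Knowledge Graph Embedding (KGE) Training Notebook

Trains TransE node embeddings on the iBKH knowledge graph (logged to Weights & Biases) and exports the node-embedding matrix to `inputs/GraphETM/embeddings.pt` for the GraphETM model.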

In [1]:
### Imports
# Standard library
import random
import time
from pathlib import Path

# External
import numpy as np
import pandas as pd
import scanpy as sc

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import torch_geometric as pyg
from torch_geometric.nn import GCNConv, HeteroConv
from torch_geometric.nn.kge import TransE
import torch_geometric.transforms as T

from tqdm.notebook import tqdm, trange
from sklearn.metrics import adjusted_rand_score

import wandb

# Local
from model.kge_trainer import KGETrainer
from utils import IBKHDataset

### W&B login
# Authenticate with Weights & Biases, which tracks the training run started further down.
# The trailing `;` suppresses the bare `True` repr in the cell output.
wandb.login();

[34m[1mwandb[0m: Currently logged in as: [33mloicduch[0m ([33mloicduch-mcgill-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
# Seeds: one named constant so the run is reproducible and the seed is easy to change.
SEED = 10
pyg.seed_everything(SEED)  # seeds random, numpy, torch and torch.cuda in one call

---
# Training

In [3]:
### Data
# Load the cached triplet graph; rebuild it from the raw iBKH files only when the cache is missing.
TRIPLETS_PATH = Path('inputs/KGE/triplets_data.pt')
try:
    # The cache is a pickled PyG `Data` object written by this notebook (trusted, local),
    # not a plain tensor/state dict, so `weights_only=False` is required: torch >= 2.6
    # defaults to `weights_only=True` and would refuse to unpickle it.
    data = torch.load(TRIPLETS_PATH, weights_only=False)
except FileNotFoundError:
    # Only a missing cache triggers a rebuild; other I/O errors (permissions, corruption)
    # surface instead of being masked by a silent rebuild.
    ibkh_dataset = IBKHDataset(data_dir='data/iBKH')
    data = ibkh_dataset.build_data()
    TRIPLETS_PATH.parent.mkdir(parents=True, exist_ok=True)  # torch.save does not create dirs
    torch.save(data, TRIPLETS_PATH)

  data = torch.load('inputs/KGE/triplets_data.pt')


In [4]:
### Device
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.mps.is_available() else 'cpu')
# device = torch.device('cpu')
print(f'Using device: {device}')

# DATA TO DEVICE
data.to(device)

Using device: mps


Data(edge_index=[2, 2778778], edge_type=[2778778], num_nodes=145609, num_edge_types=6)

In [5]:
### Parameters
config = {
    'model': dict(
        num_nodes = data.num_nodes,
        num_relations = data.num_edge_types,
        hidden_channels = 128,
        margin = 1.0, p_norm = 1),

    'dataloader': dict(
        head_index=data.edge_index[0],
        rel_type=data.edge_type,
        tail_index=data.edge_index[1],
        batch_size=4096,
        shuffle=True
    ),

    'training': dict(
        lr = 0.01,
        epochs = 15,
    ),

    'device': device,
}

### Model
kge_trainer = KGETrainer(
    model := TransE(**config['model']).to(device), # Model
    train_dataloader = model.loader(**config['dataloader']), # Dataloader
    val_dataloader = None,
    device = device,
    wandb_run = wandb.init(
        project ='GraphETM',
        group = 'iBKH-Embeddings',
        name = f'iBKH-TransE_{int(time.time())}',
        config=config, save_code=True) # Start Wandb
)

### Training
kge_trainer.fit(
    epochs = config['training']['epochs'],
    optimizer = optim.Adam(kge_trainer.model.parameters(), lr=config['training']['lr']) # Optimizer
)

Epoch 0: Training:   0%|          | 0/10185 [00:00<?, ?batch/s]

0,1
train/loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
train/loss,0.03761


In [7]:
### Extract Embeddings
rho = kge_trainer.model.node_emb.weight.detach().cpu()
torch.save(rho, 'inputs/GraphETM/embeddings.pt')

In [None]:
#DONE