## Author: Blaine Hill

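In this notebook, we embed a knowledge graph (KG) — here FB15k-237 — using RotatE. The trained weights are saved to `FB15k_237_embedding_model_weights.pth`, and the model is also exported to ONNX as `FB15k_237_embedding_model_weights.onnx`.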

In [21]:
import os
# Tell W&B which notebook file this run comes from; set before wandb is imported/initialised.
os.environ["WANDB_NOTEBOOK_NAME"] = "embedding_model.ipynb"

import wandb
# Authenticate with W&B (the cell's displayed `True` is this call's return value).
wandb.login()

True

In [22]:
import torch
from torch_geometric.datasets import FB15k_237
from torch_geometric.nn import RotatE
from torch_geometric.loader import DataLoader
import torch.optim as optim

import os
import os.path as osp

In [23]:

from datetime import datetime
from torch_geometric.datasets import FB15k_237
from torch_geometric.nn import ComplEx  # NOTE(review): unused below — RotatE is the model trained here

# Run configuration.
dataset_name = 'FB15k_237'
embedding_model_name = 'RotatE'
SEED = 42  # global torch seed so model init and loader shuffling repeat across runs

# Seed before any random op happens (model initialisation, train-loader shuffling, and any
# sampling inside model.loss). torch.manual_seed also seeds CUDA, but some GPU kernels can
# still be nondeterministic.
torch.manual_seed(SEED)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Dataset root: ../data/<dataset_name>, relative to the notebook's working directory.
path = osp.join(os.getcwd(), '..', 'data', dataset_name)

# Each split is a single graph object: edge_index[0] / edge_index[1] are head / tail entity ids
# and edge_type holds the relation id of each triple. Moved to `device` up front.
train_data = FB15k_237(path, split='train')[0].to(device)
val_data = FB15k_237(path, split='val')[0].to(device)
test_data = FB15k_237(path, split='test')[0].to(device)

In [24]:
# Start a W&B run; the timestamp suffix keeps repeated runs of this notebook distinguishable.
run_name = (
    f"{dataset_name}_{embedding_model_name}_embedding_model "
    f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
)
hyperparams = dict(
    epochs=500,
    batch_size=1000,
    lr=0.001,
    weight_decay=1e-6,
    k=10,  # cutoff used for Hits@k during evaluation
)
wandb.init(project="ScoreMatchingDiffKG_Embedding", name=run_name, config=hyperparams)

# Read hyperparameters back through the run's config (attribute access, e.g. config.lr).
config = wandb.config

In [25]:
# RotatE sized from the training graph: one embedding per entity and one per relation type,
# with hidden_channels=50.
model = RotatE(
    num_nodes=train_data.num_nodes,
    num_relations=train_data.num_edge_types,
    hidden_channels=50,
).to(device)

wandb.watch(model)  # tracks gradients

[]

In [26]:
# Mini-batch iterator over the (head, relation, tail) training triples of FB15k-237,
# reshuffled every epoch.
train_loader = model.loader(
    head_index=train_data.edge_index[0],
    rel_type=train_data.edge_type,
    tail_index=train_data.edge_index[1],
    batch_size=config.batch_size,
    shuffle=True,
)

# Validation/test splits are scored whole inside `test()`, so no loaders are built for them.

# NOTE(review): Adagrad + weight decay looks like a ComplEx configuration (ComplEx is also
# imported in the data-loading cell); PyG's KGE example pairs RotatE with Adam. Training
# below converges slowly — confirm this optimizer choice is intentional.
optimizer = optim.Adagrad(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)

In [27]:
def train(dataloader):
    """Run one training epoch over `dataloader` and return the mean loss.

    Uses the notebook-level `model` and `optimizer`. Each batch loss is weighted by the
    number of triples in that batch, so the result is a per-triple average even when the
    final batch is smaller than the rest.

    Args:
        dataloader: iterable yielding (head_index, rel_type, tail_index) batches, e.g. the
            loader returned by `model.loader(...)`.

    Returns:
        float: average training loss over every triple seen this epoch.
    """
    model.train()
    loss_sum, n_seen = 0, 0
    for heads, rels, tails in dataloader:
        optimizer.zero_grad()
        batch_loss = model.loss(heads, rels, tails)
        batch_loss.backward()
        optimizer.step()

        batch_size = heads.numel()
        loss_sum += float(batch_loss) * batch_size
        n_seen += batch_size
    return loss_sum / n_seen

In [28]:
@torch.no_grad()
def test(data):
    """Score the notebook-level `model` on every triple of one split.

    Args:
        data: a graph whose `edge_index` rows hold head/tail entity ids and whose
            `edge_type` holds relation ids (as produced by `FB15k_237(...)[0]`).

    Returns:
        tuple: (loss, mean_rank, mrr, hits_at_k). `loss` is what `model.loss` returns on
        the whole split; the other three come from `model.test` with `k=config.k`.

    NOTE(review): confirm whether `model.test` computes raw or filtered ranks before
    comparing these numbers with published results.
    """
    model.eval()
    triples = dict(
        head_index=data.edge_index[0],
        rel_type=data.edge_type,
        tail_index=data.edge_index[1],
    )
    # Loss first, then ranking metrics, same order as before.
    loss = model.loss(**triples)
    mean_rank, mrr, hits_at_k = model.test(**triples, batch_size=config.batch_size, k=config.k)
    return loss, mean_rank, mrr, hits_at_k

In [29]:
EVAL_EVERY = 25  # validate every N epochs; a full validation pass is slow (~4 min in the recorded run)

for epoch in range(1, config.epochs + 1):  # 1-based for printing
    train_loss = train(train_loader)
    print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.4f}')

    # Fresh dict every epoch: validation metrics are logged only on the epochs where they
    # were actually computed — no stale repeats, and nothing leaks in from an earlier
    # execution of this cell (the old `'val_metrics' in locals()` check allowed both).
    epoch_metrics = {
        "train_epoch": epoch - 1,  # logged 0-based
        "train_loss": train_loss,
    }
    if epoch % EVAL_EVERY == 0:
        val_loss, mean_rank, mrr, hits_at_k = test(val_data)
        print(f'Val Mean Rank: {mean_rank:.2f}, Val Mean Reciprocal Rank: {mrr:.4f}, '
              f'Val Hits@{config.k}: {hits_at_k:.4f}')
        epoch_metrics.update({
            "val_loss": float(val_loss),  # log a plain number, not a (possibly CUDA) tensor
            "val_mean_rank": mean_rank,
            "val_mrr": mrr,
            f"val_hits_at_{config.k}": hits_at_k,  # "val_hits_at_10" for the default k=10
        })
    wandb.log(epoch_metrics)

# Once training is finished, evaluate on the held-out test split.
test_loss, mean_rank, mrr, hits_at_k = test(test_data)
print(f'Test Mean Rank: {mean_rank:.2f}, Test Mean Reciprocal Rank: {mrr:.4f}, '
      f'Test Hits@{config.k}: {hits_at_k:.4f}')
wandb.log({
    "test_loss": float(test_loss),
    "test_mean_rank": mean_rank,
    "test_mrr": mrr,
    f"test_hits_at_{config.k}": hits_at_k,
})

Epoch: 001, Train Loss: 0.7774
Epoch: 002, Train Loss: 0.7709
Epoch: 003, Train Loss: 0.7668
Epoch: 004, Train Loss: 0.7635
Epoch: 005, Train Loss: 0.7608
Epoch: 006, Train Loss: 0.7584
Epoch: 007, Train Loss: 0.7562
Epoch: 008, Train Loss: 0.7542
Epoch: 009, Train Loss: 0.7524
Epoch: 010, Train Loss: 0.7507
Epoch: 011, Train Loss: 0.7491
Epoch: 012, Train Loss: 0.7477
Epoch: 013, Train Loss: 0.7462
Epoch: 014, Train Loss: 0.7449
Epoch: 015, Train Loss: 0.7436
Epoch: 016, Train Loss: 0.7423
Epoch: 017, Train Loss: 0.7411
Epoch: 018, Train Loss: 0.7400
Epoch: 019, Train Loss: 0.7388
Epoch: 020, Train Loss: 0.7378
Epoch: 021, Train Loss: 0.7367
Epoch: 022, Train Loss: 0.7357
Epoch: 023, Train Loss: 0.7348
Epoch: 024, Train Loss: 0.7338
Epoch: 025, Train Loss: 0.7329


100%|██████████| 17535/17535 [04:08<00:00, 70.62it/s]


Val Mean Rank: 4517.82, Val Mean Reciprocal Rank: 0.00, Val Hits@10: 0.0018
Epoch: 026, Train Loss: 0.7320
Epoch: 027, Train Loss: 0.7310
Epoch: 028, Train Loss: 0.7302
Epoch: 029, Train Loss: 0.7294
Epoch: 030, Train Loss: 0.7286
Epoch: 031, Train Loss: 0.7278
Epoch: 032, Train Loss: 0.7269
Epoch: 033, Train Loss: 0.7262
Epoch: 034, Train Loss: 0.7254
Epoch: 035, Train Loss: 0.7247
Epoch: 036, Train Loss: 0.7240
Epoch: 037, Train Loss: 0.7233
Epoch: 038, Train Loss: 0.7225
Epoch: 039, Train Loss: 0.7219
Epoch: 040, Train Loss: 0.7212
Epoch: 041, Train Loss: 0.7205
Epoch: 042, Train Loss: 0.7198
Epoch: 043, Train Loss: 0.7192
Epoch: 044, Train Loss: 0.7186
Epoch: 045, Train Loss: 0.7179
Epoch: 046, Train Loss: 0.7173
Epoch: 047, Train Loss: 0.7166
Epoch: 048, Train Loss: 0.7161
Epoch: 049, Train Loss: 0.7155
Epoch: 050, Train Loss: 0.7149


100%|██████████| 17535/17535 [04:07<00:00, 70.91it/s]


Val Mean Rank: 3340.89, Val Mean Reciprocal Rank: 0.02, Val Hits@10: 0.0376
Epoch: 051, Train Loss: 0.7143
Epoch: 052, Train Loss: 0.7137
Epoch: 053, Train Loss: 0.7132
Epoch: 054, Train Loss: 0.7126
Epoch: 055, Train Loss: 0.7121
Epoch: 056, Train Loss: 0.7116
Epoch: 057, Train Loss: 0.7110
Epoch: 058, Train Loss: 0.7104
Epoch: 059, Train Loss: 0.7100
Epoch: 060, Train Loss: 0.7094
Epoch: 061, Train Loss: 0.7089
Epoch: 062, Train Loss: 0.7085
Epoch: 063, Train Loss: 0.7080
Epoch: 064, Train Loss: 0.7074
Epoch: 065, Train Loss: 0.7069
Epoch: 066, Train Loss: 0.7065
Epoch: 067, Train Loss: 0.7059
Epoch: 068, Train Loss: 0.7055
Epoch: 069, Train Loss: 0.7050
Epoch: 070, Train Loss: 0.7046
Epoch: 071, Train Loss: 0.7041
Epoch: 072, Train Loss: 0.7036
Epoch: 073, Train Loss: 0.7032
Epoch: 074, Train Loss: 0.7028
Epoch: 075, Train Loss: 0.7022


100%|██████████| 17535/17535 [04:00<00:00, 72.88it/s]


Val Mean Rank: 2981.51, Val Mean Reciprocal Rank: 0.03, Val Hits@10: 0.0676
Epoch: 076, Train Loss: 0.7019
Epoch: 077, Train Loss: 0.7014
Epoch: 078, Train Loss: 0.7010
Epoch: 079, Train Loss: 0.7006
Epoch: 080, Train Loss: 0.7001
Epoch: 081, Train Loss: 0.6997
Epoch: 082, Train Loss: 0.6993
Epoch: 083, Train Loss: 0.6988
Epoch: 084, Train Loss: 0.6984
Epoch: 085, Train Loss: 0.6980
Epoch: 086, Train Loss: 0.6977
Epoch: 087, Train Loss: 0.6972
Epoch: 088, Train Loss: 0.6968
Epoch: 089, Train Loss: 0.6964
Epoch: 090, Train Loss: 0.6960
Epoch: 091, Train Loss: 0.6956
Epoch: 092, Train Loss: 0.6953
Epoch: 093, Train Loss: 0.6949
Epoch: 094, Train Loss: 0.6945
Epoch: 095, Train Loss: 0.6941
Epoch: 096, Train Loss: 0.6937
Epoch: 097, Train Loss: 0.6934
Epoch: 098, Train Loss: 0.6930
Epoch: 099, Train Loss: 0.6927
Epoch: 100, Train Loss: 0.6923


100%|██████████| 17535/17535 [04:10<00:00, 69.90it/s]


Val Mean Rank: 2799.25, Val Mean Reciprocal Rank: 0.04, Val Hits@10: 0.0813
Epoch: 101, Train Loss: 0.6919
Epoch: 102, Train Loss: 0.6916
Epoch: 103, Train Loss: 0.6912
Epoch: 104, Train Loss: 0.6908
Epoch: 105, Train Loss: 0.6906
Epoch: 106, Train Loss: 0.6901
Epoch: 107, Train Loss: 0.6898
Epoch: 108, Train Loss: 0.6894
Epoch: 109, Train Loss: 0.6891
Epoch: 110, Train Loss: 0.6888
Epoch: 111, Train Loss: 0.6884
Epoch: 112, Train Loss: 0.6882
Epoch: 113, Train Loss: 0.6877
Epoch: 114, Train Loss: 0.6875
Epoch: 115, Train Loss: 0.6871
Epoch: 116, Train Loss: 0.6868
Epoch: 117, Train Loss: 0.6865
Epoch: 118, Train Loss: 0.6861
Epoch: 119, Train Loss: 0.6859
Epoch: 120, Train Loss: 0.6854
Epoch: 121, Train Loss: 0.6852
Epoch: 122, Train Loss: 0.6849
Epoch: 123, Train Loss: 0.6846
Epoch: 124, Train Loss: 0.6843
Epoch: 125, Train Loss: 0.6840


100%|██████████| 17535/17535 [04:00<00:00, 73.01it/s]


Val Mean Rank: 2682.05, Val Mean Reciprocal Rank: 0.04, Val Hits@10: 0.0902
Epoch: 126, Train Loss: 0.6836
Epoch: 127, Train Loss: 0.6833
Epoch: 128, Train Loss: 0.6829
Epoch: 129, Train Loss: 0.6827
Epoch: 130, Train Loss: 0.6826
Epoch: 131, Train Loss: 0.6821
Epoch: 132, Train Loss: 0.6818
Epoch: 133, Train Loss: 0.6815
Epoch: 134, Train Loss: 0.6813
Epoch: 135, Train Loss: 0.6809
Epoch: 136, Train Loss: 0.6806
Epoch: 137, Train Loss: 0.6803
Epoch: 138, Train Loss: 0.6801
Epoch: 139, Train Loss: 0.6798
Epoch: 140, Train Loss: 0.6794
Epoch: 141, Train Loss: 0.6793
Epoch: 142, Train Loss: 0.6789
Epoch: 143, Train Loss: 0.6785
Epoch: 144, Train Loss: 0.6783
Epoch: 145, Train Loss: 0.6780
Epoch: 146, Train Loss: 0.6778
Epoch: 147, Train Loss: 0.6774
Epoch: 148, Train Loss: 0.6773
Epoch: 149, Train Loss: 0.6770
Epoch: 150, Train Loss: 0.6767


100%|██████████| 17535/17535 [04:03<00:00, 71.95it/s]


Val Mean Rank: 2596.45, Val Mean Reciprocal Rank: 0.05, Val Hits@10: 0.1003
Epoch: 151, Train Loss: 0.6764
Epoch: 152, Train Loss: 0.6762
Epoch: 153, Train Loss: 0.6758
Epoch: 154, Train Loss: 0.6755
Epoch: 155, Train Loss: 0.6754
Epoch: 156, Train Loss: 0.6751
Epoch: 157, Train Loss: 0.6748
Epoch: 158, Train Loss: 0.6746
Epoch: 159, Train Loss: 0.6743
Epoch: 160, Train Loss: 0.6739
Epoch: 161, Train Loss: 0.6736
Epoch: 162, Train Loss: 0.6735
Epoch: 163, Train Loss: 0.6731
Epoch: 164, Train Loss: 0.6731
Epoch: 165, Train Loss: 0.6726
Epoch: 166, Train Loss: 0.6724
Epoch: 167, Train Loss: 0.6721
Epoch: 168, Train Loss: 0.6721
Epoch: 169, Train Loss: 0.6715
Epoch: 170, Train Loss: 0.6714
Epoch: 171, Train Loss: 0.6714
Epoch: 172, Train Loss: 0.6710
Epoch: 173, Train Loss: 0.6708
Epoch: 174, Train Loss: 0.6705
Epoch: 175, Train Loss: 0.6704


100%|██████████| 17535/17535 [04:14<00:00, 69.02it/s]


Val Mean Rank: 2526.99, Val Mean Reciprocal Rank: 0.05, Val Hits@10: 0.1081
Epoch: 176, Train Loss: 0.6700
Epoch: 177, Train Loss: 0.6698
Epoch: 178, Train Loss: 0.6695
Epoch: 179, Train Loss: 0.6693
Epoch: 180, Train Loss: 0.6691
Epoch: 181, Train Loss: 0.6688
Epoch: 182, Train Loss: 0.6686
Epoch: 183, Train Loss: 0.6684
Epoch: 184, Train Loss: 0.6680
Epoch: 185, Train Loss: 0.6679
Epoch: 186, Train Loss: 0.6676
Epoch: 187, Train Loss: 0.6675
Epoch: 188, Train Loss: 0.6672
Epoch: 189, Train Loss: 0.6670
Epoch: 190, Train Loss: 0.6669
Epoch: 191, Train Loss: 0.6665
Epoch: 192, Train Loss: 0.6663
Epoch: 193, Train Loss: 0.6662
Epoch: 194, Train Loss: 0.6659
Epoch: 195, Train Loss: 0.6655
Epoch: 196, Train Loss: 0.6653
Epoch: 197, Train Loss: 0.6651
Epoch: 198, Train Loss: 0.6650
Epoch: 199, Train Loss: 0.6648
Epoch: 200, Train Loss: 0.6646


100%|██████████| 17535/17535 [04:02<00:00, 72.44it/s]


Val Mean Rank: 2468.34, Val Mean Reciprocal Rank: 0.06, Val Hits@10: 0.1151
Epoch: 201, Train Loss: 0.6643
Epoch: 202, Train Loss: 0.6641
Epoch: 203, Train Loss: 0.6638
Epoch: 204, Train Loss: 0.6635
Epoch: 205, Train Loss: 0.6634
Epoch: 206, Train Loss: 0.6633
Epoch: 207, Train Loss: 0.6629
Epoch: 208, Train Loss: 0.6628
Epoch: 209, Train Loss: 0.6625
Epoch: 210, Train Loss: 0.6624
Epoch: 211, Train Loss: 0.6621
Epoch: 212, Train Loss: 0.6619
Epoch: 213, Train Loss: 0.6616
Epoch: 214, Train Loss: 0.6616
Epoch: 215, Train Loss: 0.6613
Epoch: 216, Train Loss: 0.6610
Epoch: 217, Train Loss: 0.6610
Epoch: 218, Train Loss: 0.6606
Epoch: 219, Train Loss: 0.6605
Epoch: 220, Train Loss: 0.6602
Epoch: 221, Train Loss: 0.6600
Epoch: 222, Train Loss: 0.6600
Epoch: 223, Train Loss: 0.6598
Epoch: 224, Train Loss: 0.6595
Epoch: 225, Train Loss: 0.6592


100%|██████████| 17535/17535 [04:16<00:00, 68.27it/s]


Val Mean Rank: 2416.84, Val Mean Reciprocal Rank: 0.06, Val Hits@10: 0.1217
Epoch: 226, Train Loss: 0.6590
Epoch: 227, Train Loss: 0.6588
Epoch: 228, Train Loss: 0.6587
Epoch: 229, Train Loss: 0.6584
Epoch: 230, Train Loss: 0.6584
Epoch: 231, Train Loss: 0.6580
Epoch: 232, Train Loss: 0.6580
Epoch: 233, Train Loss: 0.6577
Epoch: 234, Train Loss: 0.6574
Epoch: 235, Train Loss: 0.6573
Epoch: 236, Train Loss: 0.6571
Epoch: 237, Train Loss: 0.6569
Epoch: 238, Train Loss: 0.6567
Epoch: 239, Train Loss: 0.6564
Epoch: 240, Train Loss: 0.6563
Epoch: 241, Train Loss: 0.6560
Epoch: 242, Train Loss: 0.6559
Epoch: 243, Train Loss: 0.6558
Epoch: 244, Train Loss: 0.6555
Epoch: 245, Train Loss: 0.6555
Epoch: 246, Train Loss: 0.6552
Epoch: 247, Train Loss: 0.6549
Epoch: 248, Train Loss: 0.6548
Epoch: 249, Train Loss: 0.6547
Epoch: 250, Train Loss: 0.6545


100%|██████████| 17535/17535 [04:07<00:00, 70.87it/s]


Val Mean Rank: 2370.21, Val Mean Reciprocal Rank: 0.06, Val Hits@10: 0.1276
Epoch: 251, Train Loss: 0.6543
Epoch: 252, Train Loss: 0.6541
Epoch: 253, Train Loss: 0.6540
Epoch: 254, Train Loss: 0.6536
Epoch: 255, Train Loss: 0.6536
Epoch: 256, Train Loss: 0.6533
Epoch: 257, Train Loss: 0.6532
Epoch: 258, Train Loss: 0.6531
Epoch: 259, Train Loss: 0.6529
Epoch: 260, Train Loss: 0.6526
Epoch: 261, Train Loss: 0.6524
Epoch: 262, Train Loss: 0.6522
Epoch: 263, Train Loss: 0.6521
Epoch: 264, Train Loss: 0.6519
Epoch: 265, Train Loss: 0.6519
Epoch: 266, Train Loss: 0.6516
Epoch: 267, Train Loss: 0.6514
Epoch: 268, Train Loss: 0.6513
Epoch: 269, Train Loss: 0.6511
Epoch: 270, Train Loss: 0.6509
Epoch: 271, Train Loss: 0.6505
Epoch: 272, Train Loss: 0.6505
Epoch: 273, Train Loss: 0.6505
Epoch: 274, Train Loss: 0.6503
Epoch: 275, Train Loss: 0.6500


100%|██████████| 17535/17535 [04:11<00:00, 69.63it/s]


Val Mean Rank: 2326.99, Val Mean Reciprocal Rank: 0.07, Val Hits@10: 0.1317
Epoch: 276, Train Loss: 0.6498
Epoch: 277, Train Loss: 0.6496
Epoch: 278, Train Loss: 0.6494
Epoch: 279, Train Loss: 0.6494
Epoch: 280, Train Loss: 0.6492
Epoch: 281, Train Loss: 0.6491
Epoch: 282, Train Loss: 0.6489
Epoch: 283, Train Loss: 0.6487
Epoch: 284, Train Loss: 0.6485
Epoch: 285, Train Loss: 0.6483
Epoch: 286, Train Loss: 0.6481
Epoch: 287, Train Loss: 0.6480
Epoch: 288, Train Loss: 0.6479
Epoch: 289, Train Loss: 0.6475
Epoch: 290, Train Loss: 0.6474
Epoch: 291, Train Loss: 0.6474
Epoch: 292, Train Loss: 0.6471
Epoch: 293, Train Loss: 0.6471
Epoch: 294, Train Loss: 0.6468
Epoch: 295, Train Loss: 0.6467
Epoch: 296, Train Loss: 0.6465
Epoch: 297, Train Loss: 0.6463
Epoch: 298, Train Loss: 0.6460
Epoch: 299, Train Loss: 0.6460
Epoch: 300, Train Loss: 0.6459


100%|██████████| 17535/17535 [03:56<00:00, 74.04it/s]


Val Mean Rank: 2286.27, Val Mean Reciprocal Rank: 0.07, Val Hits@10: 0.1349
Epoch: 301, Train Loss: 0.6456
Epoch: 302, Train Loss: 0.6454
Epoch: 303, Train Loss: 0.6454
Epoch: 304, Train Loss: 0.6452
Epoch: 305, Train Loss: 0.6450
Epoch: 306, Train Loss: 0.6449
Epoch: 307, Train Loss: 0.6447
Epoch: 308, Train Loss: 0.6447
Epoch: 309, Train Loss: 0.6443
Epoch: 310, Train Loss: 0.6444
Epoch: 311, Train Loss: 0.6442
Epoch: 312, Train Loss: 0.6441
Epoch: 313, Train Loss: 0.6438
Epoch: 314, Train Loss: 0.6438
Epoch: 315, Train Loss: 0.6434
Epoch: 316, Train Loss: 0.6434
Epoch: 317, Train Loss: 0.6432
Epoch: 318, Train Loss: 0.6432
Epoch: 319, Train Loss: 0.6429
Epoch: 320, Train Loss: 0.6426
Epoch: 321, Train Loss: 0.6426
Epoch: 322, Train Loss: 0.6426
Epoch: 323, Train Loss: 0.6424
Epoch: 324, Train Loss: 0.6421
Epoch: 325, Train Loss: 0.6419


100%|██████████| 17535/17535 [04:08<00:00, 70.57it/s]


Val Mean Rank: 2248.36, Val Mean Reciprocal Rank: 0.07, Val Hits@10: 0.1418
Epoch: 326, Train Loss: 0.6418
Epoch: 327, Train Loss: 0.6417
Epoch: 328, Train Loss: 0.6416
Epoch: 329, Train Loss: 0.6414
Epoch: 330, Train Loss: 0.6412
Epoch: 331, Train Loss: 0.6411
Epoch: 332, Train Loss: 0.6409
Epoch: 333, Train Loss: 0.6408
Epoch: 334, Train Loss: 0.6407
Epoch: 335, Train Loss: 0.6407
Epoch: 336, Train Loss: 0.6405
Epoch: 337, Train Loss: 0.6403
Epoch: 338, Train Loss: 0.6403
Epoch: 339, Train Loss: 0.6400
Epoch: 340, Train Loss: 0.6400
Epoch: 341, Train Loss: 0.6397
Epoch: 342, Train Loss: 0.6394
Epoch: 343, Train Loss: 0.6394
Epoch: 344, Train Loss: 0.6392
Epoch: 345, Train Loss: 0.6391
Epoch: 346, Train Loss: 0.6390
Epoch: 347, Train Loss: 0.6388
Epoch: 348, Train Loss: 0.6386
Epoch: 349, Train Loss: 0.6385
Epoch: 350, Train Loss: 0.6384


100%|██████████| 17535/17535 [04:06<00:00, 71.08it/s]


Val Mean Rank: 2212.24, Val Mean Reciprocal Rank: 0.07, Val Hits@10: 0.1468
Epoch: 351, Train Loss: 0.6383
Epoch: 352, Train Loss: 0.6382
Epoch: 353, Train Loss: 0.6380
Epoch: 354, Train Loss: 0.6378
Epoch: 355, Train Loss: 0.6378
Epoch: 356, Train Loss: 0.6376
Epoch: 357, Train Loss: 0.6374
Epoch: 358, Train Loss: 0.6374
Epoch: 359, Train Loss: 0.6371
Epoch: 360, Train Loss: 0.6371
Epoch: 361, Train Loss: 0.6368
Epoch: 362, Train Loss: 0.6365
Epoch: 363, Train Loss: 0.6367
Epoch: 364, Train Loss: 0.6365
Epoch: 365, Train Loss: 0.6362
Epoch: 366, Train Loss: 0.6363
Epoch: 367, Train Loss: 0.6360
Epoch: 368, Train Loss: 0.6361
Epoch: 369, Train Loss: 0.6357
Epoch: 370, Train Loss: 0.6358
Epoch: 371, Train Loss: 0.6354
Epoch: 372, Train Loss: 0.6353
Epoch: 373, Train Loss: 0.6353
Epoch: 374, Train Loss: 0.6352
Epoch: 375, Train Loss: 0.6350


100%|██████████| 17535/17535 [04:06<00:00, 71.02it/s]


Val Mean Rank: 2177.77, Val Mean Reciprocal Rank: 0.08, Val Hits@10: 0.1526
Epoch: 376, Train Loss: 0.6348
Epoch: 377, Train Loss: 0.6349
Epoch: 378, Train Loss: 0.6344
Epoch: 379, Train Loss: 0.6344
Epoch: 380, Train Loss: 0.6343
Epoch: 381, Train Loss: 0.6341
Epoch: 382, Train Loss: 0.6340
Epoch: 383, Train Loss: 0.6339
Epoch: 384, Train Loss: 0.6337
Epoch: 385, Train Loss: 0.6336
Epoch: 386, Train Loss: 0.6336
Epoch: 387, Train Loss: 0.6334
Epoch: 388, Train Loss: 0.6332
Epoch: 389, Train Loss: 0.6332
Epoch: 390, Train Loss: 0.6331
Epoch: 391, Train Loss: 0.6329
Epoch: 392, Train Loss: 0.6330
Epoch: 393, Train Loss: 0.6327
Epoch: 394, Train Loss: 0.6326
Epoch: 395, Train Loss: 0.6324
Epoch: 396, Train Loss: 0.6323
Epoch: 397, Train Loss: 0.6320
Epoch: 398, Train Loss: 0.6318
Epoch: 399, Train Loss: 0.6319
Epoch: 400, Train Loss: 0.6317


100%|██████████| 17535/17535 [03:55<00:00, 74.54it/s]


Val Mean Rank: 2144.64, Val Mean Reciprocal Rank: 0.08, Val Hits@10: 0.1581
Epoch: 401, Train Loss: 0.6317
Epoch: 402, Train Loss: 0.6316
Epoch: 403, Train Loss: 0.6313
Epoch: 404, Train Loss: 0.6312
Epoch: 405, Train Loss: 0.6310
Epoch: 406, Train Loss: 0.6307
Epoch: 407, Train Loss: 0.6309
Epoch: 408, Train Loss: 0.6306
Epoch: 409, Train Loss: 0.6307
Epoch: 410, Train Loss: 0.6305
Epoch: 411, Train Loss: 0.6303
Epoch: 412, Train Loss: 0.6302
Epoch: 413, Train Loss: 0.6302
Epoch: 414, Train Loss: 0.6300
Epoch: 415, Train Loss: 0.6300
Epoch: 416, Train Loss: 0.6296
Epoch: 417, Train Loss: 0.6296
Epoch: 418, Train Loss: 0.6294
Epoch: 419, Train Loss: 0.6293
Epoch: 420, Train Loss: 0.6294
Epoch: 421, Train Loss: 0.6290
Epoch: 422, Train Loss: 0.6290
Epoch: 423, Train Loss: 0.6287
Epoch: 424, Train Loss: 0.6288
Epoch: 425, Train Loss: 0.6286


100%|██████████| 17535/17535 [04:00<00:00, 72.97it/s]


Val Mean Rank: 2112.91, Val Mean Reciprocal Rank: 0.08, Val Hits@10: 0.1627
Epoch: 426, Train Loss: 0.6285
Epoch: 427, Train Loss: 0.6285
Epoch: 428, Train Loss: 0.6282
Epoch: 429, Train Loss: 0.6281
Epoch: 430, Train Loss: 0.6279
Epoch: 431, Train Loss: 0.6279
Epoch: 432, Train Loss: 0.6279
Epoch: 433, Train Loss: 0.6276
Epoch: 434, Train Loss: 0.6276
Epoch: 435, Train Loss: 0.6275
Epoch: 436, Train Loss: 0.6272
Epoch: 437, Train Loss: 0.6270
Epoch: 438, Train Loss: 0.6271
Epoch: 439, Train Loss: 0.6270
Epoch: 440, Train Loss: 0.6267
Epoch: 441, Train Loss: 0.6267
Epoch: 442, Train Loss: 0.6268
Epoch: 443, Train Loss: 0.6266
Epoch: 444, Train Loss: 0.6266
Epoch: 445, Train Loss: 0.6265
Epoch: 446, Train Loss: 0.6261
Epoch: 447, Train Loss: 0.6262
Epoch: 448, Train Loss: 0.6259
Epoch: 449, Train Loss: 0.6259
Epoch: 450, Train Loss: 0.6259


100%|██████████| 17535/17535 [03:56<00:00, 74.25it/s]


Val Mean Rank: 2082.62, Val Mean Reciprocal Rank: 0.09, Val Hits@10: 0.1661
Epoch: 451, Train Loss: 0.6255
Epoch: 452, Train Loss: 0.6254
Epoch: 453, Train Loss: 0.6253
Epoch: 454, Train Loss: 0.6253
Epoch: 455, Train Loss: 0.6254
Epoch: 456, Train Loss: 0.6249
Epoch: 457, Train Loss: 0.6250
Epoch: 458, Train Loss: 0.6249
Epoch: 459, Train Loss: 0.6249
Epoch: 460, Train Loss: 0.6246
Epoch: 461, Train Loss: 0.6246
Epoch: 462, Train Loss: 0.6243
Epoch: 463, Train Loss: 0.6243
Epoch: 464, Train Loss: 0.6242
Epoch: 465, Train Loss: 0.6241
Epoch: 466, Train Loss: 0.6241
Epoch: 467, Train Loss: 0.6239
Epoch: 468, Train Loss: 0.6237
Epoch: 469, Train Loss: 0.6235
Epoch: 470, Train Loss: 0.6234
Epoch: 471, Train Loss: 0.6233
Epoch: 472, Train Loss: 0.6234
Epoch: 473, Train Loss: 0.6232
Epoch: 474, Train Loss: 0.6227
Epoch: 475, Train Loss: 0.6231


100%|██████████| 17535/17535 [04:03<00:00, 71.95it/s]


Val Mean Rank: 2053.02, Val Mean Reciprocal Rank: 0.09, Val Hits@10: 0.1706
Epoch: 476, Train Loss: 0.6228
Epoch: 477, Train Loss: 0.6229
Epoch: 478, Train Loss: 0.6225
Epoch: 479, Train Loss: 0.6223
Epoch: 480, Train Loss: 0.6223
Epoch: 481, Train Loss: 0.6223
Epoch: 482, Train Loss: 0.6221
Epoch: 483, Train Loss: 0.6222
Epoch: 484, Train Loss: 0.6219
Epoch: 485, Train Loss: 0.6218
Epoch: 486, Train Loss: 0.6217
Epoch: 487, Train Loss: 0.6216
Epoch: 488, Train Loss: 0.6215
Epoch: 489, Train Loss: 0.6213
Epoch: 490, Train Loss: 0.6214
Epoch: 491, Train Loss: 0.6211
Epoch: 492, Train Loss: 0.6211
Epoch: 493, Train Loss: 0.6210
Epoch: 494, Train Loss: 0.6209
Epoch: 495, Train Loss: 0.6209
Epoch: 496, Train Loss: 0.6205
Epoch: 497, Train Loss: 0.6207
Epoch: 498, Train Loss: 0.6206
Epoch: 499, Train Loss: 0.6202
Epoch: 500, Train Loss: 0.6202


100%|██████████| 17535/17535 [04:06<00:00, 71.11it/s]


Val Mean Rank: 2024.23, Val Mean Reciprocal Rank: 0.09, Val Hits@10: 0.1752


100%|██████████| 20466/20466 [04:35<00:00, 74.42it/s]

Test Mean Rank: 2069.01, Test Mean Reciprocal Rank: 0.09, Test Hits@10: 0.1738





In [30]:
# Persist the trained parameters next to the notebook. A dedicated name keeps the
# dataset directory stored in `path` intact.
weights_path = osp.join(os.getcwd(), f'{dataset_name}_embedding_model_weights.pth')
torch.save(model.state_dict(), weights_path)

# ONNX export traces the model, so it needs one real batch of example inputs. The loader
# indexes tensors that were moved to `device`, so the batch already matches the model.
head_index, rel_type, tail_index = next(iter(train_loader))

onnx_file = f'{dataset_name}_embedding_model_weights.onnx'
model.eval()  # export in inference mode
torch.onnx.export(
    model,
    (head_index, rel_type, tail_index),
    onnx_file,
    opset_version=11,
    do_constant_folding=True,
    input_names=["head_index", "rel_type", "tail_index"],
    # Name the score output and make its batch axis dynamic too; otherwise the output
    # shape is frozen to the traced batch size while the inputs are dynamic.
    output_names=["score"],
    dynamic_axes={
        "head_index": {0: "batch_size"},
        "rel_type": {0: "batch_size"},
        "tail_index": {0: "batch_size"},
        "score": {0: "batch_size"},
    },
)
wandb.save(onnx_file)
wandb.finish()

0,1
test_hits_at_10,▁
test_loss,▁
test_mean_rank,▁
test_mrr,▁
train_epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_loss,█▇▆▆▆▅▅▅▄▄▄▄▄▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁
val_hits_at_10,▁▁▁▁▃▃▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇██████
val_loss,██▆▆▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁
val_mean_rank,██▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_mrr,▁▁▁▁▂▂▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇██████

0,1
test_hits_at_10,0.17385
test_loss,0.6314
test_mean_rank,2069.01172
test_mrr,0.09284
train_epoch,499.0
train_loss,0.62024
val_hits_at_10,0.17519
val_loss,0.63112
val_mean_rank,2024.2251
val_mrr,0.09036
