# Demo: training a TransE model on the FB15k-237 dataset

The code is based on the [PyTorch Geometric example](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/kge_fb15k_237.py).

In [1]:
from torch_geometric.data import Data, Dataset
from torch_geometric.datasets import FB15k_237

In [2]:
# Load the train / val / test splits in that order and keep the first
# (presumably only) graph object of each split.
data_train, data_val, data_test = (
    FB15k_237('./data/fb15k', split=split_name)[0]
    for split_name in ('train', 'val', 'test')
)

In [3]:
# Summary statistics of the training graph.
# NOTE: len() on a PyG `Data` object counts its stored attributes
# (e.g. edge_index, edge_type, num_nodes), not the number of graphs,
# so it is labelled accordingly here.
print(f'# of attributes:    {len(data_train)}')
print(f'# of nodes:    {data_train.num_nodes}')
print(f'# of edges:    {data_train.num_edges}')
print(f'# of node features: {data_train.num_node_features}')
print(f'# of edge features:    {data_train.num_edge_features}')
print(f'# of edge types: {data_train.num_edge_types}')

# of graph:    3
# of nodes:    14541
# of edges:    272115
# of node features: 0
# of edge features:    0
# of edge types: 237


In [4]:
# Entity / relation vocabulary sizes come from the training graph;
# hidden_channels is the embedding dimensionality.
num_nodes, num_relations = data_train.num_nodes, data_train.num_edge_types
hidden_channels = 50

In [5]:
from torch_geometric.nn import TransE
import torch
import torch.optim as optim

In [6]:
# Prefer the first CUDA GPU when one is visible, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)

cuda:0


In [7]:
# Build the TransE model directly on the target device, so the optimizer
# constructed afterwards wraps parameters that already live there.
# PyTorch recommends moving a model to its device before creating its
# optimizer. The later `model.to(device)` then becomes a harmless no-op.
model = TransE(
    num_nodes=num_nodes,
    num_relations=num_relations,
    hidden_channels=hidden_channels,
).to(device)

In [8]:
# Adam over every model parameter with a fixed learning rate of 0.01.
optimizer = optim.Adam(params=model.parameters(), lr=0.01)

In [9]:
from torch_geometric.nn.kge.loader import KGTripletLoader
from torch_geometric.nn.kge import KGEModel

In [10]:
# Shuffled mini-batches (batch_size=2000) of (head, relation, tail) triples
# drawn from the training graph's edges.
loader_train = KGTripletLoader(
    batch_size=2000,
    shuffle=True,
    head_index=data_train.edge_index[0],
    tail_index=data_train.edge_index[1],
    rel_type=data_train.edge_type,
)


In [11]:
def train(model: KGEModel, loader: KGTripletLoader):
    """Run one training epoch over ``loader`` and return the average loss.

    Relies on the module-level ``optimizer`` (which must wrap ``model``'s
    parameters) and ``device`` (where each batch is moved).

    Args:
        model: KGE model to optimise; switched to training mode here.
        loader: Iterable yielding ``(head_index, rel_type, tail_index)``
            batches of triples.

    Returns:
        The size-weighted average of the batch losses over the epoch, or
        ``nan`` if ``loader`` yields no batches.
    """
    total_loss = 0.0
    total_examples = 0

    model.train()
    for head_index, rel_type, tail_index in loader:
        head_index = head_index.to(device)
        rel_type = rel_type.to(device)
        tail_index = tail_index.to(device)

        optimizer.zero_grad()
        loss = model.loss(head_index, rel_type, tail_index)
        loss.backward()
        optimizer.step()

        # Weight each batch by its number of triples so a smaller final
        # batch does not skew the epoch figure (assumes ``model.loss`` is a
        # per-batch mean -- confirm against the PyG implementation).
        num_triples = head_index.numel()
        total_loss += float(loss) * num_triples
        total_examples += num_triples

    if total_examples == 0:
        # An empty loader would otherwise raise ZeroDivisionError.
        return float('nan')
    return total_loss / total_examples

In [12]:
@torch.no_grad()
def test(model: KGEModel, dataset: Data, batch_size: int = 10000,
         k: int = 10):
    """Evaluate ``model`` on every triple of one knowledge-graph split.

    Uses the module-level ``device``. Gradients are disabled for the whole
    call.

    Args:
        model: Trained KGE model; switched to evaluation mode here.
        dataset: A single split graph (e.g. ``FB15k_237(...)[0]``) providing
            ``edge_index`` (row 0 = heads, row 1 = tails) and ``edge_type``.
            NOTE(review): PyG's ``Data.to`` may move the tensors in place,
            so the caller's object can end up on ``device`` -- confirm.
        batch_size: Batch size passed to ``model.test``.
        k: Cutoff for the Hits@k metric.

    Returns:
        Tuple ``(mean_rank, hits_at_k)``.
    """
    model.eval()
    dataset = dataset.to(device)
    results = model.test(
        head_index=dataset.edge_index[0],
        rel_type=dataset.edge_type,
        tail_index=dataset.edge_index[1],
        batch_size=batch_size,
        k=k,
    )
    # Older PyG returns (mean_rank, hits_at_k); newer releases may also
    # return MRR in between. Taking the first and last entries works for
    # both forms.
    mean_rank, hits_at_k = results[0], results[-1]
    return mean_rank, hits_at_k


In [13]:
model = model.to(device)

for epoch in range(1, 500 + 1):
    loss = train(model, loader_train)

    # Report the training loss on epochs 1, 11, 21, ...
    if epoch % 10 == 1:
        print('Epoch: {:03d}, Loss: {:.4f}'.format(epoch, loss))

    # Evaluate on the validation split every 25 epochs.
    if not epoch % 25:
        rank, hits = test(model, data_val)
        print('Epoch: {:03d}, Val Mean Rank: {:.2f}, Val Hits@10: {:.4f}'
              .format(epoch, rank, hits))

# Final evaluation on the held-out test split.
rank, hits_at_10 = test(model, data_test)
print('Test Mean Rank: {:.2f}, Test Hits@10: {:.4f}'.format(rank, hits_at_10))

Epoch: 001, Loss: 0.8016
Epoch: 011, Loss: 0.1819
Epoch: 021, Loss: 0.1415


100%|██████████| 17535/17535 [00:07<00:00, 2333.67it/s]


Epoch: 025, Val Mean Rank: 384.65, Val Hits@10: 0.3689
Epoch: 031, Loss: 0.1258
Epoch: 041, Loss: 0.1164


100%|██████████| 17535/17535 [00:06<00:00, 2519.45it/s]


Epoch: 050, Val Mean Rank: 335.48, Val Hits@10: 0.3617
Epoch: 051, Loss: 0.1101
Epoch: 061, Loss: 0.1058
Epoch: 071, Loss: 0.1010


100%|██████████| 17535/17535 [00:07<00:00, 2418.79it/s]


Epoch: 075, Val Mean Rank: 309.65, Val Hits@10: 0.3652
Epoch: 081, Loss: 0.0979
Epoch: 091, Loss: 0.0954


100%|██████████| 17535/17535 [00:07<00:00, 2399.59it/s]


Epoch: 100, Val Mean Rank: 299.04, Val Hits@10: 0.3667
Epoch: 101, Loss: 0.0942
Epoch: 111, Loss: 0.0926
Epoch: 121, Loss: 0.0897


100%|██████████| 17535/17535 [00:07<00:00, 2485.48it/s]


Epoch: 125, Val Mean Rank: 286.63, Val Hits@10: 0.3710
Epoch: 131, Loss: 0.0897
Epoch: 141, Loss: 0.0886


100%|██████████| 17535/17535 [00:07<00:00, 2396.28it/s]


Epoch: 150, Val Mean Rank: 280.67, Val Hits@10: 0.3629
Epoch: 151, Loss: 0.0873
Epoch: 161, Loss: 0.0857
Epoch: 171, Loss: 0.0855


100%|██████████| 17535/17535 [00:07<00:00, 2406.64it/s]


Epoch: 175, Val Mean Rank: 275.80, Val Hits@10: 0.3586
Epoch: 181, Loss: 0.0846
Epoch: 191, Loss: 0.0846


100%|██████████| 17535/17535 [00:07<00:00, 2370.66it/s]


Epoch: 200, Val Mean Rank: 274.89, Val Hits@10: 0.3697
Epoch: 201, Loss: 0.0826
Epoch: 211, Loss: 0.0831
Epoch: 221, Loss: 0.0828


100%|██████████| 17535/17535 [00:07<00:00, 2391.87it/s]


Epoch: 225, Val Mean Rank: 268.58, Val Hits@10: 0.3587
Epoch: 231, Loss: 0.0815
Epoch: 241, Loss: 0.0812


100%|██████████| 17535/17535 [00:07<00:00, 2390.04it/s]


Epoch: 250, Val Mean Rank: 265.55, Val Hits@10: 0.3698
Epoch: 251, Loss: 0.0805
Epoch: 261, Loss: 0.0799
Epoch: 271, Loss: 0.0803


100%|██████████| 17535/17535 [00:07<00:00, 2407.27it/s]


Epoch: 275, Val Mean Rank: 266.60, Val Hits@10: 0.3626
Epoch: 281, Loss: 0.0800
Epoch: 291, Loss: 0.0802


100%|██████████| 17535/17535 [00:07<00:00, 2429.14it/s]


Epoch: 300, Val Mean Rank: 261.46, Val Hits@10: 0.3575
Epoch: 301, Loss: 0.0795
Epoch: 311, Loss: 0.0797
Epoch: 321, Loss: 0.0794


100%|██████████| 17535/17535 [00:07<00:00, 2410.94it/s]


Epoch: 325, Val Mean Rank: 259.91, Val Hits@10: 0.3725
Epoch: 331, Loss: 0.0792
Epoch: 341, Loss: 0.0776


100%|██████████| 17535/17535 [00:07<00:00, 2339.84it/s]


Epoch: 350, Val Mean Rank: 262.63, Val Hits@10: 0.3706
Epoch: 351, Loss: 0.0775
Epoch: 361, Loss: 0.0776
Epoch: 371, Loss: 0.0777


100%|██████████| 17535/17535 [00:07<00:00, 2390.79it/s]


Epoch: 375, Val Mean Rank: 261.10, Val Hits@10: 0.3689
Epoch: 381, Loss: 0.0775
Epoch: 391, Loss: 0.0768


100%|██████████| 17535/17535 [00:06<00:00, 2525.08it/s]


Epoch: 400, Val Mean Rank: 256.16, Val Hits@10: 0.3726
Epoch: 401, Loss: 0.0775
Epoch: 411, Loss: 0.0763
Epoch: 421, Loss: 0.0766


100%|██████████| 17535/17535 [00:07<00:00, 2414.93it/s]


Epoch: 425, Val Mean Rank: 257.94, Val Hits@10: 0.3624
Epoch: 431, Loss: 0.0764
Epoch: 441, Loss: 0.0762


100%|██████████| 17535/17535 [00:07<00:00, 2376.71it/s]


Epoch: 450, Val Mean Rank: 260.97, Val Hits@10: 0.3658
Epoch: 451, Loss: 0.0768
Epoch: 461, Loss: 0.0752
Epoch: 471, Loss: 0.0765


100%|██████████| 17535/17535 [00:07<00:00, 2321.14it/s]


Epoch: 475, Val Mean Rank: 257.63, Val Hits@10: 0.3595
Epoch: 481, Loss: 0.0764
Epoch: 491, Loss: 0.0752


100%|██████████| 17535/17535 [00:07<00:00, 2365.09it/s]


Epoch: 500, Val Mean Rank: 255.19, Val Hits@10: 0.3694


100%|██████████| 20466/20466 [00:08<00:00, 2388.39it/s]

Test Mean Rank: 264.45, Test Hits@10: 0.3618



