In [12]:
from torch import cuda
from torch.optim import Adam

from torchkge.models import TransEModel
from torchkge.sampling import BernoulliNegativeSampler
from torchkge.utils import MarginLoss, DataLoader
from torchkge.utils.datasets import load_fb15k
from torchkge.inference import RelationInference

from tqdm.autonotebook import tqdm

In [13]:
# Load dataset
kg_train, _, _ = load_fb15k()

In [14]:
# Define some hyper-parameters for training
emb_dim = 100
lr = 0.0004
n_epochs = 1000
b_size = 32768
margin = 0.5

In [15]:
# Define the model and criterion
model = TransEModel(emb_dim, kg_train.n_ent, kg_train.n_rel, dissimilarity_type='L2')
criterion = MarginLoss(margin)

# Move everything to CUDA if available
if cuda.is_available():
    cuda.empty_cache()
    model.cuda()
    criterion.cuda()

# Define the torch optimizer to be used
optimizer = Adam(model.parameters(), lr=lr, weight_decay=1e-5)

sampler = BernoulliNegativeSampler(kg_train)
dataloader = DataLoader(kg_train, batch_size=b_size, use_cuda='all')

iterator = tqdm(range(n_epochs), unit='epoch')
for epoch in iterator:
    running_loss = 0.0
    for i, batch in enumerate(dataloader):
        h, t, r = batch[0], batch[1], batch[2]
        n_h, n_t = sampler.corrupt_batch(h, t, r)

        optimizer.zero_grad()

        # forward + backward + optimize
        pos, neg = model(h, t, r, n_h, n_t)
        loss = criterion(pos, neg)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    iterator.set_description(
        'Epoch {} | mean loss: {:.5f}'.format(epoch + 1,
                                              running_loss / len(dataloader)))

model.normalize_parameters()

  0%|          | 0/1000 [00:00<?, ?epoch/s]

In [20]:
import torch
e1 = torch.LongTensor([3920, 839])
e2 = torch.LongTensor([8170, 13723])
result = RelationInference(model, e1, e2, top_k=3)
result.predictions

In [21]:
result.predictions

tensor([[0, 0, 0],
        [0, 0, 0]])

In [22]:
kg_train.head_idx

tensor([ 3920,   839, 10094,  ...,  8170, 13723,  6647])