# Translational Embeddings for Modeling Multi-relational Data
In this notebook we implement the [TransE](https://papers.nips.cc/paper/5071-translating-embeddings-for-modeling-multi-relational-data) model and test its performance for the task of link prediction on the WordNet18RR dataset.

### Imports and Helper Functions

In [1]:
import os
import numpy as np
import itertools
from contextlib import suppress
from typing import List, Union, Tuple
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch_geometric.datasets import WordNet18RR, WordNet18
from torch_geometric.loader import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

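# make CUDA kernel launches synchronous so device-side errors point at the failing call (useful for debugging, slows down training)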
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# check if CUDA device is available
torch.cuda.is_available()

True

Both versions of the WordNet18 dataset are imported through the PyTorch Geometric library. Training and testing on the original WN18 were done solely to compare our best model against the results reported by **Bordes et al.** (table 3, page 6).

In [2]:
# download wordnet dataset
WordNet18('./WordNet18/')
WordNet18RR('./WordNet18RR/')

WordNet18RR()

### Custom Dataset and DataLoader

Here we define a custom dataset that reads from the raw directory of the WordNet18 versions downloaded in the previous cell.

In [3]:
class Edge():
    def __init__(self, head, tail, rel) -> None:
        self.head = head
        self.tail = tail
        self.rel = rel

    def __str__(self) -> str:
        return f"{self.head} {self.rel} {self.tail}"

def process_lines(lines: List[str], delim: str='\t'):
    """ cleans up the input set of strings """
    return list(map(lambda s: s.strip('\n').split(delim), lines))

def load_edges_from_file(path: str, is_wn18: bool=True) -> List[Edge]:
    """ read edges from a raw text file, handling the different formats of
        WN18RR (head, rel, tail) and the original WN18 (head, tail, rel) """
    with open(path) as f:
        lines = f.readlines()

    # WN18 starts with a header line, is space-delimited and has a format (head, tail, rel)
    if is_wn18:
        lines = lines[1:]
        delim = ' '
    else:
        delim = '\t'

    lines = process_lines(lines, delim=delim)

    # the two WN versions order the relation differently within each line
    if is_wn18:
        edge_list = [Edge(head=head, tail=tail, rel=rel) for head, tail, rel in lines]
    else:
        edge_list = [Edge(head=head, tail=tail, rel=rel) for head, rel, tail in lines]

    return edge_list

def load_ids_dict(path: str) -> Tuple[dict, dict]:
    """ reads and returns the entity->id and relation->id dictionaries
        stored in the specified directory """

    assert os.path.exists(path), f"{path} does not exist"

    with open(os.path.join(path, "entity2id.txt")) as f:
        entity2id = process_lines(f.readlines())
    with open(os.path.join(path, "relation2id.txt")) as f:
        relation2id = process_lines(f.readlines())

    entity2id = {x[0]: int(x[1]) for x in entity2id}
    relation2id = {x[0]: int(x[1]) for x in relation2id}

    return entity2id, relation2id

def create_id_mappings(dataset_str: str="WordNet18RR") -> None:
    """ creates the mapping ids inside the raw directory of the 
        specified version of WordNet18 """

    assert(dataset_str in ["WordNet18", "WordNet18RR"])

    is_wn18 = dataset_str == "WordNet18"
    path = f"./{dataset_str}/raw/"

    if not os.path.exists(path):
        print(f"Directory {path} does not exist")
        return
    
    # read edge_list from the raw text files
    train_edge_list = load_edges_from_file(os.path.join(path, "train.txt"), is_wn18=is_wn18)
    val_edge_list = load_edges_from_file(os.path.join(path, "valid.txt"), is_wn18=is_wn18)
    test_edge_list = load_edges_from_file(os.path.join(path, "test.txt"), is_wn18=is_wn18)

    entity_list = list()
    relation_list = list()

    # assign unique id to each entity/relation
    for edge_list in [train_edge_list, val_edge_list, test_edge_list]:
        entity_list += [x.head for x in edge_list] + [x.tail for x in edge_list]
        relation_list += [x.rel for x in edge_list]

    entity_list = sorted(list(set(entity_list)))
    entity2id = dict(zip(entity_list, range(len(entity_list))))

    relation_list = sorted(list(set(relation_list)))
    relation2id = dict(zip(relation_list, range(len(relation_list))))

    # save the generated mappings into the raw directory
    with open(os.path.join(path, "entity2id.txt"), "w") as f:
        f.writelines([f"{x}\t{y}\n" for x,y in entity2id.items()])

    with open(os.path.join(path, "relation2id.txt"), "w") as f:
        f.writelines([f"{x}\t{y}\n" for x,y in relation2id.items()])

To generate the mappings we call the *create_id_mappings()* function once for each WordNet version we want to work with (e.g. *create_id_mappings("WordNet18")* is also needed before training or testing on WN18).

In [4]:
create_id_mappings("WordNet18RR")
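To make the mapping step concrete, here is a small pure-Python sketch of what *load_edges_from_file()* and *create_id_mappings()* do: both raw line formats are normalized to (head, rel, tail) and the sorted unique names are enumerated into contiguous ids. The helper names `parse_raw_lines` and `build_id_maps` are hypothetical, the entity ids in the toy lines are made up, and only the relation names follow the WN18RR style.

```python
from typing import Dict, List, Tuple

Triple = Tuple[str, str, str]  # (head, relation, tail)

def parse_raw_lines(lines: List[str], is_wn18: bool) -> List[Triple]:
    """Normalize both raw formats to (head, rel, tail) tuples.

    Assumed formats, mirroring load_edges_from_file():
    - WN18: a header line, then space-separated "head tail rel"
    - WN18RR: tab-separated "head rel tail"
    """
    if is_wn18:
        rows = [line.split() for line in lines[1:] if line.strip()]
        return [(h, r, t) for h, t, r in rows]
    rows = [line.rstrip("\n").split("\t") for line in lines if line.strip()]
    return [(h, r, t) for h, r, t in rows]

def build_id_maps(triples: List[Triple]) -> Tuple[Dict[str, int], Dict[str, int]]:
    """Assign contiguous ids to the sorted unique entity and relation names."""
    entities = sorted({h for h, _, _ in triples} | {t for _, _, t in triples})
    relations = sorted({r for _, r, _ in triples})
    return ({e: i for i, e in enumerate(entities)},
            {r: i for i, r in enumerate(relations)})

# toy lines with made-up ids, one list per raw format
wn18_lines = ["2\n", "10 20 3\n", "20 30 5\n"]
wn18rr_lines = ["10\t_hypernym\t20\n", "20\t_also_see\t30\n"]

wn18 = parse_raw_lines(wn18_lines, is_wn18=True)
wn18rr = parse_raw_lines(wn18rr_lines, is_wn18=False)
entity2id, relation2id = build_id_maps(wn18rr)
print(wn18)         # [('10', '3', '20'), ('20', '5', '30')]
print(entity2id)    # {'10': 0, '20': 1, '30': 2}
print(relation2id)  # {'_also_see': 0, '_hypernym': 1}
```

As in the notebook, names are sorted as strings, so ids follow lexicographic rather than numeric order.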

Next we define a PyTorch dataset and a Data Module that can be handled by PyTorch Lightning.

In [5]:
class WordNetDataset(Dataset):
    def __init__(self, dataset: str="WordNet18RR", split="train") -> None:
        super().__init__()
        self.path = f"./{dataset}/raw"

        if split == 'val':
            split = 'valid'
        self.split = split

        is_wn18 = dataset == "WordNet18"

        edge_list = load_edges_from_file(os.path.join(self.path, f"{self.split}.txt"), is_wn18=is_wn18)
        entity2id, relation2id = load_ids_dict(path=self.path)

        self.edge_list = torch.tensor([(entity2id[e.head], entity2id[e.tail]) for e in edge_list])
        self.relation_list = torch.tensor([relation2id[e.rel] for e in edge_list])
    
    def __len__(self):
        return self.edge_list.shape[0]

    def __getitem__(self, index) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.edge_list[index], self.relation_list[index]

class WordNetDataModule(pl.LightningDataModule):
    def __init__(self, dataset: str="WordNet18RR", batch_size=32) -> None:
        super().__init__()
        self.path = f"./{dataset}/raw"
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_entities = 40943

        if dataset == 'WordNet18RR':
            self.num_relations = 11
        else:
            self.num_relations = 18
        
        self.params = {"pin_memory": True, "batch_size": batch_size}

    def setup(self, stage: str):
        if stage == "fit":
            self.train_dataset = WordNetDataset(dataset=self.dataset, split="train")
            self.val_dataset = WordNetDataset(dataset=self.dataset, split="valid")
        
        if stage == "predict":
            self.test_dataset = WordNetDataset(dataset=self.dataset, split="test")
    
    def train_dataloader(self):
        return DataLoader(self.train_dataset, shuffle=True, **self.params)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, shuffle=False, **self.params)

    def predict_dataloader(self):
        return DataLoader(self.test_dataset, shuffle=False, **self.params)


### Lightning Model

We start by defining a custom TransE model that implements the margin ranking loss as well as the corruption procedure from the paper.
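As a reference point for the *embedding_loss()* method below, here is a minimal NumPy sketch of the margin ranking loss on toy 2-d embeddings. It assumes, like the notebook, that the dissimilarity is the L1 or L2 norm of h + r - t and that the hinge is summed over the batch; unlike the notebook it skips the normalization step, and the function names are hypothetical.

```python
import numpy as np

def transe_distance(h, r, t, p=2):
    """Row-wise dissimilarity d(h + r, t) using the L1 (p=1) or L2 (p=2) norm."""
    return np.linalg.norm(h + r - t, ord=p, axis=-1)

def margin_ranking_loss(pos, neg, rel, margin=1.0, p=2):
    """Sum over the batch of [margin + d(h + r, t) - d(h' + r, t')]_+ .

    pos, neg: float arrays of shape (batch, 2, dim) with (head, tail) embeddings
    rel:      float array of shape (batch, dim) with the relation embeddings
    """
    d_pos = transe_distance(pos[:, 0], rel, pos[:, 1], p)
    d_neg = transe_distance(neg[:, 0], rel, neg[:, 1], p)
    return float(np.clip(margin + d_pos - d_neg, 0.0, None).sum())

# toy embeddings: both true triplets satisfy h + r = t exactly
rel = np.array([[1.0, 0.0], [1.0, 0.0]])
pos = np.array([[[0.0, 0.0], [1.0, 0.0]],
                [[0.0, 0.0], [1.0, 0.0]]])
# the first corruption is far from h + r (no loss), the second is close (loss 0.5)
neg = np.array([[[0.0, 0.0], [0.0, 3.0]],
                [[0.0, 0.0], [1.0, 0.5]]])
print(margin_ranking_loss(pos, neg, rel))  # 0.5
```

Only corrupted triplets closer than `margin` to satisfying h + r = t contribute, which is why the first pair adds nothing.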

For clarity: *double_corrupt_edge_list()* was tested but ultimately left out of the analysis. It implements a corruption procedure that creates two copies of a batch, replacing the heads with random entities in the first copy and the tails in the second. We tried it because of an ambiguity in formula (2) (section 2, page 3), where the set $S'$ seems to include both versions of each triplet (head replaced and tail replaced).
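For reference, the objective (1) and the corrupted set (2) from the paper, written in its notation, where $E$ is the set of entities, $S$ the training set, $\gamma$ the margin, $d$ the dissimilarity and $[x]_+$ the positive part of $x$:

```latex
\mathcal{L} = \sum_{(h,\ell,t) \in S} \; \sum_{(h',\ell,t') \in S'_{(h,\ell,t)}}
  \left[ \gamma + d(\boldsymbol{h} + \boldsymbol{\ell}, \boldsymbol{t})
               - d(\boldsymbol{h}' + \boldsymbol{\ell}, \boldsymbol{t}') \right]_+ \qquad (1)

S'_{(h,\ell,t)} = \{ (h',\ell,t) \mid h' \in E \} \cup \{ (h,\ell,t') \mid t' \in E \} \qquad (2)
```

Read literally, (2) is the union of head-corrupted and tail-corrupted triplets, which is the reading that *double_corrupt_edge_list()* follows.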

However, a few lines later the authors state: *"Then, a small set of triplets is sampled from the training set, and will serve as the training triplets of the minibatch. For each such triplet, we then sample a single corrupted triplet."* We interpreted this as building a single corrupted list of triplets per batch, where for each triplet either the head or the tail is replaced.
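The single-corruption interpretation can be sketched in NumPy as follows. It mirrors the idea of *corrupt_edge_list()* (one coin flip per triplet decides whether the head or the tail is replaced) but uses `np.where` instead of `gather`; the function name is hypothetical and, as in the notebook, the random entity may coincide with the original one.

```python
import numpy as np

def corrupt_single(edges: np.ndarray, num_entities: int, rng: np.random.Generator) -> np.ndarray:
    """Replace either the head or the tail of each (head, tail) row, never both.

    edges: int array of shape (n, 2); returns a corrupted copy of the same shape.
    """
    n = edges.shape[0]
    sample = rng.integers(0, num_entities, size=n)          # random replacement entity
    replace_head = rng.integers(0, 2, size=n).astype(bool)  # one coin flip per row
    heads = np.where(replace_head, sample, edges[:, 0])
    tails = np.where(replace_head, edges[:, 1], sample)
    return np.stack([heads, tails], axis=1)

rng = np.random.default_rng(0)
edges = np.array([[0, 1], [2, 3], [4, 5], [6, 7]])
corrupted = corrupt_single(edges, num_entities=40943, rng=rng)
# every row keeps at least one of its original entities
print((corrupted == edges).any(axis=1).all())  # True
```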

Training with *double_corrupt_edge_list()* did not produce any noticeable change in performance; we kept the function only for completeness.

In [6]:
class TransE(pl.LightningModule):
    def __init__(self, margin: int=1, emb_dim: int=20, learning_rate=0.01, p_norm=1, dataset="WordNet18RR") -> None:
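        """ Instantiate the entity and relation embedding matrices of the TransE model
            https://papers.nips.cc/paper/5071-translating-embeddings-for-modeling-multi-relational-data

        Args:
            margin (int, optional): margin of the ranking loss. Defaults to 1.
            emb_dim (int, optional): embedding dimension. Defaults to 20.
            learning_rate (float, optional): Adam learning rate. Defaults to 0.01.
            p_norm (int, optional): norm (1 or 2) used for the dissimilarity. Defaults to 1.
            dataset (str, optional): "WordNet18" or "WordNet18RR", sets the number of relations. Defaults to "WordNet18RR".
        """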
        super().__init__()
        self.margin = margin
        self.emb_dim = emb_dim
        self.learning_rate = learning_rate
        self.p_norm = p_norm

        # dataset specific values
        self.num_entities = 40943

        if dataset == "WordNet18":
            self.num_relations = 18
        else:
            self.num_relations = 11

        # initialize embeddings
        self.entity_mat = nn.Embedding(self.num_entities, emb_dim).to(self.device)
        self.relation_mat = nn.Embedding(self.num_relations, emb_dim).to(self.device)

        with torch.no_grad():
            # initialize with random uniform
            val = 6/np.sqrt(emb_dim)
            self.entity_mat.weight.uniform_(-val, val)
            self.relation_mat.weight.uniform_(-val, val)

            # normalize entity and relation embeddings
            self.entity_mat.weight.copy_(F.normalize(self.entity_mat.weight, p=self.p_norm, dim=-1))
            self.relation_mat.weight.copy_(F.normalize(self.relation_mat.weight, p=self.p_norm, dim=-1))

    def double_corrupt_edge_list(self, edge_list: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """ given a list of triplets return two corrupted lists, the first randomly replacing
            head entities, the second randomly replacing tail entities
            NOTE: this function has not been used but was part of the experiments """
        n = edge_list.shape[0]

        # sample random entity replacements on the same device as the batch
        r1 = torch.randint(self.num_entities, (n,), device=edge_list.device)
        r2 = torch.randint(self.num_entities, (n,), device=edge_list.device)

        corrupted_heads = edge_list.detach().clone()
        corrupted_tails = edge_list.detach().clone()

        corrupted_heads[:,0] = r1
        corrupted_tails[:,1] = r2

        return corrupted_heads, corrupted_tails

    def corrupt_edge_list(self, edge_list: torch.Tensor) -> torch.Tensor:
        """ given a list of triplets return a single list of triplet where either
            the head or the tail has been randomly replaced, but not both """
        n = edge_list.shape[0]
        entity_list = range(self.num_entities)

        heads = edge_list[:,0]
        tails = edge_list[:, 1]

        # sample random entity replacements
        sample = np.random.choice(entity_list, size=n)
        sample = torch.from_numpy(sample).type(torch.int64).to(self.device)

        # random selection of either head or tail
        pos = np.random.choice([0, 1], size=n)
        pos = torch.from_numpy(pos).type(torch.int64).to(self.device)
        pos = pos.reshape(-1,1)

        # create a tensor of two columns where the first is head/tail
        # and the second represents the random sample of entities
        corrupted_heads = torch.vstack([heads, sample]).T
        corrupted_tails = torch.vstack([tails, sample]).T

        # keep either the head/tail or the random entity from each tensor:
        # pos=1 replaces the head, pos=0 replaces the tail
        corrupted_heads = corrupted_heads.gather(1, pos)
        corrupted_tails = corrupted_tails.gather(1, 1 - pos)

        # combine the resulting triplet with the guarantee that either head or
        # tail has been randomly replaced but not both at the same time
        corrupted_triplet = torch.hstack([corrupted_heads, corrupted_tails])
        
        return corrupted_triplet

    
    def embedding_loss(self, batch: torch.Tensor) -> torch.Tensor:
        """ returns the margin ranking loss for a batch of triplets
            according to the corruption procedure defined at page 3 """
        edge_list, labels = batch

        # to use double corruption uncomment the following line
        #corrupted_heads, corrupted_tails = self.double_corrupt_edge_list(edge_list)
        corrupted_triplet = self.corrupt_edge_list(edge_list)

        # to use the double list corruption, uncomment the following lines
        # and comment out the single-corruption definitions below
        # t1 = self.entity_mat.weight[edge_list.repeat(2,1)]
        # t2 = torch.vstack([self.entity_mat.weight[corrupted_heads],
        #                    self.entity_mat.weight[corrupted_tails]])
        # rel = self.relation_mat.weight[labels].repeat(2,1)

        # (single) corruption procedure
        t1 = self.entity_mat.weight[edge_list]
        t2 = self.entity_mat.weight[corrupted_triplet]
        rel = self.relation_mat.weight[labels]

        # normalize entity embeddings (possibly unnecessary here)
        t1 = F.normalize(t1, p=self.p_norm, dim=-1)
        t2 = F.normalize(t2, p=self.p_norm, dim=-1)

        # margin ranking loss: dim 0 indexes the triplet, dim 1 selects
        # head (0) or tail (1) and dim 2 is the entity embedding
        pos = torch.norm(t1[:,0,:] + rel - t1[:,1,:], dim=-1, p=self.p_norm)
        neg = torch.norm(t2[:,0,:] + rel - t2[:,1,:], dim=-1, p=self.p_norm)
        loss = torch.clip((self.margin + pos - neg), min=0).sum()

        return loss

    def evaluation_protocol(self, batch: torch.Tensor) -> torch.Tensor:
        """ raw ranking protocol: for each triplet, replace the head (then the tail) with
            every entity and return the 0-based rank of the true entity among all
            candidates sorted by increasing distance (head ranks first, then tail ranks) """
        edge_list, labels = batch
        batch_size = edge_list.shape[0]

        # combine heads, tails and labels
        triplets = torch.hstack([edge_list, labels.reshape(-1,1)])

        # repeat all triplets for n_entities times
        triplets = triplets[:,np.newaxis,:].repeat(1,self.num_entities,1)

        all_entities = torch.arange(self.num_entities, device=self.device)
        rank_pos_list = list()

        # repeat corruption for both head and tail
        for pos in [0,1]:
            x = triplets.detach().clone()

            # replace all heads/tails with list of all possible entities
            x[:,:,pos] = all_entities[np.newaxis,:].repeat(batch_size,1)

            # triplets are arranged as (head, tail, label)
            head = self.entity_mat.weight[x[:,:,0]]
            tail = self.entity_mat.weight[x[:,:,1]]
            rel = self.relation_mat.weight[x[:,:,2]]

            # compute distance between head + label and tail
            norms = torch.norm(head + rel - tail, dim=-1, p=self.p_norm)

            # get index positions of sorted norms for each triplet
            rankings = torch.argsort(norms, dim=1)

            # find the position of the true head/tail within the rankings
            rank_pos = torch.where(rankings == edge_list[:,pos].reshape(-1,1))[1]
            rank_pos_list.append(rank_pos)

        return torch.vstack(rank_pos_list).flatten()

    def training_step(self, batch: torch.Tensor, batch_idx: int):
        loss = self.embedding_loss(batch)
        self.log_dict({'train_loss': loss}, logger=True)
        # with automatic optimization Lightning skips the optimizer step if nothing is returned
        return loss

    def validation_step(self, batch: torch.Tensor, batch_idx: int):
        loss = self.embedding_loss(batch)
        batch_rankings = self.evaluation_protocol(batch)
        self.log_dict({"val_loss": loss}, prog_bar=True, on_epoch=True, logger=True)
        return {"val_loss": loss, "batch_rankings": batch_rankings}
    
    def predict_step(self, batch: torch.Tensor, batch_idx: int):
        batch_rankings = self.evaluation_protocol(batch)
        return {"batch_rankings": batch_rankings}

    def on_train_epoch_end(self):
        with torch.no_grad():
            # keep entity embeddings L2-normalized; Algorithm 1 of the paper
            # renormalizes at every minibatch iteration rather than once per epoch
            self.entity_mat.weight.copy_(F.normalize(self.entity_mat.weight, p=2, dim=1))
    
    def compute_epoch_metrics(self, outputs: List[dict], stage: str, log_value: bool=True):
        """ compute mean rank and hits@10 from the rankings collected over an epoch """
        epoch_rankings = torch.hstack([x['batch_rankings'] for x in outputs])
        # ranks are 0-based, so this mean rank is 1 lower than the usual 1-based value
        mean_rank = epoch_rankings.float().mean()
        hit_at_10 = (epoch_rankings < 10).float().mean()*100
        if log_value:
            self.log_dict({f"{stage}_mean_rank": mean_rank,
                           f"{stage}_hits@10": hit_at_10},
                           prog_bar=True, on_epoch=True, logger=True)
        else:
            return mean_rank, hit_at_10
    
    def validation_epoch_end(self, outputs):
        self.compute_epoch_metrics(outputs, stage="val")

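    # NOTE: not a Lightning hook, so it is never called automatically;
    # test metrics are computed explicitly in predict_transe()
    def prediction_epoch_end(self, outputs):
        self.compute_epoch_metrics(outputs, stage="predict")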

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        

In [7]:
def train_transe(config: dict,
                 max_epochs: int=100,
                 accelerator: str='gpu',
                 num_best_ckpt: int=3,
                 patience: int=10,
                 min_delta: float=0.5,
                 main_path: str='./',
                 dataset: str='WordNet18RR') -> None:
    """ train a TransE model on the input parameters

    Args:
        config (dict): dictionary containing keys [batch_size, emb_dim, lr, margin, p_norm]
        max_epochs (int, optional): maximum number of epochs. Defaults to 100.
        accelerator (str, optional): pytorch accelerator. Defaults to 'gpu'.
        num_best_ckpt (int, optional): number of best models to save through training. Defaults to 3.
        patience (int, optional): number of epochs to wait for early stopping. Defaults to 10.
        min_delta (float, optional): minimum change of mean rank for early stopping. Defaults to 0.5.
        main_path (str, optional): main path to store models. Defaults to './'.
        dataset (str, optional): dataset to train on.  Defaults to 'WordNet18RR'.
    """
    
    model = TransE(emb_dim=config['emb_dim'],
                learning_rate=config['lr'],
                margin=config['margin'],
                p_norm=config['p_norm'],
                dataset=dataset)

    dm = WordNetDataModule(batch_size=config['batch_size'], dataset=dataset)

    dir_path = f"ckpt_{dataset}/emb_dim={config['emb_dim']}-lr={config['lr']}-margin={config['margin']}-p_norm={config['p_norm']}"
    dir_path = os.path.join(main_path, dir_path)

    # using mean predicted rank on validation set as described in section 4.2
    early_stop_rank = EarlyStopping(monitor="val_mean_rank",
                                    min_delta=min_delta,
                                    patience=patience,
                                    verbose=False,
                                    mode="min")

    # save best models based on mean rank on validation set; {dataset} is filled in by
    # the f-string, the double-braced fields are filled in by Lightning at save time
    checkpoint_callback = ModelCheckpoint(save_top_k=num_best_ckpt,
                                        monitor="val_mean_rank",
                                        dirpath=dir_path,
                                        filename=f"transe-{dataset}-{{epoch}}-{{val_mean_rank:.0f}}-{{val_hits@10:.1f}}")

    # logging with TensorBoard
    logger = TensorBoardLogger(f'tb_logs_{dataset}', name='TransE')

    trainer = pl.Trainer(max_epochs=max_epochs,
                        accelerator=accelerator,
                        callbacks=[checkpoint_callback, early_stop_rank],
                        logger=logger)

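    try:
        # resume from an existing checkpoint if one is available (os.listdir order
        # is arbitrary, so this is not necessarily the best checkpoint)
        ckpt_path = os.path.join(dir_path, os.listdir(dir_path)[-1])
    except (FileNotFoundError, IndexError):
        ckpt_path = None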

    trainer.fit(model, datamodule=dm, ckpt_path=ckpt_path)


### Parameter search

The following cells show how to train the model over a grid of parameters or on a single target configuration. A few notes:

- We trained the model using the same parameter configuration as Bordes et al., except for the embedding dimension: 50 did not fit in memory with a batch size of 128. For efficiency we kept the larger batch size and a lower embedding dimension (40); a few tests with batch_size=16 and emb_dim=50 did not beat the best model found with the suggested configuration.
- Given the underwhelming performance on WordNet18RR (~40% hits@10) compared to the numbers the authors report on the original WordNet18 (~75% hits@10), we also tested the model on the latter, but could not replicate their results, reaching only ~60% hits@10 (see below for details).
- We made several attempts to explain the gap in performance, such as a different normalization procedure, different early-stopping parameters, and both a from-scratch and an off-the-shelf margin ranking loss; none of them produced a substantial improvement.
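A back-of-the-envelope memory estimate, assuming the bottleneck is the evaluation step: *evaluation_protocol()* materializes head, tail and relation embeddings of shape (batch, num_entities, emb_dim) for every candidate entity. The helper name is hypothetical and the figures ignore temporaries and allocator overhead.

```python
def eval_tensor_gib(batch_size: int, num_entities: int, emb_dim: int,
                    bytes_per_float: int = 4) -> float:
    """Size in GiB of one float32 tensor of shape (batch_size, num_entities, emb_dim)."""
    return batch_size * num_entities * emb_dim * bytes_per_float / 2**30

# one tensor per candidate-embedding lookup; evaluation_protocol() holds several at once
for k in (20, 40, 50):
    print(k, round(eval_tensor_gib(128, 40943, k), 2))
```

With 40943 entities and a batch of 128, each such tensor is about 0.39 GiB for emb_dim=20, 0.78 GiB for 40 and 0.98 GiB for 50, and the protocol keeps at least three of them (plus intermediates of the same shape) alive at once.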

All models are saved under the ckpt_{dataset} directory, and each filename records the epoch, the validation mean rank and the validation hits@10.
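Since the metrics live in the filenames, checkpoints can be compared without loading them. A minimal sketch with a hypothetical `parse_ckpt_metrics` helper, tested on two filenames copied from the prediction logs in this notebook:

```python
import re

def parse_ckpt_metrics(filename: str) -> dict:
    """Extract epoch, val_mean_rank and val_hits@10 from a checkpoint filename."""
    stem = filename[:-len(".ckpt")] if filename.endswith(".ckpt") else filename
    pairs = re.findall(r"(epoch|val_mean_rank|val_hits@10)=([0-9.]+)", stem)
    return {key: float(value) for key, value in pairs}

# filenames copied from the prediction logs in this notebook
rr = parse_ckpt_metrics("transe-wordnet-epoch=46-val_mean_rank=2854-val_hits@10=41.6.ckpt")
wn = parse_ckpt_metrics("transe-dataset=0-epoch=88-val_mean_rank=202-val_hits@10=61.2.ckpt")
print(rr)  # {'epoch': 46.0, 'val_mean_rank': 2854.0, 'val_hits@10': 41.6}
```

Within one directory, `max(files, key=lambda f: parse_ckpt_metrics(f)["val_hits@10"])` matches what *hit10_from_filename()* selects, while `min(..., key=... ["val_mean_rank"])` would agree with the metric monitored by `ModelCheckpoint`.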

In [None]:
# best model: emb_dim=40, lr=0.001, margin=1, p_norm=2

# add values to each list to enlarge parameter search
# all models get saved into 'ckpt_{dataset}'

# emb_dim=50 was used in the paper but on the current setup it requires a smaller
# batch size and hence a much slower training time, which is why 40 was used instead
config = {
    "batch_size": [128],
    "lr": [0.001, 0.01, 0.1],
    "emb_dim": [20, 40],
    "p_norm": [1],
    "margin": [1, 2]
}

num_epochs = 1000

# a hand-rolled grid search is used because tqdm printing issues
# produced messy output with other approaches
keys, values = zip(*config.items())
comb_list = [dict(zip(keys,v)) for v in itertools.product(*values)]

for comb in comb_list:
    with suppress(Exception):
        train_transe(config=comb, max_epochs=num_epochs, dataset="WordNet18RR")

In [51]:
# use this cell to run a single training instance
single_train_config = {
    'batch_size': 128,
    'lr': 0.001,
    'emb_dim': 40,
    'p_norm': 2,
    'margin': 1
}

# main_path defaults to './', so checkpoints land in ./ckpt_WordNet18RR/ where predict_transe() looks for them
train_transe(config=single_train_config, dataset="WordNet18RR")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | entity_mat   | Embedding | 1.6 M 
1 | relation_mat | Embedding | 440   
-------------------------------------------
1.6 M     Trainable params
0         Non-trainable params
1.6 M     Total params
6.553     Total estimated model params size (MB)



Epoch 4:  94%|█████████▎| 659/703 [00:11<00:00, 56.97it/s, loss=74.8, v_num=69, val_loss=100.0, val_mean_rank=1.08e+4, val_hits@10=8.600]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


### Prediction on WordNet18RR (and WN18)
The following cell defines a prediction function that takes the configuration of the best model, loads the best of the top-k (k=3) checkpoints stored during training (selected by validation hits@10) and prints the test results. The best performance achieved on the *WordNet18RR* test set is a *hits@10 of 41.6%* with a *mean rank of ~2975*.
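The ranking behind these numbers can be sketched in NumPy for the tail side only (*evaluation_protocol()* does the same for heads and concatenates both). This is the raw setting: other true triplets are not filtered out of the candidate list. The sketch returns 1-based ranks, whereas *evaluation_protocol()* returns 0-based positions, so the notebook's mean rank is one lower while hits@10 (`rank_pos < 10`) is unaffected. Function names and toy embeddings are hypothetical.

```python
import numpy as np

def raw_tail_ranks(entity_emb, rel_emb, triples, p=2):
    """1-based raw rank of the true tail for each (head, rel, tail) triple.

    Every entity is tried as the tail, candidates are sorted by d(h + r, t')
    and the position of the true tail in that order is its rank.
    """
    ranks = []
    for h, r, t in triples:
        dists = np.linalg.norm(entity_emb[h] + rel_emb[r] - entity_emb, ord=p, axis=-1)
        order = np.argsort(dists, kind="stable")
        ranks.append(int(np.where(order == t)[0][0]) + 1)
    return np.array(ranks)

def mean_rank_and_hits(ranks, k=10):
    """Mean rank and hits@k (in percent) from 1-based ranks."""
    return float(ranks.mean()), float((ranks <= k).mean() * 100)

# toy entities on a line and a single relation that shifts by +1
entity_emb = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [5.0, 0.0]])
rel_emb = np.array([[1.0, 0.0]])
ranks = raw_tail_ranks(entity_emb, rel_emb, [(0, 0, 1), (1, 0, 3)])
print(ranks)                           # [1 4]
print(mean_rank_and_hits(ranks, k=1))  # (2.5, 50.0)
```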


In [13]:
def hit10_from_filename(filename):
    """ parse the validation hits@10 value that ModelCheckpoint wrote into the filename """
    filename = filename.replace('.ckpt','')
    return float(filename.split('@10=')[1])

def predict_transe(config: dict):
    """ outputs the test mean rank and hits@10 for the input configuration """
    trainer = pl.Trainer()

    dm = WordNetDataModule(batch_size=config['batch_size'], dataset=config['dataset'])

    # ckpt_path from model config
    path = f"./ckpt_{config['dataset']}/emb_dim={config['emb_dim']}-lr={config['learning_rate']}-margin={config['margin']}-p_norm={config['p_norm']}/"

    # take filename of model with highest hit@10
    filename = max(os.listdir(path), key=hit10_from_filename)
    
    model = TransE(emb_dim=config['emb_dim'],
                learning_rate=config['learning_rate'],
                margin=config['margin'],
                p_norm=config['p_norm'],
                dataset=config['dataset'])
    
    pred = trainer.predict(model, datamodule=dm, ckpt_path=os.path.join(path, filename))
    test_mean_rank, test_hits_at_10 = model.compute_epoch_metrics(pred, stage="predict", log_value=False)
    print("\n")
    print(f"test_mean_rank={test_mean_rank:.0f}, test_hits@10={test_hits_at_10:.2f}%")

# PREDICTION ON TEST SET

# set the following config to match the best model parameters
best_model_config = {
    'batch_size': 128,
    'learning_rate': 0.001,
    'emb_dim': 40,
    'p_norm': 2,
    'margin': 1,
    'dataset': 'WordNet18RR'
}

predict_transe(config=best_model_config)

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Restoring states from the checkpoint path at ./ckpt_WordNet18RR/emb_dim=40-lr=0.001-margin=1-p_norm=2/transe-wordnet-epoch=46-val_mean_rank=2854-val_hits@10=41.6.ckpt
Loaded model weights from checkpoint at ./ckpt_WordNet18RR/emb_dim=40-lr=0.001-margin=1-p_norm=2/transe-wordnet-epoch=46-val_mean_rank=2854-val_hits@10=41.6.ckpt


Predicting DataLoader 0: 100%|██████████| 25/25 [01:33<00:00,  3.75s/it]


test_mean_rank=2975, test_hits@10=41.59%


For completeness we report the results on WordNet18 obtained with the same configuration as the best WN18RR model:
- hits@10 = 61.7%
- mean rank = 170

In [15]:
# for sake of completeness we report the test set performance on WordNet18
best_model_config = {
    'batch_size': 128,
    'learning_rate': 0.001,
    'emb_dim': 40,
    'p_norm': 2,
    'margin': 1,
    'dataset': 'WordNet18'
}

predict_transe(config=best_model_config)

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs



Restoring states from the checkpoint path at ./ckpt_WordNet18/emb_dim=40-lr=0.001-margin=1-p_norm=2/transe-dataset=0-epoch=88-val_mean_rank=202-val_hits@10=61.2.ckpt
Loaded model weights from checkpoint at ./ckpt_WordNet18/emb_dim=40-lr=0.001-margin=1-p_norm=2/transe-dataset=0-epoch=88-val_mean_rank=202-val_hits@10=61.2.ckpt



Predicting DataLoader 0: 100%|██████████| 40/40 [02:55<00:00,  4.38s/it]


test_mean_rank=170, test_hits@10=61.71%
