Write implementation on PyTorch for TransE model (you can use TorchGeometric or DGL library for working with graphs) and train your model on WordNet18RR dataset (you can use loaded dataset from any graph library).

As a result, you must provide a link to github (or gitlab) with all the source code.
The readability of the code, the presence of comments, type annotations, and the quality of the code as a whole will be taken into account when checking the test case.

### Imports and helper Functions

In [19]:
import numpy as np
import matplotlib.pyplot as plt
from typing import Union, Callable, Optional
import csv

import torch
from torch.utils.data import Dataset
from torch import nn
from torch_geometric.datasets import WordNet18RR
from torch_geometric.loader import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger

torch.cuda.is_available()

True

In [2]:
# download wordnet dataset, we'll be using the processed file data.pt
dataset = WordNet18RR('./WordNet18RR/')

### Custom Dataset and DataLoader

In [3]:

class Edge():
    def __init__(self, u, v, label) -> None:
        self.u = u
        self.v = v
        self.label = label

    def __str__(self) -> str:
        return f"{self.u} {self.label} {self.v}"

def load_edge_list_from_file(path: str, header: bool=False):
    edge_list = list()

    with open(path, "r") as f:
        tsv_reader = csv.reader(f, delimiter="\t")

        if header:
            next(tsv_reader)

        for row in tsv_reader:
            u, label, v = row
            edge_list.append(Edge(u=u, v=v, label=label))
    
    return edge_list

In [4]:
class WordNetEdgeDataset(Dataset):
    def __init__(self, path: str="WordNet18RR/processed/data.pt", split: str="train") -> None:
        super().__init__()
        data = torch.load(path)[0]
        mask_dict = {"train": data.train_mask, "test": data.test_mask, "val": data.val_mask}
        mask = mask_dict[split]
        self.edge_list = data.edge_index.T[mask, :]
        self.edge_labels = data.edge_type[mask]
    
    def __len__(self):
        return self.edge_list.shape[0]

    def __getitem__(self, index) -> int:
        return self.edge_list[index,:], self.edge_labels[index]
    

In [182]:
class WordNetDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str="WordNet18RR/processed/data.pt", batch_size=32, num_workers=6) -> None:
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_entities = 40943
        self.num_relations = 11
        self.num_workers = num_workers
        self.loader_params = {
            'batch_size': batch_size,
            'pin_memory': True,
            'shuffle': True
        }

    def setup(self, stage: str):
        if stage == "fit":
            self.train_dataset = WordNetEdgeDataset(split="train", path=self.data_dir)
            self.val_dataset = WordNetEdgeDataset(split="val", path=self.data_dir)
        
        if stage == "test":
            self.test_dataset = WordNetEdgeDataset(split="test", path=self.data_dir)
    
    def train_dataloader(self):
        return DataLoader(self.train_dataset, **self.loader_params)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, **self.loader_params)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, **self.loader_params)


### Lightning Model

In [213]:
class TransE(pl.LightningModule):
    def __init__(self, margin: int=1, emb_dim: int=50, learning_rate=0.01) -> None:
        """ Instatiate the entity and relation matrix of the TransE model
            https://papers.nips.cc/paper/5071-translating-embeddings-for-modeling-multi-relational-data

        Args:
            n_entities (int): _description_
            n_relations (int): _description_
            margin (int, optional): _description_. Defaults to 1.
            emb_dim (int, optional): _description_. Defaults to 50.
        """
        super().__init__()
        self.margin = margin
        self.emb_dim = emb_dim
        self.learning_rate = learning_rate

        # dataset specific values
        self.num_entities = 40943
        self.num_relations = 11

        # initialize embeddings
        self.entity_mat = nn.Embedding(self.num_entities, emb_dim).to(self.device)
        self.relation_mat = nn.Embedding(self.num_relations, emb_dim).to(self.device)

        with torch.no_grad():
            # initialize with random uniform
            val = 6/np.sqrt(emb_dim)
            self.entity_mat.weight.uniform_(-val, val)
            self.relation_mat.weight.uniform_(-val, val)

            # normalize entity embeddings
            self.entity_mat.weight.copy_(nn.functional.normalize(self.entity_mat.weight, p=2, dim=1))

    def corrupt_edge_list(self, edge_list: torch.Tensor):
        """ sample either the head or tail of x from range(n) """
        idxs = (np.random.rand(edge_list.shape[0]) < 0.5).astype(int)  # pick either head or tail
        s = edge_list.detach().clone()

        # sample random entity replacements
        vals = np.random.choice(range(self.num_entities), size=edge_list.shape[0])

        for i,idx in enumerate(idxs):
            s[i,idx] = vals[i]
        
        return s

    def embedding_loss(self, batch):
        edge_list, labels = batch
        
        loss = torch.zeros(1).to(self.device)

        edge_list_cor = self.corrupt_edge_list(edge_list)
        
        # take embedding values for entities and relations
        e1 = self.entity_mat.weight[edge_list]
        e2 = self.entity_mat.weight[edge_list_cor]
        l = self.relation_mat.weight[labels]

        # compute the loss value
        n1 = torch.norm(e1[:,0,:] + l - e1[:,1,:], dim=1)
        n2 = torch.norm(e2[:,0,:] + l - e2[:,1,:], dim=1)
        loss = (self.margin + n1 - n2)
        loss = torch.clip(loss, min=0).sum()

        return loss

    def evaluation_protocol(self, batch):
        edge_list, labels = batch
        rankings_list = list()
        hits_at_10_list = list()

        for i in range(edge_list.shape[0]):
            # take a single test triplet
            test_triplet = edge_list[i].T

            n = self.num_entities

            # replicate triplet for num_entities time for corruption
            entities = torch.tensor(list(range(self.num_entities)))
            x_cor = test_triplet.repeat(n, 1)

            # relation embeddings
            l = self.relation_mat[labels[i].repeat(n)]

            # compute ranking and hits@10 by corrupting both head and tail
            for pos in [0,1]:
                # replace triplet head with each possible entities
                x_cor[:, pos] = entities
                val = test_triplet[pos]

                # get entity matrix for all possible pairings
                e1 = self.entity_mat[x_cor]

                # compute distance between head + label and tail
                dissimilarities = torch.norm(e1[:,0,:] + l - e1[:,1,:], dim=1)

                # rank distances in ascending order
                ranking = torch.argsort(dissimilarities)

                # find position of true triplet within ranking and if is <10
                test_pos = torch.where(ranking == val)[0].item()
                is_among_10 = (ranking[:10] == val).sum().item()

                # save current rank to later compute test results
                rankings_list.append(test_pos)
                hits_at_10_list.append(is_among_10)
    
        metrics = {
            "mean_rank": np.mean(rankings_list).astype(int),
            "hits_at_10": np.sum(hits_at_10_list)
        }

        self.log_dict(metrics)
        return metrics

    def training_step(self, batch, batch_idx):
        return self.embedding_loss(batch)

    def validation_step(self, batch, batch_idx):
        loss = self.embedding_loss(batch)
        metrics = self.evaluation_protocol(batch)
        metrics["val_loss": loss]
        return metrics

    def test_step(self, batch, batch_idx):
        loss = self.embedding_loss(batch)
        metrics = self.evaluation_protocol(batch)
        metrics["test_loss": loss]
        return metrics

    def on_train_epoch_end(self):
        with torch.no_grad():
             # keep entities embeddings normalized
            self.entity_mat.weight.copy_(nn.functional.normalize(self.entity_mat.weight, p=2, dim=1))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        

In [214]:
model = TransE(emb_dim=5)
dm = WordNetDataModule(batch_size=64)
logger = TensorBoardLogger("tb_logs", name="TransE")
trainer = pl.Trainer(max_epochs=1, accelerator='gpu')
trainer.fit(model, datamodule=dm)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | entity_mat   | Embedding | 204 K 
1 | relation_mat | Embedding | 55    
-------------------------------------------
204 K     Trainable params
0         Non-trainable params
204 K     Total params
0.819     Total estimated model params size (MB)


                                                                           

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Epoch 0: 100%|██████████| 1405/1405 [00:19<00:00, 70.37it/s, loss=52.1, v_num=64, val_loss=53.40]

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 1405/1405 [00:19<00:00, 70.32it/s, loss=52.1, v_num=64, val_loss=53.40]


In [188]:
torch.norm(model.entity_mat.weight, dim=1)

tensor([3.9371, 0.6720, 1.9015,  ..., 4.8295, 4.3622, 3.9676],
       grad_fn=<NormBackward1>)

In [361]:
entity_mat = model.entity_mat.weight
relation_mat = model.relation_mat.weight

rankings_list = list()
hit_at_10_list = list()

for i in range(batch.shape[0]):
    # take a single test triplet
    x = batch[i].T
    head, tail = x

    n = model.num_entities

    # all possible entities list
    entities = torch.tensor(list(range(model.num_entities)))
    x_cor = x.repeat(n, 1)

    # relation embeddings
    l = relation_mat[labels[i].repeat(n)]

    # compute ranking and hits@10 by corrupting both head and tail
    for pos in [0,1]:
        # replace triplet head with each possible entities
        x_cor[:, pos] = entities
        val = x[pos]

        # get entity matrix for all possible pairings
        e1 = entity_mat[x_cor]

        # compute distance between head + label and tail
        dissimilarities = torch.norm(e1[:,0,:] + l - e1[:,1,:], dim=1)

        # rank distances in ascending order
        ranking = torch.argsort(dissimilarities)

        # find position of true triplet within ranking and if is <10
        test_pos = torch.where(ranking == val)[0].item()
        is_among_10 = (ranking[:10] == val).sum().item()

        # save current rank to later compute test results
        rankings_list.append(test_pos)
        hit_at_10_list.append(is_among_10)

print("mean_rank =", np.mean(rankings_list).astype(int))
print("hits@10 = ", np.sum(hit_at_10_list))

mean_rank = 16216
hits@10 =  0


3

In [318]:
x[0]

tensor(18045)

In [156]:
e1 = entity_mat[batch]
l = relation_mat[labels]
batch_cor = model.corrupt_batch(next(iter(dm.train_dataloader())))

In [151]:
margin = 1

# take embedding values for entities and relations
e1 = entity_mat[batch]
e2 = entity_mat[batch_cor]
l = relation_mat[labels]

# compute the loss value
n1 = torch.norm(e1[:,0,:] + l - e1[:,1,:], dim=1)
n2 = torch.norm(e2[:,0,:] + l - e2[:,1,:], dim=1)

loss = (margin + n1 - n2).sum()
loss = torch.clip(loss, min=0)
loss

tensor(59.9413, grad_fn=<ClampBackward1>)

In [127]:
torch.norm(entity_mat[batch[:,0]] + relation_mat[labels] - entity_mat[batch[:,1]], dim=1).sum()

tensor(113.5347, grad_fn=<SumBackward0>)