Write a PyTorch implementation of the TransE model (you may use the TorchGeometric or DGL library for working with graphs) and train your model on the WordNet18RR dataset (you may load the dataset from any graph library).

As a result, you must provide a link to a GitHub (or GitLab) repository containing all the source code.
The readability of the code, the presence of comments and type annotations, and the overall quality of the code will be taken into account when the test case is checked.

### Imports and helper Functions

In [1]:
import csv
from typing import Callable, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torch import nn
from torch.utils.data import Dataset
from torch_geometric.datasets import WordNet18RR
from torch_geometric.loader import DataLoader

torch.cuda.is_available()

True

In [2]:
# Fetch WordNet18RR through PyG. The dataset classes below read the processed
# file data.pt (default path "WordNet18RR/processed/data.pt").
dataset = WordNet18RR(root='./WordNet18RR/')

### Custom Dataset and DataLoader

In [3]:

class Edge:
    """A single labelled, directed edge ``u --label--> v``.

    Attributes are stored exactly as passed in; no validation or conversion
    is performed.
    """

    def __init__(self, u, v, label) -> None:
        # source node, target node and relation label, in that order
        self.u, self.v, self.label = u, v, label

    def __str__(self) -> str:
        # rendered as "<source> <label> <target>"
        return "{} {} {}".format(self.u, self.label, self.v)

def load_edge_list_from_file(path: str, header: bool=False) -> "List[Edge]":
    """Read a tab-separated edge file into a list of ``Edge`` objects.

    Every non-empty row must hold exactly three fields in the order
    ``u <TAB> label <TAB> v``.

    Args:
        path (str): location of the TSV file.
        header (bool, optional): when True the first row is skipped.
            Defaults to False.

    Returns:
        List[Edge]: one ``Edge`` per non-empty data row, in file order.

    Raises:
        ValueError: if a non-empty row does not have exactly three fields.
    """
    # newline="" is the mode the csv module documents for reader input; the
    # explicit encoding keeps parsing independent of the platform locale.
    with open(path, "r", newline="", encoding="utf-8") as f:
        tsv_reader = csv.reader(f, delimiter="\t")

        if header:
            # default of None: an empty file must not raise StopIteration
            next(tsv_reader, None)

        # csv.reader yields [] for blank lines; skip them instead of crashing
        # on the three-way unpack.
        return [
            Edge(u=u, v=v, label=label)
            for u, label, v in (row for row in tsv_reader if row)
        ]

In [4]:
class WordNetEdgeDataset(Dataset):
    """Map-style dataset exposing one split of the WordNet18RR edges.

    Each item is a ``(edge, label)`` pair, where ``edge`` is one row of the
    transposed ``edge_index`` and ``label`` is the matching ``edge_type``
    entry.
    """

    def __init__(self, path: str="WordNet18RR/processed/data.pt", split: str="train") -> None:
        """Load the processed graph and keep only the edges of ``split``.

        Args:
            path (str, optional): processed file to load with ``torch.load``.
            split (str, optional): one of "train", "val" or "test".
                Defaults to "train".

        Raises:
            KeyError: if ``split`` is not one of the three known splits.
        """
        super().__init__()
        # NOTE(review): assumes the first element of the saved object is the
        # graph data object exposing edge_index / edge_type / *_mask -- confirm
        # against the installed PyG version.
        data = torch.load(path)[0]
        mask_dict = {"train": data.train_mask, "test": data.test_mask, "val": data.val_mask}
        mask = mask_dict[split]
        # transpose so that every row is one edge, then keep the split's rows
        self.edge_list = data.edge_index.T[mask, :]
        self.edge_labels = data.edge_type[mask]

    def __len__(self) -> int:
        """Return the number of edges in the selected split."""
        return self.edge_list.shape[0]

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the ``index``-th edge row and its relation label."""
        return self.edge_list[index, :], self.edge_labels[index]
    

In [5]:
class WordNetDataModule(pl.LightningDataModule):
    """Lightning data module serving the WordNet18RR train/val/test edge splits."""

    def __init__(self, data_dir: str="WordNet18RR/processed/data.pt", batch_size: int=32) -> None:
        """Store the loader configuration; no data is read here.

        Args:
            data_dir (str, optional): path of the processed file handed to
                ``WordNetEdgeDataset``.
            batch_size (int, optional): batch size used by every dataloader.
                Defaults to 32.
        """
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        # dataset specific values
        self.num_entities = 40943
        self.num_relations = 11
        # keyword arguments shared by the three dataloaders
        self.params = {"pin_memory": True, "batch_size": batch_size}

    def setup(self, stage: str) -> None:
        """Instantiate the datasets required by ``stage``.

        "fit" builds the train and validation datasets, "validate" only the
        validation dataset and "test" only the test dataset.
        """
        if stage == "fit":
            self.train_dataset = WordNetEdgeDataset(split="train", path=self.data_dir)

        if stage in ("fit", "validate"):
            self.val_dataset = WordNetEdgeDataset(split="val", path=self.data_dir)

        if stage == "test":
            self.test_dataset = WordNetEdgeDataset(split="test", path=self.data_dir)

    def train_dataloader(self) -> DataLoader:
        """Return the shuffled training dataloader."""
        return DataLoader(self.train_dataset, shuffle=True, **self.params)

    def val_dataloader(self) -> DataLoader:
        """Return the validation dataloader (fixed order)."""
        return DataLoader(self.val_dataset, shuffle=False, **self.params)

    def test_dataloader(self) -> DataLoader:
        """Return the test dataloader (fixed order)."""
        # The comma after shuffle=False matters: without it the expression is
        # parsed as `False ** self.params` and raises TypeError.
        return DataLoader(self.test_dataset, shuffle=False, **self.params)

### Lightning Model

In [8]:
class TransE(pl.LightningModule):
    """TransE knowledge-graph embedding model wrapped as a LightningModule.

    Entities and relations are embedded in the same space and a triplet
    (head, relation, tail) is scored with ``||head + relation - tail||``.
    https://papers.nips.cc/paper/5071-translating-embeddings-for-modeling-multi-relational-data
    """

    def __init__(self, margin: float=1, emb_dim: int=50, learning_rate: float=0.01) -> None:
        """ Instantiate the entity and relation matrices of the TransE model

        Args:
            margin (float, optional): margin of the ranking loss. Defaults to 1.
            emb_dim (int, optional): size of entity and relation embeddings.
                Defaults to 50.
            learning_rate (float, optional): learning rate handed to Adam.
                Defaults to 0.01.
        """
        super().__init__()
        self.margin = margin
        self.emb_dim = emb_dim
        self.learning_rate = learning_rate

        # dataset specific values
        self.num_entities = 40943
        self.num_relations = 11

        # initialize embeddings (Lightning moves the module to the target
        # device, so no explicit .to(...) is needed here)
        self.entity_mat = nn.Embedding(self.num_entities, emb_dim)
        self.relation_mat = nn.Embedding(self.num_relations, emb_dim)

        with torch.no_grad():
            # initialize with random uniform in [-6/sqrt(k), 6/sqrt(k)]
            val = 6/np.sqrt(emb_dim)
            self.entity_mat.weight.uniform_(-val, val)
            self.relation_mat.weight.uniform_(-val, val)

            # normalize entity embeddings
            self.entity_mat.weight.copy_(nn.functional.normalize(self.entity_mat.weight, p=2, dim=1))

    def corrupt_edge_list(self, edge_list: torch.Tensor) -> torch.Tensor:
        """ Return a copy of ``edge_list`` where, for every row, either the
        head (column 0) or the tail (column 1) is replaced by an entity id
        drawn uniformly from ``range(self.num_entities)``.

        Args:
            edge_list (torch.Tensor): integer tensor with one (head, tail)
                pair per row.

        Returns:
            torch.Tensor: corrupted copy; the input is left untouched.
        """
        n = edge_list.shape[0]
        device = edge_list.device

        # pick, per row, which column to corrupt and the replacement entity
        columns = torch.randint(0, 2, (n,), device=device)
        replacements = torch.randint(0, self.num_entities, (n,), device=device)

        rand_corrupted = edge_list.detach().clone()
        # single vectorized write instead of a per-row Python loop
        rand_corrupted[torch.arange(n, device=device), columns] = replacements.to(rand_corrupted.dtype)

        return rand_corrupted

    def embedding_loss(self, batch: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        """ Margin ranking loss between true triplets and corrupted ones.

        Args:
            batch: ``(edge_list, labels)`` as produced by the dataloaders.

        Returns:
            torch.Tensor: scalar, sum over the batch of
            ``max(0, margin + d(true) - d(corrupted))``.
        """
        edge_list, labels = batch

        edge_list_cor = self.corrupt_edge_list(edge_list)

        # take embedding values for entities and relations
        e1 = self.entity_mat.weight[edge_list]
        e2 = self.entity_mat.weight[edge_list_cor]
        l = self.relation_mat.weight[labels]

        # compute the loss value
        n1 = torch.norm(e1[:,0,:] + l - e1[:,1,:], dim=1)
        n2 = torch.norm(e2[:,0,:] + l - e2[:,1,:], dim=1)
        loss = (self.margin + n1 - n2)
        loss = torch.clip(loss, min=0).sum()

        return loss

    def evaluation_protocol(self, batch: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[int, float]:
        """ Rank every true triplet against all head and all tail corruptions.

        For each triplet the head, then the tail, is replaced by every entity;
        the rank of the true entity among the resulting dissimilarities is
        recorded (1 = best).

        Args:
            batch: ``(edge_list, labels)`` as produced by the dataloaders.

        Returns:
            Tuple[int, float]: mean rank (truncated to int) and hits@10 in
            percent, both computed over 2 * batch_size rankings.
        """
        edge_list, labels = batch
        rankings_list = list()
        hits_at_10_list = list()

        with torch.no_grad():
            # row i holds the embedding of entity i, i.e. every candidate
            all_entities = self.entity_mat.weight

            for i in range(edge_list.shape[0]):
                # take a single test triplet
                head_id, tail_id = edge_list[i, 0], edge_list[i, 1]
                head = all_entities[head_id]
                tail = all_entities[tail_id]
                relation = self.relation_mat.weight[labels[i]]

                # Each corruption starts from the ORIGINAL triplet: only the
                # head varies in the first case and only the tail in the
                # second (the other side keeps its true entity).
                corruptions = (
                    (head_id, all_entities + relation - tail),
                    (tail_id, head + relation - all_entities),
                )

                for true_id, translation in corruptions:
                    # compute distance between head + label and tail
                    dissimilarities = torch.norm(translation, dim=1)

                    # rank of the true entity = number of strictly better
                    # candidates + 1 (same as its position in an ascending sort)
                    rank = int((dissimilarities < dissimilarities[true_id]).sum().item()) + 1

                    # save current rank to later compute test results
                    rankings_list.append(rank)
                    hits_at_10_list.append(int(rank <= 10))

            mean_rank = int(np.mean(rankings_list))
            hits_at_10 = float(np.mean(hits_at_10_list))*100

        return mean_rank, hits_at_10

    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """ Return the ranking loss of the batch. """
        return self.embedding_loss(batch)

    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> dict:
        """ Log loss, mean rank and hits@10 of the batch under ``val_*`` keys. """
        loss = self.embedding_loss(batch)
        mean_rank, hits_at_10 = self.evaluation_protocol(batch)
        metrics = {"val_loss": loss, "val_mean_rank": mean_rank, "val_hits@10": hits_at_10}
        self.log_dict(metrics, prog_bar=True, on_epoch=True)
        return metrics

    def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> dict:
        """ Log loss, mean rank and hits@10 of the batch under ``test_*`` keys. """
        loss = self.embedding_loss(batch)
        mean_rank, hits_at_10 = self.evaluation_protocol(batch)
        # test metrics get their own prefix so they are not mistaken for
        # validation results
        metrics = {"test_loss": loss, "test_mean_rank": mean_rank, "test_hits@10": hits_at_10}
        self.log_dict(metrics, prog_bar=True, on_epoch=True)
        return metrics

    def on_train_epoch_end(self) -> None:
        """ Re-normalize the entity embeddings to unit L2 norm. """
        with torch.no_grad():
            # keep entities embeddings normalized
            self.entity_mat.weight.copy_(nn.functional.normalize(self.entity_mat.weight, p=2, dim=1))

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """ Optimize all parameters with Adam at ``self.learning_rate``. """
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        

In [9]:
import os

# hyper-parameters of this run
emb_dim = 20
lr = 0.01
margin = 2
max_epochs = 1000
top_k_cp = 3

# instantiated model and data module
model = TransE(emb_dim=emb_dim,
               learning_rate=lr,
               margin=margin)

dm = WordNetDataModule(batch_size=32)

# one checkpoint directory per hyper-parameter configuration
dir_path = f"checkpoints/emb_dim={emb_dim}-lr={lr}-margin={margin}"

# using mean predicted rank on validation set as described in section 4.2
early_stop_rank = EarlyStopping(monitor="val_mean_rank",
                                min_delta=1,
                                patience=10,
                                verbose=False,
                                mode="min")

# save best models based on mean rank on validation set
checkpoint_callback = ModelCheckpoint(save_top_k=top_k_cp,
                                      monitor="val_mean_rank",
                                      dirpath=dir_path,
                                      filename="transe-wordnet-{epoch}-{val_mean_rank:.0f}")

# fall back to CPU when no GPU is visible instead of failing at start-up
trainer = pl.Trainer(max_epochs=max_epochs,
                     accelerator='gpu' if torch.cuda.is_available() else 'cpu',
                     callbacks=[checkpoint_callback,early_stop_rank])

# Resume from the most recently written checkpoint of this configuration, if
# any. os.listdir returns entries in arbitrary order, so the newest file is
# selected explicitly by modification time.
try:
    saved_checkpoints = [os.path.join(dir_path, name)
                         for name in os.listdir(dir_path)
                         if name.endswith(".ckpt")]
except OSError:
    # directory missing or unreadable: nothing to resume from
    saved_checkpoints = []

ckpt_path = max(saved_checkpoints, key=os.path.getmtime) if saved_checkpoints else None

trainer.fit(model, datamodule=dm, ckpt_path=ckpt_path)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | entity_mat   | Embedding | 818 K 
1 | relation_mat | Embedding | 220   
-------------------------------------------
819 K     Trainable params
0         Non-trainable params
819 K     Total params
3.276     Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:00<00:00,  3.86it/s]



                                                                           

  rank_zero_warn(


Epoch 0:  17%|█▋        | 482/2809 [00:05<00:27, 83.70it/s, loss=60.1, v_num=14]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [10]:
# Pull a single mini-batch from the training loader for interactive experiments
train_batches = iter(dm.train_dataloader())
edge_list, labels = next(train_batches)

In [13]:
def new_corrupt(edge_list: torch.Tensor, num_entities):
    """ Build two corrupted copies of a batch of (head, tail) pairs.

    The first copy has every head (column 0) replaced and the second has every
    tail (column 1) replaced; replacements are drawn with ``np.random.choice``
    from ``range(num_entities)``. The input tensor is not modified.

    Returns:
        (corrupted_heads, corrupted_tails)
    """
    batch_size = edge_list.shape[0]
    candidate_ids = range(num_entities)

    # heads are drawn first, tails second
    head_draws = np.random.choice(candidate_ids, size=batch_size)
    tail_draws = np.random.choice(candidate_ids, size=batch_size)

    # two independent copies of the batch, detached from any graph
    corrupted_heads, corrupted_tails = (edge_list.detach().clone() for _ in range(2))

    corrupted_heads[:, 0] = torch.from_numpy(head_draws)
    corrupted_tails[:, 1] = torch.from_numpy(tail_draws)

    return corrupted_heads, corrupted_tails

# placeholder loss tensor for the scratch computation in the following cells
loss = torch.zeros(1)

# corrupt the sampled batch: once on the heads, once on the tails
corrupted_heads, corrupted_tails = new_corrupt(edge_list=edge_list,
                                               num_entities=model.num_entities)

# display the first five rows of the original and of both corrupted batches
preview = slice(5)
edge_list[preview], corrupted_heads[preview], corrupted_tails[preview]

(tensor([[12300, 30442],
         [10109,  5578],
         [11289,  6353],
         [ 3596, 14372],
         [19970, 36486]]),
 tensor([[31949, 30442],
         [39700,  5578],
         [24556,  6353],
         [ 8000, 14372],
         [ 9051, 36486]]),
 tensor([[12300, 38851],
         [10109, 10472],
         [11289, 11403],
         [ 3596, 33886],
         [19970, 34368]]))

In [14]:
# shorthand handles on the model's two embedding tables
entity_mat, relation_mat = model.entity_mat, model.relation_mat

# entity embeddings of the true pairs (e1), the head-corrupted pairs (e2) and
# the tail-corrupted pairs (e3), plus the relation embeddings (l)
e1, e2, e3 = (
    entity_mat.weight[pairs]
    for pairs in (edge_list, corrupted_heads, corrupted_tails)
)
l = relation_mat.weight[labels]

In [None]:


# compute the loss value: margin ranking loss between the true pairs (e1) and
# the head-corrupted pairs (e2)
n1 = torch.norm(e1[:,0,:] + l - e1[:,1,:], dim=1)
n2 = torch.norm(e2[:,0,:] + l - e2[:,1,:], dim=1)
# `self` does not exist at notebook top level; the margin lives on the model
loss = (model.margin + n1 - n2)
loss = torch.clip(loss, min=0).sum()