Write implementation on PyTorch for TransE model (you can use TorchGeometric or DGL library for working with graphs) and train your model on WordNet18RR dataset (you can use loaded dataset from any graph library).

As a result, you must provide a link to github (or gitlab) with all the source code.
The readability of the code, the presence of comments, type annotations, and the quality of the code as a whole will be taken into account when checking the test case.

### Imports and helper Functions

In [43]:
import os
import csv
import numpy as np
import matplotlib.pyplot as plt
from typing import Union, Callable, Optional

import torch
from torch.utils.data import Dataset
from torch import nn
from torch_geometric.datasets import WordNet18RR
from torch_geometric.loader import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

torch.cuda.is_available()

True

In [2]:
# download wordnet dataset, we'll be using the processed file data.pt
dataset = WordNet18RR('./WordNet18RR/')

### Custom Dataset and DataLoader

In [3]:

class Edge():
    def __init__(self, u, v, label) -> None:
        self.u = u
        self.v = v
        self.label = label

    def __str__(self) -> str:
        return f"{self.u} {self.label} {self.v}"

def load_edge_list_from_file(path: str, header: bool=False):
    edge_list = list()

    with open(path, "r") as f:
        tsv_reader = csv.reader(f, delimiter="\t")

        if header:
            next(tsv_reader)

        for row in tsv_reader:
            u, label, v = row
            edge_list.append(Edge(u=u, v=v, label=label))
    
    return edge_list

In [39]:
class WordNetProcessedDataset(Dataset):
    def __init__(self, path: str="WordNet18RR/processed/data.pt", split: str="train") -> None:
        super().__init__()
        data = torch.load(path)[0]
        mask_dict = {"train": data.train_mask, "test": data.test_mask, "val": data.val_mask}
        mask = mask_dict[split]
        self.edge_list = data.edge_index.T[mask, :]
        self.edge_labels = data.edge_type[mask]
    
    def __len__(self):
        return self.edge_list.shape[0]

    def __getitem__(self, index) -> int:
        return self.edge_list[index,:], self.edge_labels[index]

class WordNetProcessedDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str="WordNet18RR/processed/data.pt", batch_size=32) -> None:
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_entities = 40943
        self.num_relations = 11
        self.params = {"pin_memory": True, "batch_size": batch_size}

    def setup(self, stage: str):
        if stage == "fit":
            self.train_dataset = WordNetProcessedDataset(split="train", path=self.data_dir)
            self.val_dataset = WordNetProcessedDataset(split="val", path=self.data_dir)
        
        if stage == "test":
            self.test_dataset = WordNetProcessedDataset(split="test", path=self.data_dir)
    
    def train_dataloader(self):
        return DataLoader(self.train_dataset, shuffle=True, **self.params)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, shuffle=False, **self.params)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, shuffle=False, **self.params)
    

In [40]:
class WordNetRawDataset(Dataset):
    def __init__(self, path: str="WordNet18RR/raw/", split="train") -> None:
        super().__init__()
        # load triplets from files
        train_edge_list = load_edge_list_from_file(os.path.join(path, "train.txt"))
        val_edge_list = load_edge_list_from_file(os.path.join(path, "valid.txt"))
        test_edge_list = load_edge_list_from_file(os.path.join(path, "test.txt"))

        # create dictionary mapping entity/label str to id
        entity_list = list()
        labels_list = list()

        for edge_list in [train_edge_list, val_edge_list, test_edge_list]:
            entity_list += [x.u for x in edge_list] + [x.v for x in edge_list]
            labels_list += [x.label for x in edge_list]

        entity_list = list(set(entity_list))
        entity_ids = dict(zip(entity_list, range(len(entity_list))))

        labels_list = list(set(labels_list))
        labels_ids = dict(zip(labels_list, range(len(labels_list))))

        edge_list_dict = {"train": train_edge_list, "test": test_edge_list, "val": val_edge_list}
        self.edge_list_id = torch.tensor([(entity_ids[edge.u], entity_ids[edge.v]) for edge in edge_list_dict[split]])
        self.labels_list_id = torch.tensor([labels_ids[edge.label] for edge in edge_list_dict[split]])
    
    def __len__(self):
        return self.edge_list_id.shape[0]

    def __getitem__(self, index) -> int:
        return self.edge_list_id[index,:], self.labels_list_id[index]

class WordNetRawDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str="WordNet18RR/processed/data.pt", batch_size=32) -> None:
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_entities = 40943
        self.num_relations = 11
        self.params = {"pin_memory": True, "batch_size": batch_size}

    def setup(self, stage: str):
        if stage == "fit":
            self.train_dataset = WordNetRawDataset(split="train")
            self.val_dataset = WordNetRawDataset(split="val")
        
        if stage == "test":
            self.test_dataset = WordNetRawDataset(split="test")
    
    def train_dataloader(self):
        return DataLoader(self.train_dataset, shuffle=True, **self.params)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, shuffle=False, **self.params)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, shuffle=False **self.params)


### Lightning Model

In [8]:
class TransE(pl.LightningModule):
    def __init__(self, margin: int=1, emb_dim: int=50, learning_rate=0.01, p_norm=1) -> None:
        """ Instatiate the entity and relation matrix of the TransE model
            https://papers.nips.cc/paper/5071-translating-embeddings-for-modeling-multi-relational-data

        Args:
            n_entities (int): _description_
            n_relations (int): _description_
            margin (int, optional): _description_. Defaults to 1.
            emb_dim (int, optional): _description_. Defaults to 50.
        """
        super().__init__()
        self.margin = margin
        self.emb_dim = emb_dim
        self.learning_rate = learning_rate
        self.p_norm = p_norm

        # dataset specific values
        self.num_entities = 40943
        self.num_relations = 11

        # initialize embeddings
        self.entity_mat = nn.Embedding(self.num_entities, emb_dim).to(self.device)
        self.relation_mat = nn.Embedding(self.num_relations, emb_dim).to(self.device)

        with torch.no_grad():
            # initialize with random uniform
            val = 6/np.sqrt(emb_dim)
            self.entity_mat.weight.uniform_(-val, val)
            self.relation_mat.weight.uniform_(-val, val)

            # normalize entity and relation embeddings
            self.entity_mat.weight.copy_(nn.functional.normalize(self.entity_mat.weight, p=2, dim=1))
            self.relation_mat.weight.copy_(nn.functional.normalize(self.relation_mat.weight, p=2, dim=1))

    def corrupt_edge_list(self, edge_list: torch.Tensor):
        """ sample either the head or tail of x from range(n) """
        n = edge_list.shape[0]
        entity_list = range(self.num_entities)

        # sample random entity replacements
        r1 = np.random.choice(entity_list, size=n)
        r2 = np.random.choice(entity_list, size=n)

        corrupted_heads = edge_list.detach().clone()
        corrupted_tails = edge_list.detach().clone()

        corrupted_heads[:,0] = torch.from_numpy(r1)
        corrupted_tails[:,1] = torch.from_numpy(r2)
                
        return corrupted_heads, corrupted_tails
    
    def embedding_loss(self, batch):
        edge_list, labels = batch
        
        loss = torch.zeros(1).to(self.device)

        #edge_list_cor = self.corrupt_edge_list(edge_list)
        corrupted_heads, corrupted_tails = self.corrupt_edge_list(edge_list)
        
        # take embedding values for entities and relations
        e1 = self.entity_mat.weight[edge_list.repeat(2,1)]
        e2 = torch.vstack([self.entity_mat.weight[corrupted_heads],
                           self.entity_mat.weight[corrupted_tails]])
        l = self.relation_mat.weight[labels].repeat(2,1)

        # compute the loss value
        n1 = torch.norm(e1[:,0,:] + l - e1[:,1,:], dim=1, p=self.p_norm)
        n2 = torch.norm(e2[:,0,:] + l - e2[:,1,:], dim=1, p=self.p_norm)
        loss = (self.margin + n1 - n2)
        loss = torch.clip(loss, min=0).sum()

        return loss

    def evaluation_protocol(self, batch):
        edge_list, labels = batch
        batch_size = edge_list.shape[0]

        # combine heads, tails and labels
        triplets = torch.hstack([edge_list, labels.reshape(-1,1)])

        # repeat all triplets for n_entities times
        triplets = triplets[:,np.newaxis,:].repeat(1,self.num_entities,1)

        true_pos_total = list()
        rank_pos_list = list()

        # repeat corruption for both head and tail
        for pos in [0,1]:
            x = triplets.detach().clone()
            
            # replace all heads with list of all possible entities
            x[:,:,pos] = torch.tensor(range(self.num_entities))[np.newaxis,:].repeat(batch_size,1).to(self.device)

            # triplets are arranged as (head, tail, label)
            e1 = self.entity_mat.weight[x[:,:,0]]
            e2 = self.entity_mat.weight[x[:,:,1]]
            l = self.relation_mat.weight[x[:,:,2]]

            # compute distance between head + label and tail
            norms = torch.norm(e1+l-e2, dim=2, p=self.p_norm)

            # get index positions of sorted norms for each triplet
            rankings = torch.vstack([torch.argsort(x) for x in norms.unbind(dim=0)])

            # find position of heads within the rankings
            rank_pos = torch.where(rankings == edge_list[:,pos].reshape(-1,1))[1]

            rank_pos_list.append(rank_pos)
            true_pos_total.append(rank_pos < 10)

        mean_rank = torch.vstack(rank_pos_list).float().mean()
        hits_at_10 = torch.vstack(true_pos_total).float().mean()*100

        return mean_rank, hits_at_10

    def training_step(self, batch, batch_idx):
        return self.embedding_loss(batch)

    def validation_step(self, batch, batch_idx):
        loss = self.embedding_loss(batch)
        mean_rank, hits_at_10 = self.evaluation_protocol(batch)
        metrics = {"val_loss": loss, "val_mean_rank": mean_rank, "val_hits@10": hits_at_10}
        self.log_dict(metrics, prog_bar=True, on_epoch=True)
        return metrics

    def test_step(self, batch, batch_idx):
        loss = self.embedding_loss(batch)
        mean_rank, hits_at_10 = self.evaluation_protocol(batch)
        metrics = {"test_loss": loss, "test_mean_rank": mean_rank, "test_hits@10": hits_at_10}
        self.log_dict(metrics, prog_bar=True, on_epoch=True)
        return metrics

    def on_train_epoch_end(self):
        with torch.no_grad():
             # keep entities embeddings normalized
            self.entity_mat.weight.copy_(nn.functional.normalize(self.entity_mat.weight, p=2, dim=1))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        

In [44]:
emb_dim = 20
lr = 0.01
margin = 2
max_epochs = 1000
top_k_cp = 3
p_norm = 2   # norm either L1 or L2

# instantiated model and data module
model = TransE(emb_dim=emb_dim,
               learning_rate=lr,
               margin=margin,
               p_norm=p_norm)

dm = WordNetRawDataModule(batch_size=32)

dir_path = f"checkpoints/emb_dim={emb_dim}-lr={lr}-margin={margin}-p_norm={p_norm}"

# using mean predicted rank on validation set as described in section 4.2
early_stop_rank = EarlyStopping(monitor="val_mean_rank",
                                min_delta=0.5,
                                patience=10,
                                verbose=False,
                                mode="min")

# save best models based on mean rank on validation set
checkpoint_callback = ModelCheckpoint(save_top_k=top_k_cp,
                                      monitor="val_mean_rank",
                                      dirpath=dir_path,
                                      filename="transe-wordnet-{epoch}-{val_mean_rank:.0f}")

logger = TensorBoardLogger('tb_logs', name='TransE')

trainer = pl.Trainer(max_epochs=max_epochs,
                     accelerator='gpu',
                     callbacks=[checkpoint_callback,early_stop_rank],
                     logger=logger)

try:
    # resume from best model if checkpoint is available
    ckpt_path = os.path.join(dir_path, os.listdir(dir_path)[-1])
except:
    ckpt_path = None

trainer.fit(model, datamodule=dm, ckpt_path=ckpt_path)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | entity_mat   | Embedding | 818 K 
1 | relation_mat | Embedding | 220   
-------------------------------------------
819 K     Trainable params
0         Non-trainable params
819 K     Total params
3.276     Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


                                                                           

  rank_zero_warn(


Epoch 16:  11%|█         | 300/2809 [00:03<00:32, 76.22it/s, loss=68, v_num=49, val_loss=40.70, val_mean_rank=3.86e+3, val_hits@10=26.20]   

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
