Write an implementation of the TransE model in PyTorch (you may use the PyTorch Geometric or DGL library for working with graphs) and train your model on the WordNet18RR dataset (you may use the dataset as provided by any graph library).

As the deliverable, you must provide a link to a GitHub (or GitLab) repository containing all the source code.
The readability of the code, the presence of comments, type annotations, and the quality of the code as a whole will be taken into account when checking the test case.

### Imports and helper Functions

In [1]:
from typing import Callable, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
from torch import nn
from torch.utils.data import Dataset
from torch_geometric.datasets import WordNet18RR
from torch_geometric.loader import DataLoader



c:\Users\Marco\AppData\Local\Programs\Python\Python39\lib\site-packages\numpy\.libs\libopenblas.EL2C6PLE4ZYW3ECEVIV3OXXGRN2NRFM2.gfortran-win_amd64.dll
c:\Users\Marco\AppData\Local\Programs\Python\Python39\lib\site-packages\numpy\.libs\libopenblas.WCDJNK7YVMPZQ2ME2ZZHJJRJ3JIKNDB7.gfortran-win_amd64.dll


In [2]:
# Download the WordNet18RR dataset into ./WordNet18RR/.
# The cells below read the processed file WordNet18RR/processed/data.pt directly.
dataset = WordNet18RR('./WordNet18RR/')

In [3]:
def l1_norm(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return the L1 (Manhattan) distance between two embedding vectors.

    Args:
        x (torch.Tensor): first vector.
        y (torch.Tensor): second vector, broadcastable against ``x``.

    Returns:
        torch.Tensor: ``torch.linalg.norm(x - y, ord=1)``.
    """
    # NOTE: for 2-D inputs torch.linalg.norm computes a *matrix* norm,
    # so pass one pair of 1-D vectors at a time.
    return torch.linalg.norm(x - y, ord=1)


def l2_norm(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return the L2 (Euclidean) distance between two embedding vectors.

    Args:
        x (torch.Tensor): first vector.
        y (torch.Tensor): second vector, broadcastable against ``x``.

    Returns:
        torch.Tensor: ``torch.linalg.norm(x - y, ord=2)``.
    """
    # NOTE: for 2-D inputs torch.linalg.norm computes a *matrix* norm,
    # so pass one pair of 1-D vectors at a time.
    return torch.linalg.norm(x - y, ord=2)

def normalize(x: Union[torch.Tensor, np.ndarray], axis: int = 1) -> Union[torch.Tensor, np.ndarray]:
    """Scale every 1-D slice of ``x`` along ``axis`` to unit L2 norm.

    Args:
        x (Union[torch.Tensor, np.ndarray]): matrix whose slices are normalized.
        axis (int, optional): axis along which each norm is computed;
            1 gives unit-norm rows, 0 gives unit-norm columns. Defaults to 1.

    Returns:
        (Union[torch.Tensor, np.ndarray]): normalized matrix of the same
        container type as the input (tensor in, tensor out; array in, array out).
        A slice whose norm is zero yields NaN entries, because no epsilon is added.
    """
    # `x.dtype` is never the class torch.Tensor, so the container type has to
    # be checked with isinstance for tensors to be returned as tensors.
    if isinstance(x, torch.Tensor):
        if not (x.is_floating_point() or x.is_complex()):
            # norms are only defined for floating-point tensors
            x = x.to(torch.get_default_dtype())
        return x / torch.linalg.norm(x, ord=2, dim=axis, keepdim=True)

    x = np.asarray(x)
    # keepdims makes the norms broadcast back against the matching slices
    return x / np.linalg.norm(x, axis=axis, keepdims=True)

### Custom Dataset and DataLoader

In [4]:
import csv

class Edge():
    """A labelled edge between the two nodes ``u`` and ``v``."""

    def __init__(self, u: str, v: str, label: str) -> None:
        """Store the two endpoints and the label of the edge.

        Args:
            u (str): first endpoint (the file loader passes the raw string field).
            v (str): second endpoint.
            label (str): label attached to the edge.
        """
        self.u, self.v, self.label = u, v, label

    def __str__(self) -> str:
        """Render the edge as ``"<u> <label> <v>"``."""
        return "{} {} {}".format(self.u, self.label, self.v)

def load_edge_list_from_file(path: str, header: bool=False) -> List["Edge"]:
    """Read a tab-separated edge list into ``Edge`` objects.

    Every data row must hold exactly three fields in the order
    ``u <TAB> label <TAB> v``. Blank lines are ignored.

    Args:
        path (str): path of the TSV file (read as UTF-8).
        header (bool, optional): if True, the first row is skipped. Defaults to False.

    Returns:
        List[Edge]: the edges in file order; empty when the file has no data rows.

    Raises:
        ValueError: if a non-blank row does not have exactly three fields.
    """
    edge_list = []

    # newline="" lets the csv module do its own line-ending handling
    with open(path, "r", newline="", encoding="utf-8") as f:
        tsv_reader = csv.reader(f, delimiter="\t")

        if header:
            # the default keeps an empty file from raising StopIteration
            next(tsv_reader, None)

        for row in tsv_reader:
            if not row:
                # csv.reader yields an empty list for a blank line
                continue
            u, label, v = row
            edge_list.append(Edge(u=u, v=v, label=label))

    return edge_list

In [5]:
class WordNetEdgeDataset(Dataset):
    """Map-style dataset over the edges of one split of the processed WordNet18RR file.

    Each item is an ``(edge, label)`` pair: ``edge`` is one row of the
    transposed ``edge_index`` and ``label`` the matching entry of ``edge_type``.
    """

    def __init__(self, path: str="WordNet18RR/processed/data.pt", split: str="train") -> None:
        """Load the processed data file and keep only the edges of ``split``.

        Args:
            path (str, optional): path of the processed ``data.pt`` file.
            split (str, optional): one of ``"train"``, ``"val"`` or ``"test"``.
                Defaults to ``"train"``.

        Raises:
            KeyError: if ``split`` is not one of the three known split names.
        """
        super().__init__()
        data = torch.load(path)[0]
        mask_dict = {"train": data.train_mask, "test": data.test_mask, "val": data.val_mask}
        try:
            mask = mask_dict[split]
        except KeyError:
            raise KeyError(
                f"unknown split {split!r}; expected one of {sorted(mask_dict)}"
            ) from None
        # transpose so that every row is one edge, then keep the rows of this split
        self.edge_list = data.edge_index.T[mask, :]
        self.edge_labels = data.edge_type[mask]

    def __len__(self) -> int:
        """Return the number of edges in the split."""
        return self.edge_list.shape[0]

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the ``(edge, label)`` pair stored at position ``index``."""
        return self.edge_list[index, :], self.edge_labels[index]
    

In [6]:
class WordNetDataModule(pl.LightningDataModule):
    """Lightning data module serving the train/val/test edge splits of WordNet18RR."""

    def __init__(self, data_dir: str="WordNet18RR/processed/data.pt", batch_size: int=32) -> None:
        """Store the data location and the batch size used by every loader.

        Args:
            data_dir (str, optional): path of the processed ``data.pt`` file.
            batch_size (int, optional): number of edges per batch. Defaults to 32.
        """
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        # dataset specific values (hard-coded)
        self.num_entities = 40943
        self.num_relations = 11

    def setup(self, stage: Optional[str]=None) -> None:
        """Create the datasets required by ``stage``.

        Args:
            stage (Optional[str]): ``"fit"``, ``"validate"`` or ``"test"``;
                ``None`` prepares every split.
        """
        if stage in (None, "fit"):
            self.train_dataset = WordNetEdgeDataset(split="train", path=self.data_dir)

        # the validation split is needed while fitting and for a stand-alone validation run
        if stage in (None, "fit", "validate"):
            self.val_dataset = WordNetEdgeDataset(split="val", path=self.data_dir)

        if stage in (None, "test"):
            self.test_dataset = WordNetEdgeDataset(split="test", path=self.data_dir)
            # kept for backward compatibility; now honours the configured batch size
            self.test_loader = self.test_dataloader()

    def train_dataloader(self) -> DataLoader:
        """Return the shuffled loader over the training edges."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        """Return the loader over the validation edges (fixed order)."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False)

    def test_dataloader(self) -> DataLoader:
        """Return the loader over the test edges (fixed order)."""
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)


### Lightning Model

In [7]:
def sample_corrupted_triplet(x: torch.Tensor, n: int) -> torch.Tensor:
    """Return a copy of ``x`` whose head or tail is replaced by a different entity.

    Args:
        x (torch.Tensor): a pair of ints ``(head, tail)``.
        n (int): number of entities; the replacement is drawn from ``range(n)``.

    Returns:
        torch.Tensor: detached clone of ``x`` differing from it in exactly one
        position; ``x`` itself is left untouched.

    Raises:
        ValueError: if ``n < 2``, since no different entity could be drawn.
    """
    if n < 2:
        raise ValueError(f"need at least 2 entities to corrupt a triplet, got n={n}")

    idx = int(np.random.rand() < 0.5)  # pick either head or tail
    original = int(x[idx])

    # Draw uniformly among the n - 1 ids that differ from `original`:
    # sample from range(n - 1) and shift values at or above `original` up by
    # one. This needs no retry loop and never builds the full id range.
    replacement = int(np.random.randint(n - 1))
    if replacement >= original:
        replacement += 1

    s = x.detach().clone()
    s[idx] = replacement
    return s

In [10]:
class TransE(pl.LightningModule):
    """TransE knowledge-graph embedding model trained with a margin ranking loss.

    Entities and relations are embedded in the same space, and a triplet
    ``(h, l, t)`` is scored by ``distance_func(h + l, t)``.
    https://papers.nips.cc/paper/5071-translating-embeddings-for-modeling-multi-relational-data
    """

    def __init__(self, margin: float=1, emb_dim: int=50, learning_rate: float=0.01,
                 distance_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]=l1_norm,
                 num_entities: int=40943, num_relations: int=11) -> None:
        """Instantiate the entity and relation matrix of the TransE model

        Args:
            margin (float, optional): margin of the ranking loss. Defaults to 1.
            emb_dim (int, optional): dimension of every embedding vector. Defaults to 50.
            learning_rate (float, optional): learning rate of the Adam optimizer.
                Defaults to 0.01.
            distance_func (Callable[[torch.Tensor, torch.Tensor], torch.Tensor], optional):
                dissimilarity between two 1-D embedding vectors, returning a
                scalar tensor. Defaults to l1_norm.
            num_entities (int, optional): number of entities in the graph.
                Defaults to 40943 (the value previously hard-coded for WordNet18RR).
            num_relations (int, optional): number of relation types.
                Defaults to 11 (the value previously hard-coded for WordNet18RR).
        """
        super().__init__()
        self.margin = margin
        self.emb_dim = emb_dim
        self.distance_func = distance_func
        self.learning_rate = learning_rate

        # dataset specific values
        self.num_entities = num_entities
        self.num_relations = num_relations

        # initialize embeddings
        self.entity_mat = nn.Embedding(self.num_entities, emb_dim)
        self.relation_mat = nn.Embedding(self.num_relations, emb_dim)

        with torch.no_grad():
            # initialize with random uniform
            val = 6/np.sqrt(emb_dim)
            self.entity_mat.weight.uniform_(-val, val)
            self.relation_mat.weight.uniform_(-val, val)

        # normalize each embedding vector
        self._normalize_embeddings()

    def _normalize_embeddings(self) -> None:
        """Rescale every entity and relation embedding to unit L2 norm, in place."""
        with torch.no_grad():
            for emb in (self.entity_mat, self.relation_mat):
                emb.weight.copy_(nn.functional.normalize(emb.weight, p=2, dim=1))

    def training_step(self, batch, batch_idx):
        """Compute the summed margin ranking loss of one batch.

        Args:
            batch: pair ``(edge_list, labels)``; row ``i`` of ``edge_list`` holds
                the (head, tail) entity ids of an edge and ``labels[i]`` its
                relation id.
            batch_idx: index of the batch (unused).

        Returns:
            torch.Tensor: tensor of shape ``(1,)`` holding the sum over the batch
            of ``relu(margin + d(h + l, t) - d(h' + l, t'))``, where
            ``(h', t')`` is a corrupted copy of ``(h, t)``.
        """
        edge_list, labels = batch

        # allocate the accumulator from the weights so it shares their device and dtype
        loss = self.entity_mat.weight.new_zeros(1)

        for x, label in zip(edge_list, labels):
            x_corrupted = sample_corrupted_triplet(x, n=self.num_entities)

            # take embedding values for entities and relation
            h, t = self.entity_mat(x)
            h_neg, t_neg = self.entity_mat(x_corrupted)
            rel = self.relation_mat(label)

            # use the configured dissimilarity instead of a hard-coded norm
            positive = self.distance_func(h + rel, t)
            negative = self.distance_func(h_neg + rel, t_neg)

            # take positive part
            loss = loss + nn.functional.relu(self.margin + positive - negative)

        return loss

    def on_train_epoch_end(self) -> None:
        """Lightning hook: keep embeddings normalized after every training epoch."""
        self._normalize_embeddings()

    def on_training_epoch_end(self, epoch_idx: Optional[int]=None) -> None:
        """Backward-compatible alias of ``on_train_epoch_end``.

        Lightning never calls a hook of this name; it is kept only so that
        existing explicit calls keep working. ``epoch_idx`` is ignored.
        """
        self.on_train_epoch_end()

    def validation_step(self, batch, batch_idx):
        """No-op: no validation metric is computed."""
        pass

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        

In [11]:
# Build the model with 5-dimensional embeddings and train it for a single
# epoch on the WordNet18RR data module.
model = TransE(emb_dim=5)
dm = WordNetDataModule()
trainer = pl.Trainer(max_epochs=1)
trainer.fit(model, datamodule=dm)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn("You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.")

  | Name         | Type      | Params
-------------------------------------------
0 | entity_mat   | Embedding | 204 K 
1 | relation_mat | Embedding | 55    
-------------------------------------------
204 K     Trainable params
0         Non-trainable params
204 K     Total params
0.819     Total estimated model params size (MB)
  rank_zero_warn(


Epoch 0:  64%|██████▍   | 1750/2714 [10:20<05:41,  2.82it/s, loss=26.8, v_num=3]
Epoch 0:  43%|████▎     | 1160/2714 [02:41<03:35,  7.20it/s, loss=29.3, v_num=4]