Write implementation on PyTorch for TransE model (you can use TorchGeometric or DGL library for working with graphs) and train your model on WordNet18RR dataset (you can use loaded dataset from any graph library).

As a result, you must provide a link to github (or gitlab) with all the source code.
The readability of the code, the presence of comments, type annotations, and the quality of the code as a whole will be taken into account when checking the test case.

### Imports and helper Functions

In [149]:
import numpy as np
import matplotlib.pyplot as plt
from typing import Union, Callable, Optional

import torch
from torch.utils.data import Dataset
from torch import nn
from torch_geometric.datasets import WordNet18RR
from torch_geometric.loader import DataLoader
import pytorch_lightning as pl



In [None]:
l1_norm = lambda x, y: torch.linalg.norm((x-y), ord=1)
l2_norm = lambda x, y: torch.linalg.norm((x-y), ord=2)

def normalize(x: Union[torch.Tensor, np.ndarray], axis: int=1):
    """ Normalize the matrix x along the specified axis

    Args:
        x (Union[torch.Tensor, np.ndarray]): _description_
        axis (int, optional): 0 = columns, 1 = rows. Defaults to 1.

    Returns:
        (Union[torch.Tensor, np.ndarray]): returns a matrix with the same dtype as the input
    """
    return_tensor = False

    if x.dtype == torch.Tensor:
        x = x.numpy()
        return_tensor = True
    
    x = np.apply_along_axis(func1d=lambda x: x / np.linalg.norm(x), arr=x, axis=axis)

    if return_tensor:
        return torch.from_numpy(x)
    else:
        return x

### Custom Dataset and DataLoader

In [None]:
class WordNetDataset(Dataset):
    def __init__(self, path="WordNet18RR/raw/") -> None:
        super().__init__()
        
    
    def __len__(self):
        pass

    def __getitem__(self, index) -> int:
        return super().__getitem__(index)
    

In [155]:
import csv

class Edge():
    def __init__(self, u, v, label) -> None:
        self.u = u
        self.v = v
        self.label = label

with open("WordNet18RR/raw/train.txt") as f:
    tsv_reader = csv.reader(f, delimiter="\t")

next(tsv_reader)

edge_list = list()

for row in tsv_reader:
    print(row)
    break

ValueError: I/O operation on closed file.

### Lightning Model

In [84]:
class TransE(pl.LightningModule):
    def __init__(self, n_entities: int, n_relations: int, margin: int=1, emb_dim: int=50,
                 distance_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]=l1_norm) -> None:
        """ Instatiate the entity and relation matrix of the TransE model
            https://papers.nips.cc/paper/5071-translating-embeddings-for-modeling-multi-relational-data

        Args:
            n_entities (int): _description_
            n_relations (int): _description_
            margin (int, optional): _description_. Defaults to 1.
            emb_dim (int, optional): _description_. Defaults to 50.
            distance_func (Callable[[torch.Tensor, torch.Tensor], torch.Tensor], optional): _description_. Defaults to l1_norm.
        """
        super().__init__()
        self.margin = margin
        self.emb_dim = emb_dim
        self.distance_func = distance_func

        val = 6/np.sqrt(emb_dim)
        entity_size = (n_entities, emb_dim)
        relation_size = (n_relations, emb_dim)

        # instantiate matrices
        entity_mat = np.random.uniform(low=-val, high=val, size=entity_size)
        relation_mat = np.random.uniform(low=val, high=val, size=relation_size)

        # normalize matrices
        entity_mat = normalize(entity_mat)
        relation_mat = normalize(relation_mat)

        # cast to torch
        self.entity_mat = torch.from_numpy(entity_mat)
        self.relation_mat = torch.from_numpy(relation_mat)
    
    def training_step(self, batch, batch_idx):
        """ batch is the list of ids of edge within the batch """
        return self.relation_mat 
        

In [105]:
dataset = WordNet18RR('./WordNet18RR/')


In [125]:
data = torch.load("./WordNet18RR/processed/data.pt")[0]
loader = DataLoader(dataset, batch_size=32, shuffle=True)

TypeError: __init__() missing 1 required positional argument: 'link_sampler'

In [88]:
model = TransE(n_entities=data.num_nodes,
               n_relations=data.num_edges)

In [131]:
from torch_geometric.transforms.random_link_split import RandomLinkSplit

rls = RandomLinkSplit()
train_data, val_data, test_data = rls(data)

In [140]:
edge_loader = DataLoader(range(train_data.edge_index.shape[1]), batch_size=32, shuffle=True)

In [146]:
data

Data(edge_index=[2, 93003], edge_type=[93003], train_mask=[93003], val_mask=[93003], test_mask=[93003], num_nodes=40943)

In [None]:
class 

In [None]:
class WordNetDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str="WordNet18RR/raw", batch_size: int=32) -> None:
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
    
    def setup(self, stage: Optional[str]=None):
        