Write a PyTorch implementation of the TransE model (you may use the TorchGeometric or DGL library for working with graphs) and train your model on the WordNet18RR dataset (you may use a preloaded dataset from any graph library).

As a result, you must provide a link to a GitHub (or GitLab) repository containing all the source code.
The readability of the code, the presence of comments and type annotations, and the overall quality of the code will be taken into account when the test task is reviewed.

### Imports and helper Functions

In [1]:
import os
import csv
import numpy as np
import matplotlib.pyplot as plt
from typing import Union, Callable, Optional, Tuple

import torch
from torch.utils.data import Dataset
from torch import nn
import torch.nn.functional as F
from torch_geometric.datasets import WordNet18RR
from torch_geometric.loader import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

# report whether PyTorch can see a CUDA device
torch.cuda.is_available()

True

In [2]:
# download the WordNet18RR dataset into ./WordNet18RR/; the cells below read
# the raw split files (train/valid/test .txt) from its raw/ sub-directory
dataset = WordNet18RR('./WordNet18RR/')

### Custom Dataset and DataLoader

In [3]:
class Edge:
    """One (head, relation, tail) triple of the knowledge graph.

    The three values are stored exactly as they are passed in.
    """

    def __init__(self, head, tail, rel) -> None:
        # NOTE: the constructor takes (head, tail, rel) while the string
        # form below is ordered "head rel tail".
        self.head, self.tail, self.rel = head, tail, rel

    def __str__(self) -> str:
        """Return the triple as ``"<head> <rel> <tail>"``."""
        return "{} {} {}".format(self.head, self.rel, self.tail)

def process_lines(lines):
    """Split tab-separated text lines into lists of fields.

    Args:
        lines: iterable of strings, e.g. a list of lines or an open text file.

    Returns:
        A list with one entry per input line, in input order.  Each entry is
        the list of fields obtained by removing newline characters from both
        ends of the line and splitting the remainder on tab characters.
    """
    return [raw_line.strip('\n').split('\t') for raw_line in lines]

def load_edges_from_file(path: str, header: bool=False):
    """Read a tab-separated triple file into a list of ``Edge`` objects.

    Every line must hold exactly three columns in the order
    ``head<TAB>relation<TAB>tail``.

    Args:
        path: location of the file to read (e.g. ``train.txt``).
        header: when True the first line is treated as a column header and
            skipped.  Defaults to False.

    Returns:
        List of ``Edge`` objects, one per data line, in file order.

    Raises:
        ValueError: if a line does not split into exactly three columns.
    """
    # the context manager guarantees the handle is closed, even on error
    with open(path) as f:
        lines = process_lines(f.readlines())

    if header:
        lines = lines[1:]

    # file columns are (head, rel, tail); Edge takes them by keyword
    return [Edge(head=head, tail=tail, rel=rel) for head, rel, tail in lines]

def load_ids_dict(path):
    """Load the entity and relation id mappings stored in a directory.

    Reads ``entity2id.txt`` and ``relation2id.txt`` from ``path``.  Each line
    of those files has the form ``name<TAB>id``.

    Args:
        path: directory that contains the two mapping files.

    Returns:
        Tuple ``(entity2id, relation2id)`` of dicts mapping a name to its
        integer id.

    Raises:
        FileNotFoundError: if one of the mapping files does not exist.
        IndexError: if a line has fewer than two tab-separated columns.
        ValueError: if the id column is not an integer.
    """
    def read_mapping(file_name):
        # same parsing rule as process_lines: strip newlines, split on tabs
        mapping = {}
        with open(os.path.join(path, file_name)) as f:
            for line in f:
                fields = line.strip('\n').split('\t')
                mapping[fields[0]] = int(fields[1])
        return mapping

    # honour the ``path`` argument instead of a module-level constant so the
    # function works for any dataset directory
    return read_mapping("entity2id.txt"), read_mapping("relation2id.txt")

# Build entity2id.txt and relation2id.txt from the raw split files.
WORDNET_PATH = "WordNet18RR/raw/"

train_edge_list, val_edge_list, test_edge_list = [
    load_edges_from_file(os.path.join(WORDNET_PATH, file_name))
    for file_name in ("train.txt", "valid.txt", "test.txt")
]

# collect every entity and relation name that occurs in any of the splits
entity_names = set()
relation_names = set()
for split_edges in (train_edge_list, val_edge_list, test_edge_list):
    for edge in split_edges:
        entity_names.update((edge.head, edge.tail))
        relation_names.add(edge.rel)

# an id is the position of the name in the sorted, de-duplicated list
entity_list = sorted(entity_names)
entity2id = {name: idx for idx, name in enumerate(entity_list)}

relation_list = sorted(relation_names)
relation2id = {name: idx for idx, name in enumerate(relation_list)}

# persist both mappings as "<name>\t<id>" lines next to the split files
for file_name, mapping in (("entity2id.txt", entity2id),
                           ("relation2id.txt", relation2id)):
    with open(os.path.join(WORDNET_PATH, file_name), "w") as f:
        f.writelines(f"{x}\t{y}\n" for x, y in mapping.items())

In [4]:
class WordNetDataset(Dataset):
    """Map-style dataset over one split of the WordNet18RR triples.

    For an integer index, an item is a pair ``(edge, relation)`` where
    ``edge`` is a tensor holding ``[head_id, tail_id]`` and ``relation`` is a
    0-dim tensor holding the relation id.
    """

    def __init__(self, path: str="WordNet18RR/raw/", split="train") -> None:
        """Load one split and translate its names into integer ids.

        Args:
            path: directory holding ``<split>.txt`` together with
                ``entity2id.txt`` and ``relation2id.txt``.
            split: name of the split file without extension; ``"val"`` is
                accepted as an alias for ``"valid"``.

        Raises:
            KeyError: if a triple mentions an entity or relation that is
                missing from the id mappings.
        """
        super().__init__()
        self.path = path
        self.split = 'valid' if split == 'val' else split

        edge_list = load_edges_from_file(os.path.join(self.path, f"{self.split}.txt"))
        entity2id, relation2id = load_ids_dict(path)

        self.edge_list, self.relation_list = self._encode(edge_list, entity2id, relation2id)

    @staticmethod
    def _encode(edge_list, entity2id, relation2id) -> Tuple[torch.Tensor, torch.Tensor]:
        """Turn edges with name attributes into id tensors.

        Returns:
            ``(pairs, relations)``: a ``(n, 2)`` long tensor of
            ``[head_id, tail_id]`` rows and a ``(n,)`` long tensor of
            relation ids.
        """
        pairs = [(entity2id[e.head], entity2id[e.tail]) for e in edge_list]
        relations = [relation2id[e.rel] for e in edge_list]

        # explicit dtype + reshape keep the (n, 2) / int64 layout even when
        # the split is empty (torch.tensor([]) alone would be float, shape (0,))
        return (torch.tensor(pairs, dtype=torch.long).reshape(-1, 2),
                torch.tensor(relations, dtype=torch.long))

    def __len__(self) -> int:
        """Number of triples in the split."""
        return self.edge_list.shape[0]

    def __getitem__(self, index) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(edge, relation)`` tensors for ``index``."""
        return self.edge_list[index], self.relation_list[index]

class WordNetDataModule(pl.LightningDataModule):
    """Lightning data module serving the WordNet18RR train/valid/test splits."""

    def __init__(self, path: str="WordNet18RR/raw/", batch_size=32) -> None:
        """
        Args:
            path: directory with the split files and id mappings; forwarded
                to ``WordNetDataset``.
            batch_size: batch size used by every dataloader.
        """
        super().__init__()
        self.path = path
        self.batch_size = batch_size
        # hard-coded dataset sizes (same values the TransE model uses)
        self.num_entities = 40943
        self.num_relations = 11
        # keyword arguments shared by all dataloaders
        self.params = {"pin_memory": True, "batch_size": batch_size}

    def setup(self, stage: Optional[str] = None):
        """Create the datasets required by ``stage``.

        Args:
            stage: ``"fit"`` builds the train and validation datasets,
                ``"validate"`` builds only the validation dataset and
                ``"predict"`` builds the test dataset.  ``None`` builds all
                of them.
        """
        if stage in ("fit", None):
            self.train_dataset = WordNetDataset(path=self.path, split="train")

        # the validation set is needed both while fitting and for a
        # stand-alone validation run
        if stage in ("fit", "validate", None):
            self.val_dataset = WordNetDataset(path=self.path, split="valid")

        if stage in ("predict", None):
            self.test_dataset = WordNetDataset(path=self.path, split="test")

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return DataLoader(self.train_dataset, shuffle=True, **self.params)

    def val_dataloader(self):
        """Ordered loader over the validation split."""
        return DataLoader(self.val_dataset, shuffle=False, **self.params)

    def predict_dataloader(self):
        """Ordered loader over the test split (used by ``trainer.predict``)."""
        return DataLoader(self.test_dataset, shuffle=False, **self.params)


### Lightning Model

In [69]:
class TransE(pl.LightningModule):
    """TransE knowledge-graph embedding model.

    https://papers.nips.cc/paper/5071-translating-embeddings-for-modeling-multi-relational-data

    Every entity and every relation owns an ``emb_dim``-dimensional vector.
    A triple (head, relation, tail) is scored by the ``p_norm`` distance
    ``||head + relation - tail||`` (lower means more plausible).  Training
    minimises a margin ranking loss between observed triples and triples in
    which the head or the tail was replaced by a random entity.
    """

    def __init__(self, margin: float=1, emb_dim: int=50, learning_rate: float=0.01,
                 p_norm: int=1, num_entities: int=40943, num_relations: int=11) -> None:
        """Instantiate the entity and relation embedding matrices.

        Args:
            margin: margin of the ranking loss. Defaults to 1.
            emb_dim: size of the embedding vectors. Defaults to 50.
            learning_rate: learning rate passed to Adam. Defaults to 0.01.
            p_norm: order of the norm used as distance (1 or 2) and for the
                initial normalisation. Defaults to 1.
            num_entities: number of rows of the entity matrix. Defaults to
                40943, the value previously hard-coded for WordNet18RR.
            num_relations: number of rows of the relation matrix. Defaults
                to 11, the value previously hard-coded for WordNet18RR.
        """
        super().__init__()
        self.margin = margin
        self.emb_dim = emb_dim
        self.learning_rate = learning_rate
        self.p_norm = p_norm

        self.num_entities = num_entities
        self.num_relations = num_relations

        # Lightning moves sub-modules to the right device, so no explicit
        # ``.to(device)`` is needed here
        self.entity_mat = nn.Embedding(self.num_entities, emb_dim)
        self.relation_mat = nn.Embedding(self.num_relations, emb_dim)

        with torch.no_grad():
            # uniform initialisation in [-6/sqrt(k), 6/sqrt(k)]
            bound = 6 / np.sqrt(emb_dim)
            self.entity_mat.weight.uniform_(-bound, bound)
            self.relation_mat.weight.uniform_(-bound, bound)

            # scale every embedding row to unit p-norm
            self.entity_mat.weight.copy_(F.normalize(self.entity_mat.weight, p=self.p_norm, dim=-1))
            self.relation_mat.weight.copy_(F.normalize(self.relation_mat.weight, p=self.p_norm, dim=-1))

    def corrupt_edge_list(self, edge_list: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Create negative samples by replacing heads and tails at random.

        Args:
            edge_list: ``(n, 2)`` tensor of ``[head_id, tail_id]`` rows.

        Returns:
            ``(corrupted_heads, corrupted_tails)``: two copies of
            ``edge_list``; in the first every head and in the second every
            tail is replaced by an entity id drawn uniformly from
            ``range(num_entities)``.  The draw is not filtered, so a
            replacement may coincide with the original entity.
        """
        n = edge_list.shape[0]

        # sample directly on the device of the batch (no numpy round trip)
        random_heads = torch.randint(0, self.num_entities, (n,), device=edge_list.device)
        random_tails = torch.randint(0, self.num_entities, (n,), device=edge_list.device)

        corrupted_heads = edge_list.detach().clone()
        corrupted_tails = edge_list.detach().clone()

        corrupted_heads[:, 0] = random_heads
        corrupted_tails[:, 1] = random_tails

        return corrupted_heads, corrupted_tails

    def embedding_loss(self, batch: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        """Margin ranking loss of a batch, summed over all samples.

        Every positive triple is paired with two negatives (one with a
        corrupted head, one with a corrupted tail).

        Args:
            batch: ``(edge_list, labels)`` with ``[head_id, tail_id]`` rows
                and the matching relation ids.

        Returns:
            0-dim tensor ``sum(max(0, margin + d(pos) - d(neg)))``.
        """
        edge_list, labels = batch

        corrupted_heads, corrupted_tails = self.corrupt_edge_list(edge_list)

        # positives are repeated so they line up with both kinds of negatives
        positives = self.entity_mat(edge_list.repeat(2, 1))
        negatives = self.entity_mat(torch.vstack([corrupted_heads, corrupted_tails]))
        rel = self.relation_mat(labels).repeat(2, 1)

        # index 0 along dim 1 is the head embedding, index 1 the tail
        pos = torch.norm(positives[:, 0, :] + rel - positives[:, 1, :], dim=-1, p=self.p_norm)
        neg = torch.norm(negatives[:, 0, :] + rel - negatives[:, 1, :], dim=-1, p=self.p_norm)

        return torch.clip(self.margin + pos - neg, min=0).sum()

    @torch.no_grad()
    def evaluation_protocol(self, batch: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        """Rank the true head and the true tail of each triple among all entities.

        For every triple the head (then the tail) is replaced by each entity
        in turn, the candidates are sorted by distance and the zero-based
        position of the true entity is recorded.  Candidates are not
        filtered against known true triples.

        Args:
            batch: ``(edge_list, labels)`` as produced by the dataloaders.

        Returns:
            1-D tensor of length ``2 * batch_size``: the head ranks of all
            triples followed by the tail ranks.
        """
        edge_list, labels = batch

        heads = self.entity_mat(edge_list[:, 0])
        tails = self.entity_mat(edge_list[:, 1])
        rels = self.relation_mat(labels.reshape(-1))

        # broadcasting against the full entity matrix avoids materialising
        # a (batch, num_entities, 3) index tensor
        candidates = self.entity_mat.weight.unsqueeze(0)

        rank_pos_list = []
        for pos in (0, 1):
            if pos == 0:
                # every entity as head: candidate + rel - tail
                diff = candidates + rels.unsqueeze(1) - tails.unsqueeze(1)
            else:
                # every entity as tail: head + rel - candidate
                diff = (heads + rels).unsqueeze(1) - candidates

            norms = torch.norm(diff, dim=-1, p=self.p_norm)

            # entity ids ordered from closest to farthest, per triple
            rankings = torch.argsort(norms, dim=-1)

            # column at which the true entity appears = its zero-based rank
            true_ids = edge_list[:, pos].reshape(-1, 1)
            rank_pos_list.append(torch.where(rankings == true_ids)[1])

        return torch.cat(rank_pos_list)

    def training_step(self, batch, batch_idx):
        """Return the ranking loss of the batch."""
        return self.embedding_loss(batch)

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and hand the batch ranks to the epoch hook."""
        loss = self.embedding_loss(batch)
        batch_rankings = self.evaluation_protocol(batch)
        self.log_dict({"val_loss": loss}, prog_bar=True, on_epoch=True)
        return {"val_loss": loss, "batch_rankings": batch_rankings}

    def predict_step(self, batch, batch_idx):
        """Return ``(mean_rank, hits@10 in percent)`` for one batch.

        ``evaluation_protocol`` yields the raw ranks; they are reduced here
        so callers receive the two per-batch metrics.
        """
        batch_rankings = self.evaluation_protocol(batch)
        mean_rank = batch_rankings.float().mean()
        hits_at_10 = (batch_rankings < 10).float().mean() * 100
        return mean_rank, hits_at_10

    def on_train_epoch_end(self):
        """Re-normalise the entity embeddings after every training epoch."""
        with torch.no_grad():
            # keep entity embeddings normalized
            # NOTE(review): this always uses the L2 norm, whereas __init__
            # normalises with ``p_norm`` -- confirm which one is intended.
            self.entity_mat.weight.copy_(F.normalize(self.entity_mat.weight, p=2, dim=1))

    def compute_epoch_metrics(self, outputs, stage):
        """Log mean rank and hits@10 (percent) over all batches of an epoch.

        Args:
            outputs: list of dicts that each hold a ``'batch_rankings'``
                tensor, as returned by ``validation_step``.
            stage: prefix of the logged metric names, e.g. ``"val"``.
        """
        epoch_rankings = torch.hstack([x['batch_rankings'] for x in outputs])
        mean_rank = epoch_rankings.float().mean()
        hit_at_10 = (epoch_rankings < 10).float().mean()*100
        self.log_dict({f"{stage}_mean_rank": mean_rank, f"{stage}_hits@10": hit_at_10},
                      prog_bar=True, on_epoch=True)

    def validation_epoch_end(self, outputs):
        """Aggregate the validation ranks of the finished epoch.

        NOTE(review): this relies on the Lightning 1.x hook that receives
        the collected ``validation_step`` outputs -- confirm the installed
        Lightning version still calls it.
        """
        self.compute_epoch_metrics(outputs, stage="val")

    def predict_epoch_end(self, outputs):
        """Aggregate prediction ranks.

        NOTE(review): ``predict_step`` returns a tuple, not a dict with a
        ``'batch_rankings'`` key, so its outputs cannot be passed here; it
        is also unclear whether Lightning calls a hook of this name --
        confirm before relying on it.
        """
        self.compute_epoch_metrics(outputs, stage="predict")

    def configure_optimizers(self):
        """Optimise all parameters with Adam at ``learning_rate``."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        

In [70]:
# hyper-parameters of this run
emb_dim = 30
lr = 0.001
margin = 1
max_epochs = 1000
top_k_cp = 3
p_norm = 2   # norm either L1 or L2

# instantiated model and data module
model = TransE(emb_dim=emb_dim,
               learning_rate=lr,
               margin=margin,
               p_norm=p_norm)

dm = WordNetDataModule(batch_size=32)

# one checkpoint directory per hyper-parameter configuration
dir_path = f"checkpoints/emb_dim={emb_dim}-lr={lr}-margin={margin}-p_norm={p_norm}"

# using mean predicted rank on validation set as described in section 4.2
early_stop_rank = EarlyStopping(monitor="val_mean_rank",
                                min_delta=0.5,
                                patience=10,
                                verbose=False,
                                mode="min")

# save best models based on mean rank on validation set (lower is better)
checkpoint_callback = ModelCheckpoint(save_top_k=top_k_cp,
                                      monitor="val_mean_rank",
                                      mode="min",
                                      dirpath=dir_path,
                                      filename="transe-wordnet-{epoch}-{val_mean_rank:.0f}-{val_hits@10:.1f}")

logger = TensorBoardLogger('tb_logs', name='TransE')

trainer = pl.Trainer(max_epochs=max_epochs,
                     # fall back to the CPU when no CUDA device is visible
                     accelerator='gpu' if torch.cuda.is_available() else 'cpu',
                     callbacks=[checkpoint_callback,early_stop_rank],
                     logger=logger)

# resume from the most recently written checkpoint of this configuration,
# or start from scratch when there is none (os.listdir order is arbitrary,
# so the newest file is picked explicitly by modification time)
try:
    ckpt_files = [os.path.join(dir_path, name)
                  for name in os.listdir(dir_path)
                  if name.endswith(".ckpt")]
except FileNotFoundError:
    # first run for this configuration: the directory does not exist yet
    ckpt_files = []

ckpt_path = max(ckpt_files, key=os.path.getmtime) if ckpt_files else None

trainer.fit(model, datamodule=dm, ckpt_path=ckpt_path)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Epoch 13:  34%|███▍      | 965/2809 [11:26<21:51,  1.41it/s, loss=18.1, v_num=21, val_loss=19.40, val_mean_rank=4.79e+3, val_hits@10=23.70]


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type      | Params
-------------------------------------------
0 | entity_mat   | Embedding | 1.2 M 
1 | relation_mat | Embedding | 330   
-------------------------------------------
1.2 M     Trainable params
0         Non-trainable params
1.2 M     Total params
4.914     Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

In [52]:
# load stored validation-step outputs from val_outputs.pt and show the mean
# of all ranks they contain
val_outputs = torch.load("val_outputs.pt")
all_val_rankings = torch.hstack([out['batch_rankings'] for out in val_outputs])
all_val_rankings.float().mean()

tensor(10923.9092, device='cuda:0')

In [58]:
# shape of the "rank < 10" indicator vector built from all stored ranks
hit_flags = torch.hstack([out['batch_rankings'] for out in val_outputs]) < 10
hit_flags.float().shape

torch.Size([6068])

In [47]:
# run the predict dataloader through the model and average the per-batch
# (mean_rank, hits@10) pairs over all batches
pred = trainer.predict(model, datamodule=dm, ckpt_path=ckpt_path)
per_batch_metrics = torch.tensor(pred)
test_mean_rank, test_hits_at_10 = per_batch_metrics.mean(0)
print("\n")
print(f"test_mean_rank={test_mean_rank:.0f}, test_hits@10={test_hits_at_10:.2f}%")

Restoring states from the checkpoint path at checkpoints/emb_dim=20-lr=0.01-margin=1-p_norm=2\transe-wordnet-epoch=0-val_mean_rank=20402-val_hits@10=0.0.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at checkpoints/emb_dim=20-lr=0.01-margin=1-p_norm=2\transe-wordnet-epoch=0-val_mean_rank=20402-val_hits@10=0.0.ckpt
  rank_zero_warn(


Predicting DataLoader 0: 100%|██████████| 98/98 [00:06<00:00, 14.84it/s]
test_mean_rank=20634, test_hits@10=0.03
