In [None]:
#@title OpenBioLink Module Install EXTERNAL{ form-width: "15%" }
! pip install openbiolink

In [2]:
#@title PyG Installation EXTERNAL{ form-width: "15%" }
!pip install -q torch-scatter -f https://pytorch-geometric.com/whl/torch-1.9.0+cu111.html
!pip install -q torch-sparse -f https://pytorch-geometric.com/whl/torch-1.9.0+cu111.html
!pip install -q git+https://github.com/rusty1s/pytorch_geometric.git

In [3]:
#@title Module Imports HEADER { form-width: "15%" }
from openbiolink.obl2021 import OBL2021Dataset
import torch
from torch.nn import Module,\
                     ModuleList,\
                     Embedding,\
                     BatchNorm1d,\
                     LogSoftmax,\
                     Softmax,\
                     Linear,\
                     NLLLoss,\
                     CrossEntropyLoss
from torch.optim import Adam
import torch.nn.functional as F
import torch_geometric as PyG
from torch_geometric.data import Data
from torch_geometric.nn.conv import RGCNConv
from torch_geometric.utils import to_networkx
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from typing import NoReturn
import typing
from enum import Enum
from collections import defaultdict

In [4]:
#@title Global Variables CLASS { form-width: "15%" }
class Global(Enum):
  """Notebook-wide constants; read each one through its ``.value`` attribute.

  NOTE: ``Enum`` treats members that share a value as aliases of the first
  member defined with that value, so ``TEST_HITS_AT`` is an alias of
  ``VAL_HITS_AT`` (both are 50). ``.value`` lookups are unaffected, but
  iterating over ``Global`` does not yield the alias as a separate member.
  """
  # Column positions of head / relation / tail inside a triple tensor.
  HEAD_INDEX = 0
  RELATION_INDEX = 1
  TAIL_INDEX = 2
  # Default node-feature scheme passed as ``x_feature`` to the graph builders.
  FEATURE_ENG = 'one-hot'
  # Number of relation types handed to the RGCN layers.
  NUM_RELATIONS = 28
  # Resolved once, when this class body executes.
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
  # Number of supervision edges per training step.
  MINI_BATCH_SIZE = 32
  # Cut-off k (and batch size) used for validation / test ranking.
  VAL_HITS_AT = 50
  TEST_HITS_AT = 50

In [None]:
#@title add_tail_to_head(data_split) FUNCTION { form-width: "15%" }
def add_tail_to_head(data_split: torch.Tensor, plus: int=28) -> torch.Tensor:
  """Append a reversed copy of every triple to ``data_split``.

  For each row ``(head, relation, tail)`` a row
  ``(tail, relation + plus, head)`` is created; the reversed rows are
  concatenated below the original rows, so the result has twice as many rows.

  Args:
    data_split: 2-D tensor whose first column holds heads, second column
      relations and last column tails.
    plus: Offset added to the relation id of each reversed triple.

  Returns:
    The original rows followed by the reversed rows.
  """
  head_col, relation_col, tail_col = data_split[:, 0], data_split[:, 1], data_split[:, -1]
  # Stacking along dim=1 directly yields one reversed triple per row.
  reversed_triples = torch.stack((tail_col, relation_col + plus, head_col), dim=1)
  return torch.cat([data_split, reversed_triples], dim=0)



In [None]:
#@title checksum(data_subset, graph_data) FUNCTION { form-width: "15%" }
def checksum(data_subset: torch.Tensor, graph_data: Data, step: int=500) -> None:
  """Check, chunk by chunk, that ``graph_data`` encodes the triples in ``data_subset``.

  For every edge position the head and tail entity ids are recovered through
  ``graph_data.x[graph_data.edge_index[...]]`` and the relation through
  ``graph_data.edge_type``; each is compared with the matching column of
  ``data_subset`` (0 = head, 1 = relation, 2 = tail).

  On the first mismatch the component name ('head', 'relation' or 'tail') and
  the end index of the offending chunk are printed and the scan stops.
  'All clear :)' is printed when every chunk matches.

  Args:
    data_subset: Triple tensor to validate against.
    graph_data: Object exposing ``x``, ``edge_index`` and ``edge_type``.
    step: Number of rows compared per chunk.
  """
  num_rows = data_subset.shape[0]

  def mismatched(from_graph: torch.Tensor, from_data: torch.Tensor, expected: int) -> bool:
    return (from_graph == from_data).sum().item() != expected

  # Start at 0 and clamp the end so the final (possibly shorter) chunk is
  # checked too; the previous range(step, N, step) silently skipped it.
  for start in range(0, num_rows, step):
    stop = min(start + step, num_rows)
    rows = slice(start, stop)
    expected = stop - start

    graph_head_entities = graph_data.x[graph_data.edge_index[0, rows]].reshape(-1)
    if mismatched(graph_head_entities, data_subset[rows, 0], expected):
      print('head')
      print(stop)
      break

    if mismatched(graph_data.edge_type[rows], data_subset[rows, 1], expected):
      print('relation')
      print(stop)
      break

    graph_tail_entities = graph_data.x[graph_data.edge_index[1, rows]].reshape(-1)
    # Compare the *tail* result here (the head result was re-tested before,
    # which let tail mismatches pass unnoticed).
    if mismatched(graph_tail_entities, data_subset[rows, 2], expected):
      print('tail')
      print(stop)
      break

  else:
    print('All clear :)')

In [None]:
#@title partial_graph_data_maker(dataset, x_feature) FUNCTION{ form-width: "15%" }
def _encode_entities(triples: torch.Tensor, entity_to_index: dict, head_idx: int, tail_idx: int, report_every: int=50000) -> torch.Tensor:
  """Translate the head/tail ids of ``triples`` into compact indices.

  ``entity_to_index`` is extended in place: an entity seen for the first time
  receives the next free index. Returns a ``(num_triples, 2)`` long tensor of
  (head index, tail index) pairs.
  """
  flat_entities = triples[:, (head_idx, tail_idx)].reshape(-1).tolist()
  total = len(flat_entities)
  encoded = []
  for position, entity in enumerate(flat_entities):
    encoded.append(entity_to_index.setdefault(entity, len(entity_to_index)))
    done = position + 1
    # Parenthesised on purpose: `i + 1 % 50000` parses as `i + (1 % 50000)`.
    if done % report_every == 0:
      print(f'Entity {done} encoded. {done / total * 100:.2f} %')
  # Explicit row count and dtype keep this valid for an empty split as well.
  return torch.tensor(encoded, dtype=torch.long).reshape(total // 2, 2)


def partial_graph_data_maker(messaging: torch.Tensor, supervision: torch.Tensor, negative_samples: torch.Tensor, x_feature: str=Global.FEATURE_ENG.value, check_for_correctness: bool=False) -> Data:
  """Build a ``Data`` graph whose node indices are compact (0..num_seen-1).

  Entities are numbered in order of first appearance while scanning the
  messaging, then supervision, then negative triples. ``x`` stores the
  original entity id of every compact index, so ``x[index]`` maps back.

  Args:
    messaging: Triples used for message passing.
    supervision: Triples to be predicted.
    negative_samples: Corrupted triples.
    x_feature: Only 'one-hot' is implemented.
    check_for_correctness: When True, run ``checksum`` on each of the three
      splits to confirm the re-indexing round-trips.

  Returns:
    ``Data`` with ``x`` and, for each split, ``edge_index_<split>`` of shape
    (2, num_triples) and ``edge_type_<split>``.

  Raises:
    NotImplementedError: For any ``x_feature`` other than 'one-hot'.
  """
  relation_idx = Global.RELATION_INDEX.value
  head_idx = Global.HEAD_INDEX.value
  tail_idx = Global.TAIL_INDEX.value

  if x_feature != 'one-hot':
    raise NotImplementedError('This functionality has not been implemented yet.')

  # Shared across the three splits so an entity keeps one index everywhere.
  entity_to_index = {}
  edge_index_messaging = _encode_entities(messaging, entity_to_index, head_idx, tail_idx)
  edge_index_supervision = _encode_entities(supervision, entity_to_index, head_idx, tail_idx)
  edge_index_negative = _encode_entities(negative_samples, entity_to_index, head_idx, tail_idx)

  # dicts preserve insertion order, so position i holds the entity with index i.
  x = torch.tensor(list(entity_to_index.keys()), dtype=torch.long).reshape(-1, 1)

  graph_data = Data(
      x=x,
      edge_index_messaging=edge_index_messaging.t().contiguous(),
      edge_type_messaging=messaging[:, relation_idx],
      edge_index_supervision=edge_index_supervision.t().contiguous(),
      edge_type_supervision=supervision[:, relation_idx],
      edge_index_negative=edge_index_negative.t().contiguous(),
      edge_type_negative=negative_samples[:, (relation_idx)]
  )

  if check_for_correctness:
    # checksum() reads `edge_index` / `edge_type`, so give it a view of one
    # split at a time together with the triples that split was built from.
    splits = (
        (messaging, graph_data.edge_index_messaging, graph_data.edge_type_messaging),
        (supervision, graph_data.edge_index_supervision, graph_data.edge_type_supervision),
        (negative_samples, graph_data.edge_index_negative, graph_data.edge_type_negative),
    )
    for triples, edge_index, edge_type in splits:
      checksum(triples, Data(x=x, edge_index=edge_index, edge_type=edge_type))
  return graph_data

In [None]:
#@title visualize_graph(graph_data, height, width) FUNCTION{ form-width: "15%" }
def visualize_graph(graph_data: Data, height: int=10, width:int=10) -> None:
  """Draw ``graph_data`` with networkx, labelling nodes and edges.

  Nodes are labelled with ``graph_data.x[node]`` and each edge with the
  relation id stored in ``graph_data.edge_type`` for the first occurrence of
  that (head, tail) pair in ``graph_data.edge_index``.

  Args:
    graph_data: Graph exposing ``x``, ``edge_index`` and ``edge_type``.
    height: Figure height in inches.
    width: Figure width in inches.
  """
  nx_graph = to_networkx(graph_data)

  pos = nx.spring_layout(nx_graph)

  fig = plt.gcf()
  fig.set_size_inches(width, height)

  # One pass over the edge list. Keying on the full (head, tail) pair fixes
  # the previous lookup, which matched rows where only ONE endpoint agreed.
  # .cpu() makes this work for graphs that live on the GPU.
  edge_pairs = graph_data.edge_index.t().cpu().numpy()
  edge_types = graph_data.edge_type.cpu().numpy()
  first_label = dict()
  for (head, tail), relation in zip(edge_pairs, edge_types):
    first_label.setdefault((int(head), int(tail)), int(relation))
  edge_labels = {
      edge: first_label[edge]
      for edge in nx_graph.edges()
      if edge in first_label
  }

  nx.draw_networkx_nodes(nx_graph, pos)
  nx.draw_networkx_edges(nx_graph, pos, connectionstyle='arc3,rad=0.2')
  nx.draw_networkx_labels(nx_graph, pos, labels={n:graph_data.x[n].item() for n in nx_graph})
  nx.draw_networkx_edge_labels(nx_graph, pos, edge_labels=edge_labels)

  # plt.show() renders under the notebook's inline backend, where
  # Figure.show() only emits a "non-interactive backend" warning.
  plt.show()


In [5]:
#@title graph_data_maker(dataset, x_feature) FUNCTION{ form-width: "15%" }
def graph_data_maker(messaging: torch.Tensor, supervision: torch.Tensor, negative_samples: torch.Tensor, x:torch.Tensor, x_feature: str=Global.FEATURE_ENG.value, check_for_correctness: bool=False) -> Data:
  """Pack three triple tensors and the node tensor into one ``Data`` object.

  Each split contributes ``edge_index_<split>`` (its head and tail columns,
  transposed to 2 rows) and ``edge_type_<split>`` (its relation column).
  ``x`` is stored as a single-column tensor. ``x_feature`` and
  ``check_for_correctness`` are accepted but not used.
  """
  endpoint_cols = (Global.HEAD_INDEX.value, Global.TAIL_INDEX.value)
  relation_col = Global.RELATION_INDEX.value

  def endpoints(triples: torch.Tensor) -> torch.Tensor:
    # (num_triples, 2) -> (2, num_triples), stored contiguously.
    return triples[:, endpoint_cols].t().contiguous()

  def relation_types(triples: torch.Tensor) -> torch.Tensor:
    return triples[:, relation_col]

  return Data(
      x=x.reshape(-1, 1),
      edge_index_messaging=endpoints(messaging),
      edge_type_messaging=relation_types(messaging),
      edge_index_supervision=endpoints(supervision),
      edge_type_supervision=relation_types(supervision),
      edge_index_negative=endpoints(negative_samples),
      edge_type_negative=relation_types(negative_samples)
  )


In [6]:
#@title mini_batch_maker() FUNCTION { form-width: "15%" }
def mini_batch_maker(messaging, supervision, candidates, x_feature='one-hot'):
  """Build the graph for one mini-batch, including corrupted (negative) triples.

  One negative triple is produced per supervision triple: the first half of
  the batch gets a randomly drawn tail, the remaining triples get a randomly
  drawn head. Replacement entities are drawn uniformly, with replacement,
  from ``candidates``; a drawn entity is not checked against the true
  triples, so an occasional negative may coincide with a real edge.

  Args:
    messaging: Triples used for message passing.
    supervision: Triples to score, one (head, relation, tail) per row.
    candidates: 1-D tensor of entity ids that may be used as replacements.
    x_feature: Forwarded to ``graph_data_maker``.

  Returns:
    The ``Data`` object produced by ``graph_data_maker``; its negative edges
    list the head-corrupted triples first, then the tail-corrupted ones.
  """
  heads = supervision[:, 0]
  relations = supervision[:, 1]
  tails = supervision[:, -1]

  ct_size = supervision.shape[0] // 2        # triples whose tail is replaced
  ch_size = supervision.shape[0] - ct_size   # triples whose head is replaced

  candidate_pool = candidates.to(supervision.device)

  def sample_entities(count):
    # Draw positions uniformly and look the entities up. torch.multinomial on
    # the candidate tensor would treat entity ids as sampling *weights* and
    # hand back positions rather than entities.
    picks = torch.randint(candidate_pool.shape[0], (count,), device=candidate_pool.device)
    return candidate_pool[picks]

  negative_samples_corrupted_tails = torch.vstack(
      (heads[: ct_size], relations[: ct_size], sample_entities(ct_size))
  ).t().contiguous()

  # The head-corrupted part covers the rows *after* the tail-corrupted part,
  # i.e. it starts at ct_size (slicing from ch_size reused the wrong rows and
  # produced a length mismatch for odd batch sizes).
  negative_samples_corrupted_heads = torch.vstack(
      (sample_entities(ch_size), relations[ct_size:], tails[ct_size:])
  ).t().contiguous()

  negative_samples = torch.cat(
      (negative_samples_corrupted_heads, negative_samples_corrupted_tails),
      dim=0
  )

  graph = graph_data_maker(
      messaging=messaging,
      supervision=supervision,
      negative_samples=negative_samples,
      x=candidates.to(Global.DEVICE.value),
      x_feature=x_feature
  )

  return graph

In [7]:
#@title GNN Model CLASS{ form-width: "10%" }
class GNN(Module):
  """Link-prediction model: stacked ``RGCNConv`` encoder plus a bilinear decoder.

  The encoder turns node features into node vectors by message passing over
  the messaging edges. The decoder scores an edge (head, relation, tail) as
  ``(W_relation @ x[head]) . x[tail]`` with one bias-free ``Linear`` weight
  matrix per relation.
  """
  def __init__(self, conv_dims: list, fully_connected_dims: list, x_feature:str, dropout: dict, embedding_dims: tuple=None)-> None:
    """Create the layers.

    Args:
      conv_dims: Output width of each ``RGCNConv`` layer, in order.
      fully_connected_dims: Accepted for interface compatibility; unused.
      x_feature: 'one-hot' (node ids are looked up in an ``Embedding``) or
        'identity' (every node gets the constant feature 1).
      dropout: Dropout probabilities; ``forward`` reads the keys "conv" and,
        in 'identity' mode, "emb".
      embedding_dims: ``(num_embeddings, embedding_dim)``; required when
        ``x_feature`` is 'one-hot'.

    Raises:
      ValueError: On an unknown ``x_feature`` or a missing ``embedding_dims``.
    """
    super(GNN, self).__init__()
    self.mode = None # 'train' or 'test' or 'dev' later
    self.num_relations = Global.NUM_RELATIONS.value
    self.dropout = dropout
    self.x_feature = x_feature
    # One decoder matrix per relation, sized from num_relations rather than a
    # second hard-coded copy of the same number.
    self.relation_weights = ModuleList(
        [Linear(conv_dims[-1], conv_dims[-1], bias=False) for _ in range(self.num_relations)]
    )
    if x_feature == 'one-hot':
      if embedding_dims is None:
        raise ValueError("embedding_dims is required when x_feature is 'one-hot'")
      # int() accepts plain ints as well as 0-dim tensors such as ids.max() + 1.
      num_embeddings, embedding_dim = int(embedding_dims[0]), int(embedding_dims[1])
      #one-hot to latent
      self.embed = Embedding(num_embeddings, embedding_dim)
      in_channels = embedding_dim
    elif x_feature == 'identity':
      in_channels = 1
    else:
      raise ValueError(f"Unsupported x_feature: {x_feature!r}")

    conv_list = [RGCNConv(in_channels, conv_dims[0], self.num_relations)] + \
                [
                  RGCNConv(conv_dims[i], conv_dims[i+1], self.num_relations)
                  for i in range(len(conv_dims) - 1)
                ]

    #graph conv layers
    self.conv_layers = ModuleList(conv_list)

  def reset_parameters(self):
    """Re-initialise every learnable layer of the model."""
    # The embedding only exists in 'one-hot' mode.
    if hasattr(self, 'embed'):
      self.embed.reset_parameters()
    for conv in self.conv_layers:
      conv.reset_parameters()
    for relation_weight in self.relation_weights:
      relation_weight.reset_parameters()

  def _score_edges(self, x: torch.Tensor, edge_index: torch.Tensor, edge_type: torch.Tensor) -> torch.Tensor:
    """Return one score per edge: ``(W_relation @ x[head]) . x[tail]``."""
    head_vectors = x[edge_index[0]]
    tail_vectors = x[edge_index[1]]
    # (num_relations, D, D) -> (num_edges, D, D): each edge's relation matrix.
    weight_bank = torch.stack([layer.weight for layer in self.relation_weights])
    per_edge_weights = weight_bank[edge_type]
    # Linear(head) == W @ head, computed for all edges at once.
    projected = torch.bmm(per_edge_weights, head_vectors.unsqueeze(-1)).squeeze(-1)
    return (projected * tail_vectors).sum(dim=-1)

  def forward(self, data: Data) -> tuple:
    """Encode the nodes and score the supervision and negative edges.

    Args:
      data: Graph exposing ``x``, ``edge_index_messaging``,
        ``edge_type_messaging``, ``edge_index_supervision``,
        ``edge_type_supervision``, ``edge_index_negative`` and
        ``edge_type_negative``.

    Returns:
      ``(positive_scores, negative_scores)``: raw (un-squashed) scores of the
      supervision edges and of the negative edges, each in the order the
      edges appear in ``data``.
    """
    edge_index = data.edge_index_messaging
    x = data.x
    edge_type = data.edge_type_messaging

    ####################################### ONE - HOT #######################################
    if self.x_feature == 'one-hot':
      # Flatten the (N, 1) id column so the lookup yields (N, embedding_dim)
      # for any N, not only when N equals the size of the embedding table.
      x = self.embed(x.reshape(-1))

    ############################################## IDENTITY ################################################
    elif self.x_feature == 'identity':
      # Created on the input's device to avoid a CPU/GPU mismatch.
      x = torch.ones(x.shape[0], 1, device=x.device)
      if self.training:
        x = F.dropout(x, p=self.dropout["emb"])

    ####################################### Encoder: RGCN #######################################
    for conv in self.conv_layers[:-1]:
      x = conv(x, edge_index=edge_index, edge_type=edge_type)
      x = F.relu(x)
      if self.training:
        x = F.dropout(x, p=self.dropout["conv"])
    # No non-linearity after the last layer.
    x = self.conv_layers[-1](x, edge_index, edge_type)
    if self.training:
      x = F.dropout(x, p=self.dropout["conv"])

    ####################################### Decoder #######################################
    # Score the two edge sets separately so the supervision scores can only
    # end up in the first return value. The earlier shuffle-and-tag scheme had
    # the tags inverted and returned the two sets swapped.
    positive_scores = self._score_edges(x, data.edge_index_supervision, data.edge_type_supervision)
    negative_scores = self._score_edges(x, data.edge_index_negative, data.edge_type_negative)
    return positive_scores, negative_scores
    


In [None]:
#@title Model and Hyperparameters MAIN { form-width: "15%" }
dataset = OBL2021Dataset()

# The embedding table is indexed by entity id, so it needs max id + 1 rows.
# .item() turns the 0-dim tensor into a plain int, which is what Embedding
# expects and what prints cleanly in the model summary below.
num_entities = int(dataset.candidates.max().item()) + 1

model = GNN(
    x_feature='one-hot',
    conv_dims=[32, 32, 32, 32],
    embedding_dims=(num_entities, 32),
    fully_connected_dims=[1],
    dropout={
        "emb": 0.1,
        "conv": 0.2,
        "fc": 0.01
    }
).to(Global.DEVICE.value)
print(model)
# The model returns raw scores, hence the logits variant of BCE.
loss_fn = torch.nn.BCEWithLogitsLoss()
opt = Adam(model.parameters())

In [9]:
#@title train(model, graph, optimizer, loss_fn) FUNCTION{ form-width: "15%" }
def train(model: GNN, graph: Data, optimizer: torch.optim, loss_fn:torch.nn.modules.loss) -> tuple((torch.float, torch.Tensor, torch.Tensor)):
  """Run one optimisation step on ``graph``.

  The model is put in training mode and produces scores for the supervision
  and the negative edges. ``loss_fn`` is applied to the negative scores with
  all-zero targets and to the positive scores with all-one targets, and the
  two terms are summed.

  Returns:
    The loss as a Python float, followed by the sigmoid of the positive
    scores and the sigmoid of the negative scores.
  """
  optimizer.zero_grad()
  model.train()
  positive, negative = model(graph)

  device = Global.DEVICE.value
  num_negative = graph.edge_index_negative.shape[1]
  num_positive = graph.edge_index_supervision.shape[1]

  negative_term = loss_fn(negative, torch.zeros(num_negative).to(device))
  positive_term = loss_fn(positive, torch.ones(num_positive).to(device))
  loss = negative_term + positive_term

  loss.backward()
  optimizer.step()

  positive_prob = torch.sigmoid(positive)
  negative_prob = torch.sigmoid(negative)
  return loss.item(), positive_prob, negative_prob

In [10]:
#@title evaluate(model, graph, hits_at) FUNCTION{ form-width: "15%" }
@torch.no_grad()
def evaluate(model: GNN, graph: Data, hits_at: int) -> torch.Tensor:
  """Return the fraction of true edges among the ``hits_at`` best-scored edges.

  The supervision and negative edges of ``graph`` are scored in eval mode,
  ranked together by sigmoid score (highest first) and the mean label
  (1 = supervision edge, 0 = negative edge) of the top ``hits_at`` entries is
  returned as a 0-dim tensor.

  The model's training/eval mode is restored to what it was on entry.
  """
  was_training = model.training
  model.eval()
  try:
    positive, negative = model(graph)
  finally:
    # Put the model back the way the caller left it instead of forcing
    # training mode (which silently re-enabled dropout after a test run).
    model.train(was_training)

  scores = torch.cat((torch.sigmoid(positive), torch.sigmoid(negative)))
  # Float labels from the start; no reliance on bool/float type promotion.
  labels = torch.cat((torch.ones_like(positive), torch.zeros_like(negative)))
  ranking = torch.sort(scores, descending=True)[1]
  return labels[ranking][:hits_at].mean()

In [None]:
# @title Parameter and Hyperparameter Tuning MAIN{ form-width: "10%" }
model.reset_parameters()
# Fresh parameters need fresh optimizer state; otherwise re-running this cell
# would carry Adam's moment estimates over from the previous run.
opt = Adam(model.parameters())

train_set = dataset.training.to(Global.DEVICE.value) # torch.tensor of shape(num_train,3)
val_set = dataset.validation.to(Global.DEVICE.value) # torch.tensor of shape(num_val,3)
test_set = dataset.testing.to(Global.DEVICE.value)   # torch.tensor of shape(num_test,3)
# knowledge_graph = torch.cat(
#     (train_set, validation_set, test_set),
#     dim=0
# )

NUM_EPOCHS = 5
LOG_EVERY = 100        # training batches between progress reports
VALIDATE_EVERY = 5000  # training batches between validation passes
train_batch_size = Global.MINI_BATCH_SIZE.value
val_batch_size = Global.VAL_HITS_AT.value

try:
  for epoch in range(NUM_EPOCHS):
    iteration = 0
  ############################## TRAIN Graph Maker #####################################
    # "+ 1" keeps the last full batch when the set size is an exact multiple
    # of the batch size; a trailing partial batch is still skipped.
    for train_start in tqdm(range(0, train_set.shape[0] - train_batch_size + 1, train_batch_size)):
      iteration += 1
      train_stop = train_start + train_batch_size
      train_supervision = train_set[train_start: train_stop, :]
      # Every training edge outside the current batch is used for messaging.
      train_messaging = torch.cat(
          (train_set[: train_start, :], train_set[train_stop: , :]),
          dim=0
      )
      train_graph = mini_batch_maker(
          train_messaging,
          train_supervision,
          dataset.candidates,
          'one-hot'
      )
      loss, positive_score, negative_score = train(model, train_graph, opt, loss_fn)
      if iteration % LOG_EVERY == 0:
        print(' ')
        agg_p_score = positive_score.sum().item()
        agg_n_score = negative_score.sum().item()
        print(f'Train Batch {iteration}:')
        print(f'Loss:        {loss: .4f}')
        print(f'Agg P-Score: {agg_p_score: .4f}')
        print(f'Agg N-Score: {agg_n_score: .4f}')
        print(f'===' * 25)
      if iteration % VALIDATE_EVERY == 0:
        results_list = list()
        # Separate loop variable: the validation loop no longer reuses the
        # training loop's index name.
        for val_start in tqdm(range(0, val_set.shape[0] - val_batch_size + 1, val_batch_size)):
          val_supervision = val_set[val_start: val_start + val_batch_size]
          val_messaging = train_set
          val_graph = mini_batch_maker(
              val_messaging,
              val_supervision,
              dataset.candidates,
              'one-hot'
          )
          results_list.append(evaluate(model, val_graph, Global.VAL_HITS_AT.value))
        print('')
        if results_list:
          result = torch.stack(results_list).mean().item()
          print(f'Dev: \nHits @ {Global.VAL_HITS_AT.value}: {result: .3f}')
        else:
          print('Dev: validation set is smaller than one batch; nothing evaluated.')
        print(f'---' * 25)
except KeyboardInterrupt:
  # Deliberate: stopping the cell by hand keeps the weights trained so far.
  print('Training interrupted by user.')


In [None]:
#@title Test MAIN { form-width: "15%" }
results_list = list()
try:
  # Identical for every batch, so build it once instead of per iteration.
  test_messaging = torch.cat((train_set, val_set))
  # Indexed below as [0] = number of batches and [1] = iterable of batches.
  test_batches = dataset.get_test_batches(Global.TEST_HITS_AT.value)
  for supervision in tqdm(test_batches[1], total=test_batches[0]):
    test_supervision = supervision.to(Global.DEVICE.value)
    test_graph = mini_batch_maker(
        test_messaging,
        test_supervision,
        dataset.candidates,
        'one-hot'
    )
    results_list.append(evaluate(model, test_graph, Global.TEST_HITS_AT.value))
except KeyboardInterrupt:
  # Deliberate: report whatever was evaluated before the interruption.
  pass
finally:
  print('')
  if results_list:
    results = torch.stack(results_list)
    print(f'Test: \nHits @ {Global.TEST_HITS_AT.value}: {results[:].mean().item(): .4f}')
  else:
    # torch.stack([]) would raise here and hide the real cause of the failure.
    print('Test: no batch was evaluated.')

In [None]:
#@title Visualization OPTIONAL-MAIN { form-width: "15%" }
# Work on the CPU: the drawing code converts tensors to numpy arrays.
data_subset = train_set[:2500, :].cpu()

# Re-index the entities of this subset to 0..k-1 so the drawn graph contains
# only the nodes that actually occur in it. `entity_ids[i]` is the original
# id of compact node i; `compact_endpoints` has the same (rows, 2) shape as
# the head/tail columns it was computed from.
entity_ids, compact_endpoints = torch.unique(
    data_subset[:, (Global.HEAD_INDEX.value, Global.TAIL_INDEX.value)],
    return_inverse=True
)

# visualize_graph reads `x`, `edge_index` and `edge_type`, so build exactly
# that (graph_data_maker takes different arguments and emits per-split names).
graph_data = Data(
    x=entity_ids.reshape(-1, 1),
    edge_index=compact_endpoints.t().contiguous(),
    edge_type=data_subset[:, Global.RELATION_INDEX.value]
)
visualize_graph(graph_data, 100, 100)

In [None]:
#@title Mostly Trash { form-width: "15%" }
# ! pip install deepsnap
# import deepsnap


# NX_graph = PyG.utils.to_networkx(train_graph)
# SX_graph = deepsnap.graph.Graph(NX_graph)

# SX_graph
# SX_graph.negative_sampling(SX_graph.edge_index, SX_graph.num_nodes, 100)
# #@title
# print(min([i[1].shape[0] for i in train_list]))
# #@title
# d = defaultdict(lambda: 0)
# for i in dataset.training:
#   h = i[0]
#   t = i[-1]
#   d[h.item()] += 1
#   d[t.item()] += 1
# new_list = list()
# other_list = list()
# ################################################################################################################
# new_candidates = torch.tensor(list(filter(lambda z: z[1] > 1000, sorted(d.items(), key= lambda z: z[1]))))[:, 0]
# ################################################################################################################
# for i in tqdm(dataset.training):
#   h = i[0]
#   t = i[-1]
#   if h in new_candidates or t in new_candidates:
#     new_list.append(i)
#   else:
#     other_list.append(i)
# #@title
# for i in train_list:
#   print(torch.unique(i[1]).shape)
# #@title

# # #@title Dataset Load and Train/Val/Test Split MAIN 2{ form-width: "15%" }
# # dataset = OBL2021Dataset()
# # train_set = dataset.training # torch.tensor of shape(num_train,3)
# # val_set = dataset.validation # torch.tensor of shape(num_val,3)
# # test_set = dataset.testing   # torch.tensor of shape(num_train,3)

# # train_supervision, train_messaging = torch.split(
# #     torch.stack(new_list)[torch.randperm(torch.stack(new_list).shape[0])],
# #     split_size_or_sections=(
# #         Global.TRAIN_SUPERVISION_SIZE.value,
# #         torch.stack(new_list).shape[0] - Global.TRAIN_SUPERVISION_SIZE.value
# #     )
# # )
# # train_supervision_data_loader = torch.split(train_supervision, split_size_or_sections=32)
# # train_supervision_negative_samples = [None for _ in range(Global.TRAIN_SUPERVISION_SIZE.value // 32)]
# # for i in range(Global.TRAIN_SUPERVISION_SIZE.value // 32):
# #   heads = train_supervision_data_loader[i][:, 0]
# #   relations = train_supervision_data_loader[i][:, 1]
# #   tails = train_supervision_data_loader[i][:, -1]
# #   negative_samples = torch.vstack(
# #       (heads, relations, torch.multinomial(dataset.candidates.type(torch.float), 32))
# #   ).t().contiguous()
# #   train_supervision_negative_samples[i] = negative_samples
# # knowledge_graph = torch.cat(
# #     (dataset.training, dataset.validation, dataset.testing),
# #     dim=0
# # )

# # # for _index, mini_batch in enumerate(tqdm(train_supervision_negative_samples)):
# # #   for __index, candidate in enumerate(mini_batch):
# # #     is_corrupt = not any(np.equal(knowledge_graph.numpy(), candidate.numpy().tolist()).all(1))
# # #     if not is_corrupt:
# # #       print(f'{candidate} at [{_index}][{__index}] is not a true corrupted edge')

# #@title

# torch.unique(torch.cat(
#     (train_messaging, torch.stack(other_list)),
#     dim=0
# )).shape
# #@title
# torch.split(
#     torch.tensor([_ for _ in range(1, 15)])[torch.randperm(14)],
#     split_size_or_sections = (7, 7)
# )
# #@title
# train_supervision, train_messaging = torch.split(
#     torch.stack(new_list),
#     split_size_or_sections=(
#         3,
#         torch.stack(new_list).shape[0] - 3
#     )
# )
# #@title
# torch.unique(train_messaging).shape
# #@title
# torch.unique(torch.stack(other_list + new_list)).shape
# # torch.cat(
# #     (train_messaging, torch.stack(other_list)),
# #     dim=0
# # ).shape
# # train_supervision.shape
# #@title
# # d = defaultdict(lambda: 0)
# # for i in dataset.training:
# #   h = i[0]
# #   t = i[-1]
# #   d[h.item()] += 1
# #   d[t.item()] += 1
# new_list = list()
# other_list = list()
# ################################################################################################################
# new_candidates = torch.tensor(list(filter(lambda z: z[1] > 1000, sorted(d.items(), key= lambda z: z[1]))))[:, 0]
# ################################################################################################################
# for i in tqdm(dataset.training):
#   h = i[0]
#   t = i[-1]
#   if h in new_candidates or t in new_candidates:
#     new_list.append(i)
#   else:
#     other_list.append(i)

# #@title
# torch.stack(new_list).shape
# dataset.training.shape