In [2]:
#@title OpenBioLink Module Install EXTERNAL{ form-width: "15%" }
! pip install openbiolink



In [3]:
#@title PyG Installation EXTERNAL{ form-width: "15%" }
!pip install -q torch-scatter -f https://pytorch-geometric.com/whl/torch-1.9.0+cu111.html
!pip install -q torch-sparse -f https://pytorch-geometric.com/whl/torch-1.9.0+cu111.html
!pip install -q git+https://github.com/rusty1s/pytorch_geometric.git

In [4]:
#@title Module Imports HEADER { form-width: "15%" }
from openbiolink.obl2021 import OBL2021Dataset
import torch
from torch.nn import Module,\
                     ModuleList,\
                     Embedding,\
                     BatchNorm1d,\
                     LogSoftmax,\
                     Softmax,\
                     Linear,\
                     NLLLoss,\
                     CrossEntropyLoss
from torch.optim import Adam
import torch.nn.functional as F
import torch_geometric as PyG
from torch_geometric.data import Data
from torch_geometric.nn.conv import RGCNConv
from torch_geometric.utils import to_networkx
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from typing import NoReturn
from enum import Enum
from collections import defaultdict

In [5]:
#@title Global Variables CLASS { form-width: "15%" }
class Global(Enum):
  """Notebook-wide configuration constants; read them via ``Global.<NAME>.value``."""

  # Column positions inside a (head, relation, tail) triple tensor.
  HEAD_INDEX = 0
  RELATION_INDEX = 1
  TAIL_INDEX = 2

  # Default node-feature scheme used by graph_data_maker.
  FEATURE_ENG = 'one-hot'

  # Number of relation types the RGCN layers are constructed for.
  NUM_RELATIONS = 28

  # Torch device the model is moved to.
  DEVICE = 'cpu'

  # Number of supervision triples per mini-batch.
  MINI_BATCH_SIZE = 32

In [6]:
#@title add_tail_to_head(data_split) FUNCTION { form-width: "15%" }
def add_tail_to_head(data_split: torch.Tensor, plus: int=28) -> torch.Tensor:
  """Append the reversed copy of every triple to ``data_split``.

  Each row (h, r, t) produces an extra row (t, r + plus, h); all reversed
  rows are placed after the original rows.

  Args:
    data_split: 2-D tensor of triples; column 0 is the head, column 1 the
      relation and the last column the tail.
    plus: offset added to the relation id of every reversed triple.

  Returns:
    The original rows followed by their reversed counterparts.
  """
  reversed_triples = torch.stack(
      [data_split[:, -1], data_split[:, 1] + plus, data_split[:, 0]],
      dim=1,
  )
  return torch.cat([data_split, reversed_triples], dim=0)



In [7]:
#@title checksum(data_subset, graph_data) FUNCTION { form-width: "15%" }
def checksum(data_subset: torch.Tensor, graph_data: 'Data', step: int=500) -> bool:
  """Verify that ``graph_data`` encodes the triples of ``data_subset`` row by row.

  The triples are compared in chunks of ``step`` rows.  For every chunk, the
  entity ids recovered through ``graph_data.x[edge_index]`` and the relation
  ids in ``graph_data.edge_type`` are compared against the raw triples.  On
  the first mismatching chunk it prints which component failed ('head',
  'relation' or 'tail'), followed by the exclusive end row of that chunk.
  Otherwise it prints 'All clear :)'.

  Args:
    data_subset: tensor of triples with columns (head, relation, tail).
    graph_data: object exposing ``x`` (entity id per node, one per row),
      ``edge_index`` (2 x E) and ``edge_type`` (E), edge k corresponding to
      row k of ``data_subset``.
    step: number of rows compared per chunk; must be positive.

  Returns:
    True when every chunk matches, False at the first mismatch.

  Raises:
    ValueError: if ``step`` is not positive.
  """
  if step <= 0:
    raise ValueError(f'step must be positive, got {step}')

  num_rows = data_subset.shape[0]
  # Start at 0 and clamp the end so the final (possibly partial) chunk is
  # checked too, including the case num_rows <= step.
  for start in range(0, num_rows, step):
    end = min(start + step, num_rows)
    expected = end - start

    graph_heads = graph_data.x[graph_data.edge_index[0, start:end]].reshape(-1)
    if (graph_heads == data_subset[start:end, 0]).sum().item() != expected:
      print('head')
      print(end)
      return False

    graph_relations = graph_data.edge_type[start:end]
    if (graph_relations == data_subset[start:end, 1]).sum().item() != expected:
      print('relation')
      print(end)
      return False

    graph_tails = graph_data.x[graph_data.edge_index[1, start:end]].reshape(-1)
    # Must test the tail agreement itself, not re-test the head checksum.
    if (graph_tails == data_subset[start:end, 2]).sum().item() != expected:
      print('tail')
      print(end)
      return False

  print('All clear :)')
  return True

In [8]:
#@title graph_data_maker(dataset, x_feature) FUNCTION{ form-width: "15%" }
def _encode_entities(triples: torch.Tensor, head_idx: int, tail_idx: int, entity_to_node: dict, verbose: bool=False) -> torch.Tensor:
  """Map the head/tail entities of ``triples`` to dense node indices.

  ``entity_to_node`` is updated in place: an entity seen for the first time
  gets the next free index (``len(entity_to_node)``), so indices stay
  consistent across successive calls that share the same dict.

  Returns:
    Long tensor of shape (2, num_triples): row 0 holds head node indices,
    row 1 tail node indices.
  """
  # Flattened as h0, t0, h1, t1, ... so first-seen order follows row order.
  flat_entities = triples[:, (head_idx, tail_idx)].reshape(-1).tolist()
  total = len(flat_entities)
  encoded = []
  for i, entity in enumerate(flat_entities):
    encoded.append(entity_to_node.setdefault(entity, len(entity_to_node)))
    # Parenthesised: `i + 1 % 50000` parses as `i + (1 % 50000)` and never fires.
    if verbose and (i + 1) % 50000 == 0:
      print(f'Entity {i} encoded. {i / total * 100:.2f} %')
  return torch.tensor(encoded, dtype=torch.long).reshape(-1, 2).t().contiguous()


def graph_data_maker(messaging: torch.Tensor, supervision: torch.Tensor, negative_samples: torch.Tensor, x_feature: str=Global.FEATURE_ENG.value, check_for_correctness: bool=False, verbose: bool=False) -> Data:
  """Build a PyG ``Data`` object for one training step.

  Every distinct entity appearing in ``messaging``, then ``supervision``,
  then ``negative_samples`` becomes one node, in first-seen order.
  ``x[node]`` holds the original entity id.

  Args:
    messaging: triples used for message passing.
    supervision: positive triples to be scored.
    negative_samples: corrupted triples to be scored.
    x_feature: node-feature scheme; only 'one-hot' is implemented.
    check_for_correctness: if True, run ``checksum`` on each of the three
      edge sets against its source triples.
    verbose: if True, print encoding progress every 50000 entities.

  Returns:
    ``Data`` with ``x`` (num_nodes x 1) and, for each split in
    messaging/supervision/negative, ``edge_index_<split>`` (2 x E) and
    ``edge_type_<split>`` (E).

  Raises:
    NotImplementedError: for any ``x_feature`` other than 'one-hot'.
  """
  if x_feature != 'one-hot':
    raise NotImplementedError('This functionality has not been implemented yet.')

  relation_idx = Global.RELATION_INDEX.value
  head_idx = Global.HEAD_INDEX.value
  tail_idx = Global.TAIL_INDEX.value

  # Shared across the three splits so a given entity always maps to one node.
  entity_to_node = {}
  edge_index_messaging = _encode_entities(messaging, head_idx, tail_idx, entity_to_node, verbose)
  edge_index_supervision = _encode_entities(supervision, head_idx, tail_idx, entity_to_node, verbose)
  edge_index_negative = _encode_entities(negative_samples, head_idx, tail_idx, entity_to_node, verbose)

  # Dicts keep insertion order, so row k of x is the entity given node index k.
  x = torch.tensor(list(entity_to_node), dtype=torch.long).reshape(-1, 1)

  graph_data = Data(
      x=x,
      edge_index_messaging=edge_index_messaging,
      edge_type_messaging=messaging[:, relation_idx],
      edge_index_supervision=edge_index_supervision,
      edge_type_supervision=supervision[:, relation_idx],
      edge_index_negative=edge_index_negative,
      edge_type_negative=negative_samples[:, relation_idx],
  )

  if check_for_correctness:
    # checksum() reads `edge_index` / `edge_type`, so verify each split of
    # this function's own inputs through a view exposing those names.
    for triples, edge_index, edge_type in (
        (messaging, graph_data.edge_index_messaging, graph_data.edge_type_messaging),
        (supervision, graph_data.edge_index_supervision, graph_data.edge_type_supervision),
        (negative_samples, graph_data.edge_index_negative, graph_data.edge_type_negative),
    ):
      checksum(triples, Data(x=x, edge_index=edge_index, edge_type=edge_type))
  return graph_data

In [9]:
#@title visualize_graph(graph_data, height, width) FUNCTION{ form-width: "15%" }
def visualize_graph(graph_data: Data, height: int=10, width:int=10) -> None:
  """Draw ``graph_data`` with a spring layout.

  Entity ids are used as node labels and relation types as edge labels.

  Args:
    graph_data: graph exposing ``x`` (entity id per node), ``edge_index``
      (2 x E) and ``edge_type`` (E).
    height: figure height in inches.
    width: figure width in inches.
  """
  nx_graph = to_networkx(graph_data)
  pos = nx.spring_layout(nx_graph)

  fig = plt.gcf()
  fig.set_size_inches(width, height)

  # Build (src, dst) -> relation type once, keeping the first edge for a pair.
  # An exact pair match is required; comparing `edge_index == edge`
  # element-wise would accept rows where only one endpoint agrees.
  edge_type_of = {}
  for (src, dst), rel in zip(graph_data.edge_index.t().tolist(), graph_data.edge_type.tolist()):
    edge_type_of.setdefault((src, dst), rel)
  edge_labels = {edge: edge_type_of[edge] for edge in nx_graph.edges()}

  nx.draw_networkx_nodes(nx_graph, pos)
  nx.draw_networkx_edges(nx_graph, pos, connectionstyle='arc3,rad=0.2')
  nx.draw_networkx_labels(nx_graph, pos, labels={n: graph_data.x[n].item() for n in nx_graph})
  nx.draw_networkx_edge_labels(nx_graph, pos, edge_labels=edge_labels)

  plt.show()


In [38]:
#@title GNN Model CLASS{ form-width: "10%" }
class GNN(Module):
  """RGCN encoder plus a relation-specific bilinear decoder for link scoring.

  The encoder maps node features to embeddings with a stack of ``RGCNConv``
  layers over the messaging edges.  The decoder scores an edge (h, r, t) as
  ``relation_weights[r](x_h) @ x_t``, using one bias-free ``Linear`` per
  relation.

  Args:
    conv_dims: output width of each RGCN layer; ``conv_dims[-1]`` is the
      width the decoder operates on.
    fully_connected_dims: currently unused (kept for the constructor API).
    x_feature: 'one-hot' (learned ``Embedding`` looked up with the ids in
      ``data.x``) or 'identity' (a constant 1-d feature per node).
    dropout: dict with keys 'emb' and 'conv' giving dropout probabilities.
    embedding_dims: (num_embeddings, embedding_dim); required for 'one-hot'.

  Raises:
    ValueError: for an unknown ``x_feature`` or missing ``embedding_dims``.
  """

  def __init__(self, conv_dims: list, fully_connected_dims: list, x_feature:str, dropout: dict, embedding_dims: tuple=None) -> None:
    super().__init__()
    self.mode = None # 'train' or 'test' or 'dev' later
    self.num_relations = Global.NUM_RELATIONS.value
    self.dropout = dropout
    self.x_feature = x_feature

    # One bilinear map per relation, sized to the encoder's output width.
    hidden_dim = conv_dims[-1]
    self.relation_weights = ModuleList(
        [Linear(hidden_dim, hidden_dim, bias=False) for _ in range(self.num_relations)]
    )

    if x_feature == 'one-hot':
      if embedding_dims is None:
        raise ValueError("x_feature='one-hot' requires embedding_dims=(num_embeddings, embedding_dim)")
      # entity id -> learned latent vector
      self.embed = Embedding(embedding_dims[0], embedding_dims[1])
      in_dim = embedding_dims[1]
    elif x_feature == 'identity':
      in_dim = 1
    else:
      raise ValueError(f"Unsupported x_feature: {x_feature!r} (expected 'one-hot' or 'identity')")

    # Graph conv layers: in_dim -> conv_dims[0] -> ... -> conv_dims[-1].
    layer_dims = [in_dim] + list(conv_dims)
    self.conv_layers = ModuleList(
        [RGCNConv(d_in, d_out, self.num_relations) for d_in, d_out in zip(layer_dims[:-1], layer_dims[1:])]
    )

    self.classifier = LogSoftmax(dim=1)

  def reset_parameters(self) -> None:
    """Re-initialise every learnable sub-module."""
    if hasattr(self, 'embed'):
      self.embed.reset_parameters()
    for conv in self.conv_layers:
      conv.reset_parameters()
    for weights in self.relation_weights:
      weights.reset_parameters()

  def forward(self, data: Data) -> tuple:
    """Score the negative and supervision edges of ``data``.

    Returns:
      ``(negative_scores, positive_scores)``: raw (pre-sigmoid) scores, one
      per edge of ``data.edge_index_negative`` / ``data.edge_index_supervision``.
    """
    edge_index = data.edge_index_messaging
    edge_type = data.edge_type_messaging
    x = data.x

    if self.x_feature == 'one-hot':
      # One lookup for all nodes; row k of the result belongs to node k.
      x = self.embed(x.reshape(-1))
    elif self.x_feature == 'identity':
      x = torch.ones(x.shape[0], 1, device=x.device)
      x = F.dropout(x, p=self.dropout["emb"], training=self.training)

    # Encoder: RGCN stack, ReLU + dropout between layers.
    for conv in self.conv_layers[:-1]:
      x = conv(x, edge_index=edge_index, edge_type=edge_type)
      x = F.relu(x)
      x = F.dropout(x, p=self.dropout["conv"], training=self.training)
    x = self.conv_layers[-1](x, edge_index, edge_type)
    x = F.dropout(x, p=self.dropout["conv"], training=self.training)

    # Decoder
    negative_scores = self._score_edges(x, data.edge_index_negative, data.edge_type_negative)
    positive_scores = self._score_edges(x, data.edge_index_supervision, data.edge_type_supervision)
    return negative_scores, positive_scores

  def _score_edges(self, x: torch.Tensor, edge_index: torch.Tensor, edge_type: torch.Tensor) -> torch.Tensor:
    """Vectorised ``relation_weights[r](x[h]) @ x[t]`` for every edge (h, r, t)."""
    # (num_relations, out, in); Linear(head) computes weight @ head.
    weights = torch.stack([linear.weight for linear in self.relation_weights])
    heads = x[edge_index[0]]
    tails = x[edge_index[1]]
    return torch.einsum('ei,eoi,eo->e', heads, weights[edge_type], tails)
    


In [None]:
#@title Dataset Load and Train/Val/Test Split MAIN 1{ form-width: "15%" }
dataset = OBL2021Dataset()
train_set = dataset.training  # torch.tensor of shape (num_train, 3)
val_set = dataset.validation  # torch.tensor of shape (num_val, 3)
test_set = dataset.testing    # torch.tensor of shape (num_test, 3)
knowledge_graph = torch.cat(
    (dataset.training, dataset.validation, dataset.testing),
    dim=0
)
candidates = dataset.candidates
batch_size = Global.MINI_BATCH_SIZE.value

# Encode every triple as one int64 key.  Checking whether a negative sample
# is actually a known edge then becomes a binary search in a sorted tensor,
# not a full scan of the knowledge graph per candidate.
num_entity_ids = int(max(knowledge_graph[:, [0, 2]].max(), candidates.max())) + 1
num_relation_ids = int(knowledge_graph[:, 1].max()) + 1

def triple_keys(triples: torch.Tensor) -> torch.Tensor:
  """Map each (head, relation, tail) row to a unique int64 key."""
  triples = triples.long()
  return (triples[:, 0] * num_relation_ids + triples[:, 1]) * num_entity_ids + triples[:, 2]

known_keys, _ = torch.sort(triple_keys(knowledge_graph))

for batch_number, i in enumerate(tqdm(range(0, train_set.shape[0], batch_size))):
  train_supervision = train_set[i: i + batch_size, :]
  heads = train_supervision[:, 0]
  relations = train_supervision[:, 1]

  # Corrupt each tail with a candidate drawn uniformly at random.  Draw one
  # per row so the last (possibly shorter) batch works.  torch.multinomial
  # would return *indices* weighted by the candidate values.
  random_tails = candidates[torch.randint(candidates.shape[0], (heads.shape[0],))].to(heads.dtype)
  negative_samples = torch.vstack((heads, relations, random_tails)).t().contiguous()

  # All training triples outside the current batch carry messages.
  train_messaging = torch.cat(
      (train_set[: i, :], train_set[i + batch_size: , :]),
      dim=0
  )

  sample_keys = triple_keys(negative_samples)
  slots = torch.searchsorted(known_keys, sample_keys).clamp(max=known_keys.shape[0] - 1)
  for sample_index in torch.nonzero(known_keys[slots] == sample_keys).flatten().tolist():
    print(f'{negative_samples[sample_index]} at [{batch_number}][{sample_index}] is not a true corrupted edge')

  g = graph_data_maker(train_messaging, train_supervision, negative_samples, 'one-hot', False)


In [41]:
#@title Model and Hyperparameters MAIN { form-width: "15%" }
model = GNN(
    x_feature='one-hot',
    conv_dims=[16, 16, 16],
    # Plain int, so Embedding stores a number (not a 0-d tensor) as num_embeddings.
    embedding_dims=(int(dataset.candidates.max()) + 1, 32),
    fully_connected_dims=[1],
    dropout={
        "emb": 0.01,
        "conv": 0.01,
        "fc": 0.01
    }
).to(Global.DEVICE.value)
print(model)

GNN(
  (relation_weights): ModuleList(
    (0): Linear(in_features=16, out_features=16, bias=False)
    (1): Linear(in_features=16, out_features=16, bias=False)
    (2): Linear(in_features=16, out_features=16, bias=False)
    (3): Linear(in_features=16, out_features=16, bias=False)
    (4): Linear(in_features=16, out_features=16, bias=False)
    (5): Linear(in_features=16, out_features=16, bias=False)
    (6): Linear(in_features=16, out_features=16, bias=False)
    (7): Linear(in_features=16, out_features=16, bias=False)
    (8): Linear(in_features=16, out_features=16, bias=False)
    (9): Linear(in_features=16, out_features=16, bias=False)
    (10): Linear(in_features=16, out_features=16, bias=False)
    (11): Linear(in_features=16, out_features=16, bias=False)
    (12): Linear(in_features=16, out_features=16, bias=False)
    (13): Linear(in_features=16, out_features=16, bias=False)
    (14): Linear(in_features=16, out_features=16, bias=False)
    (15): Linear(in_features=16, out_feat

In [None]:
# Binary cross-entropy on raw scores (positives -> 1, negatives -> 0).
loss_fn = torch.nn.BCEWithLogitsLoss(reduction='mean')
opt = Adam(params=model.parameters())


def train(model, graph, optimizer, loss_fn):
  """Run one optimisation step on ``graph`` and return diagnostics.

  ``model(graph)`` must return ``(negative_scores, positive_scores)`` as raw
  logits.  Negatives are pushed towards 0 and positives towards 1.

  Returns:
    ``(loss, positive_probs, negative_probs)``: ``loss`` is a float, and the
    probabilities are detached sigmoid outputs.
  """
  model.train()  # make sure dropout is active even if eval mode was left on
  optimizer.zero_grad()
  negative, positive = model(graph)
  # Targets sized from the scores, so any batch size works.
  loss = loss_fn(negative, torch.zeros_like(negative)) + loss_fn(positive, torch.ones_like(positive))
  loss.backward()
  optimizer.step()
  return loss.item(), torch.sigmoid(positive).detach(), torch.sigmoid(negative).detach()


@torch.no_grad()
def evaluate(model, graph):
  """Score ``graph`` without gradients and with dropout disabled.

  Returns:
    ``(positive_scores, negative_scores)`` as raw logits.  The order is
    swapped relative to ``model(graph)``.
  """
  was_training = model.training
  model.eval()
  try:
    negative, positive = model(graph)
  finally:
    # Restore the caller's mode so later training steps are unaffected.
    model.train(was_training)
  return positive, negative


NUM_TRAIN_ITERATIONS = 320

# Repeatedly fit the most recently built mini-batch graph `g`.
for iteration in range(1, NUM_TRAIN_ITERATIONS + 1):
  loss_value, positive_probs, negative_probs = train(model, g, opt, loss_fn)
  p_total = positive_probs.sum().item()
  n_total = negative_probs.sum().item()
  print(f'Iteration {iteration}, Loss: {loss_value: .4f}, P-Score: {p_total: .4f}, N-Score: {n_total: .4f}')
  print('================================================================================')

In [None]:
#@title Visualization OPTIONAL-MAIN { form-width: "15%" }
data_subset = train_set[:2500, :]
# graph_data_maker needs messaging, supervision and negative triples.  Use the
# subset as messaging edges only, with empty supervision/negative sets.
no_triples = data_subset[:0]
graph_data = graph_data_maker(data_subset, no_triples, no_triples, 'one-hot')
# visualize_graph reads `edge_index` / `edge_type`, so expose the messaging
# edges under those names.
plot_data = Data(
    x=graph_data.x,
    edge_index=graph_data.edge_index_messaging,
    edge_type=graph_data.edge_type_messaging,
)
visualize_graph(plot_data, 100, 100)

In [None]:
#@title Mostly Trash { form-width: "15%" }
# ! pip install deepsnap
# import deepsnap


# NX_graph = PyG.utils.to_networkx(train_graph)
# SX_graph = deepsnap.graph.Graph(NX_graph)

# SX_graph
# SX_graph.negative_sampling(SX_graph.edge_index, SX_graph.num_nodes, 100)
# #@title
# print(min([i[1].shape[0] for i in train_list]))
# #@title
# d = defaultdict(lambda: 0)
# for i in dataset.training:
#   h = i[0]
#   t = i[-1]
#   d[h.item()] += 1
#   d[t.item()] += 1
# new_list = list()
# other_list = list()
# ################################################################################################################
# new_candidates = torch.tensor(list(filter(lambda z: z[1] > 1000, sorted(d.items(), key= lambda z: z[1]))))[:, 0]
# ################################################################################################################
# for i in tqdm(dataset.training):
#   h = i[0]
#   t = i[-1]
#   if h in new_candidates or t in new_candidates:
#     new_list.append(i)
#   else:
#     other_list.append(i)
# #@title
# for i in train_list:
#   print(torch.unique(i[1]).shape)
# #@title

# # #@title Dataset Load and Train/Val/Test Split MAIN 2{ form-width: "15%" }
# # dataset = OBL2021Dataset()
# # train_set = dataset.training # torch.tensor of shape(num_train,3)
# # val_set = dataset.validation # torch.tensor of shape(num_val,3)
# # test_set = dataset.testing   # torch.tensor of shape(num_train,3)

# # train_supervision, train_messaging = torch.split(
# #     torch.stack(new_list)[torch.randperm(torch.stack(new_list).shape[0])],
# #     split_size_or_sections=(
# #         Global.TRAIN_SUPERVISION_SIZE.value,
# #         torch.stack(new_list).shape[0] - Global.TRAIN_SUPERVISION_SIZE.value
# #     )
# # )
# # train_supervision_data_loader = torch.split(train_supervision, split_size_or_sections=32)
# # train_supervision_negative_samples = [None for _ in range(Global.TRAIN_SUPERVISION_SIZE.value // 32)]
# # for i in range(Global.TRAIN_SUPERVISION_SIZE.value // 32):
# #   heads = train_supervision_data_loader[i][:, 0]
# #   relations = train_supervision_data_loader[i][:, 1]
# #   tails = train_supervision_data_loader[i][:, -1]
# #   negative_samples = torch.vstack(
# #       (heads, relations, torch.multinomial(dataset.candidates.type(torch.float), 32))
# #   ).t().contiguous()
# #   train_supervision_negative_samples[i] = negative_samples
# # knowledge_graph = torch.cat(
# #     (dataset.training, dataset.validation, dataset.testing),
# #     dim=0
# # )

# # # for _index, mini_batch in enumerate(tqdm(train_supervision_negative_samples)):
# # #   for __index, candidate in enumerate(mini_batch):
# # #     is_corrupt = not any(np.equal(knowledge_graph.numpy(), candidate.numpy().tolist()).all(1))
# # #     if not is_corrupt:
# # #       print(f'{candidate} at [{_index}][{__index}] is not a true corrupted edge')

# #@title

# torch.unique(torch.cat(
#     (train_messaging, torch.stack(other_list)),
#     dim=0
# )).shape
# #@title
# torch.split(
#     torch.tensor([_ for _ in range(1, 15)])[torch.randperm(14)],
#     split_size_or_sections = (7, 7)
# )
# #@title
# train_supervision, train_messaging = torch.split(
#     torch.stack(new_list),
#     split_size_or_sections=(
#         3,
#         torch.stack(new_list).shape[0] - 3
#     )
# )
# #@title
# torch.unique(train_messaging).shape
# #@title
# torch.unique(torch.stack(other_list + new_list)).shape
# # torch.cat(
# #     (train_messaging, torch.stack(other_list)),
# #     dim=0
# # ).shape
# # train_supervision.shape
# #@title
# # d = defaultdict(lambda: 0)
# # for i in dataset.training:
# #   h = i[0]
# #   t = i[-1]
# #   d[h.item()] += 1
# #   d[t.item()] += 1
# new_list = list()
# other_list = list()
# ################################################################################################################
# new_candidates = torch.tensor(list(filter(lambda z: z[1] > 1000, sorted(d.items(), key= lambda z: z[1]))))[:, 0]
# ################################################################################################################
# for i in tqdm(dataset.training):
#   h = i[0]
#   t = i[-1]
#   if h in new_candidates or t in new_candidates:
#     new_list.append(i)
#   else:
#     other_list.append(i)

# #@title
# torch.stack(new_list).shape
# dataset.training.shape