In [None]:
#@title OpenBioLink Module Install EXTERNAL{ form-width: "15%" }
! pip install openbiolink

In [None]:
#@title PyG Installation EXTERNAL{ form-width: "15%" }
!pip install -q torch-scatter -f https://pytorch-geometric.com/whl/torch-1.9.0+cu111.html
!pip install -q torch-sparse -f https://pytorch-geometric.com/whl/torch-1.9.0+cu111.html
!pip install -q git+https://github.com/rusty1s/pytorch_geometric.git

In [3]:
#@title Module Imports HEADER { form-width: "15%" }
from collections import defaultdict
from enum import Enum
from typing import NoReturn, Optional

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import torch
import torch.nn.functional as F
import torch_geometric as PyG
from openbiolink.obl2021 import OBL2021Dataset
from torch.nn import (
    BatchNorm1d,
    CrossEntropyLoss,
    Embedding,
    Linear,
    LogSoftmax,
    Module,
    ModuleList,
    NLLLoss,
    Softmax,
)
from torch.optim import Adam
from torch_geometric.data import Data
from torch_geometric.nn.conv import RGCNConv
from torch_geometric.utils import to_networkx

In [4]:
#@title Global Variables CLASS { form-width: "15%" }
class Global(Enum):
  """Notebook-wide constants; read them as ``Global.<NAME>.value``.

  These are Enum members, so two members with equal values would silently
  become aliases of one another -- keep every value distinct.
  """
  HEAD_INDEX = 0                   # column of the head entity in an (N, 3) triple tensor
  RELATION_INDEX = 1               # column of the relation id
  TAIL_INDEX = 2                   # column of the tail entity
  FEATURE_ENG = 'one-hot'          # node-feature scheme passed to graph_data_maker
  NUM_RELATIONS = 28               # relation types given to each RGCNConv layer
  DEVICE = 'cpu'                   # torch device the model is moved to
  TRAIN_SUPERVISION_SIZE = 180000  # leading training triples used for supervision

In [44]:
#@title add_tail_to_head(data_split) FUNCTION { form-width: "15%" }
def add_tail_to_head(data_split: torch.Tensor, plus: Optional[int] = None) -> torch.Tensor:
  """Append a reversed copy of every (head, relation, tail) triple.

  Each input row ``(h, r, t)`` gets a companion row ``(t, r + plus, h)``, so
  information can also flow from tail to head under a distinct relation id.

  Args:
    data_split: Tensor of shape (N, 3) with columns (head, relation, tail).
    plus: Offset added to the relation id of every reversed triple. Defaults
      to ``Global.NUM_RELATIONS.value``.

  Returns:
    Tensor of shape (2N, 3): the original rows followed by their reversals.
    If the input relation ids lie in ``[0, plus)``, the reversed ones lie in
    ``[plus, 2 * plus)``, so a consumer needs ``2 * plus`` relation types.
  """
  # Resolved at call time. A default of `dataset.num_relations` is evaluated
  # when this cell runs, which is before the cell that creates `dataset`.
  if plus is None:
    plus = Global.NUM_RELATIONS.value

  heads = data_split[:, 0]
  relations = data_split[:, 1]
  tails = data_split[:, 2]

  # (tail, relation + plus, head) for every row, laid out as an (N, 3) tensor.
  tail_to_head = torch.vstack((tails, relations + plus, heads)).t()

  return torch.cat((data_split, tail_to_head), dim=0)



In [33]:
#@title checksum(data_subset, graph_data) FUNCTION { form-width: "15%" }
def _window_matches(graph_values: torch.Tensor, data_values: torch.Tensor) -> bool:
  """True when both 1-D tensors have the same shape and identical elements."""
  return graph_values.shape == data_values.shape and bool((graph_values == data_values).all())


def checksum(data_subset: torch.Tensor, graph_data: "Data", step: int=500) -> bool:
  """Verify that a graph built by graph_data_maker reproduces its source triples.

  Rows are compared in windows of ``step``. For each window, three things must
  equal columns 0, 1 and 2 of ``data_subset``:
  the head entities ``graph_data.x[edge_index[0]]``, the relation ids
  ``graph_data.edge_type`` and the tail entities ``graph_data.x[edge_index[1]]``.
  On the first mismatch, the failing part ('head', 'relation' or 'tail') and the
  end row of its window are printed, and the check stops. If every window
  matches, 'All clear :)' is printed.

  Args:
    data_subset: (N, 3) tensor of (head, relation, tail) triples.
    graph_data: Graph whose edge i is expected to correspond to row i.
    step: Window size used for the comparison.

  Returns:
    True if every triple matched, otherwise False. Callers may ignore the value.
  """
  num_rows = data_subset.shape[0]
  # Iterate over window *starts* so that the final window (full or partial) is
  # checked too. Stepping over window ends with range(step, num_rows, step)
  # skips it.
  for start in range(0, num_rows, step):
    end = min(start + step, num_rows)
    window = slice(start, end)

    head_entities = graph_data.x[graph_data.edge_index[0, window]].reshape(-1)
    if not _window_matches(head_entities, data_subset[window, 0]):
      print('head')
      print(end)
      return False

    if not _window_matches(graph_data.edge_type[window], data_subset[window, 1]):
      print('relation')
      print(end)
      return False

    tail_entities = graph_data.x[graph_data.edge_index[1, window]].reshape(-1)
    if not _window_matches(tail_entities, data_subset[window, 2]):
      print('tail')
      print(end)
      return False

  print('All clear :)')
  return True

In [45]:
#@title graph_data_maker(dataset, x_feature) FUNCTION{ form-width: "15%" }
def graph_data_maker(dataset: torch.Tensor, x_feature: str=Global.FEATURE_ENG.value, check_for_correctness: bool=False) -> Data:
  """Build a PyG ``Data`` graph from (head, relation, tail) triples.

  With the 'one-hot' scheme, every distinct entity id gets a consecutive node
  index. Indices are assigned in order of first appearance while scanning
  h0, t0, h1, t1, ...

  - ``x`` has shape (num_nodes, 1), and ``x[n]`` is the original entity id of node ``n``.
  - ``edge_index`` has shape (2, N) and holds the re-indexed (head, tail) pairs.
  - ``edge_type`` is the relation column, so edge i corresponds to row i.

  Args:
    dataset: (N, 3) integer tensor of triples; columns are located via Global.
    x_feature: Node-feature scheme; only 'one-hot' is implemented.
    check_for_correctness: If True, run ``checksum`` on the result.

  Returns:
    The constructed graph.

  Raises:
    NotImplementedError: for any x_feature other than 'one-hot'.
  """
  relation_idx = Global.RELATION_INDEX.value
  head_idx = Global.HEAD_INDEX.value
  tail_idx = Global.TAIL_INDEX.value

  if x_feature != 'one-hot':
    raise NotImplementedError('This functionality has not been implemented yet.')

  # Endpoints flattened row by row: h0, t0, h1, t1, ...
  # One tolist() call instead of an .item() call per element.
  endpoints_flat = dataset[:, (head_idx, tail_idx)].reshape(-1).tolist()
  total = len(endpoints_flat)

  node_of_entity = {}  # entity id -> node index; dict insertion order == node index order
  node_indices = []
  for i, entity in enumerate(endpoints_flat):
    # An unseen entity gets the next free index (the current dict size).
    node_indices.append(node_of_entity.setdefault(entity, len(node_of_entity)))
    # Parenthesised on purpose: `i + 1 % 50000` parses as `i + (1 % 50000)`.
    if (i + 1) % 50000 == 0:
      print(f'Entity {i + 1} encoded. {(i + 1) / total * 100:.2f} %')

  # An explicit long dtype keeps edge_index integral even for an empty dataset.
  edge_index = torch.tensor(node_indices, dtype=torch.long).reshape(-1, 2)
  x = torch.tensor(list(node_of_entity), dtype=torch.long).reshape(-1, 1)

  graph_data = Data(
      x=x,
      edge_index=edge_index.t().contiguous(),
      edge_type=dataset[:, relation_idx]
  )

  if check_for_correctness:
    checksum(dataset, graph_data)
  return graph_data

In [46]:
#@title visualize_graph(graph_data, height, width) FUNCTION{ form-width: "15%" }
def visualize_graph(graph_data: Data, height: int=10, width: int=10) -> None:
  """Draw ``graph_data`` with networkx on the current matplotlib figure.

  Nodes are labelled with their value in ``graph_data.x``. Edges are labelled
  with their ``graph_data.edge_type``.

  Args:
    graph_data: Graph to draw; x, edge_index and edge_type are read.
    height: Figure height in inches.
    width: Figure width in inches.
  """
  nx_graph = to_networkx(graph_data)
  pos = nx.spring_layout(nx_graph)

  fig = plt.gcf()
  fig.set_size_inches(width, height)

  # Map each (src, dst) pair to the relation type of its first occurrence in
  # edge_index. Whole pairs must match: an element-wise comparison against
  # every row also hits rows where only the source or only the target agrees.
  # Building the dict once is O(E), instead of an O(E) scan per edge.
  first_edge_type = {}
  for src, dst, rel in zip(graph_data.edge_index[0].tolist(),
                           graph_data.edge_index[1].tolist(),
                           graph_data.edge_type.tolist()):
    first_edge_type.setdefault((src, dst), rel)
  edge_labels = {edge: first_edge_type[edge] for edge in nx_graph.edges()}
  node_labels = {n: graph_data.x[n].item() for n in nx_graph}

  nx.draw_networkx_nodes(nx_graph, pos)
  nx.draw_networkx_edges(nx_graph, pos, connectionstyle='arc3,rad=0.2')
  nx.draw_networkx_labels(nx_graph, pos, labels=node_labels)
  nx.draw_networkx_edge_labels(nx_graph, pos, edge_labels=edge_labels)

  # pyplot.show() also works with non-GUI backends such as Jupyter's inline one,
  # where Figure.show() only emits a warning.
  plt.show()


In [47]:
#@title GNN Model CLASS{ form-width: "10%" }
class GNN(Module):
  """Relational-GCN encoder that maps every node of a graph to a vector.

  forward() applies these steps in order:
    1. Build input features (a learned embedding or constant ones).
    2. Apply dropout.
    3. Run the RGCNConv layers, with ReLU and dropout between them.
    4. Apply dropout to the last layer's output.

  No decoder exists yet, so the node representations are returned directly.

  Args:
    conv_dims: Output width of each RGCNConv layer; one layer per entry.
    fully_connected_dims: Reserved for a dense decoder; currently unused.
    x_feature: Either of:
      - 'one_hot', or 'one-hot' (the spelling used by Global.FEATURE_ENG):
        learn an Embedding looked up by the values in data.x.
      - 'identity': use a constant all-ones input of width 1.
    dropout: Dropout probabilities; the keys 'emb' and 'conv' are read.
    embedding_dims: (num_embeddings, embedding_dim) of the Embedding; required
      in one-hot mode.

  Raises:
    ValueError: if x_feature is unknown, if embedding_dims is missing in
      one-hot mode, or if conv_dims is empty.
  """

  # Both spellings select the learned embedding.
  _ONE_HOT_NAMES = ('one_hot', 'one-hot')

  def __init__(self, conv_dims: list, fully_connected_dims: list, x_feature: str, dropout: dict, embedding_dims: tuple=None) -> None:
    super(GNN, self).__init__()

    is_one_hot = x_feature in self._ONE_HOT_NAMES
    if not is_one_hot and x_feature != 'identity':
      raise ValueError(f"Unknown x_feature {x_feature!r}; expected 'one_hot', 'one-hot' or 'identity'.")
    if is_one_hot and embedding_dims is None:
      raise ValueError("One-hot x_feature requires embedding_dims=(num_embeddings, embedding_dim).")
    if len(conv_dims) == 0:
      raise ValueError("conv_dims must contain at least one layer width.")

    self.mode = None  # 'train' or 'test' or 'dev' later
    self.num_relations = Global.NUM_RELATIONS.value
    self.dropout = dropout
    self.x_feature = x_feature

    if is_one_hot:
      # one-hot to latent
      self.embed = Embedding(embedding_dims[0], embedding_dims[1])
      in_channels = embedding_dims[1]
    else:  # 'identity': a single constant input channel
      in_channels = 1

    # Layer i maps channels[i] -> channels[i + 1].
    channels = [in_channels] + list(conv_dims)
    # graph conv layers
    self.conv_layers = ModuleList(
        RGCNConv(channels[i], channels[i + 1], self.num_relations)
        for i in range(len(conv_dims))
    )

    # TODO: dense decoder built from fully_connected_dims (not implemented yet).
    self.classifier = LogSoftmax(dim=1)

  def reset_parameters(self):
    """Re-initialise the embedding (one-hot mode only) and every conv layer."""
    # `embed` exists only in one-hot mode.
    if hasattr(self, 'embed'):
      self.embed.reset_parameters()
    for conv in self.conv_layers:
      conv.reset_parameters()

  def forward(self, data: "Data") -> torch.Tensor:
    """Encode every node of ``data``.

    Args:
      data: Graph with x, edge_index and edge_type, as built by graph_data_maker.

    Returns:
      Node representations with conv_dims[-1] columns, one row per node.
    """
    edge_index = data.edge_index
    edge_type = data.edge_type
    x = data.x

    # ---------------- Encoder: input features ----------------
    if self.x_feature in self._ONE_HOT_NAMES:
      # data.x is (num_nodes, 1); flatten it so the lookup yields (num_nodes, dim).
      # NOTE(review): assumes every value in data.x is < embedding_dims[0].
      x = self.embed(x.reshape(-1))
    elif self.x_feature == 'identity':
      # Float ones with the shape of data.x, on the same device as data.x.
      x = torch.ones_like(x, dtype=torch.float)
    x = F.dropout(x, p=self.dropout["emb"], training=self.training)

    # ---------------- Encoder: RGCN stack ----------------
    for conv in self.conv_layers[:-1]:
      x = conv(x, edge_index, edge_type)
      x = F.relu(x)
      x = F.dropout(x, p=self.dropout["conv"], training=self.training)

    x = self.conv_layers[-1](x, edge_index, edge_type)
    x = F.dropout(x, p=self.dropout["conv"], training=self.training)

    # ---------------- Decoder: not implemented yet ----------------
    return x


In [None]:
#@title Dataset Load and Train/Val/Test Split MAIN { form-width: "15%" }
dataset = OBL2021Dataset()
train_set = dataset.training  # torch.tensor of shape (num_train, 3)
val_set = dataset.validation  # torch.tensor of shape (num_val, 3)
test_set = dataset.testing    # torch.tensor of shape (num_test, 3)

# The first TRAIN_SUPERVISION_SIZE training triples are used for supervision;
# the remaining ones are used for message passing.
supervision_size = Global.TRAIN_SUPERVISION_SIZE.value
train_supervision, train_messaging = torch.split(
    train_set,
    split_size_or_sections=(supervision_size, train_set.shape[0] - supervision_size)
)

BATCH_SIZE = 32
# Negatives are generated only for the first NUM_NEGATIVE_BATCHES mini-batches.
# Use len(train_supervision_data_loader) to cover the whole supervision split.
NUM_NEGATIVE_BATCHES = 32
train_supervision_data_loader = torch.split(train_supervision, split_size_or_sections=BATCH_SIZE)

# NOTE(review): torch.multinomial returns *indices* into dataset.candidates,
# drawn with probability proportional to the candidate values. That is a
# uniform draw of entity ids only if `candidates` is a 0/1 mask over entities.
# If it holds entity ids instead, index into it and sample uniformly.
# Confirm against the OBL2021Dataset documentation.
candidate_weights = dataset.candidates.type(torch.float)
train_supervision_negative_samples = []
for batch in train_supervision_data_loader[:NUM_NEGATIVE_BATCHES]:
  heads = batch[:, Global.HEAD_INDEX.value]
  relations = batch[:, Global.RELATION_INDEX.value]
  # One corrupted tail per positive triple. The sample count follows the batch
  # length, so a short final batch still stacks correctly.
  corrupted_tails = torch.multinomial(candidate_weights, batch.shape[0])
  train_supervision_negative_samples.append(
      torch.vstack((heads, relations, corrupted_tails)).t().contiguous()
  )

knowledge_graph = torch.cat((train_set, val_set, test_set), dim=0)

# A sampled negative must not be a real triple. Each (h, r, t) is encoded as a
# single int64 key (assumes non-negative ids), and all negatives are tested at
# once with np.isin, instead of scanning the full knowledge graph per candidate.
if train_supervision_negative_samples:
  all_negatives = torch.cat(train_supervision_negative_samples, dim=0)
  kg_triples = knowledge_graph.numpy().astype(np.int64)
  negative_triples = all_negatives.numpy().astype(np.int64)
  entity_base = int(max(kg_triples[:, [0, 2]].max(), negative_triples[:, [0, 2]].max())) + 1
  relation_base = int(max(kg_triples[:, 1].max(), negative_triples[:, 1].max())) + 1
  kg_keys = (kg_triples[:, 0] * relation_base + kg_triples[:, 1]) * entity_base + kg_triples[:, 2]
  negative_keys = (negative_triples[:, 0] * relation_base + negative_triples[:, 1]) * entity_base + negative_triples[:, 2]
  is_true_edge = np.isin(negative_keys, kg_keys)
  for candidate in all_negatives[torch.from_numpy(is_true_edge)]:
    print(f'{candidate} is not a true corrupted edge')


In [48]:
#@title Graph Maker MAIN { form-width: "15%" }
# Build the PyG graph for the training split with the configured node-feature scheme.
feature_eng_method = Global.FEATURE_ENG.value
train_graph = graph_data_maker(train_set, feature_eng_method)
# val_graph / test_graph can be built the same way from val_set / test_set.

In [None]:
#@title Visualization OPTIONAL-MAIN { form-width: "15%" }
# Draw only the first 2,500 training triples, on a 100 x 100 inch figure.
data_subset = train_set[:2500]
graph_data = graph_data_maker(data_subset, x_feature='one-hot')
visualize_graph(graph_data, height=100, width=100)

In [None]:
#@title Model and Hyperparameters MAIN { form-width: "15%" }
# Identity input features feed three 16-wide RGCN layers; every dropout rate is 1%.
model = GNN(
    conv_dims=[16, 16, 16],
    fully_connected_dims=[1],
    x_feature='identity',
    dropout=dict(emb=0.01, conv=0.01, fc=0.01),
)
model = model.to(Global.DEVICE.value)

In [None]:
# Sanity check: run the encoder on the full training graph and display the
# output shape (rows = nodes, columns = conv_dims[-1]). The model is still in
# training mode here, so dropout is active.
model(train_graph).shape

torch.Size([180992, 16])

In [None]:
#@title
! pip install deepsnap
import deepsnap  # NOTE(review): mid-notebook import; consider moving it to the imports cell

# Convert the PyG training graph to networkx, then wrap it as a DeepSNAP graph.
NX_graph = PyG.utils.to_networkx(train_graph)
SX_graph = deepsnap.graph.Graph(NX_graph)

SX_graph  # NOTE(review): not the last expression of the cell, so nothing is displayed
# Presumably draws 100 node pairs that are not edges (negative edges); confirm against the DeepSNAP docs.
SX_graph.negative_sampling(SX_graph.edge_index, SX_graph.num_nodes, 100)