In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from dgl import heterograph
import dgl.function as fn
import dgl.utils as dgl_utils
from functools import partial
from dgl.nn.pytorch import RelGraphConv
from dgl.contrib.data import load_data

import numpy as np
import pygraphviz as pgv

import utils
from base import BaseRGCN

### Load graph from dictionary

In [69]:
# Read the edge dictionary from disk and build the DGL heterograph from it.
# NOTE(review): presumably the dictionary is keyed by
# (source type, edge type, destination type), as `heterograph` expects -
# confirm against utils.read_dict_file.
graph_dict = utils.read_dict_file('../data/clean/graph_dict.txt')
g = heterograph(graph_dict)

In [70]:
# Display the heterograph summary: node counts per type, edge counts per
# relation, and the metagraph.
g

Graph(num_nodes={'disease': 18527, 'drug': 8094, 'gene': 37912},
      num_edges={('drug', 'treat', 'disease'): 7908, ('drug', 'carrier', 'gene'): 834, ('drug', 'enzyme', 'gene'): 5271, ('drug', 'target', 'gene'): 20553, ('drug', 'transport', 'gene'): 3023, ('gene', 'binding', 'gene'): 566365, ('gene', 'activation', 'gene'): 144427, ('gene', 'catalysis', 'gene'): 211527, ('gene', 'reaction', 'gene'): 74799, ('gene', 'expression', 'gene'): 14026, ('gene', 'inhibition', 'gene'): 21048, ('gene', 'ptmod', 'gene'): 8304, ('gene', 'associate', 'disease'): 11413, ('drug', 'side_effect', 'drug'): 52292},
      metagraph=[('drug', 'disease'), ('drug', 'gene'), ('drug', 'gene'), ('drug', 'gene'), ('drug', 'gene'), ('drug', 'drug'), ('gene', 'gene'), ('gene', 'gene'), ('gene', 'gene'), ('gene', 'gene'), ('gene', 'gene'), ('gene', 'gene'), ('gene', 'gene'), ('gene', 'disease')])

In [47]:
def plot_graph(nxg, path='graph.png', prog='dot', fontsize=None):
    """Render a multigraph to an image file with Graphviz.

    Each edge of ``nxg`` is added to a non-strict, directed ``pgv.AGraph``
    and labelled with its edge key.

    Parameters
    ----------
    nxg
        Object whose ``edges(keys=True)`` yields ``(u, v, key)`` triples.
    path : str, optional
        Output file handed to ``AGraph.draw``. Defaults to ``'graph.png'``.
    prog : str, optional
        Graphviz layout program handed to ``AGraph.layout``. Defaults to
        ``'dot'``.
    fontsize : number or str, optional
        Font size for edge labels. ``None`` (the default) leaves the
        Graphviz default untouched.

    Notes
    -----
    An earlier version assigned ``edge_attr['front_size']`` after drawing.
    The key was misspelled and the image was already written, so the
    assignment had no effect. The attribute is now spelled ``fontsize`` and
    is applied before layout, but only when explicitly requested, so the
    default output is unchanged.
    """
    ag = pgv.AGraph(strict=False, directed=True)
    if fontsize is not None:
        # Default edge attributes must be in place before layout/draw to
        # show up in the rendered file.
        ag.edge_attr['fontsize'] = str(fontsize)
    for u, v, k in nxg.edges(keys=True):
        ag.add_edge(u, v, label=k)
    ag.layout(prog)
    ag.draw(path)

# Draw the heterograph's metagraph with plot_graph.
# NOTE(review): `g.metagraph` is passed without being called. In DGL releases
# where `metagraph` is a method this would need to be `g.metagraph()` -
# confirm for the installed version.
plot_graph(g.metagraph)

### Edge types

### Build Model

In [7]:
class EmbeddingLayer(nn.Module):
    """Input layer that looks up a learnable embedding for each node ID.

    Parameters
    ----------
    num_nodes : int
        Number of rows in the embedding table.
    h_dim : int
        Size of each embedding vector.
    """

    def __init__(self, num_nodes, h_dim):
        super(EmbeddingLayer, self).__init__()
        self.embedding = torch.nn.Embedding(num_nodes, h_dim)

    def forward(self, g, h, r, norm):
        """Return the embeddings of the node IDs in ``h``.

        ``g``, ``r`` and ``norm`` are accepted so this layer has the same
        call signature as the other layers, but they are not used.

        Size-1 dimensions of ``h`` are squeezed away before the lookup, so
        an ``(N, 1)`` ID column produces an ``(N, h_dim)`` result. When
        every dimension of ``h`` has size 1 (a single node), the squeeze
        would yield a 0-d index and a bare ``(h_dim,)`` vector; the node
        dimension is restored so the result is ``(1, h_dim)``.
        """
        node_ids = h.squeeze()
        if node_ids.dim() == 0:
            # Single node: keep a leading node dimension.
            node_ids = node_ids.unsqueeze(0)
        return self.embedding(node_ids)

class RGCN(BaseRGCN):
    """Supplies the input and hidden layer builders for ``BaseRGCN``.

    How the layers are assembled and invoked is defined by ``BaseRGCN``.
    """

    def build_input_layer(self):
        """Create the embedding layer that maps node IDs to ``h_dim`` vectors."""
        return EmbeddingLayer(self.num_nodes, self.h_dim)

    def build_hidden_layer(self, idx):
        """Create hidden layer number ``idx``.

        Each hidden layer is a basis-regularized ``RelGraphConv`` with a
        self-loop. Every layer except the last applies ReLU; the last one
        has no activation.
        """
        is_last_layer = idx >= self.num_hidden_layers - 1
        activation = None if is_last_layer else F.relu
        return RelGraphConv(
            self.h_dim, self.h_dim, self.num_rels, 'basis', self.num_bases,
            activation=activation,
            self_loop=True,
            dropout=self.dropout,
        )

class LinkPredict(nn.Module):
    """Link-prediction model: an ``RGCN`` encoder plus DistMult scoring.

    Parameters
    ----------
    in_dim : int
        First argument forwarded to ``RGCN``.
    h_dim : int
        Size of the node embeddings and of each relation vector.
    num_rels : int
        Number of relation types scored by the decoder. The encoder is
        built with ``num_rels * 2`` relations.
    num_bases : int, optional
        Forwarded to ``RGCN``.
    num_hidden_layers : int, optional
        Forwarded to ``RGCN``.
    dropout : float, optional
        Forwarded to ``RGCN``.
    use_cuda : bool, optional
        Forwarded to ``RGCN``.
    reg_param : float, optional
        Weight of the regularization term in ``get_loss``.
    """

    def __init__(self, in_dim, h_dim, num_rels, num_bases=-1,
                 num_hidden_layers=1, dropout=0, use_cuda=False, reg_param=0):
        super(LinkPredict, self).__init__()
        # NOTE(review): the encoder gets twice as many relations as the
        # decoder, presumably one extra per relation for the inverse
        # direction - confirm against the graph-building helpers.
        self.rgcn = RGCN(in_dim, h_dim, h_dim, num_rels * 2, num_bases,
                         num_hidden_layers, dropout, use_cuda)
        self.reg_param = reg_param
        # One vector per relation; allocated uninitialized and filled by
        # the Xavier initializer below.
        self.w_relation = nn.Parameter(torch.empty(num_rels, h_dim))
        nn.init.xavier_uniform_(self.w_relation,
                                gain=nn.init.calculate_gain('relu'))

    def calc_score(self, embedding, triplets):
        """Return one DistMult score per triplet.

        Each row of ``triplets`` is indexed as (column 0, column 1,
        column 2) = (subject node, relation, object node). The score is
        the sum over the feature dimension of the element-wise product of
        the three corresponding vectors.
        """
        # DistMult
        s = embedding[triplets[:, 0]]
        r = self.w_relation[triplets[:, 1]]
        o = embedding[triplets[:, 2]]
        score = torch.sum(s * r * o, dim=1)
        return score

    def forward(self, g, h, r, norm):
        """Run the encoder and return its output.

        The encoder is invoked through ``__call__`` rather than
        ``.forward`` so that any hooks registered on it are executed.
        """
        return self.rgcn(g, h, r, norm)

    def regularization_loss(self, embedding):
        """Return mean squared embedding value plus mean squared relation weight."""
        return torch.mean(embedding.pow(2)) + torch.mean(self.w_relation.pow(2))

    def get_loss(self, g, embed, triplets, labels):
        """Return the prediction loss plus ``reg_param`` times the regularization loss.

        ``g`` is accepted for interface compatibility but not used.
        """
        # triplets is a list of data samples (positive and negative)
        # each row in the triplets is a 3-tuple of (source, relation, destination)
        score = self.calc_score(embed, triplets)
        predict_loss = F.binary_cross_entropy_with_logits(score, labels)
        reg_loss = self.regularization_loss(embed)
        return predict_loss + self.reg_param * reg_loss

def node_norm_to_edge_norm(g, node_norm):
    """Give every edge the normalization value of its destination node.

    ``node_norm`` is stored as the ``'norm'`` node feature on a local view
    of ``g`` (obtained with ``local_var``), copied from each edge's
    destination node onto the edge, and the resulting ``'norm'`` edge
    feature is returned.
    """
    def _take_dst_norm(edges):
        # Each edge inherits the norm of the node it points to.
        return {'norm': edges.dst['norm']}

    local_graph = g.local_var()
    local_graph.ndata['norm'] = node_norm
    local_graph.apply_edges(_take_dst_norm)
    return local_graph.edata['norm']

In [95]:
# Load the FB15k-237 dataset through DGL's contrib loader.
# NOTE(review): the values read from `data` in the next cell (`num_nodes`,
# `num_rels`) are reassigned from `graph.npy` further down - check whether
# this cell is still needed.
data = load_data('FB15k-237')

# entities: 14541
# relations: 237
# edges: 272115


In [77]:
    # Pull the sizes and the train/valid/test splits off the loaded dataset.
    # NOTE(review): `num_nodes` and `num_rels` are reassigned from `graph.npy`
    # in a later cell, so these two values are overwritten on a top-to-bottom
    # run. The leading indentation of this cell is also unnecessary.
    num_nodes = data.num_nodes
    train_data = data.train
    valid_data = data.valid
    test_data = data.test
    num_rels = data.num_rels

In [60]:
# Inspect the features stored on 'disease' nodes (empty in the output below).
g.nodes['disease'].data

{}

### Load the graph

In [2]:
# Load the graph array from disk.
# NOTE(review): later cells read columns 0 and 2 as node IDs and column 1 as
# relation IDs, so rows look like (source, relation, destination) triplets -
# confirm against the script that wrote graph.npy.
graph = np.load('../data/clean/graph.npy')

In [3]:
# Number of distinct IDs that appear as either endpoint (columns 0 and 2).
# NOTE(review): this count is later used as the embedding-table size, which
# assumes node IDs are contiguous from 0 - confirm.
num_nodes = np.union1d(graph[:, 0], graph[:, 2]).size
# Number of distinct relation IDs (column 1).
num_rels = np.unique(graph[:, 1]).size

In [4]:
# Model hyperparameters; each is passed to LinkPredict in the next cell.
n_hidden = 500  # h_dim
n_bases = 100  # num_bases
n_layers = 2  # num_hidden_layers
dropout = 0.2  # dropout
regularization =  0.01  # reg_param: weight of the regularization loss term

# Whether a CUDA device is available to PyTorch.
use_cuda = torch.cuda.is_available()

In [8]:
# Build the link-prediction model from the hyperparameters above.
# `use_cuda` is passed through instead of a hard-coded True so the flag given
# to the model agrees with the guarded `.cuda()` move below on CPU-only
# machines.
model = LinkPredict(
    num_nodes,
    n_hidden,
    num_rels,
    num_bases=n_bases,
    num_hidden_layers=n_layers,
    dropout=dropout,
    use_cuda=use_cuda,
    reg_param=regularization,
)
if use_cuda:
    model.cuda()

In [9]:
# Build the evaluation graph from the full array, together with its relation
# array and norm array (both NumPy arrays; converted to tensors in the next cell).
# NOTE(review): `test_norm` is treated as a per-node norm by the next cell -
# confirm against utils.build_test_graph.
test_graph, test_rel, test_norm = utils.build_test_graph(
        num_nodes, num_rels, graph)

Test graph:
# nodes: 26785, # edges: 2283580


In [10]:
# In-degree of every node in the test graph, as a float column vector.
test_deg = test_graph.in_degrees(
                range(test_graph.number_of_nodes())).float().view(-1,1)
# Node IDs 0..num_nodes-1 as a long column vector.
test_node_id = torch.arange(0, num_nodes, dtype=torch.long).view(-1, 1)
# NOTE(review): the two lines below rebind `test_rel` and `test_norm` from
# NumPy arrays to tensors in place, so this cell is not safe to run twice
# without first re-running the cell that builds the test graph.
test_rel = torch.from_numpy(test_rel)
# Convert the node norm to a per-edge norm (each edge takes the value of its
# destination node).
test_norm = node_norm_to_edge_norm(test_graph, torch.from_numpy(test_norm).view(-1, 1))

In [12]:
# Build the adjacency list and the degree array used for sampling.
# NOTE(review): the exact structure of `adj_list` is defined by
# utils.get_adj_and_degrees - confirm there.
adj_list, degrees = utils.get_adj_and_degrees(num_nodes, graph)

In [15]:
# Check the shape of the degree array (the output below shows a 1-D array).
degrees.shape

(26785,)