In [1]:
%matplotlib inline



Relational graph convolutional network
================================================

**Author:** Lingfan Yu, Mufei Li, Zheng Zhang

In this tutorial, you learn how to implement a relational graph convolutional
network (R-GCN). This type of network is one effort to generalize GCN
to handle different relationships between entities in a knowledge base. To
learn more about the research behind R-GCN, see `Modeling Relational Data with Graph Convolutional
Networks <https://arxiv.org/pdf/1703.06103.pdf>`_.

The straightforward graph convolutional network (GCN), covered in the
`DGL tutorial <http://doc.dgl.ai/tutorials/index.html>`_, exploits
structural information of a dataset (that is, the graph connectivity) to
improve the extraction of node representations. Graph edges are left
untyped.

A knowledge graph is made up of a collection of triples in the form
(subject, relation, object). Edges thus encode important information and
have their own embeddings to be learned. Furthermore, there may be
multiple edges between any given pair of nodes.


A brief introduction to R-GCN
-----------------------------
In *statistical relational learning* (SRL), there are two fundamental
tasks:

- **Entity classification** - Where you assign types and categorical
  properties to entities.
- **Link prediction** - Where you recover missing triples.

In both cases, missing information is expected to be recovered from the
neighborhood structure of the graph. The R-GCN paper cited earlier gives
an example: knowing that Mikhail Baryshnikov was educated at the Vaganova Academy
implies both that Mikhail Baryshnikov should have the label person, and
that the triple (Mikhail Baryshnikov, lived in, Russia) must belong to the
knowledge graph.

R-GCN solves both problems with a common graph convolutional network,
extended with multi-edge encoding to compute embeddings of the entities. The
two tasks differ only in their downstream processing.

- Entity classification is done by attaching a softmax classifier to the
  final embedding of an entity (node). Training uses the standard
  cross-entropy loss.
- Link prediction is done by reconstructing an edge with an autoencoder
  architecture, using a parameterized score function. Training uses negative
  sampling.
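To make "negative sampling" concrete, here is a minimal NumPy sketch of one common way to build negatives: replace either the subject or the object of each positive triple with a random entity. The function name `corrupt_triples` and the 50/50 subject/object coin flip are illustrative assumptions, not necessarily what the DGL example does. The sketch does not filter out corrupted triples that happen to be true.

```python
import numpy as np

def corrupt_triples(pos, num_entities, negative_rate, rng):
    """Build negatives by replacing the subject or the object of each positive triple.

    pos: int array of shape (n, 3) with columns (subject, relation, object).
    Returns (samples, labels): the n positives first, then n * negative_rate negatives.
    """
    n = pos.shape[0]
    neg = np.tile(pos, (negative_rate, 1))                   # copies of the positives
    random_entities = rng.integers(0, num_entities, size=len(neg))
    corrupt_subject = rng.random(len(neg)) < 0.5             # coin flip per negative
    neg[corrupt_subject, 0] = random_entities[corrupt_subject]
    neg[~corrupt_subject, 2] = random_entities[~corrupt_subject]
    samples = np.concatenate([pos, neg])
    labels = np.zeros(len(samples), dtype=np.float32)
    labels[:n] = 1.0                                         # 1 = observed, 0 = corrupted
    return samples, labels

rng = np.random.default_rng(0)
pos = np.array([[0, 0, 1], [1, 1, 2]])
samples, labels = corrupt_triples(pos, num_entities=5, negative_rate=2, rng=rng)
```

The relation column is never corrupted, so each negative keeps the relation of the positive it was derived from.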

This tutorial focuses on the first task, entity classification, to show how to generate entity
representations. `Complete
code <https://github.com/dmlc/dgl/tree/rgcn/examples/pytorch/rgcn>`_
for both tasks is in the DGL GitHub repository.

Key ideas of R-GCN
-------------------
Recall that in GCN, the hidden representation for each node $i$ at the
$(l+1)^{th}$ layer is computed by:

\begin{align}h_i^{(l+1)} = \sigma\left(\sum_{j\in N_i}\frac{1}{c_i} W^{(l)} h_j^{(l)}\right)~~~~~~~~~~(1)\\\end{align}

where $c_i$ is a normalization constant.

The key difference between R-GCN and GCN is that in R-GCN, edges can
represent different relations. In GCN, the weight $W^{(l)}$ in equation
$(1)$ is shared by all edges in layer $l$. In contrast, in
R-GCN, different edge types use different weights, and only edges of the
same relation type $r$ are associated with the same projection weight
$W_r^{(l)}$.

So the hidden representation of entities in the $(l+1)^{th}$ layer of
R-GCN can be formulated as the following equation:

\begin{align}h_i^{(l+1)} = \sigma\left(W_0^{(l)}h_i^{(l)}+\sum_{r\in R}\sum_{j\in N_i^r}\frac{1}{c_{i,r}}W_r^{(l)}h_j^{(l)}\right)~~~~~~~~~~(2)\\\end{align}

where $N_i^r$ denotes the set of neighbor indices of node $i$
under relation $r\in R$, $W_0^{(l)}$ is a self-connection weight applied to
the node's own representation, and $c_{i,r}$ is a normalization
constant. In entity classification, the R-GCN paper uses
$c_{i,r}=|N_i^r|$.
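As a sanity check on the notation, the following PyTorch sketch evaluates equation (2) literally, node by node, on a toy graph with $c_{i,r}=|N_i^r|$. The helper name `rgcn_layer_reference` is hypothetical, and it uses row vectors (`h @ W`) rather than the column-vector convention $W h$ of the equation; it is meant for reading, not for speed.

```python
import torch

def rgcn_layer_reference(h, edges, W_self, W_rel, activation=torch.relu):
    """Evaluate equation (2) node by node.

    h:      (N, d_in) node representations h^{(l)}
    edges:  list of (src j, relation r, dst i) triples, meaning j is in N_i^r
    W_self: (d_in, d_out) self-connection weight W_0
    W_rel:  (R, d_in, d_out) one weight W_r per relation
    """
    out = h @ W_self                                   # self-connection term W_0 h_i
    for i in range(h.shape[0]):
        for r in range(W_rel.shape[0]):
            neighbors = [j for (j, rel, dst) in edges if dst == i and rel == r]
            c_ir = len(neighbors)                      # c_{i,r} = |N_i^r|
            for j in neighbors:                        # empty loop when N_i^r is empty
                out[i] = out[i] + (h[j] @ W_rel[r]) / c_ir
    return activation(out)

# toy graph: node 0 receives two edges of relation 0 and one edge of relation 1
h = torch.eye(3)
W_self = torch.zeros(3, 2)
W_rel = torch.arange(12.).reshape(2, 3, 2)
edges = [(1, 0, 0), (2, 0, 0), (2, 1, 0)]
out = rgcn_layer_reference(h, edges, W_self, W_rel, activation=lambda x: x)
print(out[0])   # tensor([13., 15.])
```

Node 0 gets the average of the two relation-0 messages, $([2, 3] + [4, 5]) / 2 = [3, 4]$, plus the single relation-1 message $[10, 11]$.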

The problem with applying the above equation directly is the rapid growth in
the number of parameters, especially with highly multi-relational data. To
reduce model parameter size and prevent overfitting, the original
paper proposes basis decomposition:

\begin{align}W_r^{(l)}=\sum\limits_{b=1}^B a_{rb}^{(l)}V_b^{(l)}~~~~~~~~~~(3)\\\end{align}

That is, the weight $W_r^{(l)}$ is a linear combination of basis
transformations $V_b^{(l)}$ with coefficients $a_{rb}^{(l)}$.
The number of bases $B$ is much smaller than the number of relations
in the knowledge base, so only the $B$ basis matrices and the $|R| \times B$ coefficients need to be learned.
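The following sketch builds every $W_r^{(l)}$ from equation (3) at once and compares parameter counts. The sizes are illustrative assumptions (91 relations as in the AIFB dataset used later, 10 bases, 16 hidden units). The `einsum` and the flattened matrix product are two equivalent ways to compute the same combination.

```python
import torch

torch.manual_seed(0)
num_rels, num_bases, d_in, d_out = 91, 10, 16, 16   # illustrative sizes

V = torch.randn(num_bases, d_in, d_out)   # bases V_b, shape (B, d_in, d_out)
a = torch.randn(num_rels, num_bases)      # coefficients a_{rb}, shape (R, B)

# W_r = sum_b a_{rb} V_b for all r at once -> shape (R, d_in, d_out)
W = torch.einsum('rb,bio->rio', a, V)

# equivalent: one matrix product on the flattened bases, then reshape
W_mm = (a @ V.view(num_bases, -1)).view(num_rels, d_in, d_out)

full_params = num_rels * d_in * d_out                            # one free matrix per relation
basis_params = num_bases * d_in * d_out + num_rels * num_bases   # bases + coefficients
print(full_params, basis_params)   # 23296 3470
```

With these sizes the decomposition needs roughly 15% of the parameters of one free matrix per relation.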

<div class="alert alert-info"><h4>Note</h4><p>Another weight regularization, block-decomposition, is implemented in
   the `link prediction <link-prediction_>`_.</p></div>

Implement R-GCN in DGL
----------------------

An R-GCN model is composed of several R-GCN layers. The first R-GCN layer
also serves as the input layer: it takes in features associated with each
node entity (for example, description texts) and projects them into the hidden space.
This tutorial uses only the entity ID as the entity feature.
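When the only input feature is the entity ID, the input is effectively a one-hot vector, and multiplying a one-hot row vector by $W_r$ simply selects one row of $W_r$. This sketch (toy sizes, random weights, all names illustrative) checks that the explicit product matches the flattened lookup `W.view(-1, d_out)[r * num_nodes + id]`, which is the indexing trick the input layer below relies on.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_nodes, num_rels, d_out = 5, 3, 4
W = torch.randn(num_rels, num_nodes, d_out)   # one (num_nodes x d_out) matrix per relation

src_ids = torch.tensor([0, 2, 4, 1])          # source node ID of each edge
rel_types = torch.tensor([1, 0, 2, 2])        # relation type of each edge

# explicit product: one_hot(src) @ W_r, computed per edge
one_hot = F.one_hot(src_ids, num_classes=num_nodes).float()           # (E, num_nodes)
explicit = torch.bmm(one_hot.unsqueeze(1), W[rel_types]).squeeze(1)  # (E, d_out)

# lookup: row (r * num_nodes + id) of W flattened to (num_rels * num_nodes, d_out)
lookup = W.view(-1, d_out)[rel_types * num_nodes + src_ids]
```

The lookup avoids materializing the (num_nodes x num_nodes) one-hot input, which matters when there are thousands of entities.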

R-GCN layers
~~~~~~~~~~~~

For each node, an R-GCN layer performs the following steps:

- Compute the outgoing message using the node representation and the weight matrix
  associated with the edge type (message function).
- Aggregate incoming messages and generate new node representations (reduce
  and apply functions).

The following code defines an R-GCN layer. The same class is used for the input, hidden, and output layers.

<div class="alert alert-info"><h4>Note</h4><p>Each relation type is associated with a different weight. Therefore,
   the full weight matrix has three dimensions: relation, input_feature,
   output_feature.</p></div>
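The message and reduce steps can also be written without DGL, which may help when reading the layer code. In this hedged sketch (the function name `rgcn_propagate` and the tensor layout are assumptions), each edge gathers its relation's matrix from the 3-D weight tensor, multiplies the source representation by it, scales the result by the edge norm, and `index_add_` sums the messages at each destination node.

```python
import torch

def rgcn_propagate(h, src, dst, rel, norm, W):
    """Edge-wise message passing mirroring the message/reduce steps of an R-GCN layer.

    h:             (N, d_in) node features
    src, dst, rel: (E,) int64 tensors describing each typed edge
    norm:          (E, 1) per-edge normalization, e.g. 1 / c_{i,r}
    W:             (R, d_in, d_out) weight tensor: relation x input_feature x output_feature
    """
    w = W[rel]                                           # (E, d_in, d_out): one matrix per edge
    msg = torch.bmm(h[src].unsqueeze(1), w).squeeze(1)   # message: h_j W_r for each edge
    msg = msg * norm                                     # scale by the normalization constant
    out = torch.zeros(h.shape[0], W.shape[2], dtype=msg.dtype)
    out.index_add_(0, dst, msg)                          # reduce: sum incoming messages per node
    return out

h = torch.eye(3)
W = torch.arange(12.).reshape(2, 3, 2)
src = torch.tensor([1, 2, 2])
dst = torch.tensor([0, 0, 0])
rel = torch.tensor([0, 0, 1])
norm = torch.tensor([[0.5], [0.5], [1.0]])   # 1 / |N_0^0| = 0.5, 1 / |N_0^1| = 1
out = rgcn_propagate(h, src, dst, rel, norm, W)
```

With `norm` set to $1/c_{i,r}$ this reproduces the neighbor sum of equation (2); the self-connection term $W_0 h_i$ is not included here.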




In [29]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
import dgl.function as fn
from functools import partial
import dgl

class RGCNLayer(nn.Module):
    def __init__(self, in_feat, out_feat, num_rels, num_bases=-1, bias=None,
                 activation=None, is_input_layer=False):
        super(RGCNLayer, self).__init__()
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.num_rels = num_rels
        self.num_bases = num_bases
        self.activation = activation
        self.is_input_layer = is_input_layer

        # sanity check
        if self.num_bases <= 0 or self.num_bases > self.num_rels:
            self.num_bases = self.num_rels

        # weight bases in equation (3)
        self.weight = nn.Parameter(torch.Tensor(self.num_bases, self.in_feat,
                                                self.out_feat))
        if self.num_bases < self.num_rels:
            # linear combination coefficients in equation (3)
            self.w_comp = nn.Parameter(torch.Tensor(self.num_rels, self.num_bases))

        # optional bias; it is 1-D, so zero-initialize it (Xavier needs >= 2 dims)
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_feat))
            nn.init.zeros_(self.bias)
        else:
            self.bias = None

        # init trainable parameters
        nn.init.xavier_uniform_(self.weight,
                                gain=nn.init.calculate_gain('relu'))
        if self.num_bases < self.num_rels:
            nn.init.xavier_uniform_(self.w_comp,
                                    gain=nn.init.calculate_gain('relu'))

    def forward(self, g):
        if self.num_bases < self.num_rels:
            # generate all weights from bases (equation (3)):
            # (num_rels, num_bases) @ (num_bases, in_feat * out_feat)
            weight = self.weight.view(self.num_bases, self.in_feat * self.out_feat)
            weight = torch.matmul(self.w_comp, weight).view(self.num_rels,
                                                        self.in_feat, self.out_feat)
        else:
            weight = self.weight

        if self.is_input_layer:
            def message_func(edges):
                # for input layer, matrix multiply can be converted to be
                # an embedding lookup using source node id
                embed = weight.view(-1, self.out_feat)
                index = edges.data['rel_type'] * self.in_feat + edges.src['id']
                return {'msg': embed[index] * edges.data['norm']}
        else:
            def message_func(edges):
                # pick one (in_feat x out_feat) matrix per edge by relation type
                w = weight[edges.data['rel_type'].long()]
                msg = torch.bmm(edges.src['h'].unsqueeze(1), w).squeeze(1)
                msg = msg * edges.data['norm']
                return {'msg': msg}

        def apply_func(nodes):
            h = nodes.data['h']
            if self.bias is not None:
                h = h + self.bias
            if self.activation:
                h = self.activation(h)
            return {'h': h}

        g.update_all(message_func, fn.sum(msg='msg', out='h'), apply_func)

Full R-GCN model defined
~~~~~~~~~~~~~~~~~~~~~~~~



In [3]:
class Model(nn.Module):
    def __init__(self, num_nodes, h_dim, out_dim, num_rels,
                 num_bases=-1, num_hidden_layers=1):
        super(Model, self).__init__()
        self.num_nodes = num_nodes
        self.h_dim = h_dim
        self.out_dim = out_dim
        self.num_rels = num_rels
        self.num_bases = num_bases
        self.num_hidden_layers = num_hidden_layers

        # create rgcn layers
        self.build_model()

        # create initial features
        self.features = self.create_features()

    def build_model(self):
        self.layers = nn.ModuleList()
        # input to hidden
        i2h = self.build_input_layer()
        self.layers.append(i2h)
        # hidden to hidden
        for _ in range(self.num_hidden_layers):
            h2h = self.build_hidden_layer()
            self.layers.append(h2h)
        # hidden to output
        h2o = self.build_output_layer()
        self.layers.append(h2o)

    # initialize feature for each node
    def create_features(self):
        features = torch.arange(self.num_nodes)
        return features

    def build_input_layer(self):
        return RGCNLayer(self.num_nodes, self.h_dim, self.num_rels, self.num_bases,
                         activation=F.relu, is_input_layer=True)

    def build_hidden_layer(self):
        return RGCNLayer(self.h_dim, self.h_dim, self.num_rels, self.num_bases,
                         activation=F.relu)

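    def build_output_layer(self):
        # note: F.cross_entropy in the training loop applies log-softmax itself,
        # so returning raw scores (activation=None) is the more standard choice
        return RGCNLayer(self.h_dim, self.out_dim, self.num_rels, self.num_bases,
                         activation=partial(F.softmax, dim=1))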

    def forward(self, g):
        if self.features is not None:
            g.ndata['id'] = self.features
        for layer in self.layers:
            layer(g)
        return g.ndata.pop('h')

Handle dataset
~~~~~~~~~~~~~~~~
This tutorial uses the Institute for Applied Informatics and Formal Description Methods (AIFB) dataset from the R-GCN paper.



In [4]:
# load graph data
from dgl.contrib.data import load_data
import numpy as np
data = load_data(dataset='aifb')
num_nodes = data.num_nodes
num_rels = data.num_rels
num_classes = data.num_classes
labels = data.labels
train_idx = data.train_idx
# split training and validation set
val_idx = train_idx[:len(train_idx) // 5]
train_idx = train_idx[len(train_idx) // 5:]

# edge type and normalization factor
edge_type = torch.from_numpy(data.edge_type)
edge_norm = torch.from_numpy(data.edge_norm).unsqueeze(1)

labels = torch.from_numpy(labels).view(-1)

Loading dataset aifb
Number of nodes:  8285
Number of edges:  66371
Number of relations:  91
Number of classes:  4
removing nodes that are more than 3 hops away
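`data.edge_norm` is loaded precomputed. Assuming it follows the paper's choice $c_{i,r} = |N_i^r|$ for the destination node $i$ (an assumption; the loader's exact convention is not shown here), the per-edge factor $1/c_{i,r}$ could be computed from the edge arrays as in this sketch. The helper name `relation_norm` is hypothetical, and it counts edges per (destination, relation) pair, so a repeated edge is counted twice.

```python
import numpy as np

def relation_norm(dst, rel):
    """Per-edge factor 1 / c_{i,r}, counting edges per (destination i, relation r) pair."""
    pairs = np.stack([dst, rel], axis=1)
    _, inverse, counts = np.unique(pairs, axis=0, return_inverse=True, return_counts=True)
    return 1.0 / counts[inverse.reshape(-1)]   # reshape guards against NumPy version differences

# node 0 has two relation-0 edges and one relation-1 edge; node 1 has one relation-0 edge
norm = relation_norm(np.array([0, 0, 0, 1]), np.array([0, 0, 1, 0]))
```

The result is 1-D, so like `data.edge_norm` above it would need `.unsqueeze(1)` after conversion to a tensor to broadcast against per-edge messages.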


Create graph and model
~~~~~~~~~~~~~~~~~~~~~~~



In [5]:
len(data.labels)

8285

In [6]:
# print the first ten (source, destination) edge pairs
for i, j in zip(data.edge_src[:10], data.edge_dst[:10]):
    print(i, j)

0 0
919 0
919 0
1127 0
1127 0
1934 0
2062 0
5286 0
5286 0
7786 0

In [7]:
len(data.edge_dst),len(data.edge_src)

(65439, 65439)

In [8]:
# configurations
n_hidden = 16 # number of hidden units
n_bases = -1 # use number of relations as number of bases
n_hidden_layers = 0 # use 1 input layer, 1 output layer, no hidden layer
n_epochs = 25 # epochs to train
lr = 0.01 # learning rate
l2norm = 0 # L2 norm coefficient

# create graph
g = DGLGraph((data.edge_src, data.edge_dst))
g.edata.update({'rel_type': edge_type, 'norm': edge_norm})

# create model
model = Model(len(g),
              n_hidden,
              num_classes,
              num_rels,
              num_bases=n_bases,
              num_hidden_layers=n_hidden_layers)

In [9]:
g.edges(),g.nodes()

((tensor([   0,  919,  919,  ..., 5939, 5939, 8284]),
  tensor([   0,    0,    0,  ..., 8284, 8284, 8284])),
 tensor([   0,    1,    2,  ..., 8282, 8283, 8284]))

Training loop
~~~~~~~~~~~~~~~~



In [12]:
# optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2norm)

print("start training...")
model.train()
for epoch in range(n_epochs):
    optimizer.zero_grad()
    logits = model.forward(g)
    loss = F.cross_entropy(logits[train_idx], labels[train_idx].long())
    loss.backward()

    optimizer.step()

    train_acc = torch.sum(logits[train_idx].argmax(dim=1) == labels[train_idx])
    train_acc = train_acc.item() / len(train_idx)
    val_loss = F.cross_entropy(logits[val_idx], labels[val_idx].long())
    val_acc = torch.sum(logits[val_idx].argmax(dim=1) == labels[val_idx])
    val_acc = val_acc.item() / len(val_idx)
    print("Epoch {:05d} | ".format(epoch) +
          "Train Accuracy: {:.4f} | Train Loss: {:.4f} | ".format(
              train_acc, loss.item()) +
          "Validation Accuracy: {:.4f} | Validation loss: {:.4f}".format(
              val_acc, val_loss.item()))

start training...
Epoch 00000 | Train Accuracy: 0.1429 | Train Loss: 1.3868 | Validation Accuracy: 0.1786 | Validation loss: 1.3866
Epoch 00001 | Train Accuracy: 0.9554 | Train Loss: 1.3479 | Validation Accuracy: 1.0000 | Validation loss: 1.3601
Epoch 00002 | Train Accuracy: 0.9554 | Train Loss: 1.2883 | Validation Accuracy: 1.0000 | Validation loss: 1.3200
Epoch 00003 | Train Accuracy: 0.9554 | Train Loss: 1.2108 | Validation Accuracy: 1.0000 | Validation loss: 1.2657
Epoch 00004 | Train Accuracy: 0.9464 | Train Loss: 1.1272 | Validation Accuracy: 1.0000 | Validation loss: 1.2005
Epoch 00005 | Train Accuracy: 0.9554 | Train Loss: 1.0516 | Validation Accuracy: 1.0000 | Validation loss: 1.1306
Epoch 00006 | Train Accuracy: 0.9643 | Train Loss: 0.9899 | Validation Accuracy: 1.0000 | Validation loss: 1.0632
Epoch 00007 | Train Accuracy: 0.9821 | Train Loss: 0.9412 | Validation Accuracy: 1.0000 | Validation loss: 1.0039
Epoch 00008 | Train Accuracy: 0.9821 | Train Loss: 0.9022 | Validation


The second task, link prediction
--------------------------------
So far, you have seen how to use DGL to implement entity classification with an
R-GCN model. In the knowledge base setting, representations generated by
R-GCN can be used to uncover potential relationships between nodes. In the
R-GCN paper, the authors feed the entity representations generated by R-GCN
into the `DistMult <https://arxiv.org/pdf/1412.6575.pdf>`_ prediction model
to predict possible relationships.
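For reference, DistMult scores a triple $(s, r, o)$ as $\sum_k e_s[k]\, w_r[k]\, e_o[k]$, a bilinear form with a diagonal relation matrix. The sketch below uses random stand-in embeddings rather than R-GCN outputs (all sizes and names are illustrative) and checks a known property of this score: it is symmetric in subject and object, so DistMult on its own cannot tell $(s, r, o)$ from $(o, r, s)$.

```python
import torch

def distmult_score(e_s, w_r, e_o):
    """DistMult: score(s, r, o) = sum_k e_s[k] * w_r[k] * e_o[k]."""
    return torch.sum(e_s * w_r * e_o, dim=-1)

torch.manual_seed(0)
num_entities, num_relations, dim = 4, 2, 8
entity_emb = torch.randn(num_entities, dim)   # stand-in for R-GCN entity embeddings
rel_emb = torch.randn(num_relations, dim)     # one diagonal vector per relation

triples = torch.tensor([[0, 1, 2], [3, 0, 1]])   # rows of (subject, relation, object)
scores = distmult_score(entity_emb[triples[:, 0]], rel_emb[triples[:, 1]],
                        entity_emb[triples[:, 2]])
probs = torch.sigmoid(scores)   # probability that each triple holds

# swapping subject and object leaves the score unchanged
swapped = distmult_score(entity_emb[triples[:, 2]], rel_emb[triples[:, 1]],
                         entity_emb[triples[:, 0]])
```

Training with `F.binary_cross_entropy_with_logits`, as in the code below, applies the sigmoid internally, so the raw scores are passed in directly.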

The implementation is similar to the one presented here, but with an extra DistMult layer
stacked on top of the R-GCN layers. You can find the complete
implementation of link prediction with R-GCN in our `GitHub Python code example
<https://github.com/dmlc/dgl/blob/master/examples/pytorch/rgcn/link_predict.py>`_.



In [45]:
from dgl.nn.pytorch import RelGraphConv
class BaseRGCN(nn.Module):
    def __init__(self, num_nodes, h_dim, out_dim, num_rels, num_bases,
                 num_hidden_layers=1, dropout=0,
                 use_self_loop=False, use_cuda=False):
        super(BaseRGCN, self).__init__()
        self.num_nodes = num_nodes
        self.h_dim = h_dim
        self.out_dim = out_dim
        self.num_rels = num_rels
        self.num_bases = None if num_bases < 0 else num_bases
        self.num_hidden_layers = num_hidden_layers
        self.dropout = dropout
        self.use_self_loop = use_self_loop
        self.use_cuda = use_cuda

        # create rgcn layers
        self.build_model()

    def build_model(self):
        self.layers = nn.ModuleList()
        # i2h
        i2h = self.build_input_layer()
        if i2h is not None:
            self.layers.append(i2h)
        # h2h
        for idx in range(self.num_hidden_layers):
            h2h = self.build_hidden_layer(idx)
            self.layers.append(h2h)
        # h2o
        h2o = self.build_output_layer()
        if h2o is not None:
            self.layers.append(h2o)

    def build_input_layer(self):
        return None

    def build_hidden_layer(self, idx):
        raise NotImplementedError

    def build_output_layer(self):
        return None

    def forward(self, g, h, r, norm):
        for layer in self.layers:
            h = layer(g, h, r, norm)
        return h

In [46]:
class EmbeddingLayer(nn.Module):
    def __init__(self, num_nodes, h_dim):
        super(EmbeddingLayer, self).__init__()
        self.embedding = torch.nn.Embedding(num_nodes, h_dim)

    def forward(self, g, h, r, norm):
        return self.embedding(h.squeeze())

class RGCN(BaseRGCN):
    def build_input_layer(self):
        return EmbeddingLayer(self.num_nodes, self.h_dim)

    def build_hidden_layer(self, idx):
        act = F.relu if idx < self.num_hidden_layers - 1 else None
        return RelGraphConv(self.h_dim, self.h_dim, self.num_rels, "bdd",
                self.num_bases, activation=act, self_loop=True,
                dropout=self.dropout)

class LinkPredict(nn.Module):
    def __init__(self, in_dim, h_dim, num_rels, num_bases=-1,
                 num_hidden_layers=1, dropout=0, use_cuda=False, reg_param=0):
        super(LinkPredict, self).__init__()
        # pass use_cuda by keyword; positionally it would land in use_self_loop
        self.rgcn = RGCN(in_dim, h_dim, h_dim, num_rels * 2, num_bases,
                         num_hidden_layers, dropout, use_cuda=use_cuda)
        self.reg_param = reg_param
        self.w_relation = nn.Parameter(torch.Tensor(num_rels, h_dim))
        nn.init.xavier_uniform_(self.w_relation,
                                gain=nn.init.calculate_gain('relu'))

    def calc_score(self, embedding, triplets):
        # DistMult
        s = embedding[triplets[:,0]]
        r = self.w_relation[triplets[:,1]]
        o = embedding[triplets[:,2]]
        score = torch.sum(s * r * o, dim=1)
        return score

    def forward(self, g, h, r, norm):
        return self.rgcn.forward(g, h, r, norm)

    def regularization_loss(self, embedding):
        return torch.mean(embedding.pow(2)) + torch.mean(self.w_relation.pow(2))

    def get_loss(self, g, embed, triplets, labels):
        # triplets is a list of data samples (positive and negative)
        # each row in the triplets is a 3-tuple of (source, relation, destination)
        score = self.calc_score(embed, triplets)
        predict_loss = F.binary_cross_entropy_with_logits(score, labels)
        reg_loss = self.regularization_loss(embed)
        return predict_loss + self.reg_param * reg_loss

def node_norm_to_edge_norm(g, node_norm):
    g = g.local_var()
    # convert to edge norm
    g.ndata['norm'] = node_norm
    g.apply_edges(lambda edges : {'norm' : edges.dst['norm']})
    return g.edata['norm']

In [48]:
data = load_data('FB15k-237')

# entities: 14541
# relations: 237
# edges: 272115


In [49]:
num_nodes = data.num_nodes
train_data = data.train
valid_data = data.valid
test_data = data.test
num_rels = data.num_rels

# build the model only after num_nodes and num_rels refer to FB15k-237
model = LinkPredict(num_nodes,
                    500,
                    num_rels,
                    num_bases=100,
                    num_hidden_layers=2,
                    dropout=0.2,
                    use_cuda=False,
                    reg_param=0.01)

In [50]:
valid_data = torch.LongTensor(valid_data)
test_data = torch.LongTensor(test_data)

In [51]:
def get_adj_and_degrees(num_nodes, triplets):
    """ Get adjacency list and degrees of the graph
    """
    adj_list = [[] for _ in range(num_nodes)]
    for i,triplet in enumerate(triplets):
        adj_list[triplet[0]].append([i, triplet[2]])
        adj_list[triplet[2]].append([i, triplet[0]])

    degrees = np.array([len(a) for a in adj_list])
    adj_list = [np.array(a) for a in adj_list]
    return adj_list, degrees

def sample_edge_neighborhood(adj_list, degrees, n_triplets, sample_size):
    """Sample edges by neighborhool expansion.
    This guarantees that the sampled edges form a connected graph, which
    may help deeper GNNs that require information from more than one hop.
    """
    edges = np.zeros((sample_size), dtype=np.int32)

    # initialize
    sample_counts = np.array(degrees)            # unpicked edges left at each node
    picked = np.zeros(n_triplets, dtype=bool)    # whether each triplet was sampled
    seen = np.zeros(len(degrees), dtype=bool)    # whether each node has been reached

    for i in range(0, sample_size):
        weights = sample_counts * seen

        if np.sum(weights) == 0:
            weights = np.ones_like(weights)
            weights[np.where(sample_counts == 0)] = 0

        probabilities = (weights) / np.sum(weights)
        chosen_vertex = np.random.choice(np.arange(degrees.shape[0]),
                                         p=probabilities)
        chosen_adj_list = adj_list[chosen_vertex]
        seen[chosen_vertex] = True

        chosen_edge = np.random.choice(np.arange(chosen_adj_list.shape[0]))
        chosen_edge = chosen_adj_list[chosen_edge]
        edge_number = chosen_edge[0]

        while picked[edge_number]:
            chosen_edge = np.random.choice(np.arange(chosen_adj_list.shape[0]))
            chosen_edge = chosen_adj_list[chosen_edge]
            edge_number = chosen_edge[0]

        edges[i] = edge_number
        other_vertex = chosen_edge[1]
        picked[edge_number] = True
        sample_counts[chosen_vertex] -= 1
        sample_counts[other_vertex] -= 1
        seen[other_vertex] = True

    return edges

def sample_edge_uniform(adj_list, degrees, n_triplets, sample_size):
    """Sample edges uniformly from all the edges."""
    all_edges = np.arange(n_triplets)
    return np.random.choice(all_edges, sample_size, replace=False)

def generate_sampled_graph_and_labels(triplets, sample_size, split_size,
                                      num_rels, adj_list, degrees,
                                      negative_rate, sampler="uniform"):
    """Get training graph and signals
    First perform edge neighborhood sampling on graph, then perform negative
    sampling to generate negative samples
    """
    # perform edge neighbor sampling
    if sampler == "uniform":
        edges = sample_edge_uniform(adj_list, degrees, len(triplets), sample_size)
    elif sampler == "neighbor":
        edges = sample_edge_neighborhood(adj_list, degrees, len(triplets), sample_size)
    else:
        raise ValueError("Sampler type must be either 'uniform' or 'neighbor'.")

    # relabel nodes to have consecutive node ids
    edges = triplets[edges]
    src, rel, dst = edges.transpose()
    uniq_v, edges = np.unique((src, dst), return_inverse=True)
    src, dst = np.reshape(edges, (2, -1))
    relabeled_edges = np.stack((src, rel, dst)).transpose()

    # negative sampling
    samples, labels = negative_sampling(relabeled_edges, len(uniq_v),
                                        negative_rate)

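    # further split the sampled edges: only a `split_size` fraction of them
    # (half, for split_size=0.5) is used as graph structure, so the remaining
    # edges are unseen by message passing; note that all sampled edges,
    # including those kept in the graph, are used as positive samples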
    split_size = int(sample_size * split_size)
    graph_split_ids = np.random.choice(np.arange(sample_size),
                                       size=split_size, replace=False)
    src = src[graph_split_ids]
    dst = dst[graph_split_ids]
    rel = rel[graph_split_ids]

    # build DGL graph
    print("# sampled nodes: {}".format(len(uniq_v)))
    print("# sampled edges: {}".format(len(src) * 2))
    g, rel, norm = build_graph_from_triplets(len(uniq_v), num_rels,
                                             (src, rel, dst))
    return g, uniq_v, rel, norm, samples, labels
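
The relabeling step above relies on `np.unique(..., return_inverse=True)`. The unique values become the sampled graph's node list (`uniq_v`), and the inverse indices serve as the new, consecutive node IDs. Here is a toy run with made-up entity IDs:

```python
import numpy as np

# sampled triplets that use sparse, global entity ids
src = np.array([10, 42, 10])
rel = np.array([0, 1, 2])
dst = np.array([42, 7, 7])

# unique() works on the stacked (src, dst) values; the inverse maps every
# entry to its position in uniq_v, i.e. to its new local node id
uniq_v, inverse = np.unique((src, dst), return_inverse=True)
new_src, new_dst = np.reshape(inverse, (2, -1))

relabeled = np.stack((new_src, rel, new_dst)).transpose()
print(uniq_v)      # global id of each local node
print(relabeled)   # triplets over local ids 0..len(uniq_v)-1
```

Recent NumPy releases return the inverse with the shape of the input rather than flattened. The `reshape` to `(2, -1)` gives the same result in either case.
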

def comp_deg_norm(g):
    g = g.local_var()
    in_deg = g.in_degrees(range(g.number_of_nodes())).float().numpy()
    norm = 1.0 / in_deg
    norm[np.isinf(norm)] = 0
    return norm
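
`comp_deg_norm` computes `1 / in_degree` and then zeroes out the `inf` values produced by nodes without incoming edges, and NumPy emits a divide-by-zero warning along the way. The sketch below computes the same normalization without the warning by dividing only where the degree is positive:

```python
import numpy as np

def inverse_degree_norm(in_deg):
    """Return 1 / in_deg as float32, with 0 for nodes whose in-degree is 0."""
    in_deg = np.asarray(in_deg, dtype=np.float32)
    norm = np.zeros_like(in_deg)
    np.divide(1.0, in_deg, out=norm, where=in_deg > 0)
    return norm

print(inverse_degree_norm([4, 0, 1, 2]))
```

The result must stay a float array. Casting it to an integer type would truncate every value below 1 to 0.
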

def build_graph_from_triplets(num_nodes, num_rels, triplets):
    """ Create a DGL graph. The graph is bidirectional because RGCN authors
        use reversed relations.
        This function also generates edge type and normalization factor
        (reciprocal of node incoming degree)
    """
    g = dgl.DGLGraph()
    g.add_nodes(num_nodes)
    src, rel, dst = triplets
    src, dst = np.concatenate((src, dst)), np.concatenate((dst, src))
    rel = np.concatenate((rel, rel + num_rels))
    edges = sorted(zip(dst, src, rel))
    dst, src, rel = np.array(edges).transpose()
    g.add_edges(src, dst)
    norm = comp_deg_norm(g)
    print("# nodes: {}, # edges: {}".format(num_nodes, len(src)))
    # norm holds fractions (1 / in-degree), so it must stay floating point
    return g, rel.astype('int64'), norm.astype('float32')
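
`build_graph_from_triplets` makes the graph bidirectional. Each triplet (s, r, o) also contributes a reverse edge o -> s whose relation ID is shifted by `num_rels`, so `num_rels` relations yield `2 * num_rels` edge types. The edges are then sorted by destination. The NumPy part of that step on a toy input is shown below, where `np.lexsort` stands in for the tutorial's `sorted(zip(dst, src, rel))`:

```python
import numpy as np

num_rels = 2
src = np.array([0, 1])
rel = np.array([0, 1])
dst = np.array([1, 2])

# forward edges keep their relation; reverse edges get relation + num_rels
all_src = np.concatenate((src, dst))
all_dst = np.concatenate((dst, src))
all_rel = np.concatenate((rel, rel + num_rels))

# sort by (dst, src, rel); lexsort uses the last key as the primary key
order = np.lexsort((all_rel, all_src, all_dst))
all_src, all_dst, all_rel = all_src[order], all_dst[order], all_rel[order]
print(list(zip(all_src.tolist(), all_rel.tolist(), all_dst.tolist())))
```
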

def build_test_graph(num_nodes, num_rels, edges):
    src, rel, dst = edges.transpose()
    print("Test graph:")
    return build_graph_from_triplets(num_nodes, num_rels, (src, rel, dst))

def negative_sampling(pos_samples, num_entity, negative_rate):
    size_of_batch = len(pos_samples)
    num_to_generate = size_of_batch * negative_rate
    neg_samples = np.tile(pos_samples, (negative_rate, 1))
    labels = np.zeros(size_of_batch * (negative_rate + 1), dtype=np.float32)
    labels[: size_of_batch] = 1
    values = np.random.randint(num_entity, size=num_to_generate)
    choices = np.random.uniform(size=num_to_generate)
    subj = choices > 0.5
    obj = choices <= 0.5
    neg_samples[subj, 0] = values[subj]
    neg_samples[obj, 2] = values[obj]

    return np.concatenate((pos_samples, neg_samples)), labels
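
`negative_sampling` makes `negative_rate` copies of every positive triplet. It corrupts each copy by replacing either the subject or the object, each with probability 0.5, with a uniformly drawn entity. The relation is never changed, and the label vector marks only the first `len(pos_samples)` rows as positives. The sketch below follows the same scheme but uses a seeded `np.random.default_rng` generator, our choice for reproducibility, and then checks those properties:

```python
import numpy as np

def corrupt_triplets(pos, num_entity, negative_rate, rng):
    """Return positives followed by corrupted copies, plus 1/0 labels."""
    n = len(pos)
    neg = np.tile(pos, (negative_rate, 1))
    values = rng.integers(num_entity, size=len(neg))
    replace_subject = rng.uniform(size=len(neg)) > 0.5
    neg[replace_subject, 0] = values[replace_subject]     # corrupt the subject
    neg[~replace_subject, 2] = values[~replace_subject]   # or corrupt the object
    labels = np.zeros(n * (negative_rate + 1), dtype=np.float32)
    labels[:n] = 1
    return np.concatenate((pos, neg)), labels

rng = np.random.default_rng(0)
pos = np.array([[0, 3, 1], [2, 5, 4]])
samples, labels = corrupt_triplets(pos, num_entity=10, negative_rate=3, rng=rng)
print(samples.shape, labels.sum())   # 8 rows, 2 of them positive
```

Negatives are not checked against the known triplets, as in the tutorial. A few of them may therefore be true facts, or even identical to the positive they were copied from.
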

#######################################################################
#
# Utility functions for evaluations (raw)
#
#######################################################################

def sort_and_rank(score, target):
    _, indices = torch.sort(score, dim=1, descending=True)
    indices = torch.nonzero(indices == target.view(-1, 1))
    indices = indices[:, 1].view(-1)
    return indices
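
For each row of `score`, `sort_and_rank` returns the 0-based position of the target entity after the candidates are sorted by descending score. Here is a NumPy sketch of the same computation on a toy score matrix (the function name is hypothetical):

```python
import numpy as np

def rank_of_target(score, target):
    """0-based rank of target[i] within score[i], highest score first."""
    order = np.argsort(-score, axis=1)                  # candidate ids, best first
    return np.argmax(order == target[:, None], axis=1)  # where the target landed

score = np.array([[0.1, 0.9, 0.5],
                  [0.7, 0.2, 0.3]])
target = np.array([2, 0])
print(rank_of_target(score, target))
```

With tied scores, the order among the tied candidates depends on the sort implementation, both here and with `torch.sort`.
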

def perturb_and_get_raw_rank(embedding, w, a, r, b, test_size, batch_size=100):
    """ Perturb one element in the triplets
    """
    n_batch = (test_size + batch_size - 1) // batch_size
    ranks = []
    for idx in range(n_batch):
        print("batch {} / {}".format(idx, n_batch))
        batch_start = idx * batch_size
        batch_end = min(test_size, (idx + 1) * batch_size)
        batch_a = a[batch_start: batch_end]
        batch_r = r[batch_start: batch_end]
        emb_ar = embedding[batch_a] * w[batch_r]
        emb_ar = emb_ar.transpose(0, 1).unsqueeze(2) # size: D x E x 1
        emb_c = embedding.transpose(0, 1).unsqueeze(1) # size: D x 1 x V
        # out-prod and reduce sum
        out_prod = torch.bmm(emb_ar, emb_c) # size D x E x V
        score = torch.sum(out_prod, dim=0) # size E x V
        score = torch.sigmoid(score)
        target = b[batch_start: batch_end]
        ranks.append(sort_and_rank(score, target))
    return torch.cat(ranks)
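
`perturb_and_get_raw_rank` scores every candidate entity with the DistMult form score(a, r, c) = Σ_d e_a[d] · w_r[d] · e_c[d]. The `bmm` followed by `sum(dim=0)` builds an outer product and reduces it over the embedding dimension D. That is the same as a single matrix product with the transposed embedding table. A NumPy check of this equivalence on random data:

```python
import numpy as np

rng = np.random.default_rng(0)
num_entities, num_rels, dim = 6, 3, 4
embedding = rng.normal(size=(num_entities, dim))
w = rng.normal(size=(num_rels, dim))
a = np.array([0, 2])      # fixed entities of the batch
r = np.array([1, 0])      # their relations

emb_ar = embedding[a] * w[r]                                 # E x D

# tutorial-style: D x E x 1 times D x 1 x V, then sum over D
out_prod = emb_ar.T[:, :, None] * embedding.T[:, None, :]    # D x E x V
score_outer = out_prod.sum(axis=0)                           # E x V

# equivalent single matrix product
score_matmul = emb_ar @ embedding.T                          # E x V
print(np.allclose(score_outer, score_matmul))
```

The matrix-product form never materializes the D x E x V tensor. Assuming 500-dimensional output embeddings, that tensor would hold roughly 3.6 billion values for FB15k-237's 14,541 entities and a batch of 500 triplets. The later `torch.sigmoid` does not change the ranking, because the sigmoid is monotonic.
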

# return MRR (raw), and Hits @ (1, 3, 10)
def calc_raw_mrr(embedding, w, test_triplets, hits=[], eval_bz=100):
    with torch.no_grad():
        s = test_triplets[:, 0]
        r = test_triplets[:, 1]
        o = test_triplets[:, 2]
        test_size = test_triplets.shape[0]

        # perturb subject
        ranks_s = perturb_and_get_raw_rank(embedding, w, o, r, s, test_size, eval_bz)
        # perturb object
        ranks_o = perturb_and_get_raw_rank(embedding, w, s, r, o, test_size, eval_bz)

        ranks = torch.cat([ranks_s, ranks_o])
        ranks += 1 # change to 1-indexed

        mrr = torch.mean(1.0 / ranks.float())
        print("MRR (raw): {:.6f}".format(mrr.item()))

        for hit in hits:
            avg_count = torch.mean((ranks <= hit).float())
            print("Hits (raw) @ {}: {:.6f}".format(hit, avg_count.item()))
    return mrr.item()
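
`calc_raw_mrr` and its filtered counterpart convert 0-based ranks to 1-based ones. They then report the mean reciprocal rank and, for each cutoff k, the fraction of ranks at or below k. On a hand-made set of ranks:

```python
import numpy as np

def mrr_and_hits(ranks_0_based, hits=(1, 3, 10)):
    """Mean reciprocal rank and Hits@k from 0-based ranks."""
    ranks = np.asarray(ranks_0_based) + 1      # change to 1-indexed
    mrr = np.mean(1.0 / ranks)
    hits_at = {k: np.mean(ranks <= k) for k in hits}
    return mrr, hits_at

mrr, hits_at = mrr_and_hits([0, 2, 9])   # 1-based ranks 1, 3 and 10
print(mrr, hits_at)
```
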

#######################################################################
#
# Utility functions for evaluations (filtered)
#
#######################################################################

def filter_o(triplets_to_filter, target_s, target_r, target_o, num_entities):
    target_s, target_r, target_o = int(target_s), int(target_r), int(target_o)
    # Keep the test triplet itself, since we want to rank it, and drop every
    # other object that forms a known triplet. The filter set is not modified,
    # so other test triplets stay filtered when later queries are ranked.
    filtered_o = [o for o in range(num_entities)
                  if o == target_o or (target_s, target_r, o) not in triplets_to_filter]
    return torch.LongTensor(filtered_o)

def filter_s(triplets_to_filter, target_s, target_r, target_o, num_entities):
    target_s, target_r, target_o = int(target_s), int(target_r), int(target_o)
    # Same as filter_o, but for candidate subjects
    filtered_s = [s for s in range(num_entities)
                  if s == target_s or (s, target_r, target_o) not in triplets_to_filter]
    return torch.LongTensor(filtered_s)
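
The filtered protocol removes candidates that would form some other known triplet, from train, valid, or test, before ranking. As a result, the model is not penalized for scoring a different correct answer above the target. The toy comparison below uses made-up scores and plain Python/NumPy rather than the tutorial's functions to show the raw and filtered rank of one target object:

```python
import numpy as np

# scores for candidate objects 0..4 of the query (s=0, r=0, ?)
scores = np.array([0.9, 0.8, 0.7, 0.6, 0.1])
target_o = 2
known = {(0, 0, 1), (0, 0, 2)}   # (0, 0, 1) is another true fact; (0, 0, 2) is the test triplet

def rank_among(candidates, target, scores):
    """0-based position of target when candidates are sorted by descending score."""
    ordered = sorted(candidates, key=lambda o: -scores[o])
    return ordered.index(target)

raw_rank = rank_among(range(len(scores)), target_o, scores)
filtered = [o for o in range(len(scores)) if o == target_o or (0, 0, o) not in known]
filtered_rank = rank_among(filtered, target_o, scores)
print(raw_rank, filtered_rank)
```
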

def perturb_o_and_get_filtered_rank(embedding, w, s, r, o, test_size, triplets_to_filter):
    """ Perturb object in the triplets
    """
    num_entities = embedding.shape[0]
    ranks = []
    for idx in range(test_size):
        if idx % 100 == 0:
            print("test triplet {} / {}".format(idx, test_size))
        target_s = s[idx]
        target_r = r[idx]
        target_o = o[idx]
        filtered_o = filter_o(triplets_to_filter, target_s, target_r, target_o, num_entities)
        target_o_idx = int((filtered_o == target_o).nonzero())
        emb_s = embedding[target_s]
        emb_r = w[target_r]
        emb_o = embedding[filtered_o]
        emb_triplet = emb_s * emb_r * emb_o
        scores = torch.sigmoid(torch.sum(emb_triplet, dim=1))
        _, indices = torch.sort(scores, descending=True)
        rank = int((indices == target_o_idx).nonzero())
        ranks.append(rank)
    return torch.LongTensor(ranks)

def perturb_s_and_get_filtered_rank(embedding, w, s, r, o, test_size, triplets_to_filter):
    """ Perturb subject in the triplets
    """
    num_entities = embedding.shape[0]
    ranks = []
    for idx in range(test_size):
        if idx % 100 == 0:
            print("test triplet {} / {}".format(idx, test_size))
        target_s = s[idx]
        target_r = r[idx]
        target_o = o[idx]
        filtered_s = filter_s(triplets_to_filter, target_s, target_r, target_o, num_entities)
        target_s_idx = int((filtered_s == target_s).nonzero())
        emb_s = embedding[filtered_s]
        emb_r = w[target_r]
        emb_o = embedding[target_o]
        emb_triplet = emb_s * emb_r * emb_o
        scores = torch.sigmoid(torch.sum(emb_triplet, dim=1))
        _, indices = torch.sort(scores, descending=True)
        rank = int((indices == target_s_idx).nonzero())
        ranks.append(rank)
    return torch.LongTensor(ranks)
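
For every test triplet, the per-triplet loops above rebuild a candidate list of up to `num_entities` IDs, which is slow on FB15k-237. A common alternative is to keep all candidates and mask the other known answers with `-inf`. This is an assumption, not what the tutorial does. The filtered rank is then the number of remaining candidates that score strictly higher than the target:

```python
import numpy as np

def filtered_rank_masked(scores, target, other_true):
    """0-based filtered rank of `target`.

    scores:     1-D array of scores for every candidate entity
    other_true: ids of candidates that form other known triplets (target excluded)
    """
    masked = scores.astype(np.float64)   # astype copies, so the caller's array is untouched
    masked[np.array(sorted(other_true), dtype=np.intp)] = -np.inf
    return int(np.sum(masked > masked[target]))

scores = np.array([0.9, 0.8, 0.7, 0.6, 0.1])
print(filtered_rank_masked(scores, target=2, other_true={1}))
```

Because only strictly higher scores are counted, ties are resolved in the target's favor. The sort-based version places tied candidates in an implementation-defined order instead.
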

def calc_filtered_mrr(embedding, w, train_triplets, valid_triplets, test_triplets, hits=[]):
    with torch.no_grad():
        s = test_triplets[:, 0]
        r = test_triplets[:, 1]
        o = test_triplets[:, 2]
        test_size = test_triplets.shape[0]

        triplets_to_filter = torch.cat([train_triplets, valid_triplets, test_triplets]).tolist()
        triplets_to_filter = {tuple(triplet) for triplet in triplets_to_filter}
        print('Perturbing subject...')
        ranks_s = perturb_s_and_get_filtered_rank(embedding, w, s, r, o, test_size, triplets_to_filter)
        print('Perturbing object...')
        ranks_o = perturb_o_and_get_filtered_rank(embedding, w, s, r, o, test_size, triplets_to_filter)

        ranks = torch.cat([ranks_s, ranks_o])
        ranks += 1 # change to 1-indexed

        mrr = torch.mean(1.0 / ranks.float())
        print("MRR (filtered): {:.6f}".format(mrr.item()))

        for hit in hits:
            avg_count = torch.mean((ranks <= hit).float())
            print("Hits (filtered) @ {}: {:.6f}".format(hit, avg_count.item()))
    return mrr.item()

#######################################################################
#
# Main evaluation function
#
#######################################################################

def calc_mrr(embedding, w, train_triplets, valid_triplets, test_triplets, hits=[], eval_bz=100, eval_p="filtered"):
    if eval_p == "filtered":
        mrr = calc_filtered_mrr(embedding, w, train_triplets, valid_triplets, test_triplets, hits)
    else:
        mrr = calc_raw_mrr(embedding, w, test_triplets, hits, eval_bz)
    return mrr

In [52]:
# build test graph
test_graph, test_rel, test_norm = build_test_graph(
    num_nodes, num_rels, train_data)
test_deg = test_graph.in_degrees(range(test_graph.number_of_nodes())).float().view(-1,1)
test_node_id = torch.arange(0, num_nodes, dtype=torch.long).view(-1, 1)
test_rel = torch.from_numpy(test_rel)
test_norm = node_norm_to_edge_norm(test_graph, torch.from_numpy(test_norm).view(-1, 1))

Test graph:
# nodes: 14541, # edges: 544230


In [53]:
# build adj list and calculate degrees for sampling
adj_list, degrees = get_adj_and_degrees(num_nodes, train_data)

# optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

In [57]:
import time
model_state_file = 'model_state.pth'
forward_time = []
backward_time = []
use_cuda = False
# training loop
print("start training...")

epoch = 0
best_mrr = 0
while True:
    model.train()
    epoch += 1

    # perform edge neighborhood sampling to generate training graph and data
    g, node_id, edge_type, node_norm, data, labels = \
        generate_sampled_graph_and_labels(
            train_data, 30000, 0.5,
            num_rels, adj_list, degrees, 10,
            'uniform')
    print("Done edge sampling")

    # set node/edge features
    node_id = torch.from_numpy(node_id).view(-1, 1).long()
    edge_type = torch.from_numpy(edge_type)
    edge_norm = node_norm_to_edge_norm(g, torch.from_numpy(node_norm).view(-1, 1))
    data, labels = torch.from_numpy(data), torch.from_numpy(labels)
    deg = g.in_degrees(range(g.number_of_nodes())).float().view(-1, 1)
    if use_cuda:
        node_id, deg = node_id.cuda(), deg.cuda()
        edge_type, edge_norm = edge_type.cuda(), edge_norm.cuda()
        data, labels = data.cuda(), labels.cuda()

    t0 = time.time()
    embed = model(g, node_id, edge_type, edge_norm)
    loss = model.get_loss(g, embed, data, labels)
    t1 = time.time()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # clip gradients
    optimizer.step()
    t2 = time.time()

    forward_time.append(t1 - t0)
    backward_time.append(t2 - t1)
    print("Epoch {:04d} | Loss {:.4f} | Best MRR {:.4f} | Forward {:.4f}s | Backward {:.4f}s".
            format(epoch, loss.item(), best_mrr, forward_time[-1], backward_time[-1]))

    optimizer.zero_grad()

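    # validation
    if epoch % 2 == 0:
        # perform validation on CPU because the full graph is too large
        if use_cuda:
            model.cpu()
        model.eval()
        print("start eval")
        with torch.no_grad():
            embed = model(test_graph, test_node_id, test_rel, test_norm)
        # Rank the validation triplets. test_data is passed only so that test
        # triplets are also filtered out as known facts; selecting the best
        # model by its score on test_data would leak the test set.
        mrr = calc_mrr(embed, model.w_relation, torch.LongTensor(train_data),
                       test_data, valid_data, hits=[1, 3, 10], eval_bz=500,
                       eval_p='filtered')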
        # save best model
        if mrr < best_mrr:
            if epoch >= 10:
                break
        else:
            best_mrr = mrr
            torch.save({'state_dict': model.state_dict(), 'epoch': epoch},
                        model_state_file)
        if use_cuda:
            model.cuda()

print("training done")
print("Mean forward time: {:4f}s".format(np.mean(forward_time)))
print("Mean Backward time: {:4f}s".format(np.mean(backward_time)))

start training...
# sampled nodes: 11815
# sampled edges: 30000
# nodes: 11815, # edges: 30000
Done edge sampling
Epoch 0001 | Loss 4.8210 | Best MRR 0.0000 | Forward 59.7757s | Backward 26.2488s
# sampled nodes: 11784
# sampled edges: 30000
# nodes: 11784, # edges: 30000
Done edge sampling
Epoch 0002 | Loss 3.4957 | Best MRR 0.0000 | Forward 6.3529s | Backward 10.1972s
start eval


KeyboardInterrupt: 
