In [2]:
import os
import sys
from collections import defaultdict
import itertools

import numpy as np
np.set_printoptions(edgeitems=10)
np.set_printoptions(edgeitems=30, linewidth=100000, 
    formatter=dict(float=lambda x: "%.3g" % x))
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import dgl

PackageDir = '/home/ubuntu/KGReasoning/tKGR'
sys.path.insert(1, PackageDir)

from utils import Data, NeighborFinder

%load_ext autoreload
%autoreload 2

# Reproducibility
torch.manual_seed(0)
np.random.seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

Using backend: pytorch


In [3]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:80% !important; }</style>"))

## Load Data

In [8]:
contents = Data(dataset='ICEWS14_forecasting', add_reverse_relation=False)

In [9]:
adj = contents.get_adj_dict()
max_time = max(contents.data[:, 3])

In [14]:
contents.test_data_seen_entity[-10:]

array([[4573,   11,   18, 8736,   -1],
       [4573,   25,    5, 8736,   -1],
       [4573,   25,   18, 8736,   -1],
       [5136,   11,  900, 8736,   -1],
       [5299,   43,  128, 8736,   -1],
       [5299,   48,  128, 8736,   -1],
       [5411,   15, 2607, 8736,   -1],
       [5453,   11, 2165, 8736,   -1],
       [5726,   15,  302, 8736,   -1],
       [5737,   11,   95, 8736,   -1]])

## E2Graph

In [191]:
class E2Graph(NeighborFinder):
    def __init__(self, adj, sampling, num_entities, num_rel, num_neighbors, max_time=366*24):
        """

        :param adj:
        :param sampling:
        :param max_time: maximal timestamp of dataset
        :param num_entities:
        :param num_rel:
        """
        super(E2Graph, self).__init__(adj, sampling, max_time, num_entities)
        self.num_neighbors = num_neighbors
        self.selfloop = num_rel # index assigned for selfloop
        
    def set_init(self, events):
        """

        :param events: list of quadruplet (sub, pre, obj, timestamp), shape: [batch_size, ]
        Variables:
        event_graph_nodes: concatenate along node dimension, since different queries can have different number of neighbors
        """
        self.query_src_idx_l = np.array([event[0] for event in events])
        self.query_rel_idx_l = np.array([event[1] for event in events])
        self.query_tar_idx_l = np.array([event[2] for event in events])
        self.query_cut_time_l = np.array([event[3] for event in events])
        self.DP_step = 0
        self.num_query = len(events)
        
        # init Event Graph
        sub_batch = np.repeat(e2g.query_src_idx_l, e2g.num_neighbors + 1)
        group_index = np.repeat(np.arange(e2g.num_query), e2g.num_neighbors + 1)
        rel_batch_with_selfloop = np.concatenate([rel_batch, np.ones([e2g.num_query,1])*e2g.selfloop], axis=1)
        obj_batch_with_selfloop = np.concatenate([obj_batch, e2g.query_src_idx_l[:, np.newaxis]], axis=1)
        cut_time_batch_with_selfloop = np.concatenate([cut_time_batch, e2g.query_cut_time_l[:, np.newaxis]], axis=1)
        node_index = np.repeat(np.arange(self.num_neighbors-1, -1, -1)[np.newaxis, :], self.num_query, axis=0)
        
        event_graph_edges = self._construct_graph(sub_batch, rel_batch_with_selfloop, obj_batch_with_selfloop, cut_time_batch_with_selfloop, node_index)
        
        steps = np.zeros_like(node_index)
        event_graph_nodes = np.hstack([group_index[:, np.newaxis], 
                                 sub_batch[:, np.newaxis], 
                                 rel_batch_with_selfloop.reshape(-1,1), 
                                 obj_batch_with_selfloop.reshape(-1,1), 
                                 cut_time_batch_with_selfloop.reshape(-1,1), 
                                 node_index[:, np.newaxis], 
                                 steps]).astype(np.int32)
        event_graph_nodes = graph_nodes[graph_nodes[:, 2]!=-1]
        self.EventGraph = {0:{'nodes': graph_nodes, 'edges': {'same_sub': event_graph_edges[0], 'same_obj': event_graph_edges[1], 'sub_obj': event_graph_edges[2]}}}
        
    @staticmethod
    def _construct_graph(sub_batch, obj_batch, node_index_batch):
        """
        Really time consuming
        construct adj matrix(tensor) from event_graph_nodes
        return:
        event_graph_edges:
        """
        edge_same_sub = [] # edge: same subject
        edge_same_obj = [] # edge: same object
        edge_sub_obj = [] # edge: sub to obj
        for row in tqdm(range(len(sub_batch))):
            mask = sub_batch[row] != -1
            sub_fil = sub_batch[row][mask]
            obj_fil = obj_batch[row][mask]
            node_index_fil = node_index_batch[row][mask]
            
            sub2idx = defaultdict(list)
            obj2idx = defaultdict(list)
            
            for s, o, idx in zip(sub_fil, obj_fil, node_index_fil):
                sub2idx[s].append(idx)
                obj2idx[o].append(idx)
                
            for same_sub in sub2idx.values():
                if len(same_sub) > 1:
                    edge_same_sub.append(np.array([[row, i, j] for i, j in itertools.combinations(same_sub, r=2)]))

            for same_obj in obj2idx.values():
                if len(same_obj) > 1:
                    edge_same_obj.append(np.array([[row, i, j] for i, j in itertools.combinations(same_obj, r=2)]))

            obj_keys = obj2idx.keys()
            for sub in sub2idx.keys():
                if sub in obj_keys:
                    edge_sub_obj.append(np.array([[row, i, j, row, j, i] for i, j in itertools.product(sub2idx[sub], obj2idx[sub])]).reshape(-1, 3))
                    
        return np.concatenate(edge_same_sub), np.concatenate(edge_same_obj), np.concatenate(edge_sub_obj)
#         return edge_same_sub
    
    def flow(self)

In [192]:
e2g = E2Graph(adj, 2, len(contents.id2entity), len(contents.id2relation), 20, max_time=max_time)

In [44]:
e2g.find_before(5737, 8736)

(array([2074,   49,   49,  321,   49,   75,   95,   75,   95,   49,   75,
          49,   47,   95, 2074,   29,   49,   75,   49,  953,   95,   95,
          95,   75,   46,   55,   55,   75,   55,   75,  321,  321,  321,
         321,   76,  321,  321,   95,   92,  290,   95,   49,   49,   94,
          49,   49, 4841,   49,   26,   92,    0,   26,   96,   26,    0,
          26,   49,   26,   52, 2074,   52,   49,   52,   26,  151,   26,
          95, 1244, 1221, 1244,   95,   49,   46,   49,   95,  151,   46,
          49,   95,   46,   29,   95,   49,   75,   26,   75,   46,   95,
          95,   95,  321, 2180,   95, 2008,   94, 5454,   46,   95,   46,
          95, 1221,   46,   94, 2009,   95,   46,   95,  151,   46,   95,
         922,   46,   95,   46,  184,   46,   95,   46,   46,   46, 1430,
          46,   46,   46,   95,   46,   95,  391, 2734,   85,   95,   95,
          95,   46,   95,  953,   94,   94,   95,   95,   95,   46,   52,
        1383,   95,   95, 2866,   46, 

In [28]:
for batch_idx, sample in enumerate(DataLoader(contents.train_data, batch_size=8, shuffle=True)):
    print(sample)
    break

tensor([[ 265,    9,  633, 1728,   -1],
        [  18,    7,    8, 4320,   -1],
        [ 106,   13,  448, 4248,   -1],
        [1102,   14,   46, 4152,   -1],
        [  18,   74,    5, 3480,   -1],
        [  85,   18,   46, 2472,   -1],
        [  49,   25,   95, 4080,   -1],
        [ 143,   18,  936, 2328,   -1]])


In [29]:
e2g.set_init(sample.numpy())

In [78]:
graph_nodes[graph_nodes[:, 2]!=-1]

array([[   0,  265,   25,  633,  888],
       [   0,  265,   11,  633, 1488],
       [   0,  265,  230,  265, 1728],
       [   1,   18,   11,   36, 4224],
       [   1,   18,   11,  177, 4224],
       [   1,   18,    4,    8, 4248],
       [   1,   18,    4,   12, 4248],
       [   1,   18,    7, 2467, 4248],
       [   1,   18,   10,   12, 4248],
       [   1,   18,   13, 1802, 4248],
       [   1,   18,   18,  106, 4248],
       [   1,   18,   22,   12, 4248],
       [   1,   18,    0,   96, 4272],
       [   1,   18,    7,  302, 4272],
       [   1,   18,    7, 3294, 4272],
       [   1,   18,    9,   12, 4272],
       [   1,   18,    9,   50, 4272],
       [   1,   18,   14,   12, 4272],
       [   1,   18,   17,    8, 4272],
       [   1,   18,   17,   11, 4272],
       [   1,   18,   22,  132, 4272],
       [   1,   18,    7, 2467, 4296],
       [   1,   18,   11, 2467, 4296],
       [   1,   18,  230,   18, 4320],
       [   2,  106,   13,  107, 3888],
       [   2,  106,   13,

In [81]:
np.arange(,0,-1)

array([10,  9,  8,  7,  6,  5,  4,  3,  2,  1])

In [87]:
np.repeat(np.array([1,2,3]), 3, axis=0)

array([1, 1, 1, 2, 2, 2, 3, 3, 3])

In [89]:
x = np.ones([3,4])

In [91]:
import timeit

In [97]:
a = np.random.randint(3, size = (10))

In [153]:
sub = np.random.randint(100, size=(1024, 800))
obj = np.random.randint(100, size=(1024, 800))
node_idx = np.repeat(np.arange(800)[np.newaxis, :], 1024, axis=0)
_construct_graph(sub, obj, node_idx)

  9%|▉         | 91/1024 [03:38<36:54,  2.37s/it]

KeyboardInterrupt: 

In [163]:
import time
from tqdm import tqdm
import multiprocessing

def _construct_graph(sub_batch, obj_batch, node_index_batch):
    """
    construct adj matrix(tensor) from event_graph_nodes
    return:
    event_graph_edges:
    """
    t_0 = time.time()
    edge_same_sub = [] # edge: same subject
    edge_same_obj = [] # edge: same object
    edge_sub_obj = [] # edge: sub to obj
    batch_size = len(sub_batch)
    def each_row(sub_batch_row, obj_batch_row, node_index_row):
        mask = sub_batch_row != -1
        sub_fil = sub_batch_row[mask]
        obj_fil = obj_batch_row[mask]
        node_index_fil = node_index_batch_row[mask]

        sub2idx = defaultdict(list)
        obj2idx = defaultdict(list)

        for s, o, idx in zip(sub_fil, obj_fil, node_index_fil):
            sub2idx[s].append(idx)
            obj2idx[o].append(idx)

            for same_sub in sub2idx.values():
                same_sub = np.array([[row, i, j] for i, j in itertools.combinations(same_sub, r=2)])

            for same_obj in obj2idx.values():
                same_obj = np.array([[row, i, j] for i, j in itertools.combinations(same_obj, r=2)])

            for sub in sub2idx.keys():
                sub_obj = np.array([[row, i, j, row, j, i] for i, j in itertools.product(sub2idx[sub], obj2idx[sub])]).reshape(-1, 3)
        
        return same_sub, same_obj, sub_obj
    
    with multiprocessing.Pool(processes=4) as pool:
        results = pool.starmap(each_row, zip(sub_batch, obj_batch, node_index_batch))
        
    for res in results:
        edge_same_sub.append(res[0])
        edge_same_obj.append(res[1])
        edge_sub_obj.append(res[2])
    a, b, c = np.concatenate(edge_same_sub), np.concatenate(edge_same_obj), np.concatenate(edge_sub_obj)
    print(time.time() - t_0)
    return a, b, c    
        

In [184]:
sub = np.random.randint(10, size=(128, 40))
obj = np.random.randint(10, size=(128, 40))
node_idx = np.repeat(np.arange(40)[np.newaxis, :], 128, axis=0)
E2Graph._construct_graph(sub, obj, node_idx)


  0%|          | 0/128 [00:00<?, ?it/s][A
 19%|█▉        | 24/128 [00:00<00:00, 232.28it/s][A
 36%|███▌      | 46/128 [00:00<00:00, 226.57it/s][A
 54%|█████▍    | 69/128 [00:00<00:00, 224.48it/s][A
 72%|███████▏  | 92/128 [00:00<00:00, 223.22it/s][A
100%|██████████| 128/128 [00:00<00:00, 219.91it/s][A


[array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([[0, 2, 3]]),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([[0, 2, 3]]),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([[0, 2, 3]]),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([[0, 2, 3]]),
 array([], dtype=float64),
 array([], dtype=float64),
 array([[0, 4, 6]]),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 array([], dtype=float64),
 arr

In [190]:
sub = np.random.randint(10, size=(128, 800))
obj = np.random.randint(10, size=(128, 800))
node_idx = np.repeat(np.arange(800)[np.newaxis, :], 128, axis=0)
E2Graph._construct_graph(sub, obj, node_idx)


  0%|          | 0/128 [00:00<?, ?it/s][A
  1%|          | 1/128 [00:28<1:00:07, 28.40s/it][A

KeyboardInterrupt: 

## Step 1

H = [batch_size, num_nodes, embed_dim]\
A = [batch_size, num_nodes, num_nodes]
a = [batch_size, num_nodes]

In [None]:
def attn_head(seq, out_sz, bias_mat, activation, in_drop=0.0, coef_drop=0.0, residual=False):
    if in_drop != 0.0:
        seq =  nn.Dropout(p=in_drop)(seq)
        

In [6]:
m = torch.nn.Dropout(p=0.5)
input = torch.randn(5,3)
m(input)

tensor([[ 0.2454,  0.0000,  0.0000],
        [-4.3763,  0.0000,  1.4911],
        [-0.0000,  1.7023, -0.0000],
        [ 0.0000, -4.3149, -0.0000],
        [-0.7221,  3.7479, -2.7672]])

In [7]:
a=np.array([2,3])[np.newaxis]

In [10]:
a.shape

(1, 2)

In [13]:
import pdb

import torch.nn as nn
import torch.nn.functional as F

# Define a GAT layer
class GATLayer(nn.Module):
    def __init__(self, g, in_feats, out_feats, out_feats_query=5):
        super(GATLayer, self).__init__()
        self.g = g
        self.linear_func = nn.Linear(in_feats, out_feats, bias=False)
        
        self.attention_func_wq = nn.Linear(out_feats, out_feats_query, bias=False)
        self.attention_func_wk = nn.Linear(out_feats, out_feats_query, bias=False)
        self.attention_func_wv = nn.Linear(out_feats, out_feats, bias=False)
        
    def reset_parameters(self):
        """Reinitialize learnable parameters."""
        gain = nn.init.calculate_gain('relu')
        nn.init.xavier_normal_(self.fc.weight, gain=gain)
        nn.init.xavier_normal_(self.attn_fc.weight, gain=gain)
        nn.init.xavier_normal_(self.attention_func_wq, gain=gain)
        nn.init.xavier_normal_(self.attention_func_wk, gain=gain)
        nn.init.xavier_normal_(self.attention_func_wv, gain=gain)
        
    def edge_attention(self, edges):
        concat_z = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
        Q = self.attention_func_wq(edges.src['z'])
        K = self.attention_func_wk(edges.dst['z'])
        Q_len, Q_feat_dim = Q.shape
#         src_e = torch.dot(self.attention_func_wq(edges.src['z']), self.attention_func_wk(edges.dst['z']))
        src_e = torch.bmm(Q.view(Q_len, 1, Q_feat_dim), K.view(Q_len, Q_feat_dim, 1))
        src_e = torch.squeeze(src_e, dim=-1)
#         src_e = F.leaky_relu(src_e)
        return {'e': src_e}
    
    def message_func(self, edges):
        V = self.attention_func_wv(edges.src['z'])
        return {'v': V, 'flow_score': edges.src['flow_score'], 'e':edges.data['e']}
        
    def reduce_func(self, nodes):
        alpha = F.softmax(F.normalize(nodes.mailbox['e'], p=1), dim=1)
#         pdb.set_trace()
        h = torch.sum(alpha * nodes.mailbox['v'], dim=1)
        new_flow_score = torch.sum(alpha * nodes.mailbox['flow_score'], dim=1)
        return {'h': h, 'flow_score': new_flow_score}
                               
    def forward(self, h, flow_score):
        self.g.ndata['z'] = h
        self.g.ndata['flow_score'] = flow_score
        self.g.apply_edges(self.edge_attention)
        self.g.update_all(self.message_func, self.reduce_func)
        return self.g.ndata.pop('h'), self.g.ndata.pop('flow_score')
  
   
class MultiHeadGATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):
        super(MultiHeadGATLayer, self).__init__()
        self.heads = nn.ModuleList()
        for i in range(num_heads):
            self.heads.append(GATLayer(g, in_dim, out_dim))
        self.merge = merge

    def forward(self, h, flow_score):
        head_outs = [attn_head(h, flow_score) for attn_head in self.heads]
        head_outs_h = [o[0] for o in head_outs]
        head_outs_flow_score = [o[1] for o in head_outs]
#         pdb.set_trace()
        head_outs_flow_score = torch.mean(torch.stack(head_outs_flow_score), 0)
        if self.merge == 'cat':
            # concat on the output feature dimension (dim=1)
            return torch.cat(head_outs_h, dim=1), head_outs_flow_score
        else:
            # merge using average
            return torch.mean(torch.stack(head_outs_h), 0), head_outs_flow_score
        
    
class GAT(nn.Module):
    def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):
        super(GAT, self).__init__()
        self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)
        # Be aware that the input dimension is hidden_dim*num_heads since
        # multiple head outputs are concatenated together. Also, only
        # one attention head in the output layer.
        self.layer2 = MultiHeadGATLayer(g, hidden_dim * num_heads, out_dim, 1)

    def forward(self, h, flow_score):
        h, flow_score = self.layer1(h, flow_score)
        h = F.elu(h)
        flow_score = F.softmax(flow_score)
        h = self.layer2(h, flow_score)
        return h

In [14]:
import time
import numpy as np
import torch
from dgl.data import citation_graph as citegrh
from dgl import DGLGraph

In [15]:
def load_cora_data():
    data = citegrh.load_cora()
    features = torch.FloatTensor(data.features)
    labels = torch.LongTensor(data.labels)
    flow_score = torch.randn(len(labels), 1, dtype=torch.float32)
    mask = torch.BoolTensor(data.train_mask)
    g = DGLGraph(data.graph)
    return g, features, labels, flow_score, mask

In [19]:
data = citegrh.load_cora()

In [40]:
DGLGraph(data.graph)

AttributeError: 'DGLGraph' object has no attribute 'g'

In [16]:
g, features, labels, flow_score, mask = load_cora_data()

net = GAT(g,
          in_dim=features.size()[1],
          hidden_dim=8,
          out_dim=7,
          num_heads=2)

# create optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

In [17]:
# main loop
dur = []
for epoch in range(30):
    if epoch >= 3:
        t0 = time.time()

    logits, new_flow_score = net(features, flow_score)
#     print(new_flow_score)
    logp = F.log_softmax(logits, 1)
    loss = F.nll_loss(logp[mask], labels[mask])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch >= 3:
        dur.append(time.time() - t0)

    print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f}".format(
        epoch, loss.item(), np.mean(dur)))

RuntimeError: size mismatch, m1: [10556 x 1433], m2: [8 x 5] at /pytorch/aten/src/TH/generic/THTensorMath.cpp:136

In [None]:
class MultiHeadGATLayer(nn.Module):
        "Multi-Head Attention"
    def __init__(self, dim_model, n_head):
        "h: number of heads; dim_model: hidden dimension"
        super(MultiHeadAttention, self).__init__()
        self.d_k = dim_model // n_heaad
        self.n_head = n_head
        # W_q, W_k, W_v, W_o
        self.attention_func_wq = nn.Linear(dim_model, self.d_k, bias=False)
        self.attention_func_wk = nn.Linear(dim_model, self.d_k, bias=False)
        self.attention_func_wv = nn.Linear(dim_model, self.d_k, bias=False)
        self.attention_func_wo = nn.Linear(dim_model, self,d_k, bias=False)

    def get(self, x, fields='qkv'):
        "Return a dict of queries / keys / values."
        batch_size = x.shape[0]
        ret = {}
        if 'q' in fields:
            ret['q'] = self.attention_func_wq(x).view(batch_size, self.n_head, self.d_k)
        if 'k' in fields:
            ret['k'] = self.attention_func_wk(x).view(batch_size, self.n_head, self.d_k)
        if 'v' in fields:
            ret['v'] = self.attention_func_wv(x).view(batch_size, self.n_head, self.d_k)
        return ret
    
    def get_o(self, x):
        "get output of the multi-head attention"
        batch_size = x.shape[0]
        return self.attention_func_wo(x.view(batch_size, -1))
    
    def forward(self, graph, h, flow_score):
        graph.ndata['z'] = h
        graph.ndata['flow_score'] = flow_score
        graph.apply_nodes()
        self.propagate_attention(graph, eids)
        g.apply_edges(self.propagate_attention)
        g.update_all(self.message_func, self.reduce_func)
        return self.g.ndata.pop('h'), self.g.ndata.pop('flow_score')


import copy
def clones(module, k):
    return nn.ModuleList(
        copy.deepcopy(module) for _ in range(k)
    )

class Encoder(nn.Module):
    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        self.N = N
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)
        
    def pre_func(self, i, fields='qkv'):
        layer = self.layers[i]
        def func(nodes):
            x = nodes.data['x']
            norm_x = layer.sublayer[0].norm(x)
            return layer.self_attn.get(norm_x, fields=fields)
        return func
    
    def post_func(self, i):
        """
        1, Normalize (softmax denominator) and get output of multi-Head attention
        2, Applying a two layer position-wise feed forward layer on x then add residual connection:
        """
        layer = self.layers[i]
        def func(nodes):
            x, wv, z = nodes.data['x'], nodes.data['wv'], nodes.data['z']
            o = layer.self_attn.get_o(wv / z)
            x = x + layer.sublayer[0].dropout(o)
            x = layer.sublayer[1](x, layer.feed_forward)
            return {'x': x if i < self.N - 1 else self.norm(x)}
        return func
        
def src_dot_dst(src_field, dst_field, out_field):
    def func(edges):
        return {out_field: (edges.src[src_field] * edges.dst[dst_field]).sum(-1, keepdim=True)}

    return func

def scaled_exp(field, scale_constant):
    def func(edges):
        # clamp for softmax numerical stability
        return {field: torch.exp((edges.data[field] / scale_constant).clamp(-5, 5))}

    return func

class GAT(nn.Module):
    def __init__(self, encoder, node_embed, h, d_k):
        super(GAT, self).__init__()
        self.node_embed = node_embed
        self.encoder encoder
        self.h = h
        self.d_k = d_k
        
    def propagate_attention(self, g, eids):
        # Compute attention score
        g.apply_edges(src_dot_dst('k', 'q', 'score'), eids)
        g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)), eids)
        # Update node state
        g.send_and_recv(eids,
                        [fn.src_mul_edge('v', 'score', 'v'), # src to edge
                         fn.src_mul_edge('flow_score', 'score', 'flow_score'), # src to edge
                         fn.copy_e('score', 'score'), # edge to message
                         fn.copy_e('flow_score', 'flow_score')], # edge to message
                        [fn.sum('v', 'wv'), fn.sum('score', 'z'), fn.sum('flow_score', 'flow_score')])

    def update_graph(g, eids, pre_pairs, post_pairs):
        # pre-compute queries and key_value pairs
        for pre_func, nids in pre_pairs:
            g.apply_nodes(pre_func, nids)
        self.propagate_attention(g, eids)
        # post-compute
        for post_func, nids in post_pairs:
            g.apply_nodes(post_func, nids)
        
    
    def forward(self, g, nodes_l):
        """
        g: DGLGraph, preprocessed: node raw embedding from entity, relation and time embedding
        nodes: 2-d numpy array: (node_idx, sub_idx, rel_idx, obj_idx, timestamp)
        """
        # embed
        node_embed = self.node_embed(nodes_l)
        g.nodes[nodes_l[:,0]].data['x'] = node_embed
        
        for i in range(self.encoder.N):
            # dynamically expand graph
            # TBD
            pre_func = self.encoder.pre_func(i, 'qkv')
            post_func = self.encoder.post_func(i)
            edges = g.edges()
            nodes = g.nodes()
            self.update_graph(g, edges, [(pre_func, nodes)], [(post_func, nodes)])
            
            # pruning
            # TBD
    
class NodeEncoder(nn.Module):
    def __init__(self, entity_encoder, relation_encoder, time_encoder):
        super(NodeEncoder, self).__init__()
        self.entity_encoder = entity_encoder
        self.relation_encoder = relation_encoder
        self.time_encoder = time_encoder
        
    def forward(self, node):
        """
        node: 2d-numpy.array, (sub_idx, rel_idx, obj_idx, timestamp)
        return:
        torch.tensor
        """
        # TBD
#         sub_embed = self.entity_encoder(node[0])
        pass
            
            

In [None]:
entity_encoder = torch.nn.Embedding(num_nodes + 1, embed_dim)
rel_encoder = torch.nn.Embedding(num_relation + 1, embed_dim)
time_encoder = TimeEncode(embed_dim)

In [None]:
model = GAT(encoder, decoder, )

In [82]:
import dgl
g = dgl.graph([(0, 1), (1, 2)], 'user', 'follows')

In [86]:
g.ndata['f'] = torch.randn(3,4)

In [60]:
feat = torch.randn(3,4)
g.nodes['user'].data['h']=feat

In [71]:
g.nodes[np.array([1,0,2])].data['h'] = feat

In [72]:
g.nodes[2]

NodeSpace(data={'h': tensor([[ 0.2211,  0.1151, -1.5755,  0.6421]])})

In [68]:
feat

tensor([[-0.3082,  0.4389,  0.6447, -0.1865],
        [ 1.3271, -1.3016, -0.4016, -0.4803],
        [ 0.2211,  0.1151, -1.5755,  0.6421]])

## Unit test on Graph Attention Network based on scaled dot-product attention 

In [1]:
import time

import numpy as np
import torch
import torch.nn as nn

from dgl.data import citation_graph as citegrh
from dgl import DGLGraph
import dgl.function as fn

Using backend: pytorch


In [9]:
class MultiHeadGATLayer(nn.Module):
    """
    Multi-Head Attention
    """
    def __init__(self, dim_model, n_head):
        "h: number of heads; dim_model: hidden dimension"
        super(MultiHeadGATLayer, self).__init__()
        self.d_k = dim_model // n_head
        self.n_head = n_head
        # W_q, W_k, W_v, W_o
        self.attention_func_wq = nn.Linear(dim_model, self.d_k, bias=False)
        self.attention_func_wk = nn.Linear(dim_model, self.d_k, bias=False)
        self.attention_func_wv = nn.Linear(dim_model, self.d_k, bias=False)
        self.attention_func_wo = nn.Linear(dim_model, self.d_k, bias=False)

    def get(self, x, fields='qkv'):
        "Return a dict of queries / keys / values."
        batch_size = x.shape[0]
        ret = {}
        if 'q' in fields:
            ret['q'] = self.attention_func_wq(x).view(batch_size, self.n_head, self.d_k)
        if 'k' in fields:
            ret['k'] = self.attention_func_wk(x).view(batch_size, self.n_head, self.d_k)
        if 'v' in fields:
            ret['v'] = self.attention_func_wv(x).view(batch_size, self.n_head, self.d_k)
        return ret
    
    def get_o(self, x):
        "get output of the multi-head attention"
        batch_size = x.shape[0]
        return self.attention_func_wo(x.view(batch_size, -1))


import copy
def clones(module, k):
    return nn.ModuleList(
        copy.deepcopy(module) for _ in range(k)
    )

class Encoder(nn.Module):
    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        self.N = N
        self.layers = clones(layer, N)
        self.norm = torch.nn.LayerNorm(layer.size)
        
    def pre_func(self, i, fields='qkv'):
        layer = self.layers[i]
        def func(nodes):
            x = nodes.data['x']
            norm_x = layer.sublayer[0].norm(x)
            return layer.self_attn.get(norm_x, fields=fields)
        return func
    
    def post_func(self, i):
        """
        1, Normalize (softmax denominator) and get output of multi-Head attention
        2, Applying a two layer position-wise feed forward layer on x then add residual connection:
        """
        layer = self.layers[i]
        def func(nodes):
            x, wv, z = nodes.data['x'], nodes.data['wv'], nodes.data['z']
            o = layer.self_attn.get_o(wv / z)
            x = x + layer.sublayer[0].dropout(o)
            x = layer.sublayer[1](x, layer.feed_forward)
            return {'x': x if i < self.N - 1 else self.norm(x)}
        return func
    
class EncoderLayer(nn.Module):
    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn # (key, query, value, mask)
        self.feed_forward = feed_forward
        self.sublayer = clones(SubLayerWrapper(size, dropout), 2)
        
class SubLayerWrapper(nn.Module):
    '''
    The module wraps normalization, dropout, residual connection into one equation:
    sublayerwrapper(sublayer)(x) = x + dropout(sublayer(norm(x)))
    '''
    def __init__(self, size, dropout):
        super(SubLayerWrapper, self).__init__()
        self.norm = torch.nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))
        
def src_dot_dst(src_field, dst_field, out_field):
    def func(edges):
        return {out_field: (edges.src[src_field] * edges.dst[dst_field]).sum(-1, keepdim=True)}

    return func

def scaled_exp(field, scale_constant):
    def func(edges):
        # clamp for softmax numerical stability
        return {field: torch.exp((edges.data[field] / scale_constant).clamp(-5, 5))}

    return func

class GAT(nn.Module):
    def __init__(self, encoder, h, d_k):
        super(GAT, self).__init__()
        self.encoder = encoder
        self.h = h
        self.d_k = d_k
        
    def propagate_attention(self, g, eids):
        # Compute attention score
        g.apply_edges(src_dot_dst('k', 'q', 'score'), eids)
        g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)), eids)
        # Update node state
        g.send_and_recv(eids,
                        [fn.src_mul_edge('v', 'score', 'v'), # src to edge
                         fn.copy_e('score', 'score')], # edge to message
                        [fn.sum('v', 'wv'), fn.sum('score', 'z')])

    def update_graph(self, g, eids, pre_pairs, post_pairs):
        # pre-compute queries and key_value pairs
        for pre_func, nids in pre_pairs:
            g.apply_nodes(pre_func, nids)
        self.propagate_attention(g, eids)
        # post-compute
        for post_func, nids in post_pairs:
            g.apply_nodes(post_func, nids)
        
    
    def forward(self, g):
        """
        g: DGLGraph, preprocessed: node raw embedding from entity, relation and time embedding
        nodes: 2-d numpy array: (node_idx, sub_idx, rel_idx, obj_idx, timestamp)
        """
        for i in range(self.encoder.N):
            # dynamically expand graph
            # TBD
            pre_func = self.encoder.pre_func(i, 'qkv')
            post_func = self.encoder.post_func(i)
            edges = g.edges()
            nodes = g.nodes()
            self.update_graph(g, edges, [(pre_func, nodes)], [(post_func, nodes)])
            
            # pruning
            # TBD
        return g.ndata.pop('x')

load data

In [10]:
def load_cora_data():
    data = citegrh.load_cora()
    features = torch.FloatTensor(data.features)
    labels = torch.LongTensor(data.labels)
    flow_score = torch.ones(len(labels), 1, dtype=torch.float32)
    mask = torch.BoolTensor(data.train_mask)
    g = DGLGraph(data.graph)
    return g, features, labels, flow_score, mask

In [14]:
g, features, labels, flow_score, mask = load_cora_data()

In [15]:
dim_model = 1433
c = copy.deepcopy
attn = MultiHeadGATLayer(dim_model, 1)

encoder = Encoder(EncoderLayer(dim_model, c(attn), torch.nn.Linear(dim_model, dim_model), 0.1), 1)

net = GAT(encoder, 1, dim_model//1)

# create optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

In [16]:
from tqdm import tqdm
# main loop
dur = []

for epoch in tqdm(range(30)):
    optimizer.zero_grad()
    if epoch >= 3:
        t0 = time.time()
    g_copy = copy.deepcopy(g) # otherwise RuntimeError: Trying to backward through the graph a second time
    g_copy.ndata['x'] = features
    out = net(g_copy)
    loss = torch.sum(out)
    loss.backward()
    optimizer.step()
    
#     logits, new_flow_score = net(g)
# #     print(new_flow_score)
#     logp = F.log_softmax(logits, 1)
#     loss = F.nll_loss(logp[mask], labels[mask])

#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()

#     if epoch >= 3:
#         dur.append(time.time() - t0)

#     print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f}".format(
#         epoch, loss.item(), np.mean(dur)))

100%|██████████| 30/30 [00:48<00:00,  1.62s/it]


## Add flow score

In [None]:
class MultiHeadGATLayer(nn.Module):
    """
    Multi-Head Attention
    """
    def __init__(self, dim_model, n_head):
        "h: number of heads; dim_model: hidden dimension"
        super(MultiHeadGATLayer, self).__init__()
        self.d_k = dim_model // n_head
        self.n_head = n_head
        # W_q, W_k, W_v, W_o
        self.attention_func_wq = nn.Linear(dim_model, self.d_k, bias=False)
        self.attention_func_wk = nn.Linear(dim_model, self.d_k, bias=False)
        self.attention_func_wv = nn.Linear(dim_model, self.d_k, bias=False)
        self.attention_func_wo = nn.Linear(dim_model, self.d_k, bias=False)

    def get(self, x, fields='qkv'):
        "Return a dict of queries / keys / values."
        batch_size = x.shape[0]
        ret = {}
        if 'q' in fields:
            ret['q'] = self.attention_func_wq(x).view(batch_size, self.n_head, self.d_k)
        if 'k' in fields:
            ret['k'] = self.attention_func_wk(x).view(batch_size, self.n_head, self.d_k)
        if 'v' in fields:
            ret['v'] = self.attention_func_wv(x).view(batch_size, self.n_head, self.d_k)
        return ret
    
    def get_o(self, x):
        "get output of the multi-head attention"
        batch_size = x.shape[0]
        return self.attention_func_wo(x.view(batch_size, -1))


import copy
def clones(module, k):
    return nn.ModuleList(
        copy.deepcopy(module) for _ in range(k)
    )

class Encoder(nn.Module):
    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        self.N = N
        self.layers = clones(layer, N)
        self.norm = torch.nn.LayerNorm(layer.size)
        
    def pre_func(self, i, fields='qkv'):
        layer = self.layers[i]
        def func(nodes):
            x = nodes.data['x']
            norm_x = layer.sublayer[0].norm(x)
            return layer.self_attn.get(norm_x, fields=fields)
        return func
    
    def post_func(self, i):
        """
        1, Normalize (softmax denominator) and get output of multi-Head attention
        2, Applying a two layer position-wise feed forward layer on x then add residual connection:
        """
        layer = self.layers[i]
        def func(nodes):
            x, wv, z = nodes.data['x'], nodes.data['wv'], nodes.data['z']
            o = layer.self_attn.get_o(wv / z)
            x = x + layer.sublayer[0].dropout(o)
            x = layer.sublayer[1](x, layer.feed_forward)
            return {'x': x if i < self.N - 1 else self.norm(x)}
        return func
    
class EncoderLayer(nn.Module):
    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn # (key, query, value, mask)
        self.feed_forward = feed_forward
        self.sublayer = clones(SubLayerWrapper(size, dropout), 2)
        
class SubLayerWrapper(nn.Module):
    '''
    The module wraps normalization, dropout, residual connection into one equation:
    sublayerwrapper(sublayer)(x) = x + dropout(sublayer(norm(x)))
    '''
    def __init__(self, size, dropout):
        super(SubLayerWrapper, self).__init__()
        self.norm = torch.nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))
        
def src_dot_dst(src_field, dst_field, out_field):
    def func(edges):
        return {out_field: (edges.src[src_field] * edges.dst[dst_field]).sum(-1, keepdim=True)}

    return func

def scaled_exp(field, scale_constant):
    def func(edges):
        # clamp for softmax numerical stability
        return {field: torch.exp((edges.data[field] / scale_constant).clamp(-5, 5))}

    return func

class GAT(nn.Module):
    def __init__(self, encoder, h, d_k):
        super(GAT, self).__init__()
        self.encoder = encoder
        self.h = h
        self.d_k = d_k
        
    def propagate_attention(self, g, eids):
        # Compute attention score
        g.apply_edges(src_dot_dst('k', 'q', 'score'), eids)
        g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)), eids)
        # Update node state
        g.send_and_recv(eids,
                        [fn.src_mul_edge('v', 'score', 'v'), # src to edge
                         fn.src_mul_edge('flow_score', 'score', 'flow_score'), # src to edge
                         fn.copy_e('score', 'score'), 
                         fn.copy_e('flow_score', 'flow_score')], # edge to message
                        [fn.sum('v', 'wv'), fn.sum('score', 'z'), fn.sum('flow_score', 'flow_score')])

    def update_graph(self, g, eids, pre_pairs, post_pairs):
        # pre-compute queries and key_value pairs
        for pre_func, nids in pre_pairs:
            g.apply_nodes(pre_func, nids)
        self.propagate_attention(g, eids)
        # post-compute
        for post_func, nids in post_pairs:
            g.apply_nodes(post_func, nids)
        
    
    def forward(self, g):
        """
        g: DGLGraph, preprocessed: node raw embedding from entity, relation and time embedding
        nodes: 2-d numpy array: (node_idx, sub_idx, rel_idx, obj_idx, timestamp)
        """
        for i in range(self.encoder.N):
            # dynamically expand graph
            # TBD
            
            pre_func = self.encoder.pre_func(i, 'qkv')
            post_func = self.encoder.post_func(i)
            edges = g.edges()
            nodes = g.nodes()
            self.update_graph(g, edges, [(pre_func, nodes)], [(post_func, nodes)])
            
            # pruning
            # TBD
            
            # update entity graph
            
        return g.ndata.pop('x')

## Load dataset

In [5]:
contents = Data(dataset='ICEWS14_forecasting', add_reverse_relation=False)

In [6]:
adj = contents.get_adj_dict()
max_time = max(contents.data[:, 3])

In [7]:
for batch_idx, sample in enumerate(DataLoader(contents.train_data, batch_size=8, shuffle=True)):
    print(sample)
    break

tensor([[  22,   36,   23, 5976,   -1],
        [ 718,   56,  113, 5352,   -1],
        [  18,    7, 1223,  840,   -1],
        [   5,   13,  802, 1704,   -1],
        [  23,   15,   22, 4176,   -1],
        [  69,   24,   18, 5184,   -1],
        [  72,   30,   63, 1992,   -1],
        [ 405,   25,   15, 1656,   -1]])


In [8]:
sample

tensor([[  22,   36,   23, 5976,   -1],
        [ 718,   56,  113, 5352,   -1],
        [  18,    7, 1223,  840,   -1],
        [   5,   13,  802, 1704,   -1],
        [  23,   15,   22, 4176,   -1],
        [  69,   24,   18, 5184,   -1],
        [  72,   30,   63, 1992,   -1],
        [ 405,   25,   15, 1656,   -1]])

In [8]:
nf_last20 = NeighborFinder(adj, sampling=2, max_time=max_time, num_entities=len(contents.id2entity))

In [10]:
sample_np = sample.numpy()
src_idx_l = sample[:, 0]
cut_time_l = sample[:, 3]
ngh_obj_batch, ngh_rel_batch, ngh_ts_batch = nf_last20.get_temporal_neighbor(src_idx_l, cut_time_l)

In [11]:
ngh_obj_batch

array([[  23, 5251,   21,   21, 2000,   23,   60,   23, 5197,   23,   23, 2664, 3455, 3455, 1912, 5840,   23, 4981, 3246, 1197],
       [  22,   22,   22,   22,   22,   23, 2000,   22,   23,  611,   23,   23,   22, 2000,   22,   23,   23,   22,   23,   23],
       [ 591, 1378,   26,  550,    5,   36, 1107,    5,    5,   12, 2020,    6,    5,  633,   26,   96,   17,    5,   39,   96],
       [  49,   49, 2796,   18,   49,   12,    8,    8,   12,    8,  802,   26,   12,  756,   49,   12,  365,   18,   12,   12],
       [  22,  113,  477,   22,   22,   22,  108,   22,   22,  113,   22,   22,   22,   22,   22,   22,   22,   22,   22,  113],
       [ 262,  262,    1,  262,  262,    1,  262, 1303,  218,  263,  262,  262,  262,    1,  262,  262,    1,  262,   18,  262],
       [  63,   63,   63,   63,   63,   63,   63,   63,   63,   63,   63,   63,   63,   63,   63,   63,   63,   63,   63,   63],
       [  15, 1466,   15,   15, 1913,   15,   15,   15,   15,   15,   15,   15,  286,  658,   15,

In [12]:
batch_offset = 2000 # > max_attended_nodes * num_neighbors * DP_step
batch_size = 8
num_neighbors = 20

In [57]:
ngh_events = np.hstack([np.repeat(src_idx_l, num_neighbors)[:, np.newaxis], ngh_rel_batch.reshape(-1, 1), ngh_obj_batch.reshape(-1, 1), ngh_ts_batch.reshape(-1, 1)])
ngh_events_idx = np.tile(np.arange(num_neighbors)[::-1], batch_size) + np.repeat(np.arange(batch_size)*batch_offset, num_neighbors)
ngh_events_mask = ngh_events[:, 2] != -1
ngh_events = ngh_events[ngh_events_mask]
ngh_events_idx = ngh_events_idx[ngh_events_mask]
id2event = {idx: tuple(evt) for idx, evt in zip(ngh_events_idx, ngh_events)}

Let's assume there is only one relation in event graph

In [53]:
import pdb
from tqdm import tqdm
def construct_graph(sub_batch, obj_batch, node_index_batch):
    """
    Really time consuming
    construct adj matrix(tensor) from event_graph_nodes
    return:
    event_graph_edges:
    """
    edge_same_sub = [] # edge: same subject
    edge_same_obj = [] # edge: same object
    edge_sub_obj = [] # edge: sub to obj
    for row in tqdm(range(len(sub_batch))):

        mask = sub_batch[row] != -1
# event may have no neighbors
        sub_fil = sub_batch[row][mask]
        obj_fil = obj_batch[row][mask]
        node_index_fil = node_index_batch[row][mask]

        sub2idx = defaultdict(list)
        obj2idx = defaultdict(list)

        for s, o, idx in zip(sub_fil, obj_fil, node_index_fil):
            sub2idx[s].append(idx)
            obj2idx[o].append(idx)

#         pdb.set_trace()
        for same_sub in sub2idx.values():
            if len(same_sub) > 1:
                edge_same_sub.append(np.array([[i, j] for i, j in itertools.combinations(same_sub, r=2)]))

        for same_obj in obj2idx.values():
            if len(same_obj) > 1:
                edge_same_obj.append(np.array([[i, j] for i, j in itertools.combinations(same_obj, r=2)]))

        obj_keys = obj2idx.keys()
        for sub in sub2idx.keys():
            if sub in obj_keys:
                edge_sub_obj.append(np.array([[i, j, j, i] for i, j in itertools.product(sub2idx[sub], obj2idx[sub])]).reshape(-1, 3))

    edges = edge_same_sub + edge_same_obj + edge_sub_obj
    if edges:
        return np.concatenate(edges)
    else:
        None
#     return np.concatenate(edge_same_sub), np.concatenate(edge_same_obj), np.concatenate(edge_sub_obj)
#         return edge_same_sub

In [74]:
def expand_graph(source_event, id2event, subgraph_idx):
    """
    source_event: nodes index in DGLGraph, 1d tensor, [batch_size*max_attended_nodes, ]
    id2event: map from event id to quadruplet
    subgraph_idx: subgraph_idx[i] indicates source_event[i] is in which subgraph 
    """
    source_event_quad = np.vstack([np.array(id2event[src]) for src in source_event.numpy()])
    src_idx_l = source_event_quad[:, 2] # object of source_event is the subject of new events
    cut_time_l = source_event_quad[:, 3]
    
    # each with shape [batch_size*max_attended_nodes x num_neighbors]
    ngh_obj_batch, ngh_rel_batch, ngh_ts_batch = nf_last20.get_temporal_neighbor(src_idx_l, cut_time_l)
    ngh_events = np.hstack([np.repeat(src_idx_l, num_neighbors)[:, np.newaxis], 
                            ngh_rel_batch[:, np.newaxis], 
                            ngh_obj_batch[:, np.newaxis], 
                            ngh_ts_batch[:, np.newaxis]]) # N x 4
    ngh_subgraph_idx = np.repeat(subgraph_idx, num)
    
    
    # for each subgraph:
        # update id2event
        # 
    
    #
    
    
    
    
    
    
    

In [54]:
src_batch = np.repeat(src_idx_l.numpy()[:, np.newaxis], num_neighbors, axis=1)
ngh_events_idx = np.repeat(np.arange(num_neighbors)[::-1][np.newaxis], batch_size, axis=0) + batch_offset*np.repeat(np.arange(batch_size)[:, np.newaxis], num_neighbors, axis=1)

In [55]:
edges = construct_graph(src_batch, ngh_obj_batch, ngh_events_idx)



100%|██████████| 8/8 [00:00<00:00, 4762.87it/s]


In [57]:
edges.shape

(2076, 2)

In [64]:
import dgl
G = dgl.DGLGraph((list(edges[:, 0]), list(edges[:,1])))

torch.Size([2076])

In [76]:
G.nodes()

tensor([    0,     1,     2,  ..., 14017, 14018, 14019])

In [77]:
np.array((0,1,2))

array([0, 1, 2])

In [79]:
np.vstack([np.array([1,2,3]), np.array([2,3,4])])

array([[1, 2, 3],
       [2, 3, 4]])

In [9]:
from collections import defaultdict
sub2evt = defaultdict(list)
obj2evt = defaultdict(list)

for i, evt in enumerate(contents.train_data):
    sub2evt[evt[0]].append(i)
    obj2evt[evt[2]].append(i)

In [None]:
sub_sub_edges = [(i,j, abs(contents.train_data[i,3]-contents.train_data[j,3])) for gr in sub2evt.values() for i, j in itertools.combinations(gr, r=2)]
obj_obj_edges = [(i,j, abs(contents.train_data[i,3]-contents.train_data[j,3])) for gr in obj2evt.values() for i, j in itertools.combinations(gr, r=2)]
sub_obj_edges = list(itertools.chain.from_iterable(((i,j, abs(contents.train_data[i,3]-contents.train_data[j,3])),(j,i,abs(contents.train_data[j,3]-contents.train_data[i,3]))) for sub, sub_evt in sub2evt.items() for i, j in itertools.product(sub_evt, obj2evt[sub])))

In [60]:
edges = sub_sub_edges + obj_obj_edges + sub_obj_edges

In [61]:
temp = np.array(edges)
u,v = temp[:, 0], temp[:, 1]

In [62]:
G = dgl.DGLGraph()
G.add_nodes(len(contents.train_data))
#TBD add weights
G = dgl.DGLGraph((u,v))

In [106]:
contents.train_data

array([[   0,    0,    1,    0,   -1],
       [   0,    9,   84,    0,   -1],
       [   1,    0,    0,    0,   -1],
       [   1,   14,   26,    0,   -1],
       [   2,    1,    3,    0,   -1],
       [   2,   29,   78,    0,   -1],
       [   4,    2,    5,    0,   -1],
       [   5,    7,   13,    0,   -1],
       [   6,    3,    3,    0,   -1],
       [   6,   12,    3,    0,   -1],
       [   6,   26,   60,    0,   -1],
       [   7,    4,    8,    0,   -1],
       [   8,    7,   11,    0,   -1],
       [   8,   24,  101,    0,   -1],
       [   8,   36,  112,    0,   -1],
       [   9,    5,   10,    0,   -1],
       [  11,    4,    8,    0,   -1],
       [  11,    6,   12,    0,   -1],
       [  11,   14,    8,    0,   -1],
       [  11,   20,   12,    0,   -1],
       [  12,   10,   18,    0,   -1],
       [  12,   22,   18,    0,   -1],
       [  13,   11,   27,    0,   -1],
       [  13,   14,    5,    0,   -1],
       [  14,    8,   15,    0,   -1],
       [  14,   13,   15,

In [91]:
G.readonly()
hop1neigh = dgl.contrib.sampling.sampler.NeighborSampler(G, 20, expand_factor=500, num_hops=1, seed_nodes = torch.from_numpy(np.tile(np.arange(20),1000)))

In [92]:
for n in hop1neigh:
    print(n.layer_parent_nid(1))

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])
tensor([ 0,  1,  2,  3,  4,  5

In [45]:
sub = G.subgraph(np.arange(8))

In [46]:
sub.parent_nid

tensor([0, 1, 2, 3, 4, 5, 6, 7])

In [47]:
sub.parent_eid

tensor([       0, 24084998,    92235, 24177233,   189255, 24274253])

In [49]:
G.find_edges(24084998)

(tensor([0]), tensor([1]))

In [54]:
u[24084998:24085008]

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [39]:
sub_obj_edges

[(0, 2),
 (2, 0),
 (0, 119),
 (119, 0),
 (0, 195),
 (195, 0),
 (0, 288),
 (288, 0),
 (0, 576),
 (576, 0),
 (0, 680),
 (680, 0),
 (0, 693),
 (693, 0),
 (0, 753),
 (753, 0),
 (0, 1006),
 (1006, 0),
 (0, 1064),
 (1064, 0),
 (0, 1165),
 (1165, 0),
 (0, 1235),
 (1235, 0),
 (0, 1478),
 (1478, 0),
 (0, 1608),
 (1608, 0),
 (0, 1735),
 (1735, 0),
 (0, 1784),
 (1784, 0),
 (0, 1978),
 (1978, 0),
 (0, 2115),
 (2115, 0),
 (0, 2275),
 (2275, 0),
 (0, 2342),
 (2342, 0),
 (0, 2348),
 (2348, 0),
 (0, 2350),
 (2350, 0),
 (0, 2355),
 (2355, 0),
 (0, 2360),
 (2360, 0),
 (0, 2634),
 (2634, 0),
 (0, 2657),
 (2657, 0),
 (0, 2672),
 (2672, 0),
 (0, 2706),
 (2706, 0),
 (0, 2892),
 (2892, 0),
 (0, 3630),
 (3630, 0),
 (0, 3631),
 (3631, 0),
 (0, 3750),
 (3750, 0),
 (0, 3963),
 (3963, 0),
 (0, 4272),
 (4272, 0),
 (0, 4522),
 (4522, 0),
 (0, 4525),
 (4525, 0),
 (0, 4614),
 (4614, 0),
 (0, 4793),
 (4793, 0),
 (0, 4873),
 (4873, 0),
 (0, 4922),
 (4922, 0),
 (0, 4923),
 (4923, 0),
 (0, 5190),
 (5190, 0),
 (0, 5197),


In [38]:
len(sub_sub_edges) + len(obj_obj_edges)

24084998

In [35]:
G.edges()[0]

tensor([    0,     0,     0,  ..., 63674, 63675, 63675])

In [140]:
sub_obj_edges = list(itertools.chain.from_iterable(((i,j),(j,i)) for sub, sub_evt in sub2evt.items() for i, j in itertools.product(sub_evt, obj2evt[sub])))

In [24]:
test_g = dgl.DGLGraph()
test_g.add_nodes(5)
test_g.add_edges([0, 1, 2, 3, 4], [1, 2, 3, 4, 0])
SG = test_g.subgraph([0, 1, 4])

In [26]:
SG.edges()

(tensor([0, 2]), tensor([1, 0]))

In [40]:
sub_sub_edges_edges

[(0, 1),
 (0, 113),
 (0, 114),
 (0, 115),
 (0, 116),
 (0, 287),
 (0, 574),
 (0, 575),
 (0, 689),
 (0, 690),
 (0, 691),
 (0, 692),
 (0, 928),
 (0, 1163),
 (0, 1164),
 (0, 1900),
 (0, 2169),
 (0, 2170),
 (0, 2171),
 (0, 2172),
 (0, 2173),
 (0, 2174),
 (0, 2502),
 (0, 2503),
 (0, 2504),
 (0, 2762),
 (0, 2763),
 (0, 2764),
 (0, 3300),
 (0, 3548),
 (0, 4086),
 (0, 4366),
 (0, 4367),
 (0, 4368),
 (0, 4634),
 (0, 4635),
 (0, 4917),
 (0, 4918),
 (0, 4919),
 (0, 4920),
 (0, 5300),
 (0, 5301),
 (0, 5454),
 (0, 5722),
 (0, 6535),
 (0, 6800),
 (0, 6801),
 (0, 6926),
 (0, 6927),
 (0, 6928),
 (0, 6929),
 (0, 7072),
 (0, 7073),
 (0, 7074),
 (0, 7629),
 (0, 7630),
 (0, 7631),
 (0, 8372),
 (0, 8509),
 (0, 8602),
 (0, 8890),
 (0, 9153),
 (0, 9440),
 (0, 9441),
 (0, 9442),
 (0, 9722),
 (0, 9723),
 (0, 9724),
 (0, 9725),
 (0, 9726),
 (0, 9727),
 (0, 9728),
 (0, 9729),
 (0, 10038),
 (0, 10039),
 (0, 10040),
 (0, 10041),
 (0, 10042),
 (0, 10043),
 (0, 10044),
 (0, 10226),
 (0, 10343),
 (0, 10344),
 (0, 1034

In [29]:
sub_obj_edges

[(0, 2),
 (2, 0),
 (0, 119),
 (119, 0),
 (0, 195),
 (195, 0),
 (0, 288),
 (288, 0),
 (0, 576),
 (576, 0),
 (0, 680),
 (680, 0),
 (0, 693),
 (693, 0),
 (0, 753),
 (753, 0),
 (0, 1006),
 (1006, 0),
 (0, 1064),
 (1064, 0),
 (0, 1165),
 (1165, 0),
 (0, 1235),
 (1235, 0),
 (0, 1478),
 (1478, 0),
 (0, 1608),
 (1608, 0),
 (0, 1735),
 (1735, 0),
 (0, 1784),
 (1784, 0),
 (0, 1978),
 (1978, 0),
 (0, 2115),
 (2115, 0),
 (0, 2275),
 (2275, 0),
 (0, 2342),
 (2342, 0),
 (0, 2348),
 (2348, 0),
 (0, 2350),
 (2350, 0),
 (0, 2355),
 (2355, 0),
 (0, 2360),
 (2360, 0),
 (0, 2634),
 (2634, 0),
 (0, 2657),
 (2657, 0),
 (0, 2672),
 (2672, 0),
 (0, 2706),
 (2706, 0),
 (0, 2892),
 (2892, 0),
 (0, 3630),
 (3630, 0),
 (0, 3631),
 (3631, 0),
 (0, 3750),
 (3750, 0),
 (0, 3963),
 (3963, 0),
 (0, 4272),
 (4272, 0),
 (0, 4522),
 (4522, 0),
 (0, 4525),
 (4525, 0),
 (0, 4614),
 (4614, 0),
 (0, 4793),
 (4793, 0),
 (0, 4873),
 (4873, 0),
 (0, 4922),
 (4922, 0),
 (0, 4923),
 (4923, 0),
 (0, 5190),
 (5190, 0),
 (0, 5197),


Unit test on DGL subgraph and sampler

In [13]:
contents = Data(dataset='ICEWS14_forecasting', add_reverse_relation=False)

In [23]:
id2evts = {k:tuple(v) for k, v in enumerate(contents.valid_data)}

In [30]:
from collections import defaultdict
sub2evt = defaultdict(list)
obj2evt = defaultdict(list)

for i, evt in enumerate(contents.train_data):
    sub2evt[evt[0]].append(i)
    obj2evt[evt[2]].append(i)
    
# edge point from event happen earlier to later
sub_sub_edges = [(i,j, contents.train_data[j,3]-contents.train_data[i,3])
                 for gr in sub2evt.values() for i, j in itertools.combinations(gr, r=2) 
                 if contents.train_data[j,3]-contents.train_data[i,3] > 0]
obj_obj_edges = [(i,j, contents.train_data[j,3]-contents.train_data[i,3]) 
                 for gr in obj2evt.values() for i, j in itertools.combinations(gr, r=2) 
                 if contents.train_data[j,3]-contents.train_data[i,3] > 0]
sub_obj_edges = [((i,j, contents.train_data[j,3]-contents.train_data[i,3])) 
                 if contents.train_data[j,3] > contents.train_data[i,3] 
                 else (j,i,contents.train_data[i,3]-contents.train_data[j,3])
                 for sub, sub_evt in sub2evt.items() 
                 for i, j in itertools.product(sub_evt, obj2evt[sub])]
edges = sub_sub_edges + obj_obj_edges + sub_obj_edges

In [31]:
temp = np.array(edges)
u,v = temp[:, 0], temp[:, 1]
G = dgl.DGLGraph()
G.add_nodes(len(contents.train_data))
#TBD add weights
G = dgl.DGLGraph((u,v))

In [None]:
G.readonly()

In [None]:
# test connected components
G_nx = G.to_networkx().to_undirected()
cc = nx.connected_components(G_nx)

In [25]:
for batch_idx, sample in enumerate(DataLoader(np.arange(len(contents.train_data)), batch_size=8, shuffle=True)):
    print(sample)
    break

tensor([34166, 41879, 26680, 24638, 56275, 53601, 15173, 54132])


In [26]:
hop1neigh = dgl.contrib.sampling.sampler.NeighborSampler(G, len(sample), expand_factor=10000, num_hops=1, seed_nodes = sample)

NameError: name 'G' is not defined