In [7]:
pip install dgl

Note: you may need to restart the kernel to use updated packages.


In [9]:
import dgl

import torch as th

# edges 0->1, 0->2, 0->3, 1->3

# Source and destination node IDs of the four edges, kept as two parallel tensors.
u = th.tensor([0, 0, 0, 1])
v = th.tensor([1, 2, 3, 3])
g = dgl.graph((u, v))

# No node count is given, so it is inferred from the largest node ID among the edges.
print(g)

# IDs of all nodes.
print(g.nodes())

# Source and destination endpoints of every edge.
print(g.edges())

# The same endpoints, with the edge IDs returned as a third tensor.
print(g.edges(form='all'))

# Inference cannot see isolated nodes whose ID is larger than every edge
# endpoint, so here the node count is passed explicitly.
g = dgl.graph((u, v), num_nodes=8)

Graph(num_nodes=4, num_edges=4,
      ndata_schemes={}
      edata_schemes={})
tensor([0, 1, 2, 3])
(tensor([0, 0, 0, 1]), tensor([1, 2, 3, 3]))
(tensor([0, 0, 0, 1]), tensor([1, 2, 3, 3]), tensor([0, 1, 2, 3]))


In [10]:
# Edges 2->3, 5->5, 3->0 given as a (source IDs, destination IDs) pair.
edges = th.tensor([2, 5, 3]), th.tensor([3, 5, 0])

g64 = dgl.graph(edges)  # DGL uses int64 by default
g32 = dgl.graph(edges, idtype=th.int32)  # create an int32 graph
g64_2 = g32.long()  # convert to int64
g32_2 = g64.int()  # convert to int32

# A notebook cell only echoes its final expression, so every ID type is
# printed explicitly; bare `x.idtype` lines in the middle of the cell would
# be evaluated and silently discarded.
print(g64.idtype)
print(g32.idtype)
print(g64_2.idtype)
print(g32_2.idtype)

torch.int64


torch.int32

In [71]:
# Heterogeneous graph with two relations:
# drug -interacts-> drug and drug -treats-> disease.
g = dgl.heterograph({
    ('drug', 'interacts', 'drug'): (th.tensor([0, 1]), th.tensor([1, 2])),
    ('drug', 'treats', 'disease'): (th.tensor([1]), th.tensor([2]))})

# Node features 'hv' have the same shape for both node types ...
g.nodes['drug'].data['hv'] = th.ones(3, 4)
g.nodes['disease'].data['hv'] = th.ones(3, 4)
# ... whereas edge features 'he' differ in shape between the two relations.
g.edges['interacts'].data['he'] = th.zeros(2, 1)
g.edges['treats'].data['he'] = th.zeros(1, 2)

# By default, the conversion does not merge any features. The check is
# printed because a bare expression in the middle of a cell shows nothing.
hg = dgl.to_homogeneous(g)
print('hv' in hg.ndata)

# Feature copy expects features to have the same size and dtype across
# node/edge types, so the edge-feature variant stays disabled for this graph:
# hg = dgl.to_homogeneous(g, edata=['he'])
# Node features could be copied with ndata=['hv']; this notebook keeps the
# feature-less conversion made above:
# hg = dgl.to_homogeneous(g, ndata=['hv'])
print(hg)

# NOTE(review): with several node types this is presumably a dict keyed by
# node type ('drug', 'disease') -- confirm against the DGL documentation.
g_feat = g.ndata['hv']

Graph(num_nodes=6, num_edges=3,
      ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64), '_TYPE': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64), '_TYPE': Scheme(shape=(), dtype=torch.int64)})


In [16]:
import dgl
import torch as th
# Three edges 0->2, 1->3, 2->4 over five nodes.
u = th.tensor([0, 1, 2])
v = th.tensor([2, 3, 4])
g = dgl.graph((u, v))
g.ndata['x'] = th.randn(5, 3)  # original feature is on CPU

# Last expression of the cell: the device the graph lives on.
g.device

# GPU variant of the same demo, kept disabled in this notebook:
# cuda_g = g.to('cuda:0')  # accepts any device objects from backend framework
# cuda_g.device
# cuda_g.ndata['x'].device       # feature data is copied to GPU too
# # A graph constructed from GPU tensors is also on GPU
# u, v = u.to('cuda:0'), v.to('cuda:0')
# g = dgl.graph((u, v))
# g.device

device(type='cpu')

In [73]:
import torch.nn as nn

from dgl.utils import expand_as_pair

import dgl.function as fn
import torch.nn.functional as F
from dgl.utils import check_eq_shape

class SAGEConv(nn.Module):
    """GraphSAGE convolution layer.

    Aggregates the features of every destination node's in-neighbours with
    the selected aggregator and projects the result with ``fc_neigh``. For all
    aggregators except 'gcn' the destination node's own feature is projected
    with ``fc_self`` and added to the neighbour term.

    Parameters
    ----------
    in_feats : int or pair of ints
        Input feature size. It is passed through ``expand_as_pair`` to obtain
        the source and destination feature sizes.
    out_feats : int
        Output feature size.
    aggregator_type : str
        One of 'mean', 'pool', 'lstm' or 'gcn'.
    bias : bool, optional
        Whether the linear projections carry a bias term. Default: True.
    norm : callable or None, optional
        If given, applied to the output after ``activation``.
    activation : callable or None, optional
        If given, applied to the projected output.

    Raises
    ------
    KeyError
        If ``aggregator_type`` is not one of the supported values.
    """

    def __init__(self,
                 in_feats,
                 out_feats,
                 aggregator_type,
                 bias=True,
                 norm=None,
                 activation=None):
        super(SAGEConv, self).__init__()

        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        self._aggre_type = aggregator_type
        self.norm = norm
        self.activation = activation
        # aggregator type: mean, pool, lstm, gcn
        if aggregator_type not in ['mean', 'pool', 'lstm', 'gcn']:
            raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))
        if aggregator_type == 'pool':
            self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
        if aggregator_type == 'lstm':
            self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
        # The 'gcn' aggregator folds the node's own feature into the
        # neighbourhood average, so it has no separate self projection.
        if aggregator_type in ['mean', 'pool', 'lstm']:
            self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
        self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)

        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize learnable parameters."""
        gain = nn.init.calculate_gain('relu')
        if self._aggre_type == 'pool':
            nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
        if self._aggre_type == 'lstm':
            self.lstm.reset_parameters()
        if self._aggre_type != 'gcn':
            nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)

    def _lstm_reducer(self, nodes):
        """Reduce each node's mailbox by running the LSTM over its messages.

        The mailbox is fed to the ``batch_first`` LSTM with a zero initial
        state, and the final hidden state becomes the 'neigh' feature.
        NOTE(review): assumes the mailbox is laid out as
        (nodes, messages, features) -- confirm against the DGL documentation.
        """
        messages = nodes.mailbox['m']
        batch_size = messages.shape[0]
        initial_state = (
            messages.new_zeros((1, batch_size, self._in_src_feats)),
            messages.new_zeros((1, batch_size, self._in_src_feats)),
        )
        _, (last_hidden, _) = self.lstm(messages, initial_state)
        return {'neigh': last_hidden.squeeze(0)}

    def forward(self, graph, feat):
        """Compute the layer output for the destination nodes of ``graph``.

        Parameters
        ----------
        graph : DGLGraph
            Graph to run message passing on; it is only modified inside a
            local scope.
        feat : tensor or pair of tensors
            Input features; expanded to a (source, destination) pair with
            ``expand_as_pair``.

        Returns
        -------
        tensor
            One row per destination node, with ``out_feats`` columns.

        Raises
        ------
        KeyError
            If the stored aggregator type is not recognised.
        """
        with graph.local_scope():
            # Specify graph type then expand input feature according to graph type
            feat_src, feat_dst = expand_as_pair(feat, graph)
            # The self term always uses the untouched destination features.
            # Reading it back from graph.dstdata is unsafe: the 'pool' branch
            # never stores it there, and the source/destination views may
            # share storage so a write to one can overwrite the other.
            h_self = feat_dst
            if self._aggre_type == 'mean':
                graph.srcdata['h'] = feat_src
                graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
                h_neigh = graph.dstdata['neigh']
            elif self._aggre_type == 'gcn':
                check_eq_shape(feat)
                graph.srcdata['h'] = feat_src
                graph.dstdata['h'] = feat_dst
                graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
                # divide in_degrees (+1 accounts for the node itself)
                degs = graph.in_degrees().to(feat_dst)
                h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
            elif self._aggre_type == 'pool':
                graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
                graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))
                h_neigh = graph.dstdata['neigh']
            elif self._aggre_type == 'lstm':
                graph.srcdata['h'] = feat_src
                graph.update_all(fn.copy_u('h', 'm'), self._lstm_reducer)
                h_neigh = graph.dstdata['neigh']
            else:
                raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))

            # GraphSAGE GCN does not require fc_self.
            if self._aggre_type == 'gcn':
                rst = self.fc_neigh(h_neigh)
            else:
                rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)

            # Optional post-processing; both default to None, in which case
            # the projected features are returned unchanged.
            if self.activation is not None:
                rst = self.activation(rst)
            if self.norm is not None:
                rst = self.norm(rst)
        return rst


In [74]:
# Mean-aggregating GraphSAGE layer mapping 4 input features to 2 output features.
graphSage = SAGEConv(in_feats=4, out_feats=2, aggregator_type='mean')

In [75]:
print(hg)
# One all-ones feature row of width 4 per node of hg. The row count is read
# from the graph instead of being hard-coded, so it stays correct if the
# graph built earlier changes.
feat = th.ones(hg.num_nodes(), 4)

Graph(num_nodes=6, num_edges=3,
      ndata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64), '_TYPE': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={'_ID': Scheme(shape=(), dtype=torch.int64), '_TYPE': Scheme(shape=(), dtype=torch.int64)})


In [76]:
# Run the GraphSAGE layer on the homogeneous graph; the features are cast to
# float32 before being handed to the layer.
dst_feat = graphSage(graph=hg, feat=feat.float())
print(dst_feat)

tensor([[ 2.2324, -0.7498],
        [ 2.2324, -0.7498],
        [ 3.9787,  1.6786],
        [ 2.2324, -0.7498],
        [ 3.9787,  1.6786],
        [ 3.9787,  1.6786]], grad_fn=<AddBackward0>)


In [77]:
import torch.nn as nn

def get_aggregate_fn(aggregator_type, in_feats, out_feats):
    """Build the module associated with an aggregator name.

    Parameters
    ----------
    aggregator_type : str
        One of 'mean', 'pool', 'lstm', 'gcn' or 'sum'.
    in_feats : int
        Input feature size of the created module.
    out_feats : int
        Output feature size; only used by the 'mean' aggregator.

    Returns
    -------
    nn.Module or None
        'pool' -> ``nn.Linear(in_feats, in_feats)``;
        'lstm' -> batch-first ``nn.LSTM(in_feats, in_feats)``;
        'mean' -> ``nn.Linear(in_feats, out_feats)`` with bias;
        'gcn' and 'sum' -> ``None`` (no module is created for them).

    Raises
    ------
    KeyError
        If ``aggregator_type`` is not one of the supported names.
    """
    supported = ('mean', 'pool', 'lstm', 'gcn', 'sum')
    if aggregator_type not in supported:
        raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))

    # Exactly one module (or none) is constructed per call.
    module = None
    if aggregator_type == 'mean':
        module = nn.Linear(in_feats, out_feats, bias=True)
    elif aggregator_type == 'pool':
        module = nn.Linear(in_feats, in_feats)
    elif aggregator_type == 'lstm':
        module = nn.LSTM(in_feats, in_feats, batch_first=True)
    return module
    
class HeteroGraphConv(nn.Module):
    """Run one sub-module per edge type of a heterogeneous graph.

    Parameters
    ----------
    mods : dict
        Maps an edge type name to the module applied to that relation; the
        modules are registered through ``nn.ModuleDict``.
    in_feats, out_feats : int
        Forwarded to ``get_aggregate_fn`` when ``aggregate`` is a string.
    aggregate : str or object, optional
        Either an aggregator name resolved with ``get_aggregate_fn`` or an
        object stored as-is in ``agg_fn``. Default: 'sum'.

    NOTE(review): ``agg_fn`` is stored but never used by ``forward``; the
    per-relation results are returned as lists without being combined.
    """

    def __init__(self, mods,
                 in_feats,
                 out_feats,
                 aggregate='sum'):
        super().__init__()
        self.mods = nn.ModuleDict(mods)
        # Strings name a built-in aggregator; anything else is kept untouched.
        if not isinstance(aggregate, str):
            self.agg_fn = aggregate
        else:
            self.agg_fn = get_aggregate_fn(aggregate, in_feats, out_feats)

    def forward(self, g, inputs, mod_args=None, mod_kwargs=None):
        """Apply every relation's module and collect results per destination type.

        Parameters
        ----------
        g : graph
            Heterogeneous graph (or block) to process.
        inputs : dict
            Maps a node type to its input features.
        mod_args : dict, optional
            Extra positional arguments per edge type.
        mod_kwargs : dict, optional
            Extra keyword arguments per edge type.

        Returns
        -------
        dict
            Maps each destination node type to the list of outputs produced
            by the relations pointing at it (possibly empty).
        """
        extra_args = {} if mod_args is None else mod_args
        extra_kwargs = {} if mod_kwargs is None else mod_kwargs
        collected = {ntype: [] for ntype in g.dsttypes}

        src_inputs = inputs
        if g.is_block:
            # In a block only the leading rows of each feature tensor are
            # used as destination features.
            dst_inputs = {
                ntype: feats[:g.number_of_dst_nodes(ntype)]
                for ntype, feats in inputs.items()
            }
        else:
            dst_inputs = inputs

        for stype, etype, dtype in g.canonical_etypes:
            rel_graph = g[stype, etype, dtype]
            # Skip relations without edges or without features on either end.
            usable = (
                rel_graph.num_edges() != 0
                and stype in src_inputs
                and dtype in dst_inputs
            )
            if not usable:
                continue
            relation_module = self.mods[etype]
            feature_pair = (src_inputs[stype], dst_inputs[dtype])
            relation_output = relation_module(
                rel_graph,
                feature_pair,
                *extra_args.get(etype, ()),
                **extra_kwargs.get(etype, {}))
            collected[dtype].append(relation_output)
        return collected
    
    

In [78]:

# Feature sizes shared by every relation module and by the wrapper, defined
# once so the individual layers cannot drift apart from these values.
in_feats = 4
out_feats = 2

# One mean-aggregating GraphSAGE layer per edge type of the heterograph.
mods = {
    'interacts': SAGEConv(in_feats, out_feats, 'mean'),
    'treats': SAGEConv(in_feats, out_feats, 'mean'),
}

HetroConv = HeteroGraphConv(mods,
                            in_feats,
                            out_feats,
                            aggregate='sum')


In [79]:
# Apply the per-relation layers to the heterograph and its node features,
# then show the collected result.
out = HetroConv(g, inputs=g_feat)
print(out)