In [2]:
import torch, dgl

In [3]:
import os, os.path as osp

In [4]:
print("Load ogbn-products graph")
dataset = dgl.load_graphs(osp.join(os.environ['DATASETS'], 'ogbn_products/graph.dgl'))

Load ogbn-products graph


In [5]:
g = dataset[0][0]
g.num_nodes(), g.num_edges()

(2449029, 123718280)

In [6]:
seeds = torch.randint(g.num_nodes(), (32,))
print("Seed nodes to start sampling with:")
print(seeds)

Seed nodes to start sampling with:
tensor([1300887, 2099796, 2418273, 1632254, 1366510, 1241817, 1395205, 1757922,
         344507, 2110793, 2446189, 1886552, 1636847,  544967,  475622,  442955,
        1457436, 2046614, 1496639,  157365,  941564, 1344644, 2125957, 2128770,
        1856015,  626054,  103844, 1597071, 1000066, 2376133,  925419,  627206])


In [7]:
fanout=10
frontier = g.sample_neighbors(seeds, fanout, edge_dir='in')
print("Sampled 1-hop subgraph:")
print(frontier.edges())

Sampled 1-hop subgraph:
(tensor([1414391, 1062794,  582377,  745235, 1441563,  685249, 2322820, 1392156,
        1138176,  939746,  909662, 1008337, 1472125, 2041130,  100846, 1203738,
          88334,  186645, 1973699, 1927726, 2022744,  313528, 1483036,  378927,
        1827406, 1800309, 1939697, 1400924, 2019163,  615993, 2167952,   80597,
          37809, 2409352, 2251837, 1778888, 1669695, 1263006, 2279300,   68707,
         532943,   16379,  313967,   16352, 1970360,  107201,  189773,  141295,
         668877, 2260706,  317543,  174567, 2260516, 2370161,  779409, 1177104,
        1007843, 1605586, 2020269,  908377, 1320843,  410841,  898381, 1491302,
        1390901, 2311585, 1952108, 2273798,  131563,   53828,   60637, 2441018,
        1756402,  119966,  188475, 1564025, 1997013, 2090986,  737371, 2333467,
        2426177,  365996, 2342601, 1164346,  581063, 1402793, 2275613,  323629,
        1596405, 1409250, 1733362,  204048,  586019,  563390, 1721149, 1574548,
        1767642

In [8]:
block = dgl.transforms.to_block(frontier, seeds)
print("Bipartite graph from frontier:")
print(block.edges())

Bipartite graph from frontier:
(tensor([ 32,  33,  34,  35,  36,  37,  38,  39,  40,  41,  42,  43,  44,  45,
         46,  47,  48,  49,  50,  51,  52,  53,  54,  55,  56,  57,  58,  59,
         60,  61,  62,  63,  64,  65,  66,  67,  68,  69,  70,  71,  72,  73,
         74,  75,  76,  77,  78,  79,  80,  81,  82,  83,  84,  85,  86,  87,
         88,  89,  90,  91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101,
        102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115,
        116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129,
        130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143,
        144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157,
        158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171,
        172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185,
        186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199,
        200, 201, 202, 203, 204,

In [9]:
block.srcnodes()
block.dstnodes()

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])

## Multi-Layer Neighbor Sampler

In [10]:
sampler = dgl.dataloading.NeighborSampler([8,4])
seeds = torch.randint(g.num_nodes(), (8,))
input_nodes, output_nodes, blocks = sampler.sample_blocks(g, seeds)
print(blocks)

[Block(num_src_nodes=297, num_dst_nodes=40, num_edges=296), Block(num_src_nodes=40, num_dst_nodes=8, num_edges=32)]


In [12]:
print(blocks[-1].srcnodes())
print(blocks[-1].dstnodes())
print(blocks[-1].dstdata[dgl.NID])
print(output_nodes)

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39])
tensor([0, 1, 2, 3, 4, 5, 6, 7])
tensor([1432072, 1395124, 2342110, 2339586,  646553,  938625, 1141776, 1789836])
tensor([1432072, 1395124, 2342110, 2339586,  646553,  938625, 1141776, 1789836])


In [44]:
print("src nodes:", blocks[0].srcnodes())
print("dst nodes:", blocks[0].dstnodes())

src nodes: tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139,
        140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153,
        154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167,
        168, 169, 170, 171, 172, 173, 174, 175, 176, 

In [47]:
print("src ID num: ", blocks[0].srcdata[dgl.NID].numel())
print("src IDs:", blocks[0].srcdata[dgl.NID])

src ID num:  299
src IDs: tensor([1228655,  281709, 2409152, 1937782,  821957, 2319000,  268568,  509148,
        1791701,  142619,  285469,  156119,   56929, 1200929, 1516444,   88742,
        1988004, 1188688, 2032623,  877336,  328800,  259688, 1354468, 2362417,
         227433,  104100, 2306380,  140074,  668167, 1654503, 2350934, 2205293,
          67453,  293234, 1583973,  140187,  308465, 1719122, 1461419, 1797708,
        1794787, 1881660,  156502, 2014592,    2094, 1882628, 1149813, 2020692,
         215165,  177790,  118851,   80285, 1320856, 1655103,  738696, 1610451,
        2038501,  158359, 1007817,   94722,   78232, 2055461, 1991166,  672949,
        2206720, 1929019, 1138372, 2355601,  483959, 2397263, 1579795, 2044294,
        2096285, 1679746, 1598540, 1362730,    4975,  665638,  147829,  728023,
         772611, 2020157, 1782811,   74370,  586685,  402880, 1211174,  622852,
         241512,  762259,  996497,  918847, 1188115,  493786, 1716866, 2111665,
        205625