# Training a GCN with neighbor sampling

In [1]:
import dgl
import dgl.nn
import dgl.dataloading
import dgl.sampling
import dgl.function as fn
import ogb.nodeproppred
import torch
import torch.nn as nn
import torch.nn.functional as F

dataset = ogb.nodeproppred.DglNodePropPredDataset('ogbn-products')
g, labels = dataset[0]
labels = labels[:, 0]

Using backend: pytorch


In [2]:
spl = dataset.get_idx_split()
train_idx = spl['train']
val_idx = spl['valid']
test_idx = spl['test']

The message passing formulation is

$$
h_{v}^{(l+1)} = \sigma \left( \sum_{u \in \mathcal{N}(v)} e_{uv} h_{u}^{(l)} W^{(l+1)} \right)
$$

where $e_{uv} = \dfrac{1}{\sqrt{d_u d_v}}$ is the entry of the graph Laplacian on $u$-th row and $v$-th column.

We estimate it via neighbor sampling with

$$
\tilde{h}_{v}^{(l+1)} = \sigma \left( \tilde{D}_v \mathbb{E}_{u \sim P_v(u)}  \left[ \tilde{h}_{u}^{(l)} W^{(l+1)} \right] \right)
$$

where $P_v(u) \propto e_{uv}$ and $\tilde{D}_v = \sum_{u \in \mathcal{N}(v)} e_{uv}$.

## Compute $e$ and $\tilde{D}$

In [25]:
def compute_D_and_e(g):
    g.ndata['D_in'] = 1 / g.in_degrees().float().sqrt()
    g.ndata['D_out'] = 1 / g.out_degrees().float().sqrt()
    g.apply_edges(fn.u_mul_v('D_in', 'D_out', 'e'))
    g.update_all(fn.copy_e('e', 'm'), fn.sum('m', 'D_tilde'))
    g.edata['e'] = g.edata['e'].view(g.num_edges(), 1)
    # produces g.edata['e'] and g.ndata['D_tilde']

In [26]:
g = dgl.add_self_loop(dgl.remove_self_loop(g))
compute_D_and_e(g)

## Module definition

In [33]:
class Conv(nn.Module):
    def __init__(self, in_dims, out_dims):
        super().__init__()
        
        self.W = nn.Linear(in_dims, out_dims)
        
    def forward_train(self, block, x):
        with block.local_scope():
            block.srcdata['x'] = x
            block.update_all(fn.copy_u('x', 'm'), fn.mean('m', 'y'))
            y = block.dstdata['y']
            return self.W(y) * block.dstdata['D_tilde'][:, None]   # D_tilde is computed outside in preprocessing stage
        
    def forward_eval(self, block, x):
        with block.local_scope():
            block.srcdata['x'] = x
            block.update_all(fn.u_mul_e('x', 'e', 'm'), fn.sum('m', 'y'))
            return self.W(block.dstdata['y'])
        
    def forward(self, block, x):
        if self.training:
            return self.forward_train(block, x)
        else:
            return self.forward_eval(block, x)

class StochasticGCN(nn.Module):
    def __init__(self, in_dims, hid_dims, out_dims):
        super().__init__()
        
        self.conv1 = Conv(in_dims, hid_dims)
        self.conv2 = Conv(hid_dims, hid_dims)
        self.conv3 = Conv(hid_dims, out_dims)
        
        self.hid_dims = hid_dims
        self.out_dims = out_dims
        
    def forward(self, blocks, x):
        x = F.relu(self.conv1(blocks[0], x))
        x = F.relu(self.conv2(blocks[1], x))
        x = self.conv3(blocks[2], x)
        return x
    
    def inference(self, g, x, batch_size, device):
        layers = [self.conv1, self.conv2, self.conv3]
        
        for l, layer in enumerate(layers):
            y = torch.zeros(g.num_nodes(), self.hid_dims if l != len(layers) - 1 else self.out_dims)
            
            sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
            dataloader = dgl.dataloading.NodeDataLoader(
                g,
                torch.arange(g.num_nodes()),
                sampler,
                batch_size=batch_size,
                shuffle=True,
                drop_last=False,
                num_workers=4)
            
            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                block = blocks[0].to(device)
                
                # get inputs
                h = x[input_nodes].to(device)
                
                # This must match the procedure in forward()
                h = layer(block, h)
                if l != len(layers) - 1:
                    h = F.relu(h)
                    
                # write outputs
                y[output_nodes] = h.cpu()
                
            x = y
        return y

## Sampler definition

In [34]:
class NonUniformNeighborSampler(dgl.dataloading.MultiLayerNeighborSampler):
    def __init__(self, fanouts, return_eids=False):
        # Always with replacement
        super().__init__(fanouts, replace=True, return_eids=return_eids)
        
    def sample_frontier(self, block_id, g, seed_nodes):
        fanout = self.fanouts[block_id]
        # e refers to the weights on each edge (i.e. graph Laplacian).
        frontier = dgl.sampling.sample_neighbors(g, seed_nodes, fanout, replace=True, prob='e')
        return frontier

In [35]:
BATCH_SIZE = 1024

sampler = NonUniformNeighborSampler([5, 10, 15])
dataloader = dgl.dataloading.NodeDataLoader(
    g, train_idx, sampler, batch_size=BATCH_SIZE, shuffle=True, drop_last=False, num_workers=4)

## Misc

In [36]:
def compute_acc(pred, labels):
    return ((torch.argmax(pred, dim=1) == labels).float().sum() / len(pred)).item()

## Train loop

In [37]:
import tqdm

N_EPOCHS = 20
LR = 0.003

model = StochasticGCN(g.ndata['feat'].shape[1], 256, dataset.num_classes).cuda()
opt = torch.optim.Adam(model.parameters(), lr=LR)

In [None]:
for epoch in range(N_EPOCHS):
    with tqdm.tqdm(dataloader) as tq:
        for step, (input_nodes, seeds, blocks) in enumerate(tq):
            blocks = [block.to('cuda') for block in blocks]
            batch_inputs = blocks[0].srcdata['feat']
            batch_labels = labels[seeds].cuda()

            batch_pred = model(blocks, batch_inputs)
            loss = F.cross_entropy(batch_pred, batch_labels)
            acc = compute_acc(batch_pred, batch_labels)
            opt.zero_grad()
            loss.backward()
            opt.step()
            
            tq.set_postfix({'loss': '%.3f' % loss.item(), 'acc': '%.3f' % acc}, refresh=False)
            
        if epoch % 5 == 0:
            # evaluate every 5 epochs
            model.eval()
            with torch.no_grad():
                pred = model.inference(g, g.ndata['feat'], BATCH_SIZE, 'cuda')
            model.train()
            
            val_acc = compute_acc(pred[val_idx], labels[val_idx])
            test_acc = compute_acc(pred[test_idx], labels[test_idx])
            
            print('Val Acc', val_acc, 'Test Acc', test_acc)

100%|██████████| 193/193 [01:56<00:00,  1.66it/s, loss=1.065, acc=0.857]
100%|██████████| 2392/2392 [01:15<00:00, 31.56it/s]
100%|██████████| 2392/2392 [01:39<00:00, 23.96it/s]
100%|██████████| 2392/2392 [01:30<00:00, 26.38it/s]
  0%|          | 0/193 [00:00<?, ?it/s]

Val Acc 0.8800193071365356 Test Acc 0.7254767417907715


 66%|██████▌   | 127/193 [01:16<00:38,  1.73it/s, loss=0.525, acc=0.889]