
In [1]:
#%%capture
!pip install dgl-cu111 -f https://data.dgl.ai/wheels/repo.html

Looking in links: https://data.dgl.ai/wheels-test/repo.html
Collecting dgl-cu111
  Downloading https://data.dgl.ai/wheels-test/dgl_cu111-0.8a210920-cp37-cp37m-manylinux1_x86_64.whl (144.8 MB)
Installing collected packages: dgl-cu111
Successfully installed dgl-cu111-0.8a210920



Node Classification with DGL
============================

Graph neural networks (GNNs) are powerful tools for many machine
learning tasks on graphs. In this introductory tutorial, you will learn
the basic workflow of using GNNs for node classification, i.e.,
predicting the category of a node in a graph.

By completing this tutorial, you will be able to:

-  Load a DGL-provided dataset.
-  Build a GNN model with DGL-provided neural network modules.
-  Train and evaluate a GNN model for node classification on either CPU
   or GPU.

This tutorial assumes that you have experience in building neural
networks with PyTorch.

(Time estimate: 13 minutes)


In [13]:
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [16]:
# Construct a two-layer GraphSAGE model
import dgl.nn as dglnn
import torch.nn as nn
import torch.nn.functional as F
class SAGE(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats):
        super().__init__()
        self.conv1 = dglnn.SAGEConv(
            in_feats=in_feats, out_feats=hid_feats, aggregator_type='mean')
        self.conv2 = dglnn.SAGEConv(
            in_feats=hid_feats, out_feats=out_feats, aggregator_type='mean')

    def forward(self, graph, inputs):
        # inputs are features of nodes
        h = self.conv1(graph, inputs)
        h = F.relu(h)
        h = self.conv2(graph, h)
        return h
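
The `'mean'` aggregator used in both layers averages the features of each node's neighbors before they are combined with the node's own features. The dependency-free sketch below illustrates only that averaging step on a toy three-node graph. It is not DGL's implementation: the `mean_aggregate` helper, the dictionary-based graph, and the choice to return zeros for a node without in-neighbors are all assumptions made for illustration.

```python
def mean_aggregate(neighbors, feats):
    """Average the feature vectors of each node's in-neighbors.

    neighbors: dict mapping node id -> list of in-neighbor ids (hypothetical format)
    feats:     dict mapping node id -> list of floats, all of the same length
    """
    dim = len(next(iter(feats.values())))
    out = {}
    for node in feats:
        nbrs = neighbors.get(node, [])
        if not nbrs:
            # Assumption: a node with no in-neighbors aggregates to a zero vector
            out[node] = [0.0] * dim
            continue
        out[node] = [sum(feats[n][k] for n in nbrs) / len(nbrs) for k in range(dim)]
    return out


# Toy graph: node 0 receives from nodes 1 and 2, node 1 receives from node 0
neighbors = {0: [1, 2], 1: [0], 2: []}
feats = {0: [1.0, 0.0], 1: [0.0, 2.0], 2: [4.0, 4.0]}
agg = mean_aggregate(neighbors, feats)
print(agg[0])  # [2.0, 3.0], the mean of the features of nodes 1 and 2
```

In the real layer, the aggregated vector and the node's own features each pass through learned linear maps, which is where `in_feats` and `out_feats` come in.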

In [22]:
import dgl

dataset = dgl.data.CiteseerGraphDataset()
graph = dataset[0]

Downloading /root/.dgl/citeseer.zip from https://data.dgl.ai/dataset/citeseer.zip...
Extracting file to /root/.dgl/citeseer
Finished data loading and preprocessing.
  NumNodes: 3327
  NumEdges: 9228
  NumFeats: 3703
  NumClasses: 6
  NumTrainingSamples: 120
  NumValidationSamples: 500
  NumTestSamples: 1000
Done saving data into cached files.




In [23]:
node_features = graph.ndata['feat']
node_labels = graph.ndata['label']
train_mask = graph.ndata['train_mask']
valid_mask = graph.ndata['val_mask']
test_mask = graph.ndata['test_mask']
n_features = node_features.shape[1]  # input feature dimension (3703 for Citeseer)
n_labels = int(node_labels.max().item() + 1)  # labels are 0-indexed, so max + 1 classes

In [26]:
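def evaluate(model, graph, features, labels, mask):
    """Return the accuracy of `model` on the nodes selected by `mask`."""
    model.eval()
    with torch.no_grad():
        logits = model(graph, features)
        # Keep only the nodes in the requested split
        logits = logits[mask]
        labels = labels[mask]
        # Predicted class = index of the largest logit
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels)
        return correct.item() / len(labels)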
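
To make the masking logic in `evaluate` concrete, here is a minimal plain-Python sketch of the same computation on a handful of made-up nodes. The `masked_accuracy` helper and the toy values are hypothetical; the real function operates on tensors and a boolean mask such as `val_mask`.

```python
def masked_accuracy(logits, labels, mask):
    """Fraction of masked-in nodes whose argmax prediction matches the label."""
    correct = 0
    total = 0
    for row, label, keep in zip(logits, labels, mask):
        if not keep:
            continue  # node is outside the evaluated split
        pred = max(range(len(row)), key=row.__getitem__)  # argmax over classes
        correct += int(pred == label)
        total += 1
    if total == 0:
        raise ValueError("mask selects no nodes")
    return correct / total


# Four nodes, two classes; the third node is excluded by the mask
logits = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7], [0.6, 0.4]]
labels = [1, 1, 1, 0]
mask = [True, True, False, True]
acc = masked_accuracy(logits, labels, mask)
print(acc)  # 2 of the 3 masked-in nodes are predicted correctly
```

Only nodes where the mask is true contribute to the score, which is why the same function can report training, validation, or test accuracy depending on the mask passed in.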

In [27]:
model = SAGE(in_feats=n_features, hid_feats=100, out_feats=n_labels)
opt = torch.optim.Adam(model.parameters())

for epoch in range(10):
    model.train()
    # Forward pass over all nodes (full-graph training)
    logits = model(graph, node_features)
    # Compute the loss on the training nodes only
    loss = F.cross_entropy(logits[train_mask], node_labels[train_mask])
    # Compute validation accuracy (not printed here; useful for model selection)
    acc = evaluate(model, graph, node_features, node_labels, valid_mask)
    # Backward pass and parameter update
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())

    # Save model if necessary.  Omitted in this example.

1.793243408203125
1.7782562971115112
1.7635517120361328
1.7482908964157104
1.7322710752487183
1.7153481245040894
1.6975038051605225
1.678845763206482
1.6594936847686768
1.6393656730651855
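
A quick sanity check on these numbers: with its default settings, `F.cross_entropy` averages the negative log-probability of the true class over the selected nodes, so a model that spreads probability evenly over the 6 Citeseer classes scores ln 6, roughly 1.79, which is close to the first loss printed. The sketch below re-derives that value in plain Python; the `cross_entropy` helper is a hypothetical single-node version written for illustration, not PyTorch's code.

```python
import math


def cross_entropy(logits, label):
    """Negative log-softmax probability of the true class for one node."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[label]


# Equal logits for all 6 classes give a uniform prediction
uniform_loss = cross_entropy([0.0] * 6, label=2)
print(round(uniform_loss, 4))  # 1.7918, i.e. ln(6)
```

A first-epoch loss near this value suggests the untrained model starts out close to uniform guessing, and the steady decrease afterwards shows it is learning from the 120 training nodes.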


In [48]:
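# Replay the forward pass step by step to inspect intermediate activations.
# Binding `self` lets the lines from SAGE.forward run unchanged.
self = model
inputs = node_features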

h1 = self.conv1(graph, inputs)
h2 = F.relu(h1)
h3 = self.conv2(graph, h2)

In [51]:
print(h1.shape)
print(h2.shape)
print(h3.shape)

torch.Size([3327, 100])
torch.Size([3327, 100])
torch.Size([3327, 6])


torch.Size([3327, 3703])

In [43]:
logits = model(graph, node_features)
len(logits)

3327

In [45]:
node_labels.shape

torch.Size([3327])

In [46]:
n_features

3703