References: 

Fast Graph Representation Learning with PyTorch Geometric
https://arxiv.org/abs/1903.02428


PyTorch Geometric GitHub, GraphSAINT example script: https://github.com/rusty1s/pytorch_geometric/blob/master/examples/graph_saint.py


In [None]:
# Install required packages of PyTorch Geometric
!pip install -q torch-scatter==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.7.0.html
!pip install -q torch-sparse==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.7.0.html
!pip install -q git+https://github.com/rusty1s/pytorch_geometric.git

[K     |████████████████████████████████| 11.9MB 255kB/s 
[K     |████████████████████████████████| 24.3MB 137kB/s 
[K     |████████████████████████████████| 235kB 11.9MB/s 
[K     |████████████████████████████████| 2.2MB 31.3MB/s 
[K     |████████████████████████████████| 51kB 8.4MB/s 
[?25h  Building wheel for torch-geometric (setup.py) ... [?25l[?25hdone


In [None]:
# import datasets from pytorch geometric

from torch_geometric.datasets import Yelp, Flickr, Amazon
from torch_geometric.transforms import NormalizeFeatures

# Yelp is stored as a single graph carrying train/val/test node masks; node
# features are normalized by the NormalizeFeatures transform.
dataset = Yelp(transform=NormalizeFeatures(), root='data/Yelp')
# Alternative datasets (kept for reference, not loaded):
# dataset_amazon = Amazon(root='data/Amazon', name='computers')
# dataset_flickr = Flickr(root='data/Flickr', transform=NormalizeFeatures())  # nodes are images; demo only


Downloading 1Juwx8HtDwSzmVIJ31ooVa1WljI4U5JnA into data/Yelp/raw/adj_full.npz... Done.
Downloading 1Zy6BZH_zLEjKlEFSduKE5tV9qqA_8VtM into data/Yelp/raw/feats.npy... Done.
Downloading 1VUcBGr0T0-klqerjAjxRmAqFuld_SMWU into data/Yelp/raw/class_map.json... Done.
Downloading 1NI5pa5Chpd-52eSmLW60OnB3WS5ikxq_ into data/Yelp/raw/role.json... Done.
Processing...
Done!


In [None]:
# print basic information about the dataset

def print_data(input_data):
  '''Print basic information about a PyTorch Geometric dataset.

  Args:
    input_data: a PyTorch Geometric dataset; it must support ``len()``,
      integer indexing, and expose ``num_features`` and ``num_classes``.

  Prints dataset-level counts, then statistics of the first graph
  (``input_data[0]``). Training-mask statistics are printed only when that
  graph has a ``train_mask`` attribute.
  '''

  print()
  print(f'Dataset: {input_data}:')
  print('======================')
  print(f'Number of graphs: {len(input_data)}')
  print(f'Number of features: {input_data.num_features}')
  print(f'Number of classes: {input_data.num_classes}')

  data = input_data[0]  # Get the first graph object.

  print()
  print(data)
  print('===========================================================================================================')

  # Gather some statistics about the graph.
  num_nodes = data.num_nodes
  print(f'Number of nodes: {num_nodes}')
  print(f'Number of edges: {data.num_edges}')
  # Guard against division by zero on an empty graph.
  avg_degree = data.num_edges / num_nodes if num_nodes else 0.0
  print(f'Average node degree: {avg_degree:.2f}')

  # `if data.train_mask:` is ambiguous for a multi-element tensor, so test for
  # the attribute instead; graphs without masks simply skip this part.
  train_mask = getattr(data, 'train_mask', None)
  if train_mask is not None:
    num_train = int(train_mask.sum())
    label_rate = num_train / num_nodes if num_nodes else 0.0
    print(f'Number of training nodes: {num_train}')
    print(f'Training node label rate: {label_rate:.2f}')

  # PyG 2.x provides `has_*` methods in place of the older (deprecated)
  # `contains_*` names; use whichever the installed version offers.
  has_isolated = getattr(data, 'has_isolated_nodes', None) or data.contains_isolated_nodes
  has_self_loops = getattr(data, 'has_self_loops', None) or data.contains_self_loops
  print(f'Contains isolated nodes: {has_isolated()}')
  print(f'Contains self-loops: {has_self_loops()}')
  print(f'Is undirected: {data.is_undirected()}')

In [None]:
# Summarize the loaded dataset: dataset-level counts plus first-graph statistics.
print_data(dataset)


Dataset: Yelp():
Number of graphs: 1
Number of features: 300
Number of classes: 100

Data(edge_index=[2, 13954819], test_mask=[716847], train_mask=[716847], val_mask=[716847], x=[716847, 300], y=[716847, 100])
Number of nodes: 716847
Number of edges: 13954819
Average node degree: 19.47
Contains isolated nodes: False
Contains self-loops: True
Is undirected: True


In [None]:
# Define `data`: the full graph used by the sampler below.
data = dataset[0] # the first graph object of the dataset (which has just 1 graph anyway)

In [None]:
# GraphSAINT mini-batch sampler: each batch is a subgraph built from random
# walks started at sampled root nodes, so a training step only touches a small
# part of the 716,847-node Yelp graph instead of the whole graph.
try:
    from torch_geometric.loader import GraphSAINTRandomWalkSampler  # PyG >= 2.0
except ImportError:
    from torch_geometric.data import GraphSAINTRandomWalkSampler  # PyG 1.x location

In [None]:
# GraphSAINT random-walk loader over the full graph `data`.
# Per the PyG docs:
#   batch_size      - number of random-walk root nodes per sampled subgraph
#   walk_length     - length of each random walk
#   num_steps       - number of subgraphs sampled per epoch
#   sample_coverage - how many samples are used to estimate the GraphSAINT
#                     normalization statistics (the slow "Compute GraphSAINT
#                     normalization" pass); they are saved to save_dir for reuse.
loader = GraphSAINTRandomWalkSampler(
    data,
    batch_size=6000,
    walk_length=2,
    num_steps=5,
    sample_coverage=100,
    save_dir=dataset.processed_dir,
    num_workers=4,
)

Compute GraphSAINT normalization: : 71705856it [10:35, 112750.99it/s]                            


In [None]:
# Peek at one sampled subgraph from the loader (built from the Yelp graph).
# Bind it to `batch`, not `data`, so the full graph stored in `data` by the
# earlier cell is not replaced by a sampled subgraph.
for batch in loader:
  print(batch)
  break

Data(edge_index=[2, 316298], edge_norm=[316298], node_norm=[8993], x=[8993, 767], y=[8993])


In [None]:
# Show the labels of one sampled Yelp subgraph. They are multi-hot, not
# one-hot: one column per class, and a node may have several 1s.
# Uses `batch` so the full-graph `data` defined earlier is not overwritten.
for batch in loader:
  print(batch.y)
  break

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [1., 1., 1.,  ..., 0., 0., 1.],
        [0., 1., 1.,  ..., 0., 0., 0.],
        ...,
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 1.,  ..., 0., 0., 1.]])


In [None]:
# The loader is built from the Yelp graph, so this prints the (multi-hot) Yelp
# labels of another sampled subgraph, not Flickr labels.
# Uses `batch` so the full-graph `data` defined earlier is not overwritten.
for batch in loader:
  print(batch.y)
  break

tensor([4, 4, 6,  ..., 3, 0, 6])


In [None]:
# import necessary libraries for model building
import torch
from torch.nn import Linear
from torch_geometric.nn import GCNConv
import torch.nn.functional as F

In [None]:
# model structure

class Net(torch.nn.Module):
    """Three GCNConv layers whose outputs are concatenated and passed to a
    linear classifier; returns per-node log-probabilities (log_softmax).

    Args:
        hidden_channels: width of every GCNConv layer.
        in_channels: input feature size; defaults to ``dataset.num_node_features``.
        out_channels: number of classes; defaults to ``dataset.num_classes``.
        dropout: dropout probability applied after each conv layer (default 0.2).

    NOTE(review): Yelp labels are multi-hot, but a softmax output treats the
    classes as mutually exclusive; a multi-label setup would usually return raw
    logits for a BCE-with-logits loss.
    """

    def __init__(self, hidden_channels, in_channels=None, out_channels=None,
                 dropout=0.2):
        super(Net, self).__init__()
        # Fall back to the globally loaded dataset, as before, when sizes are
        # not given explicitly.
        if in_channels is None:
            in_channels = dataset.num_node_features
        if out_channels is None:
            out_channels = dataset.num_classes
        self.dropout = dropout
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        # The classifier sees the concatenation of all three conv outputs.
        self.lin = torch.nn.Linear(3 * hidden_channels, out_channels)

    def set_aggr(self, aggr):
        """Set the neighbourhood aggregation (e.g. 'add' or 'mean') of every conv layer."""
        for conv in (self.conv1, self.conv2, self.conv3):
            conv.aggr = aggr

    def forward(self, x0, edge_index, edge_weight=None):
        """Run the three conv layers and classify the concatenated features.

        Args:
            x0: node feature matrix.
            edge_index: graph connectivity passed to each GCNConv.
            edge_weight: optional per-edge weights passed to each GCNConv.

        Returns:
            log_softmax over the last dimension of the classifier output.
        """
        layer_outputs = []
        h = x0
        for conv in (self.conv1, self.conv2, self.conv3):
            h = F.relu(conv(h, edge_index, edge_weight))
            h = F.dropout(h, p=self.dropout, training=self.training)
            layer_outputs.append(h)
        x = torch.cat(layer_outputs, dim=-1)
        x = self.lin(x)

        return x.log_softmax(dim=-1)

In [None]:
# define model

# Seed torch's default RNG so weight initialization (and anything drawn from
# that generator afterwards) is reproducible across kernel restarts.
torch.manual_seed(12345)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net(hidden_channels=256).to(device)
# NOTE: `criterion` is not used by train(), which applies F.nll_loss to the
# model's log_softmax output. CrossEntropyLoss applies log_softmax itself, so
# it expects raw logits, not this model's output.
criterion = torch.nn.CrossEntropyLoss()  # Define loss criterion.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
print(model)

Net(
  (conv1): GCNConv(300, 256)
  (conv2): GCNConv(256, 256)
  (conv3): GCNConv(256, 256)
  (lin): Linear(in_features=768, out_features=100, bias=True)
)


In [None]:
# train function
def train():
    """Run one epoch of GraphSAINT mini-batch training.

    Each sampled subgraph from ``loader`` is moved to ``device``. The loss is
    the NLL of the model's log-probabilities on that subgraph's training nodes.

    Returns:
        The epoch's average loss, weighting each batch's loss by the batch's
        node count (as in the PyG GraphSAINT example), or ``float('nan')`` if
        no batch contained a training node.
    """
    model.train()
    # GraphSAINT normalization is not used here, and the PyG example trains
    # with 'mean' aggregation in that case. test() also switches to 'mean', so
    # setting it here keeps every epoch, including the first, consistent with
    # evaluation.
    model.set_aggr('mean')

    total_loss = total_examples = 0
    for batch in loader:
        batch = batch.to(device)
        train_mask = batch.train_mask
        if not bool(train_mask.any()):
            # An empty selection would make the mean NLL NaN; skip the batch.
            continue

        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index)
        # Yelp labels are multi-hot; argmax keeps one positive class per row
        # (the first one, or class 0 for rows with no positive label).
        # NOTE(review): a multi-label loss would use every label; this keeps
        # the original single-label treatment.
        target = torch.argmax(batch.y[train_mask], dim=-1)
        loss = F.nll_loss(out[train_mask], target)

        loss.backward()
        optimizer.step()
        total_loss += loss.item() * batch.num_nodes
        total_examples += batch.num_nodes
    return total_loss / total_examples if total_examples else float('nan')

In [None]:
# test function

@torch.no_grad()
def test(graph=None):
    """Evaluate the model on a full graph and return [train, val, test] accuracy.

    Args:
        graph: the full graph to evaluate on. Defaults to ``dataset[0]`` rather
            than the module-level ``data``, because a module-level
            ``for data in loader:`` loop rebinds ``data`` to a sampled subgraph.

    Returns:
        A list ``[train_acc, val_acc, test_acc]``; a split with no nodes
        yields ``float('nan')``.
    """
    if graph is None:
        graph = dataset[0]

    model.eval()
    model.set_aggr('mean')

    out = model(graph.x.to(device), graph.edge_index.to(device))
    pred = out.argmax(dim=-1)
    # Yelp labels are multi-hot; compare against the argmax of each label row,
    # matching the target used in training.
    # NOTE(review): for multi-label data, micro-F1 would be the usual metric.
    target = torch.argmax(graph.y, dim=-1).to(device)
    correct = pred.eq(target)

    accs = []
    for mask_name in ('train_mask', 'val_mask', 'test_mask'):
        mask = getattr(graph, mask_name).to(device)
        num_in_split = int(mask.sum())
        if num_in_split == 0:
            accs.append(float('nan'))
        else:
            accs.append(int(correct[mask].sum()) / num_in_split)
    return accs

In [None]:
# Train and evaluate: each epoch runs one pass of GraphSAINT mini-batch
# training, then reports train/val/test accuracy on the full graph.
NUM_EPOCHS = 50

for epoch in range(1, NUM_EPOCHS + 1):
    loss = train()
    accs = test()
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {accs[0]:.4f}, '
          f'Val: {accs[1]:.4f}, Test: {accs[2]:.4f}')

Epoch: 01, Loss: 2.9649, Train: 0.5415, Val: 0.5513, Test: 0.5458
Epoch: 02, Loss: 3.2837, Train: 0.5404, Val: 0.5495, Test: 0.5464
Epoch: 03, Loss: 2.6773, Train: 0.5421, Val: 0.5508, Test: 0.5451
Epoch: 04, Loss: 2.4568, Train: 0.5390, Val: 0.5482, Test: 0.5431
Epoch: 05, Loss: 2.3303, Train: 0.5421, Val: 0.5508, Test: 0.5451
Epoch: 06, Loss: 2.2290, Train: 0.5390, Val: 0.5482, Test: 0.5444
Epoch: 07, Loss: 2.1174, Train: 0.5422, Val: 0.5508, Test: 0.5451
Epoch: 08, Loss: 2.0672, Train: 0.5422, Val: 0.5513, Test: 0.5451
Epoch: 09, Loss: 2.0120, Train: 0.5429, Val: 0.5513, Test: 0.5471
Epoch: 10, Loss: 1.9539, Train: 0.5414, Val: 0.5491, Test: 0.5451
Epoch: 11, Loss: 1.9080, Train: 0.5404, Val: 0.5477, Test: 0.5444
Epoch: 12, Loss: 1.8662, Train: 0.5405, Val: 0.5482, Test: 0.5451
Epoch: 13, Loss: 1.8110, Train: 0.5407, Val: 0.5473, Test: 0.5458
Epoch: 14, Loss: 1.7682, Train: 0.5428, Val: 0.5500, Test: 0.5464
Epoch: 15, Loss: 1.7396, Train: 0.5432, Val: 0.5508, Test: 0.5471
Epoch: 16,

Test Accuracy: 0.5900


Test Accuracy: 0.8140
