
References:

- Fast Graph Representation Learning with PyTorch Geometric: https://arxiv.org/abs/1903.02428
- PyTorch Geometric GitHub, GraphSAINT example: https://github.com/rusty1s/pytorch_geometric/blob/master/examples/graph_saint.py


#Install PyTorch Geometric and its required packages

In [None]:
!pip install -q torch-scatter==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.7.0.html
!pip install -q torch-sparse==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.7.0.html
!pip install -q git+https://github.com/rusty1s/pytorch_geometric.git



In [None]:
import torch
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.data import Data, DataLoader, GraphSAINTRandomWalkSampler, GraphSAINTNodeSampler, GraphSAINTEdgeSampler, GraphSAINTSampler
from torch_geometric.nn import GCNConv
import argparse

In [None]:
from torch_geometric.datasets import Yelp, Amazon, Reddit
from torch_geometric.transforms import NormalizeFeatures

dataset = Yelp(root='data/Yelp', transform=NormalizeFeatures())  # multi-label classification with 100 labels
#dataset_amazon = Amazon(root='data/Amazon', name='computers')  # multi-class classification (10 classes)
#dataset_reddit = Reddit(root='data/Reddit', transform=NormalizeFeatures())  # single-label, multi-class classification (41 classes)

Downloading 1Juwx8HtDwSzmVIJ31ooVa1WljI4U5JnA into data/Yelp/raw/adj_full.npz... Done.
Downloading 1Zy6BZH_zLEjKlEFSduKE5tV9qqA_8VtM into data/Yelp/raw/feats.npy... Done.
Downloading 1VUcBGr0T0-klqerjAjxRmAqFuld_SMWU into data/Yelp/raw/class_map.json... Done.
Downloading 1NI5pa5Chpd-52eSmLW60OnB3WS5ikxq_ into data/Yelp/raw/role.json... Done.
Processing...
Done!


#Pre-processing

Print basic information about the dataset.

In [None]:
def print_data(input_data):
  ''' input_data: pytorch geometric dataset format
      prints basic information about the dataset
  '''

  print()
  print(f'Dataset: {input_data}:')
  print('======================')
  print(f'Number of graphs: {len(input_data)}')
  print(f'Number of features: {input_data.num_features}')
  print(f'Number of classes: {input_data.num_classes}')

  data = input_data[0]  # Get the first graph object.

  print()
  print(data)
  print('===========================================================================================================')

  # Gather some statistics about the graph.
  print(f'Number of nodes: {data.num_nodes}')
  print(f'Number of edges: {data.num_edges}')
  print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
  #if getattr(data, 'train_mask', None) is not None:  # check whether the data has a train/val/test split
  #  print(f'Number of training nodes: {data.train_mask.sum()}')
  #  print(f'Training node label rate: {int(data.train_mask.sum()) / data.num_nodes:.2f}')
  print(f'Contains isolated nodes: {data.contains_isolated_nodes()}')
  print(f'Contains self-loops: {data.contains_self_loops()}')
  print(f'Is undirected: {data.is_undirected()}')

In [None]:
print_data(dataset)


Dataset: Yelp():
Number of graphs: 1
Number of features: 300
Number of classes: 100

Data(edge_index=[2, 13954819], test_mask=[716847], train_mask=[716847], val_mask=[716847], x=[716847, 300], y=[716847, 100])
Number of nodes: 716847
Number of edges: 13954819
Average node degree: 19.47
Contains isolated nodes: False
Contains self-loops: True
Is undirected: True
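`num_edges` counts the directed entries of `edge_index`, and an undirected graph stores every edge in both directions, so `num_edges / num_nodes` is the mean number of neighbors per node (self-loops included, which Yelp has). A small worked check on a made-up triangle graph, plus the arithmetic behind the 19.47 printed above:

```python
# Toy undirected triangle 0-1-2, stored in both directions like a PyG edge_index.
edge_index = [
    [0, 1, 1, 2, 2, 0],  # source nodes
    [1, 0, 2, 1, 0, 2],  # target nodes
]
num_nodes = 3
num_edges = len(edge_index[0])  # 6 directed entries for 3 undirected edges

avg_degree = num_edges / num_nodes
print(f"Toy average degree: {avg_degree:.2f}")  # Toy average degree: 2.00

# The same formula applied to the Yelp statistics printed above.
yelp_avg = 13954819 / 716847
print(f"Yelp average degree: {yelp_avg:.2f}")  # Yelp average degree: 19.47
```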


Define data and inspect

In [None]:
data = dataset[0]  # the first (and only) graph object in the dataset

In [None]:
data

Data(edge_index=[2, 13954819], test_mask=[716847], train_mask=[716847], val_mask=[716847], x=[716847, 300], y=[716847, 100])

Some extra helper functions in case we want more information about the graph data (e.g. node degrees). They mirror `degree` and its `maybe_num_nodes` helper from `torch_geometric.utils`.

Reference: (insert link)

In [None]:
from typing import Optional

def maybe_num_nodes(index: torch.Tensor,
                    num_nodes: Optional[int] = None) -> int:
    return int(index.max()) + 1 if num_nodes is None else num_nodes

def degree(index: torch.Tensor, num_nodes: Optional[int] = None,
           dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    r"""Computes the (unweighted) degree of a given one-dimensional index
    tensor.
    Args:
        index (LongTensor): Index tensor.
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)
        dtype (:obj:`torch.dtype`, optional): The desired data type of the
            returned tensor.
    :rtype: :class:`Tensor`
    """
    N = maybe_num_nodes(index, num_nodes)
    out = torch.zeros((N, ), dtype=dtype, device=index.device)
    one = torch.ones((index.size(0), ), dtype=out.dtype, device=out.device)
    return out.scatter_add_(0, index, one)
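A quick sanity check of `degree` on a made-up toy graph. The block re-declares the helper so it runs on its own, and shows `torch.bincount` as an equivalent built-in that returns integer counts:

```python
import torch
from typing import Optional


def degree(index: torch.Tensor, num_nodes: Optional[int] = None,
           dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    # Same logic as the helper above: count occurrences of each node id in `index`.
    n = int(index.max()) + 1 if num_nodes is None else num_nodes
    out = torch.zeros((n,), dtype=dtype, device=index.device)
    one = torch.ones((index.size(0),), dtype=out.dtype, device=out.device)
    return out.scatter_add_(0, index, one)


# Path graph 0-1-2, stored in both directions.
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
row, col = edge_index

in_deg = degree(col, num_nodes=3)
print(in_deg)  # tensor([1., 2., 1.])

# torch.bincount gives the same counts as an integer tensor.
print(torch.bincount(col, minlength=3))  # tensor([1, 2, 1])
```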

In [None]:
row, col = data.edge_index
data.edge_weight = 1. / degree(col, data.num_nodes)[col]  # Norm by in-degree.

#parser = argparse.ArgumentParser()
#parser.add_argument('--use_normalization', action='store_true')
#args = parser.parse_args()
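`data.edge_weight = 1. / degree(col, ...)[col]` gives every edge pointing into node `v` the weight `1 / in_degree(v)`, so a weighted sum of incoming messages becomes their mean. A minimal sketch on a made-up toy graph, using `torch.bincount` for the in-degrees, checks that the weights arriving at each node sum to 1:

```python
import torch

# Toy edges (source -> target); node 1 receives two edges, nodes 0 and 2 one each.
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
num_nodes = 3
row, col = edge_index

in_deg = torch.bincount(col, minlength=num_nodes).float()
edge_weight = 1.0 / in_deg[col]  # same in-degree normalization as the cell above
print(edge_weight)  # tensor([0.5000, 1.0000, 1.0000, 0.5000])

# Sum the weights of all edges that end at each target node.
weight_per_target = torch.zeros(num_nodes).scatter_add_(0, col, edge_weight)
print(weight_per_target)  # tensor([1., 1., 1.])

# A weighted "add" aggregation of node features therefore equals a mean over in-neighbors.
x = torch.tensor([[1.0], [2.0], [4.0]])
agg = torch.zeros(num_nodes, 1).index_add_(0, col, x[row] * edge_weight.unsqueeze(-1))
print(agg)  # node 1 gets (x[0] + x[2]) / 2 = 2.5
```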

Define the data loader. PyTorch Geometric provides three GraphSAINT samplers: `GraphSAINTRandomWalkSampler`, `GraphSAINTNodeSampler`, and `GraphSAINTEdgeSampler`.
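The random-walk sampler builds each mini-batch from the nodes visited by short random walks: roughly, `batch_size` root nodes are drawn, each starts a walk of `walk_length` steps, and the subgraph induced by all visited nodes becomes the batch; `num_steps` is the number of such subgraphs per epoch. The pure-Python sketch below illustrates that idea on a toy adjacency list. `sample_random_walk_subgraph` is a hypothetical helper, not PyG's implementation, and drawing roots uniformly is a simplifying assumption:

```python
import random
from typing import Dict, List, Set, Tuple


def sample_random_walk_subgraph(
    adj: Dict[int, List[int]], batch_size: int, walk_length: int, seed: int = 0
) -> Tuple[Set[int], List[Tuple[int, int]]]:
    """Hypothetical helper: return the visited nodes and the induced edges."""
    rng = random.Random(seed)
    nodes = list(adj)
    visited: Set[int] = set()
    for _ in range(batch_size):
        current = rng.choice(nodes)  # uniformly chosen root node (assumption)
        visited.add(current)
        for _ in range(walk_length):
            if not adj[current]:  # dead end: stop this walk early
                break
            current = rng.choice(adj[current])
            visited.add(current)
    # Keep every original edge whose endpoints were both visited (induced subgraph).
    edges = [(u, v) for u in visited for v in adj[u] if v in visited]
    return visited, edges


# Toy undirected graph: a 6-node cycle.
adj = {i: [(i - 1) % 6, (i + 1) % 6] for i in range(6)}
sub_nodes, sub_edges = sample_random_walk_subgraph(adj, batch_size=2, walk_length=2)
print(sorted(sub_nodes), sub_edges)
```

Each subgraph holds at most `batch_size * (walk_length + 1)` nodes, which is why the sampled Yelp batch inspected below is much smaller than the full graph.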

In [None]:
loader = GraphSAINTRandomWalkSampler(data, batch_size=512, walk_length=2,
                                     num_steps=5, sample_coverage=10,
                                     save_dir=dataset.processed_dir,
                                     num_workers=4)

Compute GraphSAINT normalization: : 7171207it [05:39, 21152.42it/s]                           
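The `Compute GraphSAINT normalization` step pre-samples subgraphs until roughly `sample_coverage` times the number of nodes have been drawn, and counts how often each node appears. GraphSAINT uses such counts to weight the loss so that frequently sampled nodes do not dominate training. A simplified sketch of the node-side statistic, under the assumption that a node's weight is inversely proportional to its estimated sampling probability; `node_norm_from_samples` is hypothetical, and PyG's exact formula (and its separate edge normalization) may differ:

```python
from collections import Counter
from typing import Dict, List, Set


def node_norm_from_samples(samples: List[Set[int]], num_nodes: int) -> Dict[int, float]:
    """Hypothetical helper: weight each node by 1 / (estimated sampling probability * num_nodes)."""
    counts = Counter(v for sample in samples for v in sample)
    num_samples = len(samples)
    norm = {}
    for v in range(num_nodes):
        c = counts.get(v, 0) or 0.1  # avoid division by zero for never-sampled nodes (assumption)
        p_v = c / num_samples  # estimated probability that node v is in a subgraph
        norm[v] = 1.0 / (p_v * num_nodes)
    return norm


# Four made-up pre-sampled subgraphs over 4 nodes; node 0 appears in every one.
samples = [{0, 1}, {0, 2}, {0, 3}, {0, 1}]
norm = node_norm_from_samples(samples, num_nodes=4)
print(norm)  # {0: 0.25, 1: 0.5, 2: 1.0, 3: 1.0}
```

Node 0 is sampled most often and receives the smallest weight.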


Another quick data inspection

In [None]:
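for batch in loader:  # separate name, so the full graph stored in `data` is not overwritten (test() uses it)
  print(batch)
  print(batch.y)
  print(batch.y.shape)
  print(torch.argmax(batch.y))  # argmax over the flattened tensor
  print(torch.argmax(batch.y, dim=-1))  # index of the first positive label per node
  print(batch.y.argmax(dim=1))  # same as the previous line
  print(batch.train_mask)
  break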

Data(edge_index=[2, 5163], edge_norm=[5163], node_norm=[1325], test_mask=[1325], train_mask=[1325], val_mask=[1325], x=[1325, 300], y=[1325, 100])
tensor([[0., 1., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.]])
torch.Size([1325, 100])
tensor(1)
tensor([1, 1, 3,  ..., 1, 1, 2])
tensor([1, 1, 3,  ..., 1, 1, 2])
tensor([True, True, True,  ..., True, True, True])
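Yelp's `y` is multi-hot, so a node can carry several 1s. `torch.argmax` returns the index of the first maximal value, so `argmax(dim=-1)` keeps only the lowest-indexed positive label and discards the rest. A toy illustration with made-up labels:

```python
import torch

# Made-up multi-hot labels for 3 nodes and 4 classes.
y = torch.tensor([[0., 1., 0., 1.],   # labels {1, 3}
                  [0., 0., 1., 0.],   # label {2}
                  [1., 1., 1., 0.]])  # labels {0, 1, 2}

single = y.argmax(dim=-1)
print(single)  # tensor([1, 2, 0]): only the first positive label survives

labels_per_node = y.sum(dim=-1)
print(labels_per_node)  # tensor([2., 1., 3.])
```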


#Define a model structure

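The model: two GCN layers with ReLU activation and dropout (p=0.2), whose outputs are concatenated and fed to a linear layer with a log-softmax output. GraphSAINT itself is the subgraph sampler used for training, not the model architecture.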
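For intuition, a dense sketch of what this model computes. Each GCN layer applies the symmetric normalization `D^{-1/2} (A + I) D^{-1/2} H W`, the propagation rule `GCNConv` implements on a sparse `edge_index`; the two layer outputs are concatenated and mapped to class log-probabilities. Sizes are tiny and weights random, purely for illustration; dropout and biases are omitted, and `gcn_norm_dense` is a hypothetical helper:

```python
import torch
import torch.nn.functional as F


def gcn_norm_dense(adj: torch.Tensor) -> torch.Tensor:
    """Return D^{-1/2} (A + I) D^{-1/2} for a dense adjacency matrix without self-loops."""
    a_hat = adj + torch.eye(adj.size(0))
    deg_inv_sqrt = a_hat.sum(dim=1).pow(-0.5)
    return deg_inv_sqrt.unsqueeze(1) * a_hat * deg_inv_sqrt.unsqueeze(0)


torch.manual_seed(0)
num_nodes, in_ch, hidden, num_classes = 4, 5, 8, 3

# Toy undirected path graph 0-1-2-3.
adj = torch.zeros(num_nodes, num_nodes)
for u, v in [(0, 1), (1, 2), (2, 3)]:
    adj[u, v] = adj[v, u] = 1.0

x0 = torch.randn(num_nodes, in_ch)
w1, w2 = torch.randn(in_ch, hidden), torch.randn(hidden, hidden)
lin = torch.nn.Linear(2 * hidden, num_classes)

a_norm = gcn_norm_dense(adj)
x1 = F.relu(a_norm @ x0 @ w1)  # first GCN layer
x2 = F.relu(a_norm @ x1 @ w2)  # second GCN layer
out = lin(torch.cat([x1, x2], dim=-1)).log_softmax(dim=-1)  # concat, linear, log-softmax
print(out.shape)  # torch.Size([4, 3])
```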

In [None]:
class Net(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(Net, self).__init__()
        in_channels = dataset.num_node_features
        out_channels = dataset.num_classes
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        #self.conv3 = GCNConv(hidden_channels, hidden_channels)
        #self.lin = torch.nn.Linear(3 * hidden_channels, out_channels)
        self.lin = torch.nn.Linear(2 * hidden_channels, out_channels)

    def set_aggr(self, aggr):
        self.conv1.aggr = aggr
        self.conv2.aggr = aggr
        #self.conv3.aggr = aggr

    def forward(self, x0, edge_index, edge_weight=None):
        x1 = F.relu(self.conv1(x0, edge_index, edge_weight))
        x1 = F.dropout(x1, p=0.2, training=self.training)
        x2 = F.relu(self.conv2(x1, edge_index, edge_weight))
        x2 = F.dropout(x2, p=0.2, training=self.training)
        #x3 = F.relu(self.conv3(x2, edge_index, edge_weight))
        #x3 = F.dropout(x3, p=0.2, training=self.training)
        #x = torch.cat([x1, x2, x3], dim=-1)
        x = torch.cat([x1, x2], dim=-1)
        x = self.lin(x)
        
        return x.log_softmax(dim=-1)

        #nn.Sigmoid()(preds) if self.sigmoid_loss else F.softmax(preds, dim=1)
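The commented-out line above hints at the original GraphSAINT code, which uses a sigmoid output for multi-label data. Since Yelp labels are multi-hot, an alternative to the argmax reduction is to keep every label and train with a per-class binary cross-entropy on raw logits. A minimal sketch with made-up tensors, assuming the model's last step returns `x` (logits) instead of `x.log_softmax(dim=-1)`:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_nodes, num_classes = 5, 4

logits = torch.randn(num_nodes, num_classes, requires_grad=True)  # stand-in for the model output
y = torch.tensor([[0., 1., 0., 1.],
                  [1., 0., 0., 0.],
                  [0., 0., 1., 1.],
                  [0., 1., 0., 0.],
                  [1., 1., 1., 0.]])  # made-up multi-hot labels
train_mask = torch.tensor([True, True, True, False, False])

# One independent binary decision per class: sigmoid + BCE, computed stably on logits.
loss = F.binary_cross_entropy_with_logits(logits[train_mask], y[train_mask])
loss.backward()

# Predict every class whose probability exceeds 0.5 (i.e. logit > 0).
pred = (logits > 0).float()
print(loss.item(), pred)
```

With multi-hot predictions like `pred`, `sklearn.metrics.f1_score(y_true, y_pred, average='micro')` accepts the label indicator matrices directly.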

Initialize the model

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net(hidden_channels=256).to(device)
criterion = torch.nn.CrossEntropyLoss()  
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
print(model)

Net(
  (conv1): GCNConv(300, 256)
  (conv2): GCNConv(256, 256)
  (lin): Linear(in_features=512, out_features=100, bias=True)
)
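The cell above defines `criterion = torch.nn.CrossEntropyLoss()`, but training uses `F.nll_loss` on the model's log-softmax output. Cross-entropy applies its own log-softmax, and log-softmax leaves rows that are already log-probabilities unchanged, so the commented-out `criterion(...)` line in `train()` would give the same loss here. A small check with random tensors:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(6, 4)
target = torch.tensor([0, 3, 1, 1, 2, 0])

log_probs = logits.log_softmax(dim=-1)  # what Net.forward returns

nll = F.nll_loss(log_probs, target)
ce_on_log_probs = torch.nn.CrossEntropyLoss()(log_probs, target)
ce_on_logits = F.cross_entropy(logits, target)

print(nll.item(), ce_on_log_probs.item(), ce_on_logits.item())  # agree up to float error
```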


#Train the model

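Training function: for each sampled subgraph, compute the negative log-likelihood loss on its training nodes. Because Yelp labels are multi-hot, `torch.argmax` reduces each label vector to the index of its first positive label before the loss is computed.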

In [None]:
def train():
    model.train()
    #model.set_aggr('mean')
    #model.set_aggr('max')
    model.set_aggr('add')
    #model.set_aggr('none')

    total_loss = total_examples = 0
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()

        #if args.use_normalization:
        #    edge_weight = data.edge_norm * data.edge_weight
        #    out = model(data.x, data.edge_index, edge_weight)
        #    loss = F.nll_loss(out, data.y, reduction='none')
        #    loss = (loss * data.node_norm)[data.train_mask].sum()
        #else:
        #    out = model(data.x, data.edge_index)
        #    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])

        out = model(data.x, data.edge_index)
        #print(out.argmax(dim=-1))
        loss = F.nll_loss(out[data.train_mask], torch.argmax(data.y[data.train_mask], dim=-1)) # for Yelp
        #loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask]) # for single-label datasets such as Reddit
        #loss = criterion(out[data.train_mask], torch.argmax(data.y[data.train_mask], dim=-1))

        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_nodes
        total_examples += data.num_nodes

    return total_loss / total_examples
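The commented-out `use_normalization` branch replaces the mean loss with a sum weighted by `data.node_norm`, the per-node weights computed by the sampler. A toy comparison with made-up tensors: with uniform weights of `1 / (number of training nodes)`, the weighted sum reduces to the plain mean used in the active code path.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
out = torch.randn(5, 3).log_softmax(dim=-1)  # stand-in for the model output
y = torch.tensor([0, 2, 1, 1, 0])
train_mask = torch.tensor([True, False, True, True, False])

# Per-node losses, as in the commented-out branch (reduction='none').
per_node = F.nll_loss(out, y, reduction='none')

# Uniform weights: the weighted sum equals the mean over training nodes.
n_train = int(train_mask.sum())
uniform_norm = torch.full((5,), 1.0 / n_train)
weighted = (per_node * uniform_norm)[train_mask].sum()
plain_mean = F.nll_loss(out[train_mask], y[train_mask])
print(weighted.item(), plain_mean.item())

# Made-up non-uniform weights instead emphasize rarely sampled nodes.
node_norm = torch.tensor([0.1, 0.5, 0.4, 0.2, 0.3])
print((per_node * node_norm)[train_mask].sum().item())
```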

Evaluate the model with accuracy and micro- and macro-averaged F1 scores on the train, validation, and test nodes.

In [None]:
from sklearn.metrics import f1_score

In [None]:
@torch.no_grad()  # evaluation needs no gradients, which also saves memory
def test():
    model.eval()
    #model.set_aggr('mean')
    #model.set_aggr('max')
    model.set_aggr('add')
    #model.set_aggr('none')

    # Full-graph inference on Yelp (about 14M edges) needs a lot of memory
    out = model(data.x.to(device), data.edge_index.to(device))
    pred = out.argmax(dim=-1)
    #correct = pred.eq(data.y.to(device))  # for single-label datasets such as Reddit
    correct = pred.eq(torch.argmax(data.y, dim=-1).to(device))  # for Yelp (multi-hot labels reduced via argmax)

    accs = []
    micro_scores = []
    macro_scores = []

    for _, mask in data('train_mask', 'val_mask', 'test_mask'):
        accs.append(correct[mask].sum().item() / mask.sum().item())
        
        micro_score = f1_score(torch.argmax(data.y[mask], dim=-1).cpu(), pred[mask].cpu(), average='micro')
        micro_scores.append(micro_score) 

        macro_score = f1_score(torch.argmax(data.y[mask], dim=-1).cpu(), pred[mask].cpu(), average='macro')
        macro_scores.append(macro_score)


    return accs, micro_scores, macro_scores
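`test()` reports accuracy plus micro- and macro-averaged F1. A pure-Python sketch of the two averages on made-up single-label predictions; `f1_micro` and `f1_macro` are hypothetical helpers standing in for `sklearn.metrics.f1_score`. Micro-F1 pools true and false positives over all classes, which for single-label data equals accuracy, while macro-F1 averages per-class F1 scores so rare classes weigh as much as common ones:

```python
from typing import List


def f1_micro(y_true: List[int], y_pred: List[int]) -> float:
    # Pool counts over classes: every wrong prediction is one FP and one FN.
    tp = sum(t == p for t, p in zip(y_true, y_pred))
    fp = fn = len(y_true) - tp
    return 2 * tp / (2 * tp + fp + fn)


def f1_macro(y_true: List[int], y_pred: List[int]) -> float:
    # Unweighted mean of per-class F1 = 2TP / (2TP + FP + FN).
    scores = []
    for c in sorted(set(y_true) | set(y_pred)):
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / len(scores)


y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

print(f"accuracy={accuracy:.4f} micro={f1_micro(y_true, y_pred):.4f} macro={f1_macro(y_true, y_pred):.4f}")
# accuracy=0.6667 micro=0.6667 macro=0.6556
```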

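Train for 50 epochs and print three lines per epoch: accuracy, micro-F1, and macro-F1 for the train, validation, and test splits. Because every node is reduced to a single label and receives a single predicted class, micro-F1 equals accuracy, which is why the first two lines of each epoch are identical.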

In [None]:
for epoch in range(1, 51):
    loss = train()
    accs, micro_scores, macro_scores = test()

    # Accuracy
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {accs[0]:.4f}, '
          f'Val: {accs[1]:.4f}, Test: {accs[2]:.4f}')
    
    #micro F1
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {micro_scores[0]:.4f}, '
          f'Val: {micro_scores[1]:.4f}, Test: {micro_scores[2]:.4f}')
    
    #macro F1
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {macro_scores[0]:.4f}, '
          f'Val: {macro_scores[1]:.4f}, Test: {macro_scores[2]:.4f}')

Epoch: 01, Loss: 1.6986, Train: 0.5276, Val: 0.5217, Test: 0.5725
Epoch: 01, Loss: 1.6986, Train: 0.5276, Val: 0.5217, Test: 0.5725
Epoch: 01, Loss: 1.6986, Train: 0.2120, Val: 0.2369, Test: 0.3393
Epoch: 02, Loss: 1.3529, Train: 0.5796, Val: 0.5362, Test: 0.6087
Epoch: 02, Loss: 1.3529, Train: 0.5796, Val: 0.5362, Test: 0.6087
Epoch: 02, Loss: 1.3529, Train: 0.2050, Val: 0.1368, Test: 0.4061
Epoch: 03, Loss: 1.2754, Train: 0.5888, Val: 0.5604, Test: 0.6014
Epoch: 03, Loss: 1.2754, Train: 0.5888, Val: 0.5604, Test: 0.6014
Epoch: 03, Loss: 1.2754, Train: 0.2146, Val: 0.1959, Test: 0.4211
Epoch: 04, Loss: 1.2460, Train: 0.5816, Val: 0.5362, Test: 0.5725
Epoch: 04, Loss: 1.2460, Train: 0.5816, Val: 0.5362, Test: 0.5725
Epoch: 04, Loss: 1.2460, Train: 0.2056, Val: 0.1565, Test: 0.3690
Epoch: 05, Loss: 1.1929, Train: 0.5816, Val: 0.5459, Test: 0.5942
Epoch: 05, Loss: 1.1929, Train: 0.5816, Val: 0.5459, Test: 0.5942
Epoch: 05, Loss: 1.1929, Train: 0.1819, Val: 0.1821, Test: 0.3958
Epoch: 06,

Test Accuracy: 0.5900


Test Accuracy: 0.8140


Plot the results
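A minimal matplotlib sketch for this step, assuming per-epoch metrics are collected into lists during the training loop. Here the lists are filled with the loss and validation values from epochs 1 to 5 of the log above; the output filename `graphsaint_yelp_metrics.png` is hypothetical:

```python
import matplotlib
matplotlib.use("Agg")  # render without a display; drop this line inside a notebook
import matplotlib.pyplot as plt

# Values copied from the epoch 1-5 log above; in practice, append them inside the training loop.
epochs = [1, 2, 3, 4, 5]
losses = [1.6986, 1.3529, 1.2754, 1.2460, 1.1929]
val_acc = [0.5217, 0.5362, 0.5604, 0.5362, 0.5459]
val_macro_f1 = [0.2369, 0.1368, 0.1959, 0.1565, 0.1821]

fig, (ax_loss, ax_metric) = plt.subplots(1, 2, figsize=(10, 4))

ax_loss.plot(epochs, losses, marker="o")
ax_loss.set_xlabel("Epoch")
ax_loss.set_ylabel("Training loss")

ax_metric.plot(epochs, val_acc, marker="o", label="Val accuracy / micro-F1")
ax_metric.plot(epochs, val_macro_f1, marker="s", label="Val macro-F1")
ax_metric.set_xlabel("Epoch")
ax_metric.legend()

fig.tight_layout()
fig.savefig("graphsaint_yelp_metrics.png")
```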