## Graph Sample and Aggregate (GraphSAGE) 

Paper:  [Inductive Representation Learning on Large Graphs](https://www-cs-faculty.stanford.edu/people/jure/pubs/graphsage-nips17.pdf) (NIPS 2017)

**Message Passing Perspective**

|Notion | Meaning | 
|---|---|
|$\mathcal{G}$ = $(V, E)$ | Input graph |
|$x_v$ | Node features for node $v\in V$|
|$h_v$ | Node embedding for node $v\in V$ |
|$\mathcal{N}(v)$ | Neighbours of node $v\in V$|

Initial:
$$h^{(0)}_v = x_v , \forall v \in V .$$

Aggregate:
$$h^{(l)}_{\mathcal{N}(v)} \leftarrow \text{AGG}\{h^{(l-1)}_u, \forall u\in \mathcal{N}(v)\} .$$

Update: 
$$h^{(l)}_v \leftarrow \text{ReLU}\left(W^{(l)} \cdot \text{CONCAT}\left(h^{(l-1)}_v, h^{(l)}_{\mathcal{N}(v)}\right)\right) .$$


## Reproduce Results

Metric: F1

|Dataset | Reddit | PPI | 
|---| --- | --- |
|GraphSAGE-mean | 0.950 | 0.598 |
|Ours | 0.959 | 0.974 |

Refererence of implementation: https://docs.dgl.ai/tutorials/blitz/3_message_passing.html

In [1]:
import time

import torch 
import torch.nn as nn
import torch.nn.functional as F
import dgl 
import dgl.function as fn
from dgl.data import RedditDataset, PPIDataset
import numpy as np
from sklearn.metrics import f1_score

In [2]:
class GraphSAGELayer(nn.Module):
    def __init__(self, in_feats, out_feats):
        super(GraphSAGELayer, self).__init__()
        self.linear = nn.Linear(in_feats * 2, out_feats)
    
    def forward(self, g, h):
        with g.local_scope():
            g.ndata['h'] = h
            g.update_all(message_func=fn.copy_u('h', 'm'), reduce_func=fn.mean('m', 'h_N'))
            h_N = g.ndata['h_N']
            h_total = torch.cat([h, h_N], dim=1)
            return self.linear(h_total)
        

In [3]:
# a two-layer GraphSAGE as described in the paper
class GraphSAGE(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(GraphSAGE, self).__init__()
        self.conv1 = GraphSAGELayer(in_feats, h_feats)
        self.conv2 = GraphSAGELayer(h_feats, num_classes)

    def forward(self, g, h):
        h = self.conv1(g, h)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h

In [4]:
# we use the same configurations as the paper's
hidden_size = 256
lr = 1e-2
epochs = 200

In [5]:
def evaluate(model, g, features, labels, mask):
    model.eval()
    with torch.no_grad():
        logits = model(g, features)
        logits = logits[mask]
        labels = labels[mask]
        _, indices = torch.max(logits, dim=1)
        return f1_score(indices.cpu().numpy(), labels.cpu().numpy(), average='micro')

In [6]:
def main(dataset, device='cuda'):
    g = dataset[0]
    features = g.ndata['feat'].to(device)
    labels = g.ndata['label'].to(device)
    train_mask = g.ndata['train_mask'].to(device)
    val_mask = g.ndata['val_mask'].to(device)
    test_mask = g.ndata['test_mask'].to(device)
    in_feats = features.shape[1]
    n_classes = dataset.num_classes
    g = dgl.remove_self_loop(g)
    n_edges = g.number_of_edges()
    g = g.int().to(device)
    
    print(f"#nodes: {g.number_of_nodes()}, #edges: {g.number_of_edges()}")
    
    model = GraphSAGE(in_feats, hidden_size, n_classes)
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    
    early_stopping_cnt = 0
    best_val = 100 # large enough
    for epoch in range(epochs):
        start = time.time()
        model.train()
        # forward
        logits = model(g, features)
        loss = loss_fn(logits[train_mask], labels[train_mask])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        end = time.time()
        logits = model(g, features)
        val_loss = loss_fn(logits[val_mask], labels[val_mask])
        
        if val_loss < best_val:
            best_val = val_loss
            early_stopping_cnt = 0
        else:
            early_stopping_cnt += 1
            if early_stopping_cnt == 10:
                print("Early stopping (val loss does not decrease for 10 consecutive epochs).")
                break        
                
        print("Epoch {:03d} | Time(s) {:.4f} | Train Loss {:.4f} | Val Loss {:.4f} | ".format(epoch, end - start, loss.item(), val_loss.item()))
    
    f1 = evaluate(model, g, features, labels, test_mask)
    print("Test F1 {:.3f}".format(f1))

In [7]:
dataset = RedditDataset()
main(dataset)

#nodes: 232965, #edges: 114615892
Epoch 000 | Time(s) 1.1911 | Train Loss 3.7194 | Val Loss 2.6151 | 
Epoch 001 | Time(s) 0.4856 | Train Loss 2.7078 | Val Loss 1.8982 | 
Epoch 002 | Time(s) 0.4837 | Train Loss 1.9473 | Val Loss 1.3232 | 
Epoch 003 | Time(s) 0.4855 | Train Loss 1.3681 | Val Loss 0.9192 | 
Epoch 004 | Time(s) 0.4853 | Train Loss 0.9821 | Val Loss 0.7158 | 
Epoch 005 | Time(s) 0.4863 | Train Loss 0.7571 | Val Loss 0.6174 | 
Epoch 006 | Time(s) 0.4840 | Train Loss 0.6250 | Val Loss 0.5222 | 
Epoch 007 | Time(s) 0.4873 | Train Loss 0.5357 | Val Loss 0.4646 | 
Epoch 008 | Time(s) 0.4870 | Train Loss 0.4993 | Val Loss 0.3998 | 
Epoch 009 | Time(s) 0.4818 | Train Loss 0.4423 | Val Loss 0.3646 | 
Epoch 010 | Time(s) 0.4872 | Train Loss 0.4016 | Val Loss 0.3873 | 
Epoch 011 | Time(s) 0.4871 | Train Loss 0.4243 | Val Loss 0.3513 | 
Epoch 012 | Time(s) 0.5027 | Train Loss 0.3760 | Val Loss 0.3539 | 
Epoch 013 | Time(s) 0.4861 | Train Loss 0.3717 | Val Loss 0.3689 | 
Epoch 014 | Ti

In [8]:
# we use the same configurations as the paper's
hidden_size = 256
lr = 1e-3
epochs = 200

In [9]:
model = GraphSAGE(50, hidden_size, 121)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
loss_fn = nn.CrossEntropyLoss()

In [10]:
train_dataset = PPIDataset(mode='train')
val_dataset = PPIDataset(mode='valid')
test_dataset = PPIDataset(mode='test')

early_stopping_cnt = 0
best_val = 100 # large enough
for epoch in range(epochs):
    train_loss = 0
    start = time.time()
    for g in train_dataset:
        features = g.ndata['feat']
        labels = g.ndata['label']
        logits = model(g, features)
        loss = loss_fn(logits, labels.argmax(dim=1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_loss /= len(train_dataset)
    end = time.time()
    
    with torch.no_grad():
        val_loss = 0
        for g in val_dataset:
            features = g.ndata['feat']
            labels = g.ndata['label']
            logits = model(g, features)
            val_loss = loss_fn(logits, labels.argmax(dim=1))
        val_loss /= len(val_dataset)

    if val_loss < best_val:
        best_val = val_loss
        early_stopping_cnt = 0
    else:
        early_stopping_cnt += 1
        if early_stopping_cnt == 10:
            print("Early stopping (val loss does not decrease for 10 consecutive epochs).")
            break        
                
    print("Epoch {:03d} | Time(s) {:.4f} | Train Loss {:.4f} | Val Loss {:.4f} | ".format(epoch, end - start, train_loss, val_loss))


model.eval()
with torch.no_grad():
    total_pred = np.array([])
    total_labels = np.array([])
    for g in test_dataset:
        features = g.ndata['feat']
        labels = g.ndata['label']
        logits = model(g, features)
        _, indices = torch.max(logits, dim=1)
        total_pred = np.append(total_pred, indices.cpu().numpy())
        total_labels = np.append(total_labels, labels.argmax(dim=1).cpu().numpy())
    
    f1 = f1_score(total_pred, total_labels, average='micro')
    print("Test F1 {:.3f}".format(f1))

Epoch 000 | Time(s) 0.9862 | Train Loss 3.5365 | Val Loss 1.1641 | 
Epoch 001 | Time(s) 0.9264 | Train Loss 1.6464 | Val Loss 0.7393 | 
Epoch 002 | Time(s) 1.0114 | Train Loss 1.2514 | Val Loss 0.6692 | 
Epoch 003 | Time(s) 0.9163 | Train Loss 1.1785 | Val Loss 0.6490 | 
Epoch 004 | Time(s) 0.8818 | Train Loss 1.1348 | Val Loss 0.6324 | 
Epoch 005 | Time(s) 0.9318 | Train Loss 1.0976 | Val Loss 0.6175 | 
Epoch 006 | Time(s) 0.8670 | Train Loss 1.0635 | Val Loss 0.6038 | 
Epoch 007 | Time(s) 0.8117 | Train Loss 1.0314 | Val Loss 0.5909 | 
Epoch 008 | Time(s) 0.8762 | Train Loss 1.0010 | Val Loss 0.5786 | 
Epoch 009 | Time(s) 0.9118 | Train Loss 0.9722 | Val Loss 0.5669 | 
Epoch 010 | Time(s) 0.9682 | Train Loss 0.9448 | Val Loss 0.5558 | 
Epoch 011 | Time(s) 0.9459 | Train Loss 0.9185 | Val Loss 0.5450 | 
Epoch 012 | Time(s) 0.9124 | Train Loss 0.8930 | Val Loss 0.5344 | 
Epoch 013 | Time(s) 0.9441 | Train Loss 0.8684 | Val Loss 0.5242 | 
Epoch 014 | Time(s) 0.8958 | Train Loss 0.8445 |