## **Introducing GraphSAGE**
- GraphSAGE is a GNN architecture designed to handle large graphs (more than 100,000 nodes), and it has been adopted by tech companies such as UberEats and Pinterest. It addresses two limitations of GCN and GAT: scaling to large graphs and efficiently generalizing to unseen nodes.

#### **Neighbour Sampling**
- Every GNN layer computes a node's embedding from its neighbours. A one-layer GNN therefore only needs the node's direct (1-hop) neighbours, while a two-layer GNN also needs the neighbours of those neighbours (2 hops).
- The 2-hop neighbours are aggregated to compute the embeddings of the 1-hop neighbours, which are in turn aggregated to compute the embedding of the central node.
- However, the computation graph grows exponentially with the number of hops, and highly connected nodes produce enormous computation graphs. Neighbour sampling addresses this by keeping only a fixed number of randomly selected neighbours per node at each hop, which bounds the size of every computation graph.
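To make the idea concrete, here is a minimal pure-Python sketch of fixed-fanout neighbour sampling on a toy graph. The function `sample_computation_graph` and the complete-graph example are hypothetical illustrations, not PyG's implementation. It samples neighbours uniformly without replacement and keeps the computation graph as a tree of hops; a real loader would also deduplicate nodes into a subgraph.

```python
import random

def sample_computation_graph(adj, target, fanouts, seed=0):
    """Build a sampled multi-hop computation tree for `target` (hypothetical helper).

    adj:     dict mapping node -> list of neighbours
    fanouts: max neighbours kept per node at each hop, e.g. [10, 10]
    Returns layers where layers[0] == [target] and layers[k] holds the hop-k entries.
    """
    rng = random.Random(seed)
    layers = [[target]]
    for fanout in fanouts:
        next_layer = []
        for node in layers[-1]:
            neighbours = adj[node]
            # Sample uniformly without replacement only if the node has too many neighbours
            if len(neighbours) > fanout:
                neighbours = rng.sample(neighbours, fanout)
            next_layer.extend(neighbours)
        layers.append(next_layer)
    return layers

# Toy graph: complete graph on 31 nodes, so every node has 30 neighbours
n = 31
adj = {u: [v for v in range(n) if v != u] for u in range(n)}

full = sample_computation_graph(adj, 0, [n, n])        # fanout >= degree: nothing is dropped
sampled = sample_computation_graph(adj, 0, [10, 10])   # GraphSAGE-style fixed fanout

print([len(layer) for layer in full])     # [1, 30, 900]
print([len(layer) for layer in sampled])  # [1, 10, 100]
```

Without sampling, the 2-hop tree grows as 1 + d + d² for degree d; with a fanout of 10 per hop it is capped at 1 + 10 + 100 entries no matter how connected the nodes are.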

#### **Aggregation**
The GraphSAGE paper proposes three aggregators to compute an embedding from the selected neighbours.
1. Mean aggregator
- The mean aggregator averages the embeddings of the target node and its sampled neighbours, applies a linear transformation with a weight matrix $W$, and finally applies a non-linearity such as ReLU or tanh. PyG's `SAGEConv` (used below) implements a variant with separate weight matrices for the target node and for the mean of its neighbours:
$$h_i' = \sigma\left(W_1 h_i + W_2 \cdot \text{mean}_{j \in \mathcal{N}_i} h_j\right)$$
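2. LSTM aggregator
- An LSTM processes the neighbours' embeddings as a sequence. Since an LSTM is not permutation invariant, it is applied to a random permutation of the neighbours.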
3. Pooling aggregator
- Every neighbour's embedding is fed to a multi-layer perceptron to produce a new vector. An element-wise max operation then keeps, for each feature, the highest value across all neighbours.
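The mean and pooling aggregators above can be sketched in plain Python on small vectors. This is an illustrative sketch, not `SAGEConv`'s code: the helper names are hypothetical, biases are omitted from the mean update, ReLU stands in for σ, and the pooling MLP is reduced to a single ReLU layer. Both aggregators are permutation invariant, which the last lines check by reversing the neighbour order.

```python
def matvec(W, x):
    """Multiply matrix W (a list of rows) by vector x."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]

def add(a, b):
    return [u + v for u, v in zip(a, b)]

def relu(x):
    return [max(0.0, v) for v in x]

def mean_aggregate(neighbour_feats):
    """Element-wise mean of the neighbours' embeddings."""
    k = len(neighbour_feats)
    return [sum(col) / k for col in zip(*neighbour_feats)]

def pool_aggregate(neighbour_feats, W_pool, b_pool):
    """Pass each neighbour through a one-layer MLP, then take the element-wise max."""
    transformed = [relu(add(matvec(W_pool, h), b_pool)) for h in neighbour_feats]
    return [max(col) for col in zip(*transformed)]

def sage_mean_update(h_i, neighbour_feats, W1, W2):
    """h_i' = ReLU(W1 h_i + W2 * mean_j h_j), mirroring the formula above without bias."""
    return relu(add(matvec(W1, h_i), matvec(W2, mean_aggregate(neighbour_feats))))

h_i = [1.0, 0.0]                        # target node embedding
neighbours = [[0.0, 2.0], [2.0, 4.0]]   # sampled neighbour embeddings
W1 = [[1.0, 0.0], [0.0, 1.0]]
W2 = [[0.5, 0.0], [0.0, -1.0]]
W_pool = [[1.0, 0.0], [0.0, 1.0]]
b_pool = [0.0, -3.0]

print(mean_aggregate(neighbours))                  # [1.0, 3.0]
print(sage_mean_update(h_i, neighbours, W1, W2))   # [1.5, 0.0]
print(pool_aggregate(neighbours, W_pool, b_pool))  # [2.0, 1.0]

# Reordering the neighbours does not change either aggregate
reversed_nb = neighbours[::-1]
print(mean_aggregate(reversed_nb) == mean_aggregate(neighbours))  # True
print(pool_aggregate(reversed_nb, W_pool, b_pool) == pool_aggregate(neighbours, W_pool, b_pool))  # True
```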

### **Implementing GraphSAGE to classify nodes on PubMed**

```python
import torch
# Install PyTorch Geometric and its extensions (Jupyter/Colab shell command)
!pip install -q torch-scatter~=2.1.0 torch-sparse~=0.6.16 torch-cluster~=1.6.0 torch-spline-conv~=1.2.1 torch-geometric==2.2.0 -f https://data.pyg.org/whl/torch-{torch.__version__}.html

# Fix random seeds for reproducibility
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
```



```python
from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='.', name="Pubmed")
data = dataset[0]

# Print information about the dataset
print(f'Dataset: {dataset}')
print('-------------------')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of nodes: {data.x.shape[0]}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')
print(f'Training nodes: {data.train_mask.sum().item()}')
print(f'Evaluation nodes: {data.val_mask.sum().item()}')
print(f'Test nodes: {data.test_mask.sum().item()}')
```

```text
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.test.index
Processing...
Done!

Dataset: Pubmed()
-------------------
Number of graphs: 1
Number of nodes: 19717
Number of features: 500
Number of classes: 3
Training nodes: 60
Evaluation nodes: 500
Test nodes: 1000
```


```python
from torch_geometric.loader import NeighborLoader

# Create mini-batches with neighbour sampling
train_loader = NeighborLoader(
    data,
    num_neighbors=[10, 10],  # Sample up to 10 neighbours of each target node, then up to 10 neighbours of each of those
    batch_size=16,  # Split the 60 training (target) nodes into batches of 16: 16 + 16 + 16 + 12, i.e. four batches
    input_nodes=data.train_mask)
```
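As a quick sanity check on these settings, the arithmetic below sketches how many batches the loader produces and an upper bound on the nodes sampled per batch. The helpers are hypothetical. The bound is only reached if every sampled node has at least 10 neighbours and no two seeds share neighbours; `NeighborLoader` keeps each node once in the sampled subgraph, so real batches are usually smaller.

```python
def batch_sizes(num_seeds, batch_size):
    """Sizes of consecutive mini-batches over num_seeds target nodes (the last may be smaller)."""
    full, rest = divmod(num_seeds, batch_size)
    return [batch_size] * full + ([rest] if rest else [])

def max_sampled_nodes(batch_size, num_neighbors):
    """Upper bound on nodes in one batch: each hop multiplies the frontier by its fanout."""
    total = frontier = batch_size
    for fanout in num_neighbors:
        frontier *= fanout
        total += frontier
    return total

print(batch_sizes(60, 16))              # [16, 16, 16, 12]
print(max_sampled_nodes(16, [10, 10]))  # 1776 = 16 + 160 + 1600
```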

```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

torch.manual_seed(1)

def accuracy(pred_y, y):
    """Calculate accuracy."""
    return ((pred_y == y).sum() / len(y)).item()

class GraphSAGE(torch.nn.Module):
    """Two-layer GraphSAGE model."""
    def __init__(self, dim_in, dim_h, dim_out):
        super().__init__()
        self.sage1 = SAGEConv(dim_in, dim_h)
        self.sage2 = SAGEConv(dim_h, dim_out)

    def forward(self, x, edge_index):
        h = self.sage1(x, edge_index)
        h = torch.relu(h)
        h = F.dropout(h, p=0.5, training=self.training)
        h = self.sage2(h, edge_index)
        return h

    def fit(self, loader, epochs):
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(self.parameters(), lr=0.01)

        self.train()
        for epoch in range(epochs+1):
            total_loss = 0
            acc = 0
            val_loss = 0
            val_acc = 0

            # Train on batches
            for batch in loader:
                optimizer.zero_grad()
                out = self(batch.x, batch.edge_index)
                loss = criterion(out[batch.train_mask], batch.y[batch.train_mask])
                total_loss += loss.item()
                acc += accuracy(out[batch.train_mask].argmax(dim=1), batch.y[batch.train_mask])
                loss.backward()
                optimizer.step()

                # Validation on the validation nodes present in the sampled subgraph
                # (.item() avoids keeping every batch's autograd graph in memory)
                val_loss += criterion(out[batch.val_mask], batch.y[batch.val_mask]).item()
                val_acc += accuracy(out[batch.val_mask].argmax(dim=1), batch.y[batch.val_mask])

            # Print metrics averaged over batches every 20 epochs
            if epoch % 20 == 0:
                print(f'Epoch {epoch:>3} | Train Loss: {total_loss/len(loader):.3f} | Train Acc: {acc/len(loader)*100:>6.2f}% | Val Loss: {val_loss/len(loader):.2f} | Val Acc: {val_acc/len(loader)*100:.2f}%')

    @torch.no_grad()
    def test(self, data):
        self.eval()
        out = self(data.x, data.edge_index)
        acc = accuracy(out.argmax(dim=1)[data.test_mask], data.y[data.test_mask])
        return acc
```

```python
sage = GraphSAGE(dataset.num_features, 64, dataset.num_classes)
print(sage)

sage.fit(train_loader, epochs=200)
acc = sage.test(data)
print(f'\nGraphSAGE test accuracy: {acc*100:.2f}%\n')
```

GraphSAGE(
  (sage1): SAGEConv(500, 64, aggr=mean)
  (sage2): SAGEConv(64, 3, aggr=mean)
)
Epoch   0 | Train Loss: 0.311 | Train Acc:  21.42% | Val Loss: 1.12 | Val Acc: 24.79%
Epoch  20 | Train Loss: 0.000 | Train Acc: 100.00% | Val Loss: 0.78 | Val Acc: 63.39%
Epoch  40 | Train Loss: 0.000 | Train Acc: 100.00% | Val Loss: 0.64 | Val Acc: 75.12%
Epoch  60 | Train Loss: 0.000 | Train Acc: 100.00% | Val Loss: 0.75 | Val Acc: 77.50%
Epoch  80 | Train Loss: 0.000 | Train Acc: 100.00% | Val Loss: 0.51 | Val Acc: 78.33%
Epoch 100 | Train Loss: 0.000 | Train Acc: 100.00% | Val Loss: 0.72 | Val Acc: 77.11%
Epoch 120 | Train Loss: 0.000 | Train Acc: 100.00% | Val Loss: 0.80 | Val Acc: 71.04%
Epoch 140 | Train Loss: 0.000 | Train Acc: 100.00% | Val Loss: 0.76 | Val Acc: 70.00%
Epoch 160 | Train Loss: 0.000 | Train Acc: 100.00% | Val Loss: 0.59 | Val Acc: 73.19%
Epoch 180 | Train Loss: 0.000 | Train Acc: 100.00% | Val Loss: 0.53 | Val Acc: 86.81%
Epoch 200 | Train Loss: 0.000 | Train Acc: 100.00