
GraphSAGE is a Graph Neural Network (GNN) architecture designed to handle large-scale graphs.  
In the tech industry, scalability is a key driver of growth, so systems are built from the outset to support millions of users.  
Operating at this scale requires a fundamental shift in how a GNN processes a graph compared with architectures like **GCN** and **GAT**, which are typically trained on the entire graph at once.  
It is therefore no surprise that **GraphSAGE** has been adopted by companies such as **Uber Eats** and **Pinterest**, whose PinSAGE recommender builds on it.  

In this chapter, we explore the two main ideas behind GraphSAGE:  
1. The **neighbor sampling** technique, which forms the core of its scalability and efficiency.  
2. Three types of **aggregation operators** used to generate node embeddings.  
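As a preview of the second idea, the sketch below compares two of the aggregators proposed in the GraphSAGE paper on a toy set of sampled neighbor features: the **mean** aggregator and the **max-pooling** aggregator, which passes each neighbor through a small fully connected layer before taking an element-wise max. The function names and the random weights are hypothetical placeholders (the weights are learned in practice), and the LSTM aggregator is left out to keep the sketch short.

```python
import numpy as np

def mean_aggregator(neighbor_feats):
    """Element-wise mean over the sampled neighbors (rows)."""
    return neighbor_feats.mean(axis=0)

def pool_aggregator(neighbor_feats, W_pool, b_pool):
    """Max-pooling aggregator: transform each neighbor with a one-layer
    MLP (ReLU), then take the element-wise max across neighbors."""
    transformed = np.maximum(0.0, neighbor_feats @ W_pool.T + b_pool)
    return transformed.max(axis=0)

# Three sampled neighbors with 4-dimensional features (toy values)
neighbors = np.array([[1.0, 0.0, 2.0, 0.0],
                      [3.0, 1.0, 0.0, 0.0],
                      [2.0, 2.0, 1.0, 0.0]])

rng = np.random.default_rng(0)
W_pool = rng.normal(size=(4, 4))  # hypothetical weights, learned in practice
b_pool = np.zeros(4)

print(mean_aggregator(neighbors))  # [2. 1. 1. 0.]
print(pool_aggregator(neighbors, W_pool, b_pool))
```

Both aggregators are permutation invariant: shuffling the rows of `neighbors` leaves the output unchanged, which matters because a node's neighbors have no natural order.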



**GraphSAGE** (Hamilton et al., 2017) is a scalable Graph Neural Network (GNN) designed for **inductive representation learning** on large graphs.  
It addresses two main issues of traditional GNNs like GCN and GAT:
1. **Scalability** to large graphs  
2. **Generalization** to unseen nodes and graphs  

The model works by:
- **Sampling a fixed number of neighbors** (neighbor sampling) instead of using all connected nodes  
- **Aggregating information** from these sampled neighbors to compute node embeddings  
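These two steps can be illustrated with a minimal NumPy sketch of one GraphSAGE layer, assuming a mean aggregator and the concatenate-then-transform update described in the GraphSAGE paper: each node samples a fixed number of neighbors, averages their features, concatenates the result with its own features, applies a linear map and a ReLU, and L2-normalizes the output. The function `sage_layer`, the toy graph, and the random weights are illustrative assumptions, not PyG's implementation.

```python
import numpy as np

def sage_layer(h, adj, W, num_samples, rng):
    """One GraphSAGE layer with a mean aggregator (illustrative sketch).

    h:           (num_nodes, d_in) input node features
    adj:         dict mapping node id -> non-empty list of neighbor ids
    W:           (d_out, 2 * d_in) weights applied to [self || neighbor mean]
    num_samples: fixed number of neighbors sampled per node
    rng:         numpy Generator used for sampling
    """
    out = np.zeros((h.shape[0], W.shape[0]))
    for v in range(h.shape[0]):
        neigh = adj[v]
        # Fixed-size sample; fall back to sampling with replacement
        # when a node has fewer neighbors than num_samples
        sampled = rng.choice(neigh, size=num_samples,
                             replace=len(neigh) < num_samples)
        h_neigh = h[sampled].mean(axis=0)        # aggregate sampled neighbors
        z = W @ np.concatenate([h[v], h_neigh])  # combine with the node itself
        z = np.maximum(z, 0.0)                   # ReLU non-linearity
        norm = np.linalg.norm(z)
        out[v] = z / norm if norm > 0 else z     # L2-normalize the embedding
    return out

rng = np.random.default_rng(42)
adj = {0: [1, 2, 3], 1: [0, 2], 2: [0, 1, 3], 3: [0, 2]}
h = rng.normal(size=(4, 8))    # 4 nodes, 8 input features
W = rng.normal(size=(16, 16))  # d_out = 16, 2 * d_in = 16
emb = sage_layer(h, adj, W, num_samples=2, rng=rng)
print(emb.shape)  # (4, 16)
```

In PyG, `torch_geometric.nn.SAGEConv` implements the aggregation step (mean by default) over whatever `edge_index` it receives; neighbor sampling is handled separately by a data loader.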

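This approach keeps the **computation graph from growing exponentially** with the number of layers and enables efficient **mini-batch training** on GPUs.  
For example, sampling 3 neighbors at the first hop and 5 neighbors per first-hop node at the second hop bounds the computation graph of each target node to at most 1 + 3 + 3 × 5 = 19 nodes, regardless of how many neighbors the nodes actually have.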
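The size bound for a sampled computation graph can be checked with a small pure-Python sketch. Assuming the graph is stored as an adjacency dict (a hypothetical representation chosen for illustration), the function below expands a target node hop by hop, samples at most a fixed number of neighbors per node, and counts how many embeddings must be computed. With fanouts 3 and 5 the count is at most 1 + 3 + 3 × 5 = 19, of which 15 sit at the second hop; with full neighborhoods it grows as the node degree raised to the number of hops.

```python
import random

def computation_tree_size(adj, root, fanouts, seed=0):
    """Number of node embeddings computed for `root` under neighbor sampling.

    adj:     dict mapping node id -> list of neighbor ids
    fanouts: neighbors sampled per node at each hop, e.g. [3, 5]
    Every position in the sampled tree is counted, so a node reached by two
    different paths is counted twice (this gives the worst-case size).
    """
    rng = random.Random(seed)
    frontier = [root]
    total = 1  # the target node itself
    for k in fanouts:
        next_frontier = []
        for v in frontier:
            # Sample up to k neighbors without replacement
            next_frontier.extend(rng.sample(adj[v], min(k, len(adj[v]))))
        total += len(next_frontier)
        frontier = next_frontier
    return total

# Toy graph: complete graph on 10 nodes, so every node has 9 neighbors
adj = {v: [u for u in range(10) if u != v] for v in range(10)}

print(computation_tree_size(adj, root=0, fanouts=[3, 5]))  # 19 = 1 + 3 + 3*5
print(computation_tree_size(adj, root=0, fanouts=[9, 9]))  # 91 = 1 + 9 + 9*9
```

The second call uses each node's full neighborhood, which shows how quickly the computation graph grows without sampling. PyG's `torch_geometric.loader.NeighborLoader` accepts the same per-hop fanouts through its `num_neighbors` argument, e.g. `num_neighbors=[3, 5]`.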


___

## Node Classification on PubMed

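- **Dataset:** PubMed citation network (Planetoid family), available under the MIT license on [GitHub](https://github.com/kimiyoung/planetoid)  
- **Graph size:** 19,717 nodes, 88,648 edges (PyG stores each undirected edge in both directions, so this corresponds to 44,324 undirected edges)  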
- **Node features:** 500-dimensional TF-IDF-weighted word vectors  
- **Task:** Classify nodes into three categories:  
  1. Diabetes Mellitus Experimental  
  2. Diabetes Mellitus Type 1  
  3. Diabetes Mellitus Type 2  

We will implement this step by step using **PyTorch Geometric (PyG)**.


```python
!pip install torch-geometric
```

```
Collecting torch-geometric
  Downloading torch_geometric-2.7.0-py3-none-any.whl (1.3 MB)
Installing collected packages: torch-geometric
Successfully installed torch-geometric-2.7.0
```


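```python
from torch_geometric.datasets import Planetoid
```

```python
# Download (on first run) and load the PubMed citation graph
dataset = Planetoid(root='.', name='PubMed')
data = dataset[0]  # the dataset contains a single graph

print(f'Dataset: {dataset}')
```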

```
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.pubmed.test.index
Processing...
Done!
Dataset: PubMed()
```


```python
print('Dataset: ')
print('--------')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
# num_edges counts both directions of each undirected edge, so this equals 2|E| / |V|
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}\n')

# Boolean masks define the standard Planetoid split; .sum() counts the True entries
print('Graph: ')
print('------')
print(f'Training Nodes = {data.train_mask.sum()}')
print(f'Validation Nodes = {data.val_mask.sum()}')
print(f'Testing Nodes = {data.test_mask.sum()}')
```

```
Dataset: 
--------
Number of graphs: 1
Number of features: 500
Number of classes: 3
Number of nodes: 19717
Number of edges: 88648
Average node degree: 4.50

Graph: 
------
Training Nodes = 60
Validation Nodes = 500
Testing Nodes = 1000
```
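The printed statistics can be cross-checked by hand. This sketch assumes, as PyG does for the Planetoid datasets, that every undirected edge is stored in both directions in `edge_index`, so `num_edges` counts each citation link twice and dividing it by the number of nodes gives the usual average degree 2|E|/|V|. It also shows how small the labeled set is: the standard Planetoid split uses 20 labeled nodes per class.

```python
# Statistics printed above for PubMed (as stored by PyG)
num_nodes = 19717
num_directed_edges = 88648  # each undirected edge is stored twice in edge_index
num_classes = 3
num_train = 60

# Undirected edge count and average degree 2|E| / |V|
num_undirected_edges = num_directed_edges // 2
avg_degree = 2 * num_undirected_edges / num_nodes

print(num_undirected_edges)            # 44324
print(f'{avg_degree:.2f}')             # 4.50
print(num_train // num_classes)        # 20 labeled nodes per class
print(f'{num_train / num_nodes:.2%}')  # 0.30% of all nodes are labeled
```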
