# GraphSAGE

- References:
> https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.SAGEConv.html
> https://arxiv.org/abs/1706.02216
- GraphSAGE is designed to handle large graphs.
- There are three gradient descent algorithms:
    - Batch gradient descent: update the weights and biases once after processing the entire dataset. The gradient is exact, but the whole dataset must be processed (and held in memory) for every update.
    - Stochastic gradient descent: update the weights and biases after every training example. Each update is cheap, but the gradient estimates are noisy, so the loss fluctuates strongly.
    - Mini-batch gradient descent: update the weights and biases after every batch of n training examples. This balances computational cost against the stability of convergence.
- Creating a mini-batch dataloader for a graph is tricky because splitting the nodes naively can break connections and create isolated nodes. GraphSAGE instead builds its batches with neighbor sampling.
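
The three variants differ only in how many examples feed each update. The sketch below is a toy illustration, not code from this notebook: it fits `y = w * x` in plain Python, and the helper `gradient_descent` and the toy data are assumptions made for the example. It shows how `batch_size` alone switches between the three algorithms and changes the number of updates.

```python
import random

def gradient_descent(xs, ys, batch_size, lr=0.05, epochs=50, seed=0):
    """Fit y ~= w * x by minimising mean squared error; return (w, n_updates)."""
    rng = random.Random(seed)
    w, n_updates = 0.0, 0
    indices = list(range(len(xs)))
    for _ in range(epochs):
        rng.shuffle(indices)
        # walk through the shuffled data in chunks of `batch_size`
        for start in range(0, len(indices), batch_size):
            chunk = indices[start:start + batch_size]
            # gradient of mean((w*x - y)^2) over the chunk with respect to w
            grad = sum(2 * (w * xs[i] - ys[i]) * xs[i] for i in chunk) / len(chunk)
            w -= lr * grad
            n_updates += 1
    return w, n_updates

xs = [0.1 * i for i in range(1, 33)]  # 32 toy inputs
ys = [3.0 * x for x in xs]            # noiseless targets, true w = 3

w_batch, u_batch = gradient_descent(xs, ys, batch_size=len(xs))  # batch GD
w_sgd, u_sgd = gradient_descent(xs, ys, batch_size=1)            # stochastic GD
w_mini, u_mini = gradient_descent(xs, ys, batch_size=8)          # mini-batch GD
print(u_batch, u_sgd, u_mini)  # 50 1600 200
```

Over the same 50 epochs, batch gradient descent makes one update per epoch, stochastic gradient descent makes one per example, and the mini-batch version sits in between.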
***
- In a 1-layer GNN, only the direct neighbors (1-hop) of a target node are needed. In a 2-layer GNN, the neighbors of those neighbors (2-hop) are needed as well.
- This causes two problems:
    - The computational cost grows exponentially with the number of layers.
    - The workload is unbalanced when node degrees are unevenly distributed.
- Neighbor sampling addresses both issues by limiting the number of neighbors used for aggregation.
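
The effect of capping the number of neighbors can be sketched on a toy graph. The code below is a simplified illustration in plain Python, not the sampler PyTorch Geometric uses: `sample_neighborhood` and the hub-and-ring graph are hypothetical names and data chosen for the example.

```python
import random

def sample_neighborhood(adj, seeds, num_neighbors, seed=0):
    """Collect the nodes reached from `seeds` when at most num_neighbors[k]
    neighbors are sampled per node at hop k."""
    rng = random.Random(seed)
    visited = set(seeds)
    frontier = list(seeds)
    for fanout in num_neighbors:
        next_frontier = []
        for node in frontier:
            neighbors = adj.get(node, [])
            # keep all neighbors if there are few, otherwise draw `fanout` of them
            if len(neighbors) <= fanout:
                picked = neighbors
            else:
                picked = rng.sample(neighbors, fanout)
            for n in picked:
                if n not in visited:
                    visited.add(n)
                    next_frontier.append(n)
        frontier = next_frontier
    return visited

# toy graph: hub node 0 linked to 100 leaves, leaves also linked in a ring
adj = {0: list(range(1, 101))}
for i in range(1, 101):
    prev_leaf = 100 if i == 1 else i - 1
    next_leaf = 1 if i == 100 else i + 1
    adj[i] = [0, prev_leaf, next_leaf]

full = sample_neighborhood(adj, [0], [1000, 1000])  # fan-out too large to bind
sampled = sample_neighborhood(adj, [0], [5, 5])     # at most 5 neighbors per hop
print(len(full), len(sampled))
```

Without a binding cap, the 2-hop neighborhood of the hub covers all 101 nodes. With a fan-out of 5 per hop, it can never exceed 1 + 5 + 25 nodes, regardless of the hub's degree.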
***
- There are three aggregators:
    - Mean aggregator.
    - Long short-term memory (LSTM) aggregator.
    - Pooling aggregator.
- The mean aggregator consists of the following steps:
    - Average the embeddings of the target node and its neighbors.
    - Apply a linear transformation followed by an activation function:
> $h^\prime_A = \sigma\left(W \cdot \mathrm{mean}_{i \in N_A \cup \{A\}}\{h_i\}\right)$
    - A variant of this method transforms the target node and its neighbors separately:
> $h^\prime_A = \sigma\left(W_1 h_A + W_2 \cdot \mathrm{mean}_{i \in N_A}\{h_i\}\right)$
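
Both forms of the mean aggregator can be checked by hand on a tiny example. The sketch below uses plain Python lists. It assumes ReLU for $\sigma$, and the embeddings, weight matrices, and helper names are made up for illustration.

```python
def matvec(W, v):
    """Multiply matrix W (list of rows) by vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def mean(vectors):
    """Element-wise mean of a list of equal-length vectors."""
    return [sum(col) / len(vectors) for col in zip(*vectors)]

def relu(v):
    return [max(0.0, x) for x in v]

def mean_aggregate(h, target, neighbors, W):
    # single weight matrix: average the target together with its neighbors
    pooled = mean([h[target]] + [h[i] for i in neighbors])
    return relu(matvec(W, pooled))

def mean_aggregate_separate(h, target, neighbors, W1, W2):
    # separate weights: W1 for the target, W2 for the neighbor mean
    pooled = mean([h[i] for i in neighbors])
    summed = [a + b for a, b in zip(matvec(W1, h[target]), matvec(W2, pooled))]
    return relu(summed)

h = {"A": [1.0, 0.0], "B": [0.0, 2.0], "C": [2.0, 4.0]}  # toy embeddings
identity = [[1.0, 0.0], [0.0, 1.0]]
flip_second = [[1.0, 0.0], [0.0, -1.0]]

print(mean_aggregate(h, "A", ["B", "C"], identity))                        # [1.0, 2.0]
print(mean_aggregate_separate(h, "A", ["B", "C"], identity, flip_second))  # [2.0, 0.0]
```

In the second call the neighbor mean `[1.0, 3.0]` becomes `[1.0, -3.0]` after `W2`, is added to `h_A`, and the negative component is clipped by ReLU. The `SAGEConv` layer used later follows the second form (without the activation, which the model applies separately).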

## Pubmed dataset
- The Pubmed dataset is a citation network similar to Cora and CiteSeer. It contains:
    - 19,717 articles (nodes)
    - 88,648 references (edges)
    - 500-dimensional TF-IDF word vectors (node features)
- These articles belong to 3 categories:
    - Diabetes mellitus experimental
    - Diabetes mellitus type 1
    - Diabetes mellitus type 2
- The goal of node classification is to assign each article to one of these categories.

In [None]:
# import modules
import torch
try:
    from torch_geometric.datasets import Planetoid
except ImportError:
    # install PyTorch Geometric on first use (notebook shell command)
    !pip install torch-geometric
    from torch_geometric.datasets import Planetoid

# download Pubmed
pubmed_dataset = Planetoid(root='.', name='Pubmed')
pubmed_data = pubmed_dataset[0]

## Classifying Pubmed articles with GraphSAGE

In [None]:
# device agnostic
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

In [None]:
# import dataloader module
from torch_geometric.loader import NeighborLoader

# create the train dataloader
train_loader = NeighborLoader(
    pubmed_data,
    num_neighbors=[10,10],
    batch_size=16,
    input_nodes=pubmed_data.train_mask
)

# print a batch
batch = next(iter(train_loader))
batch
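
The loader settings above put a hard ceiling on the size of each sampled subgraph. The arithmetic can be sketched as follows. `max_sampled_nodes` is a hypothetical helper, not part of PyTorch Geometric, and the figure of 60 training nodes assumes the standard Planetoid split (20 labelled nodes per class).

```python
import math

def max_sampled_nodes(batch_size, num_neighbors):
    """Upper bound on nodes per batch: seeds plus sampled nodes at every hop."""
    total = batch_size
    frontier = batch_size
    for fanout in num_neighbors:
        frontier *= fanout  # each frontier node contributes at most `fanout` nodes
        total += frontier
    return total

print(max_sampled_nodes(16, [10, 10]))  # 16 + 160 + 1600 = 1776
print(math.ceil(60 / 16))               # 4 batches per epoch for 60 training nodes
```

Real batches are usually much smaller than this bound, because many nodes have fewer than 10 neighbors and sampled neighborhoods overlap.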

In [None]:
# create graphSAGE class
from torch.nn import functional
from torch_geometric.nn import SAGEConv

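class GraphSAGE(torch.nn.Module):
    """Two-layer GraphSAGE model for node classification."""

    def __init__(self, dim_in:int, dim_h:int, dim_out:int):
        super().__init__()
        # first layer maps input features to the hidden dimension
        self.sage1 = SAGEConv(in_channels=dim_in, out_channels=dim_h)
        # second layer maps hidden embeddings to class logits
        self.sage2 = SAGEConv(in_channels=dim_h, out_channels=dim_out)

    def forward(self, x:torch.Tensor, edge_index:torch.Tensor):
        x = self.sage1(x, edge_index)
        x = torch.relu(x)
        # dropout is only active in training mode
        x = functional.dropout(x, p=0.5, training=self.training)
        x = self.sage2(x, edge_index)
        # raw logits; CrossEntropyLoss applies log-softmax internally
        return x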

In [None]:
# create a GraphSAGE instance for Pubmed and send the model to device
pubmed_model = GraphSAGE(
    dim_in = pubmed_dataset.num_features,
    dim_h = 64,
    dim_out = pubmed_dataset.num_classes
).to(device)
pubmed_model

In [None]:
# setup loss function and optimizer
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=pubmed_model.parameters(), lr=0.01)

In [None]:
# train GraphSAGE on device
from sources.revisited_engine import train_graph_sage

results = train_graph_sage(
    model=pubmed_model,
    dataloader=train_loader,
    device=device,
    loss_fn = loss_fn,
    optimizer=optimizer,
    epochs=101,
    print_results=True
)

In [None]:
# visualize results
from sources.revisited_utils import visualize_results

visualize_results(results=results)

In [None]:
# compute test accuracy
from sources.revisited_engine import test

test_acc = test(
    model=pubmed_model.cpu(),
    data=pubmed_data
)
print(f"Test accuracy: {test_acc*100:.1f}%")