In [10]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric.loader import NeighborLoader
import torch_geometric.data.data
import torch_geometric.data.storage

torch.serialization.add_safe_globals([
    torch_geometric.data.data.DataEdgeAttr,
    torch_geometric.data.data.DataTensorAttr,
    torch_geometric.data.storage.GlobalStorage
])



## Load Dataset

In [11]:
dataset = PygNodePropPredDataset(name='ogbn-products')
data = dataset[0]
split_idx = dataset.get_idx_split()

print(f"Node features: {data.x.element_size() * data.x.nelement() / 1e9:.2f} GB")
print(f"Edge index: {data.edge_index.element_size() * data.edge_index.nelement() / 1e9:.2f} GB")


Node features: 0.98 GB
Edge index: 1.98 GB


## GraphSAGE Model

In [12]:
class GraphSAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, p=0.5):
        super().__init__()
        self.num_layers = num_layers
        self.convs = torch.nn.ModuleList()
        self.dropout_p = p
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
        self.convs.append(SAGEConv(hidden_channels, out_channels))
    
    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i < self.num_layers - 1:
                x = F.relu(x)
                x = F.dropout(x, p=self.dropout_p, training=self.training)
        return x



## Set Hyperparameters

In [13]:
# Model Hyperparameters
NUMBER_OF_LAYERS = 2
NEIGHBOR_SAMPLES = [10, 5]
HIDDEN_LAYER_DIMENSION = 256

# Training Hyperparameters
NUMBER_OF_EPOCHS = 25
BATCH_SIZE = 1024
DROPOUT_P = 0.5
LEARNING_RATE = 0.003
WORKER_COUNT = 2

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')


Using device: cuda


## Initialize Model

In [14]:
model = GraphSAGE(data.x.shape[1], HIDDEN_LAYER_DIMENSION, dataset.num_classes, NUMBER_OF_LAYERS, p=DROPOUT_P).to(device)

## Train and Validation data loader 
### NeighborLoader keeps the full graph in CPU RAM, but for each iteration, it samples a small subgraph

In [15]:
train_loader = NeighborLoader(
    data,
    num_neighbors=NEIGHBOR_SAMPLES,
    batch_size=BATCH_SIZE,
    input_nodes=split_idx['train'],
    num_workers=WORKER_COUNT,
    shuffle=True,
)

val_loader = NeighborLoader(
    data,
    num_neighbors=[-1],
    batch_size=BATCH_SIZE,
    input_nodes=split_idx['valid'],
    num_workers=WORKER_COUNT,
    shuffle=False,
)

## Training function

In [16]:

def train(epoch):
    model.train()
    total_loss = total_correct = total_examples = 0
    
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index)[:batch.batch_size]
        y = batch.y[:batch.batch_size].squeeze()
        loss = F.cross_entropy(out, y)
        loss.backward()
        optimizer.step()
        
        total_loss += float(loss) * batch.batch_size
        total_correct += int((out.argmax(dim=-1) == y).sum())
        total_examples += batch.batch_size
    
    return total_loss / total_examples, total_correct / total_examples

@torch.no_grad()
def test(loader):
    model.eval()
    total_correct = total_examples = 0
    for batch in loader:
        batch = batch.to(device)
        out = model(batch.x, batch.edge_index)[:batch.batch_size]
        y = batch.y[:batch.batch_size].squeeze()
        total_correct += int((out.argmax(dim=-1) == y).sum())
        total_examples += batch.batch_size
    return total_correct / total_examples

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)



## Model Training

In [17]:
print("Training started...")
for epoch in range(1, NUMBER_OF_EPOCHS + 1):
    loss, train_acc = train(epoch)
    if epoch % 5 == 0:
        val_acc = test(val_loader)
        print(f'Epoch {epoch:02d}: Loss={loss:.4f}, Train={train_acc:.4f}, Val={val_acc:.4f}')
    else:
        print(f'Epoch {epoch:02d}: Loss={loss:.4f}, Train={train_acc:.4f}')

Training started...
Epoch 01: Loss=0.7483, Train=0.8109
Epoch 02: Loss=0.4682, Train=0.8745
Epoch 03: Loss=0.4375, Train=0.8826
Epoch 04: Loss=0.4169, Train=0.8875
Epoch 05: Loss=0.4106, Train=0.8898, Val=0.8513
Epoch 06: Loss=0.3997, Train=0.8920
Epoch 07: Loss=0.4010, Train=0.8913
Epoch 08: Loss=0.4148, Train=0.8903
Epoch 09: Loss=0.3889, Train=0.8949
Epoch 10: Loss=0.3761, Train=0.8977, Val=0.8637
Epoch 11: Loss=0.3903, Train=0.8962
Epoch 12: Loss=0.3806, Train=0.8984
Epoch 13: Loss=0.3725, Train=0.8991
Epoch 14: Loss=0.3731, Train=0.8987
Epoch 15: Loss=0.3828, Train=0.8971, Val=0.8603
Epoch 16: Loss=0.3781, Train=0.8985
Epoch 17: Loss=0.3774, Train=0.8980
Epoch 18: Loss=0.3615, Train=0.9013
Epoch 19: Loss=0.3652, Train=0.9009
Epoch 20: Loss=0.3631, Train=0.9016, Val=0.8455
Epoch 21: Loss=0.3674, Train=0.9000
Epoch 22: Loss=0.3578, Train=0.9021
Epoch 23: Loss=0.3584, Train=0.9022
Epoch 24: Loss=0.3570, Train=0.9032
Epoch 25: Loss=0.4052, Train=0.8954, Val=0.8547


## Model Testing

In [18]:
# Create test loader
test_loader = NeighborLoader(
    data,
    num_neighbors=[-1],
    batch_size=BATCH_SIZE,
    input_nodes=split_idx['test'],
    num_workers=WORKER_COUNT,
    shuffle=False,
)

# Evaluate on test set
print("Evaluating on test set...")
test_acc = test(test_loader)
print(f'Final Test Accuracy: {test_acc:.4f} ({test_acc*100:.2f}%)')

Evaluating on test set...
Final Test Accuracy: 0.6615 (66.15%)
