# Graph Machine Learning

This notebook demonstrates Graph Machine Learning techniques using PyTorch Geometric.

## Learning Objectives
- Understand Graph Neural Networks (GNNs)
- Implement node classification
- Perform link prediction *(not yet covered in this notebook)*
- Apply graph classification *(not yet covered in this notebook)*

In [None]:
# Install required packages (uncomment if needed)
# !pip install torch torch-geometric torch-scatter torch-sparse

# Import required libraries
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.data import Data, DataLoader
from torch_geometric.datasets import Planetoid, KarateClub
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np

# Seed every RNG the notebook draws from so "Restart & Run All" is reproducible.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Environment report: useful when comparing results across machines.
for label, value in (
    ("PyTorch version", torch.__version__),
    ("CUDA available", torch.cuda.is_available()),
):
    print(f"{label}: {value}")

## 1. Introduction to Graph Neural Networks

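Let's start with a simple GNN: a two-layer Graph Convolutional Network (GCN). Each GCN layer updates every node's representation by aggregating the degree-normalized features of its neighbors.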

In [None]:
# Simple GCN implementation
class SimpleGCN(torch.nn.Module):
    """Two-layer Graph Convolutional Network for node classification.

    Architecture: GCNConv -> ReLU -> dropout -> GCNConv -> log-softmax.

    Args:
        num_features: Size of each input node feature vector.
        num_classes: Number of output classes.
        hidden_channels: Width of the hidden GCN layer (default 16).
        dropout: Dropout probability applied to the hidden activations
            while the module is in training mode (default 0.5).
    """

    def __init__(self, num_features, num_classes, hidden_channels=16, dropout=0.5):
        super().__init__()
        self.dropout = dropout
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, num_classes)

    def forward(self, x, edge_index):
        """Return per-node log-probabilities with one column per class.

        Args:
            x: Node feature matrix, presumably (num_nodes, num_features).
            edge_index: Graph connectivity in COO format, shape (2, num_edges).
        """
        # First GCN layer: aggregate neighbour features, then apply the non-linearity.
        x = F.relu(self.conv1(x, edge_index))
        # Dropout is only active when self.training is True (i.e. after model.train()).
        x = F.dropout(x, p=self.dropout, training=self.training)
        # Second GCN layer maps hidden features to per-class scores.
        x = self.conv2(x, edge_index)
        # log_softmax pairs with the F.nll_loss used by the training loop.
        return F.log_softmax(x, dim=1)


print("Simple GCN model created successfully!")

## 2. Node Classification

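Let's implement node classification on the Cora citation network. In Cora, nodes are papers, edges are citations, and the task is to predict each paper's subject category. If the dataset cannot be downloaded, the next cell falls back to a small random synthetic graph so the rest of the notebook still runs. Its labels are random, so accuracies on it are not meaningful.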

In [None]:
# Load the Cora citation graph. If that fails (e.g. no network access), fall back
# to a small random synthetic graph so the remaining cells still run.
CORA_ROOT = '/tmp/Cora'  # Planetoid download/cache directory; change for your machine

try:
    dataset = Planetoid(root=CORA_ROOT, name='Cora')
    data = dataset[0]
    print(f"Dataset: {dataset}")
    print(f"Number of graphs: {len(dataset)}")
    print(f"Number of features: {dataset.num_features}")
    print(f"Number of classes: {dataset.num_classes}")
    print(f"Number of nodes: {data.num_nodes}")
    print(f"Number of edges: {data.num_edges}")
    print(f"Average node degree: {data.num_edges / data.num_nodes:.2f}")
    print(f"Number of training nodes: {data.train_mask.sum()}")
    print(f"Number of validation nodes: {data.val_mask.sum()}")
    print(f"Number of test nodes: {data.test_mask.sum()}")
except Exception as e:
    # Deliberately broad: any loading failure switches to the synthetic fallback,
    # but the error is printed so the failure is not silently hidden.
    print(f"Error loading Cora dataset: {e}")
    print("Creating a simple synthetic dataset instead...")

    # Synthetic graph size. Labels are random, so held-out accuracy on this
    # fallback is only chance-level.
    num_nodes = 100
    num_features = 16
    num_classes = 7  # matches Cora's 7 classes
    num_random_edges = 200

    # Random node features
    x = torch.randn(num_nodes, num_features)

    # Random edges: drop self-loops (GCNConv/GATConv add their own by default),
    # mirror every edge so the graph is undirected like Cora, then de-duplicate.
    edge_index = torch.randint(0, num_nodes, (2, num_random_edges))
    edge_index = edge_index[:, edge_index[0] != edge_index[1]]
    edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)
    edge_index = torch.unique(edge_index, dim=1)

    # Random labels
    y = torch.randint(0, num_classes, (num_nodes,))

    # 70 / 15 / 15 train / validation / test split by node index
    train_mask = torch.zeros(num_nodes, dtype=torch.bool)
    train_mask[:70] = True
    val_mask = torch.zeros(num_nodes, dtype=torch.bool)
    val_mask[70:85] = True
    test_mask = torch.zeros(num_nodes, dtype=torch.bool)
    test_mask[85:] = True

    data = Data(x=x, edge_index=edge_index, y=y,
                train_mask=train_mask, val_mask=val_mask, test_mask=test_mask)

    print(f"Synthetic dataset created with {num_nodes} nodes and {num_classes} classes")

In [None]:
# Train the GCN model
def train_gcn(data, epochs=200, model=None, lr=0.01, weight_decay=5e-4, log_every=50):
    """Train a node-classification GNN with full-batch Adam and NLL loss.

    Args:
        data: Graph object exposing ``x``, ``edge_index``, ``y``, ``train_mask``
            and ``num_features`` (e.g. a PyG ``Data``).
        epochs: Number of full-batch optimisation steps.
        model: Optional pre-built module called as ``model(x, edge_index)`` that
            returns per-node log-probabilities. If None, a fresh ``SimpleGCN``
            sized from ``data`` is created.
        lr: Adam learning rate.
        weight_decay: Adam weight decay (L2 penalty).
        log_every: Print the loss every ``log_every`` epochs; 0 or None disables logging.

    Returns:
        (model, losses): the trained model (left in training mode) and a list of
        per-epoch training losses as Python floats.
    """
    if model is None:
        # Infer the class count from the labels (assumes labels are 0..C-1).
        model = SimpleGCN(data.num_features, data.y.max().item() + 1)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    model.train()
    losses = []

    for epoch in range(epochs):
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        # Loss is computed on training nodes only; val/test nodes are held out.
        loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        if log_every and epoch % log_every == 0:
            print(f'Epoch {epoch:03d}: Loss: {loss.item():.4f}')

    return model, losses

# Fit a fresh GCN on the graph loaded above (Cora, or the synthetic fallback).
model, losses = train_gcn(data=data)
print("Training completed!")

In [None]:
# Evaluate the trained model on the train / validation / test splits.
def _masked_accuracy(pred, y, mask):
    """Return the fraction of nodes selected by boolean ``mask`` whose prediction equals the label.

    Returns ``float('nan')`` when ``mask`` selects no nodes, instead of raising
    ZeroDivisionError.
    """
    total = int(mask.sum())
    if total == 0:
        return float('nan')
    return pred[mask].eq(y[mask]).sum().item() / total


model.eval()  # disables dropout for deterministic predictions
with torch.no_grad():
    out = model(data.x, data.edge_index)
    pred = out.argmax(dim=1)

    # Calculate accuracies
    train_acc = _masked_accuracy(pred, data.y, data.train_mask)
    val_acc = _masked_accuracy(pred, data.y, data.val_mask)
    test_acc = _masked_accuracy(pred, data.y, data.test_mask)

    print(f'Train Accuracy: {train_acc:.4f}')
    print(f'Validation Accuracy: {val_acc:.4f}')
    print(f'Test Accuracy: {test_acc:.4f}')

# Plot the per-epoch training loss recorded during training.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(losses)
ax.set(xlabel='Epoch', ylabel='Loss', title='Training Loss Over Time')
ax.grid(True, alpha=0.3)
plt.show()

## 3. Graph Attention Networks (GAT)

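Let's implement a Graph Attention Network (GAT). GCN weights neighbors with fixed, degree-based coefficients. A GAT layer instead learns attention coefficients that decide how much each neighbor contributes. In the first layer, the outputs of several attention heads are concatenated.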

In [None]:
# GAT implementation
class GAT(torch.nn.Module):
    """Two-layer Graph Attention Network for node classification.

    Args:
        num_features: Size of each input node feature vector.
        num_classes: Number of output classes.
        num_heads: Number of attention heads in the first layer; their outputs
            are concatenated.
        hidden_channels: Per-head output width of the first layer (default 8).
        dropout: Dropout probability used on the inputs, the hidden activations
            and (via GATConv's own ``dropout``) the attention coefficients (default 0.6).
    """

    def __init__(self, num_features, num_classes, num_heads=8, hidden_channels=8, dropout=0.6):
        super().__init__()
        self.dropout = dropout
        # Heads are concatenated, so this layer outputs hidden_channels * num_heads features.
        self.conv1 = GATConv(num_features, hidden_channels, heads=num_heads, dropout=dropout)
        # Single-head output layer; concat=False keeps the output width at num_classes.
        self.conv2 = GATConv(hidden_channels * num_heads, num_classes,
                             heads=1, concat=False, dropout=dropout)

    def forward(self, x, edge_index):
        """Return per-node log-probabilities with one column per class.

        Args:
            x: Node feature matrix, presumably (num_nodes, num_features).
            edge_index: Graph connectivity in COO format, shape (2, num_edges).
        """
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)


print("GAT model created successfully!")

## 4. Summary and Key Insights

### Key Takeaways:

1. **Graph Neural Networks**: Extend neural networks to graph-structured data
2. **Message Passing**: Nodes aggregate information from their neighbors
3. **Attention Mechanisms**: GAT uses attention to weight neighbor importance
4. **Node Classification**: Predict node labels using graph structure and features

### Applications:
- **Social Networks**: User classification, recommendation systems
- **Biological Networks**: Protein function prediction, drug discovery
- **Knowledge Graphs**: Entity classification, relation extraction
- **Computer Vision**: Scene understanding, object detection