# Graph Neural Network-Based Community Detection

This notebook demonstrates how to use Graph Neural Networks (GNNs) for community detection. We'll explore:

1. Loading and preparing graphs for GNN processing
2. Implementing and training different GNN architectures
3. Extracting node embeddings and performing clustering
4. Evaluating detection results against ground truth
5. Comparing different GNN models

In [ ]:
import sys
import os
import numpy as np
import torch
import polars as pl
import rustworkx as rx
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import time
import warnings
warnings.filterwarnings('ignore')

# Import directly from the community_detection package
from community_detection.gnn_community_detection import (
    rwx_to_pyg, GCN, GraphSAGE, GAT, VGAE,
    train_gnn_embedding, extract_embeddings, detect_communities_from_embeddings,
    evaluate_communities, plot_embeddings, add_communities_to_graph,
    run_gnn_community_detection, compare_gnn_models
)

# Import functions from data_prep for generating graphs if needed
from community_detection.data_prep import generate_synthetic_graph

## 1. Check PyTorch Geometric Availability

GNN-based methods require PyTorch and PyTorch Geometric. PyTorch is already imported above, so let's check whether PyTorch Geometric is installed.

In [None]:
# Check if PyTorch Geometric is available
try:
    import torch_geometric
    from torch_geometric.data import Data
    TORCH_GEOMETRIC_AVAILABLE = True
    print("PyTorch Geometric is available.")
except ImportError:
    TORCH_GEOMETRIC_AVAILABLE = False
    print("PyTorch Geometric is not available. Please install it to run GNN-based methods.")
    print("You can install it with: pip install torch-geometric")

## 2. Load or Generate Test Graph

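Let's load the stochastic block model (SBM) graph saved by the previous notebooks from `data/sbm_graph.gpickle`. If that file does not exist, we generate a new 100-node, 5-community SBM and save it for later runs.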
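When no saved graph is found, the cell below calls `generate_synthetic_graph('sbm', ...)` with `p_in=0.3` and `p_out=0.05`. We have not inspected that function's internals; the NumPy sketch below only illustrates what those two parameters mean in a standard planted-partition SBM, where two nodes in the same community are linked with probability `p_in` and nodes in different communities with probability `p_out`. The helper `planted_partition_sbm` is hypothetical and exists only for illustration.

```python
import numpy as np


def planted_partition_sbm(n_nodes, n_communities, p_in, p_out, seed=None):
    """Sample an undirected planted-partition SBM adjacency matrix.

    Hypothetical helper for illustration; not the package's generate_synthetic_graph.
    """
    rng = np.random.default_rng(seed)
    # Assign nodes to communities in contiguous, near-equal blocks
    labels = np.arange(n_nodes) * n_communities // n_nodes
    same_block = labels[:, None] == labels[None, :]
    # Edge probability depends only on whether both endpoints share a block
    probs = np.where(same_block, p_in, p_out)
    # Sample the upper triangle only (no self-loops), then mirror it so the graph is undirected
    upper = np.triu(rng.random((n_nodes, n_nodes)) < probs, k=1)
    adj = (upper | upper.T).astype(np.int8)
    return adj, labels


adj, labels = planted_partition_sbm(100, 5, p_in=0.3, p_out=0.05, seed=0)
same = labels[:, None] == labels[None, :]
off_diag = ~np.eye(100, dtype=bool)
print("within-community edge density:", adj[same & off_diag].mean())
print("between-community edge density:", adj[~same].mean())
```

Because `p_in` is six times `p_out`, edges are much denser inside communities than between them, which is the structure the GNN embeddings later need to recover.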

In [None]:
# Check if PyTorch Geometric is available before proceeding
if not TORCH_GEOMETRIC_AVAILABLE:
    print("PyTorch Geometric is required for this notebook. Please install it and restart the notebook.")
else:
    # Check if previously saved graphs are available
    if os.path.exists('data/sbm_graph.gpickle'):
        # Load the SBM graph using pickle
        print("Loading SBM graph from file...")
        with open('data/sbm_graph.gpickle', 'rb') as f:
            G = pickle.load(f)
    else:
        # Generate a new SBM graph
        print("Generating a new SBM graph...")
        n_communities = 5
        G, _ = generate_synthetic_graph(
            'sbm', 
            n_nodes=100, 
            n_communities=n_communities,
            p_in=0.3, 
            p_out=0.05
        )
        
        # Save the graph in pickle format for future use
        os.makedirs('data', exist_ok=True)
        with open('data/sbm_graph.gpickle', 'wb') as f:
            pickle.dump(G, f)
        print("Generated and saved new SBM graph.")

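    # Extract ground-truth community labels stored on each node
    ground_truth = {}
    for i in G.node_indices():
        node_data = G.get_node_data(i)
        if node_data and 'community' in node_data:
            ground_truth[i] = node_data['community']

    # Color nodes by ground-truth community. Positions come from a spring layout
    # (not from random noise), so the plot reflects the graph's structure.
    pos = rx.spring_layout(G, seed=42)
    coords = torch.tensor([pos[i] for i in G.node_indices()], dtype=torch.float)
    labels = torch.tensor([ground_truth.get(i, 0) for i in G.node_indices()])
    plot_embeddings(coords, labels, method='none', title="Ground Truth Communities")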

## 3. GCN-Based Community Detection

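Let's start with a Graph Convolutional Network (GCN). Each GCN layer updates a node's representation by averaging its own features with its neighbors' features, using symmetric degree normalization, and then applying a learned linear transform and a nonlinearity.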
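The propagation rule described above is $H' = \mathrm{ReLU}(\hat{D}^{-1/2}(A + I)\hat{D}^{-1/2} H W)$, where $\hat{D}$ is the degree matrix of $A + I$. The following is a plain NumPy sketch of a single layer of that standard rule; we have not checked how the package's `GCN` class is built, so treat this as an illustration rather than its implementation.

```python
import numpy as np


def normalized_adjacency(adj):
    """Return D^-1/2 (A + I) D^-1/2, the symmetric-normalized adjacency with self-loops."""
    # Self-loops let each node keep part of its own features
    a_hat = adj + np.eye(adj.shape[0])
    d_inv_sqrt = 1.0 / np.sqrt(a_hat.sum(axis=1))
    return d_inv_sqrt[:, None] * a_hat * d_inv_sqrt[None, :]


def gcn_layer(adj, features, weight):
    """One GCN step: aggregate normalized neighbor features, project with W, apply ReLU."""
    return np.maximum(normalized_adjacency(adj) @ features @ weight, 0.0)


# Toy graph: the path 0 - 1 - 2, with one-hot node features
adj = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)
features = np.eye(3)
weight = np.random.default_rng(0).normal(size=(3, 2))
out = gcn_layer(adj, features, weight)
print(out.shape)  # (3, 2)
```

Stacking two such layers (as with `hidden_dim=32` and `output_dim=16` later in this notebook) lets each node's embedding depend on its two-hop neighborhood.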

In [None]:
if TORCH_GEOMETRIC_AVAILABLE:
    # Run GCN-based community detection
    print("Running GCN-based community detection...")
    
    # Get number of communities from ground truth
    n_clusters = len(set(ground_truth.values()))
    
    # Run GCN
    gcn_results = run_gnn_community_detection(
        G, 
        model_type='gcn',
        embedding_dim=16,
        n_clusters=n_clusters,
        epochs=100,
        ground_truth_attr='community'
    )
    
    # Save results
    os.makedirs('results', exist_ok=True)
    with open('results/gcn_results.pkl', 'wb') as f:
        pickle.dump(gcn_results, f)

## 4. GraphSAGE-Based Community Detection

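Now let's try GraphSAGE. Instead of multiplying by a fixed normalized adjacency matrix, GraphSAGE learns a function that aggregates features from a node's (sampled) neighbors, so a trained model can embed nodes that were not present during training. This inductive ability matters less here, since we train and embed a single fixed graph.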
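A minimal NumPy sketch of one GraphSAGE layer with a mean aggregator follows. It assumes full neighborhoods (the original method samples a fixed number of neighbors) and writes the concatenation `W · [h_v || mean(h_u)]` as two separate weight matrices, which is equivalent. It is not the package's `GraphSAGE` class.

```python
import numpy as np


def sage_mean_layer(adj, features, w_self, w_neigh):
    """GraphSAGE-style layer with a mean aggregator (full neighborhood, no sampling)."""
    deg = adj.sum(axis=1, keepdims=True)
    # Mean of neighbor features; isolated nodes get a zero vector
    neigh_mean = np.divide(adj @ features, deg,
                           out=np.zeros_like(features, dtype=float), where=deg > 0)
    # Combine self and neighbor information with separate weights, then ReLU
    h = np.maximum(features @ w_self + neigh_mean @ w_neigh, 0.0)
    # L2-normalize each embedding, as in the original GraphSAGE algorithm
    norms = np.linalg.norm(h, axis=1, keepdims=True)
    return np.divide(h, norms, out=np.zeros_like(h), where=norms > 0)


adj = np.array([[0, 1, 1, 0],
                [1, 0, 0, 0],
                [1, 0, 0, 0],
                [0, 0, 0, 0]], dtype=float)  # node 3 is isolated
rng = np.random.default_rng(1)
features = rng.normal(size=(4, 3))
w_self, w_neigh = rng.normal(size=(3, 2)), rng.normal(size=(3, 2))
emb = sage_mean_layer(adj, features, w_self, w_neigh)

# The same weights apply unchanged to a larger, unseen graph (inductive use)
new_adj = np.ones((5, 5)) - np.eye(5)
new_emb = sage_mean_layer(new_adj, rng.normal(size=(5, 3)), w_self, w_neigh)
```

The weight shapes depend only on the feature dimensions, never on the number of nodes, which is what makes the inductive use shown at the end possible.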

In [None]:
if TORCH_GEOMETRIC_AVAILABLE:
    # Run GraphSAGE-based community detection
    print("Running GraphSAGE-based community detection...")
    
    # Run GraphSAGE
    graphsage_results = run_gnn_community_detection(
        G, 
        model_type='graphsage',
        embedding_dim=16,
        n_clusters=n_clusters,
        epochs=100,
        ground_truth_attr='community'
    )
    
    # Save results
    with open('results/graphsage_results.pkl', 'wb') as f:
        pickle.dump(graphsage_results, f)

## 5. GAT-Based Community Detection

Graph Attention Networks (GAT) learn attention coefficients that determine how much each neighbor contributes to a node's update, rather than weighting neighbors by degree as GCN does.
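To make the attention step concrete, here is a single-head NumPy sketch of the standard GAT scoring rule: $e_{ij} = \mathrm{LeakyReLU}(a^\top [W h_i \,\|\, W h_j])$, followed by a softmax over each node's neighbors (including itself). Multi-head attention and dropout are omitted, and this is an illustration, not the package's `GAT` class.

```python
import numpy as np


def gat_attention(adj, features, weight, attn_vec, negative_slope=0.2):
    """Single-head GAT attention coefficients alpha_ij over each node's neighborhood."""
    n = adj.shape[0]
    z = features @ weight                 # projected features W h, shape (n, f')
    f_out = z.shape[1]
    # a^T [z_i || z_j] splits into a source term plus a target term
    src = z @ attn_vec[:f_out]            # shape (n,)
    dst = z @ attn_vec[f_out:]            # shape (n,)
    scores = src[:, None] + dst[None, :]
    scores = np.where(scores > 0, scores, negative_slope * scores)  # LeakyReLU
    # Mask non-edges (self-loops kept) so the softmax runs only over neighbors
    mask = (adj + np.eye(n)) > 0
    scores = np.where(mask, scores, -np.inf)
    scores -= scores.max(axis=1, keepdims=True)  # numerical stability
    weights = np.exp(scores)                     # exp(-inf) = 0 for non-neighbors
    return weights / weights.sum(axis=1, keepdims=True)


adj = np.array([[0, 1, 1], [1, 0, 0], [1, 0, 0]], dtype=float)
rng = np.random.default_rng(2)
features = rng.normal(size=(3, 4))
weight = rng.normal(size=(4, 2))
alpha = gat_attention(adj, features, weight, rng.normal(size=4))
new_features = alpha @ (features @ weight)  # attention-weighted aggregation
```

Each row of `alpha` is a probability distribution over that node's neighborhood, so nodes 1 and 2, which are not connected, get zero weight for each other.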

In [None]:
if TORCH_GEOMETRIC_AVAILABLE:
    # Run GAT-based community detection
    print("Running GAT-based community detection...")
    
    # Run GAT
    gat_results = run_gnn_community_detection(
        G, 
        model_type='gat',
        embedding_dim=16,
        n_clusters=n_clusters,
        epochs=100,
        ground_truth_attr='community'
    )
    
    # Save results
    with open('results/gat_results.pkl', 'wb') as f:
        pickle.dump(gat_results, f)

## 6. VGAE-Based Community Detection

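Variational Graph Autoencoders (VGAE) learn latent representations without using labels. In the standard formulation, a GCN encoder maps each node to a Gaussian distribution in latent space, and an inner-product decoder is trained to reconstruct the graph's edges from samples of those distributions.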
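The sketch below shows the pieces that follow the encoder in a standard VGAE: the reparameterized sample $z = \mu + \sigma \odot \epsilon$, the inner-product decoder $\sigma(z_i^\top z_j)$, and the loss (reconstruction cross-entropy plus a KL term toward a standard normal prior). The `mu` and `log_sigma` arrays are random stand-ins for what a GCN encoder would output; this is NumPy for illustration only and is not the package's `VGAE` class.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def vgae_forward(mu, log_sigma, adj, rng):
    """Reparameterized sample, inner-product decoder, and the two VGAE loss terms."""
    # z = mu + sigma * eps keeps the sample differentiable w.r.t. mu and sigma
    noise = rng.standard_normal(mu.shape)
    z = mu + np.exp(log_sigma) * noise
    # Inner-product decoder: p(A_ij = 1 | z) = sigmoid(z_i . z_j)
    probs = sigmoid(z @ z.T)
    # Reconstruction loss: binary cross-entropy averaged over all node pairs
    tiny = 1e-9
    recon = -np.mean(adj * np.log(probs + tiny) + (1 - adj) * np.log(1 - probs + tiny))
    # KL(N(mu, sigma^2) || N(0, I)), summed over latent dims and averaged over nodes
    kl = -0.5 * np.mean(np.sum(1 + 2 * log_sigma - mu**2 - np.exp(2 * log_sigma), axis=1))
    return z, probs, recon, kl


rng = np.random.default_rng(3)
adj = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)
mu = rng.normal(size=(3, 2))             # stand-in for the encoder's mean output
log_sigma = np.full((3, 2), -1.0)        # stand-in for the encoder's log std output
z, probs, recon, kl = vgae_forward(mu, log_sigma, adj, rng)
```

For community detection, the mean vectors `mu` (or samples `z`) serve as the node embeddings that are later clustered.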

In [None]:
if TORCH_GEOMETRIC_AVAILABLE:
    # Run VGAE-based community detection
    print("Running VGAE-based community detection...")
    
    # Run VGAE
    vgae_results = run_gnn_community_detection(
        G, 
        model_type='vgae',
        embedding_dim=16,
        n_clusters=n_clusters,
        epochs=100,
        ground_truth_attr='community'
    )
    
    # Save results
    with open('results/vgae_results.pkl', 'wb') as f:
        pickle.dump(vgae_results, f)

## 7. Comparing GNN Models

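Now let's compare all four GNN-based methods side by side under the same settings: 16-dimensional embeddings, 100 training epochs, and the number of clusters taken from the ground truth.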
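Ranking the models requires a score against the ground-truth communities. We have not verified which metrics `compare_gnn_models` or `evaluate_communities` report; normalized mutual information (NMI) is one common choice, and the NumPy sketch below shows how it is computed with arithmetic-mean normalization (the same default as scikit-learn's `normalized_mutual_info_score`). The helper name is hypothetical.

```python
import numpy as np


def normalized_mutual_info(labels_true, labels_pred):
    """NMI with arithmetic-mean normalization; invariant to relabeling of clusters."""
    _, t = np.unique(np.asarray(labels_true), return_inverse=True)
    _, p = np.unique(np.asarray(labels_pred), return_inverse=True)
    # Joint distribution of (true, predicted) labels from the contingency table
    contingency = np.zeros((t.max() + 1, p.max() + 1))
    np.add.at(contingency, (t, p), 1)
    joint = contingency / contingency.sum()
    pt, pp = joint.sum(axis=1), joint.sum(axis=0)
    # Mutual information, summing only over non-empty cells
    nz = joint > 0
    mi = np.sum(joint[nz] * np.log(joint[nz] / np.outer(pt, pp)[nz]))
    h_t = -np.sum(pt * np.log(pt))
    h_p = -np.sum(pp * np.log(pp))
    denom = (h_t + h_p) / 2
    return 1.0 if denom == 0 else mi / denom


truth = [0, 0, 0, 1, 1, 1]
print(normalized_mutual_info(truth, [2, 2, 2, 5, 5, 5]))  # same partition, different ids
print(normalized_mutual_info(truth, [0, 1, 0, 1, 0, 1]))  # mostly unrelated partition
```

Because NMI compares partitions rather than label ids, it does not matter that a clustering algorithm numbers the communities differently from the ground truth.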

In [None]:
if TORCH_GEOMETRIC_AVAILABLE:
    # Compare different GNN models
    print("Comparing different GNN models...")
    
    # Run comparison
    results_df = compare_gnn_models(
        G,
        embedding_dim=16,
        n_clusters=n_clusters,
        epochs=100,
        ground_truth_attr='community'
    )
    
    # Save the comparison results
    results_df.write_parquet('results/gnn_methods_comparison.parquet', compression="zstd")
    
    # Display results
    print("\nComparison Results:")
    print(results_df)

## 8. Visualizing Node Embeddings

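Let's take a closer look at the node embeddings generated by each GNN model. For every model, we train it, project its embeddings to 2D with t-SNE, and plot them twice: colored by ground-truth community and colored by the communities detected from the embeddings.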
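`detect_communities_from_embeddings` turns embeddings into community labels; we have not checked which clustering algorithm it uses. k-means is a common choice when the number of communities is known, so here is a minimal NumPy sketch of Lloyd's k-means on embedding vectors. The `kmeans` helper is hypothetical and is not the package's function.

```python
import numpy as np


def kmeans(embeddings, n_clusters, n_iter=100, seed=0):
    """Plain Lloyd's k-means; returns one cluster id per row of `embeddings`."""
    rng = np.random.default_rng(seed)
    # Initialize centroids at randomly chosen distinct points
    centroids = embeddings[rng.choice(len(embeddings), n_clusters, replace=False)]
    for _ in range(n_iter):
        # Assign each embedding to its nearest centroid (squared Euclidean distance)
        dists = ((embeddings[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        labels = dists.argmin(axis=1)
        # Move each centroid to the mean of its members; keep it if the cluster is empty
        new_centroids = np.array([
            embeddings[labels == k].mean(axis=0) if np.any(labels == k) else centroids[k]
            for k in range(n_clusters)
        ])
        if np.allclose(new_centroids, centroids):
            break
        centroids = new_centroids
    return labels


# Two well-separated groups of 16-dimensional "embeddings"
rng = np.random.default_rng(4)
blob_a = rng.normal(loc=0.0, scale=0.1, size=(10, 16))
blob_b = rng.normal(loc=5.0, scale=0.1, size=(10, 16))
labels = kmeans(np.vstack([blob_a, blob_b]), n_clusters=2)
```

If `extract_embeddings` returns a PyTorch tensor, it would need converting first, for example with `embeddings.detach().cpu().numpy()`.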

In [None]:
if TORCH_GEOMETRIC_AVAILABLE:
    # Create a custom function to extract and visualize embeddings from each model
    def visualize_model_embeddings(G, model_type, embedding_dim=16, epochs=100):
        # Convert graph to PyG format
        data = rwx_to_pyg(G)
        
        # Initialize model
        if model_type == 'gcn':
            model = GCN(data.x.size(1), hidden_dim=32, output_dim=embedding_dim)
        elif model_type == 'graphsage':
            model = GraphSAGE(data.x.size(1), hidden_dim=32, output_dim=embedding_dim)
        elif model_type == 'gat':
            model = GAT(data.x.size(1), hidden_dim=32, output_dim=embedding_dim)
        elif model_type == 'vgae':
            model = VGAE(data.x.size(1), hidden_dim=32, latent_dim=embedding_dim)
        else:
            raise ValueError(f"Unsupported model type: {model_type}")
        
        # Train model
        model = train_gnn_embedding(model, data, epochs=epochs, verbose=True)
        
        # Extract embeddings
        embeddings = extract_embeddings(model, data)
        
        # Get ground truth communities
        ground_truth = []
        for i in range(len(G)):
            node_data = G.get_node_data(i)
            if node_data and 'community' in node_data:
                ground_truth.append(node_data['community'])
            else:
                ground_truth.append(0)
        
        ground_truth = torch.tensor(ground_truth)
        
        # Visualize embeddings colored by ground truth
        plot_embeddings(embeddings, ground_truth, method='tsne', 
                      title=f"{model_type.upper()} Node Embeddings (Ground Truth Coloring)")
        
        # Detect communities from embeddings
        n_clusters = len(set(ground_truth.tolist()))
        communities = detect_communities_from_embeddings(embeddings, n_clusters=n_clusters)
        
        # Visualize embeddings colored by detected communities
        plot_embeddings(embeddings, communities, method='tsne',
                      title=f"{model_type.upper()} Node Embeddings (Detected Communities)")
        
        return embeddings, communities
    
    # Visualize embeddings for each model type
    model_types = ['gcn', 'graphsage', 'gat', 'vgae']
    
    for model_type in model_types:
        print(f"\nVisualizing embeddings for {model_type.upper()}...")
        try:
            visualize_model_embeddings(G, model_type, embedding_dim=16, epochs=50)
        except Exception as e:
            print(f"Error visualizing embeddings for {model_type}: {e}")

## 9. Summary and Conclusions

In this notebook, we have:

1. Applied various GNN-based community detection methods to our test graph
   - Graph Convolutional Network (GCN)
   - GraphSAGE
   - Graph Attention Network (GAT)
   - Variational Graph Autoencoder (VGAE)
2. Trained the models and extracted node embeddings
3. Detected communities from the embeddings
4. Evaluated the results against ground truth
5. Compared different GNN architectures
6. Visualized the node embeddings

Graph Neural Networks approach community detection by learning node representations that aggregate information from each node's neighborhood, so the embeddings reflect graph structure (and node attributes, when the graph has them). Standard clustering algorithms can then be applied to these embeddings to detect communities.

In the next notebook, we'll explore methods for detecting communities in dynamic graphs that evolve over time.