# Dynamic Graph Neural Networks for Community Detection

This notebook explores detecting communities in dynamic (time-evolving) graphs using specialized GNN architectures. We'll cover:

1. Generating synthetic dynamic graphs with evolving community structure
2. Implementing dynamic GNN architectures (EvolveGCN and DySAT)
3. Training models to capture temporal evolution
4. Detecting communities at each time step
5. Visualizing community evolution over time
6. Evaluating and comparing dynamic GNN methods

In [ ]:
import sys
import os
import numpy as np
import torch
import polars as pl
import rustworkx as rx
import networkx as nx  # Still needed for some visualizations
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import time
import warnings
warnings.filterwarnings('ignore')

# Import our dynamic GNN utilities
from community_detection.dynamic_gnn import (
    generate_dynamic_graphs, visualize_dynamic_communities,
    EvolveGCN, DySAT, train_dynamic_gnn, extract_temporal_embeddings,
    detect_temporal_communities, evaluate_temporal_communities,
    visualize_community_evolution, run_dynamic_community_detection,
    compare_dynamic_gnn_models
)

# Import visualization utilities
from community_detection.visualization import (
    community_membership_heatmap, alluvial_diagram, vehlow_visualization
)

## 1. Check PyTorch Geometric Availability

Dynamic GNN methods require PyTorch and PyTorch Geometric. Let's check if they're available.

In [None]:
# Check if PyTorch Geometric is available
try:
    import torch_geometric
    from torch_geometric.data import Data
    TORCH_GEOMETRIC_AVAILABLE = True
    print("PyTorch Geometric is available.")
except ImportError:
    TORCH_GEOMETRIC_AVAILABLE = False
    print("PyTorch Geometric is not available. Please install it to run dynamic GNN methods.")
    print("You can install it with: pip install torch-geometric")

## 2. Generate a Sequence of Dynamic Graphs

Let's create a sequence of dynamic graphs with evolving community structure for our experiments.
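
To give a feel for what a parameter like `change_fraction` controls, here is a minimal sketch of one way community labels could drift between time steps. It is not the implementation of `generate_dynamic_graphs`, whose internals are not shown in this notebook. The helper name `evolve_labels` and its reassignment rule (each selected node moves to a different, uniformly chosen community) are assumptions made purely for illustration.

```python
import random

def evolve_labels(labels, n_communities, change_fraction, rng):
    """Return a copy of `labels` in which a fraction of nodes switch community.

    Hypothetical helper: every selected node moves to a different community
    chosen uniformly at random.
    """
    n_nodes = len(labels)
    n_moves = int(round(change_fraction * n_nodes))
    movers = rng.sample(range(n_nodes), n_moves)
    new_labels = list(labels)
    for node in movers:
        choices = [c for c in range(n_communities) if c != labels[node]]
        new_labels[node] = rng.choice(choices)
    return new_labels

rng = random.Random(0)
labels_t0 = [node % 3 for node in range(100)]  # three roughly equal communities
labels_t1 = evolve_labels(labels_t0, n_communities=3, change_fraction=0.1, rng=rng)

n_changed = sum(a != b for a, b in zip(labels_t0, labels_t1))
print(f"{n_changed} of {len(labels_t0)} nodes changed community")
```

In a full generator, edges would then be resampled from the new labels at each step, for example with a higher connection probability inside communities than between them.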

In [None]:
# Check if PyTorch Geometric is available before proceeding
if not TORCH_GEOMETRIC_AVAILABLE:
    print("PyTorch Geometric is required for this notebook. Please install it and restart the notebook.")
else:
    # Generate a sequence of dynamic graphs
    print("Generating a sequence of dynamic graphs...")
    n_time_steps = 5
    n_communities = 3
    n_nodes = 100
    
    graphs = generate_dynamic_graphs(
        n_time_steps=n_time_steps,
        n_nodes=n_nodes,
        n_communities=n_communities,
        change_fraction=0.1  # 10% of nodes change communities between time steps
    )
    
    print(f"Generated {n_time_steps} graphs with {n_nodes} nodes and {n_communities} communities.")
    
    # Visualize the ground truth communities over time
    print("\nVisualizing ground truth communities over time...")
    visualize_dynamic_communities(
        graphs, 
        community_attr='community',
        title="Ground Truth Communities Over Time"
    )

## 3. Analyze Community Structure at Each Time Step

Let's analyze the community structure in more detail at each time step to see how it evolves.

In [None]:
if TORCH_GEOMETRIC_AVAILABLE:
    # Analyze community structure at each time step
    for t, G in enumerate(graphs):
        # Count nodes in each community
        community_counts = {}
        for i in range(len(G)):
            node_data = G.get_node_data(i)
            if node_data and 'community' in node_data:
                comm = node_data['community']
                if comm not in community_counts:
                    community_counts[comm] = 0
                community_counts[comm] += 1
        
        print(f"\nTime Step {t+1} Community Distribution:")
        for comm, count in sorted(community_counts.items()):
            print(f"  Community {comm}: {count} nodes")
        
        # Calculate edge statistics
        total_edges = G.num_edges()
        intra_community_edges = 0
        inter_community_edges = 0
        
        for edge in G.edge_list():
            source, target = edge[0], edge[1]
            source_data = G.get_node_data(source)
            target_data = G.get_node_data(target)
            if (source_data and target_data and 
                'community' in source_data and 'community' in target_data):
                if source_data['community'] == target_data['community']:
                    intra_community_edges += 1
                else:
                    inter_community_edges += 1
        
        print(f"  Total edges: {total_edges}")
        denom = max(total_edges, 1)  # guard against division by zero on an edgeless graph
        print(f"  Intra-community edges: {intra_community_edges} ({intra_community_edges/denom:.1%})")
        print(f"  Inter-community edges: {inter_community_edges} ({inter_community_edges/denom:.1%})")
    
    # Create list of community assignments
    communities_list = []
    for G in graphs:
        communities = {}
        for i in range(len(G)):
            node_data = G.get_node_data(i)
            if node_data and 'community' in node_data:
                communities[i] = node_data['community']
        communities_list.append(communities)
    
    # Visualize community membership over time
    community_membership_heatmap(communities_list, figsize=(12, 8))

## 4. EvolveGCN-Based Dynamic Community Detection

Let's apply the EvolveGCN model to detect communities in our dynamic graphs.
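
The central idea of EvolveGCN is that a recurrent network updates the GCN weights from one snapshot to the next, so the model adapts over time instead of the node embeddings carrying all the history. The NumPy sketch below illustrates that idea with a GRU-style update in which the weight matrix acts as both input and hidden state. It is a simplified illustration under those assumptions, not the `EvolveGCN` class imported above. The function names `gru_evolve_weights` and `gcn_layer` and the random snapshots are all hypothetical.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_evolve_weights(W_prev, params):
    """One GRU-style step in which the GCN weight matrix is both the
    input and the hidden state (a simplification for illustration)."""
    Wz, Uz, Wr, Ur, Wh, Uh = params
    z = sigmoid(Wz @ W_prev + Uz @ W_prev)             # update gate
    r = sigmoid(Wr @ W_prev + Ur @ W_prev)             # reset gate
    W_cand = np.tanh(Wh @ W_prev + Uh @ (r * W_prev))  # candidate weights
    return (1.0 - z) * W_prev + z * W_cand

def gcn_layer(A, X, W):
    """ReLU(D^-1/2 (A + I) D^-1/2 X W) for a dense adjacency matrix A."""
    A_hat = A + np.eye(A.shape[0])
    d_inv_sqrt = 1.0 / np.sqrt(A_hat.sum(axis=1))
    A_norm = A_hat * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]
    return np.maximum(A_norm @ X @ W, 0.0)

rng = np.random.default_rng(0)
n_nodes, in_dim, out_dim, n_steps = 6, 4, 2, 3

X = rng.normal(size=(n_nodes, in_dim))             # fixed node features
W = rng.normal(scale=0.5, size=(in_dim, out_dim))  # initial GCN weights
params = [rng.normal(scale=0.5, size=(in_dim, in_dim)) for _ in range(6)]

weights, embeddings = [], []
for t in range(n_steps):
    upper = np.triu(rng.random((n_nodes, n_nodes)) < 0.4, k=1).astype(float)
    A_t = upper + upper.T                    # random undirected snapshot
    W = gru_evolve_weights(W, params)        # weights evolve first
    weights.append(W)
    embeddings.append(gcn_layer(A_t, X, W))  # then snapshot t is embedded

print([H.shape for H in embeddings])
```

In a trained model the gate parameters would be learned by backpropagation rather than drawn at random.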

In [None]:
if TORCH_GEOMETRIC_AVAILABLE:
    # Run EvolveGCN-based dynamic community detection
    print("Running EvolveGCN-based dynamic community detection...")
    
    # Get number of communities from ground truth
    ground_truth_communities = set()
    for G in graphs:
        for i in range(len(G)):
            node_data = G.get_node_data(i)
            if node_data and 'community' in node_data:
                ground_truth_communities.add(node_data['community'])
    
    n_clusters = len(ground_truth_communities)
    
    # Run EvolveGCN
    evolvegcn_results = run_dynamic_community_detection(
        graphs, 
        model_type='evolvegcn',
        embedding_dim=16,
        n_clusters=n_clusters,
        epochs=50
    )
    
    # Save results
    os.makedirs('results', exist_ok=True)
    with open('results/evolvegcn_results.pkl', 'wb') as f:
        pickle.dump(evolvegcn_results, f)

## 5. DySAT-Based Dynamic Community Detection

Next, let's try the DySAT model, which uses self-attention mechanisms for temporal graph analysis.
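
DySAT applies attention both over a node's neighbors within a snapshot and over the node's own history across snapshots. The sketch below covers only the temporal part: causal scaled dot-product attention over one node's sequence of per-snapshot embeddings. It is a minimal single-head illustration, not the `DySAT` class imported above. The function name `temporal_self_attention` and the random projection matrices are assumptions.

```python
import numpy as np

def temporal_self_attention(H, Wq, Wk, Wv):
    """Causal scaled dot-product attention over the time axis for one node.

    H has shape (T, d): one embedding per snapshot.
    Returns the attended embeddings and the (T, T) attention weights.
    """
    Q, K, V = H @ Wq, H @ Wk, H @ Wv
    scores = Q @ K.T / np.sqrt(K.shape[1])
    # Mask future snapshots so step t only attends to steps <= t
    future = np.triu(np.ones_like(scores, dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    scores -= scores.max(axis=1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ V, weights

rng = np.random.default_rng(0)
T, d = 5, 8                                  # 5 snapshots, 8-dim embeddings
H = rng.normal(size=(T, d))
Wq, Wk, Wv = (rng.normal(scale=d ** -0.5, size=(d, d)) for _ in range(3))

Z, attn = temporal_self_attention(H, Wq, Wk, Wv)
print(np.round(attn, 2))  # lower-triangular: no attention to future steps
```

Each row of `attn` shows how much a given time step draws on earlier snapshots, which is what lets the final embedding reflect a node's history.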

In [None]:
if TORCH_GEOMETRIC_AVAILABLE:
    # Run DySAT-based dynamic community detection
    print("Running DySAT-based dynamic community detection...")
    
    # Run DySAT
    dysat_results = run_dynamic_community_detection(
        graphs, 
        model_type='dysat',
        embedding_dim=16,
        n_clusters=n_clusters,
        epochs=50
    )
    
    # Save results
    with open('results/dysat_results.pkl', 'wb') as f:
        pickle.dump(dysat_results, f)

## 6. Comparing Dynamic GNN Models

Now, let's run a comprehensive comparison of the dynamic GNN models.

In [None]:
if TORCH_GEOMETRIC_AVAILABLE:
    # Compare dynamic GNN models
    print("Comparing dynamic GNN models...")
    
    # Run comparison
    results_df = compare_dynamic_gnn_models(
        graphs,
        embedding_dim=16,
        n_clusters=n_clusters,
        epochs=50
    )
    
    # Save the comparison results
    results_df.write_parquet('results/dynamic_gnn_methods_comparison.parquet', compression="zstd")
    
    # Display results
    print("\nComparison Results:")
    print(results_df)

## 7. Visualizing Community Evolution

Let's create more advanced visualizations of community evolution over time.
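
When detected communities are plotted next to the ground truth, matching colors only mean matching communities if the label ids agree. Cluster ids produced by a clustering step are often arbitrary, and whether the utilities used here already align them is not shown in this notebook. As a hedged sketch, the hypothetical helper `align_labels` below relabels a prediction so that it overlaps a reference labeling as much as possible.

```python
from itertools import permutations

def align_labels(reference, predicted, n_clusters):
    """Relabel `predicted` so it overlaps `reference` as much as possible.

    Brute force over label permutations, which is fine for a handful of
    clusters. Returns the relabeled list and the number of agreeing nodes.
    """
    best_mapping, best_overlap = None, -1
    for perm in permutations(range(n_clusters)):
        overlap = sum(perm[p] == r for r, p in zip(reference, predicted))
        if overlap > best_overlap:
            best_mapping, best_overlap = perm, overlap
    return [best_mapping[p] for p in predicted], best_overlap

reference = [0, 0, 0, 1, 1, 1, 2, 2, 2]
predicted = [2, 2, 2, 0, 0, 1, 1, 1, 1]  # same structure, different ids, one error

aligned, overlap = align_labels(reference, predicted, n_clusters=3)
print(aligned)  # [0, 0, 0, 1, 1, 2, 2, 2, 2]
print(overlap)  # 8
```

For larger numbers of clusters, building the overlap matrix and solving the assignment with `scipy.optimize.linear_sum_assignment` would scale better than enumerating permutations.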

In [None]:
if TORCH_GEOMETRIC_AVAILABLE:
    # Check if we have results from both models
    if 'communities_list' in evolvegcn_results and 'communities_list' in dysat_results:
        # Get communities detected by each model
        evolvegcn_communities = evolvegcn_results['communities_list']
        dysat_communities = dysat_results['communities_list']
        
        # Convert communities to dictionary format if needed
        evolvegcn_dict_list = []
        dysat_dict_list = []
        
        for t in range(len(graphs)):
            evolvegcn_dict = {i: int(evolvegcn_communities[t][i].item()) for i in range(len(graphs[t]))}
            dysat_dict = {i: int(dysat_communities[t][i].item()) for i in range(len(graphs[t]))}
            
            evolvegcn_dict_list.append(evolvegcn_dict)
            dysat_dict_list.append(dysat_dict)
        
        # Visualize community evolution
        print("\nVisualizing community evolution comparison...")
        
        # Function to create a figure with multiple rows for comparison
        def compare_community_evolution(graphs, community_lists, model_names, figsize=(18, 12)):
            n_models = len(community_lists)
            n_time_steps = len(graphs)
            
            # squeeze=False keeps `axes` two-dimensional even with a single row or column
            fig, axes = plt.subplots(n_models, n_time_steps, figsize=figsize, squeeze=False)
            
            # Convert graphs to NetworkX for visualization
            graphs_nx = []
            for G in graphs:
                G_nx = nx.Graph()
                
                # Add nodes
                for i in range(len(G)):
                    G_nx.add_node(i)
                
                # Add edges
                for edge in G.edge_list():
                    source, target = edge[0], edge[1]
                    G_nx.add_edge(source, target)
                    
                graphs_nx.append(G_nx)
            
            # Create a fixed layout based on the first graph
            pos = nx.spring_layout(graphs_nx[0], seed=42)
            
            for i, communities_list in enumerate(community_lists):
                for t, (G_nx, communities) in enumerate(zip(graphs_nx, communities_list)):
                    # Get axis
                    ax = axes[i, t]
                    
                    # Get community assignments for each node
                    community_ids = [communities.get(n, -1) for n in G_nx.nodes()]
                    
                    # Draw the network
                    nx.draw_networkx(
                        G_nx, pos=pos, 
                        node_color=community_ids, 
                        cmap=plt.cm.rainbow,
                        node_size=80,
                        with_labels=False,
                        edge_color='gray',
                        alpha=0.7,
                        ax=ax
                    )
                    
                    # Add title
                    if i == 0:
                        ax.set_title(f'Time Step {t+1}')
                    
                    # Add model name to first column
                    if t == 0:
                        ax.set_ylabel(model_names[i])
                    
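                    # Hide the frame and ticks rather than calling ax.axis('off'),
                    # which would also hide the model name set via set_ylabel
                    for spine in ax.spines.values():
                        spine.set_visible(False)
                    ax.set_xticks([])
                    ax.set_yticks([])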
            
            plt.tight_layout()
            plt.suptitle("Community Evolution Comparison", fontsize=16, y=1.02)
            plt.show()
        
        # Compare ground truth with EvolveGCN and DySAT
        model_names = ['Ground Truth', 'EvolveGCN', 'DySAT']
        community_lists = [communities_list, evolvegcn_dict_list, dysat_dict_list]
        
        compare_community_evolution(graphs, community_lists, model_names, figsize=(20, 12))
        
        # Try creating an alluvial diagram if Plotly is available
        try:
            import plotly
            
            print("\nCreating alluvial diagram for ground truth communities:")
            alluvial_diagram(communities_list)
            
            print("\nCreating alluvial diagram for EvolveGCN detected communities:")
            alluvial_diagram(evolvegcn_dict_list)
            
            print("\nCreating alluvial diagram for DySAT detected communities:")
            alluvial_diagram(dysat_dict_list)
            
        except ImportError:
            print("Plotly not available. Skipping alluvial diagrams.")
            
        # Visualize using Vehlow approach
        print("\nVisualizing ground truth communities using Vehlow approach:")
        vehlow_visualization(graphs, communities_list)
        
        print("\nVisualizing EvolveGCN detected communities using Vehlow approach:")
        vehlow_visualization(graphs, evolvegcn_dict_list)

## 8. Evaluating Community Detection Quality Over Time

Let's analyze how the quality of detected communities changes over time.
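
The two metrics plotted below are NMI (normalized mutual information) and ARI (adjusted Rand index). Both compare a detected partition with the ground truth and ignore the actual label ids. The pure-Python sketch below uses the standard definitions, with NMI normalized by the arithmetic mean of the two entropies. Whether `evaluate_temporal_communities` uses exactly these conventions is an assumption, and in practice scikit-learn's `normalized_mutual_info_score` and `adjusted_rand_score` are the usual choice.

```python
from collections import Counter
from math import comb, log

def adjusted_rand_index(true_labels, pred_labels):
    """ARI from the contingency table of the two labelings."""
    n = len(true_labels)
    sum_cells = sum(comb(c, 2) for c in Counter(zip(true_labels, pred_labels)).values())
    sum_true = sum(comb(c, 2) for c in Counter(true_labels).values())
    sum_pred = sum(comb(c, 2) for c in Counter(pred_labels).values())
    expected = sum_true * sum_pred / comb(n, 2)
    max_index = 0.5 * (sum_true + sum_pred)
    if max_index == expected:  # degenerate case, e.g. a single cluster on both sides
        return 1.0
    return (sum_cells - expected) / (max_index - expected)

def normalized_mutual_info(true_labels, pred_labels):
    """NMI normalized by the arithmetic mean of the two entropies."""
    n = len(true_labels)
    joint = Counter(zip(true_labels, pred_labels))
    count_true, count_pred = Counter(true_labels), Counter(pred_labels)
    mi = sum((c / n) * log(c * n / (count_true[t] * count_pred[p]))
             for (t, p), c in joint.items())
    h_true = -sum((c / n) * log(c / n) for c in count_true.values())
    h_pred = -sum((c / n) * log(c / n) for c in count_pred.values())
    denom = 0.5 * (h_true + h_pred)
    return mi / denom if denom > 0 else 1.0

truth = [0, 0, 0, 1, 1, 1]
permuted = [1, 1, 1, 0, 0, 0]  # same partition, different label ids

ari_perfect = adjusted_rand_index(truth, permuted)
nmi_perfect = normalized_mutual_info(truth, permuted)
ari_discordant = adjusted_rand_index([0, 0, 1, 1], [0, 1, 0, 1])

print(f"ARI (permuted labels): {ari_perfect:.2f}")
print(f"NMI (permuted labels): {nmi_perfect:.2f}")
print(f"ARI (discordant):      {ari_discordant:.2f}")
```

Note that NMI lies in [0, 1], whereas ARI is 0 in expectation for random labelings and can be negative when two partitions agree less than chance would predict.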

In [None]:
if TORCH_GEOMETRIC_AVAILABLE:
    # Check if we have evaluation metrics for both models
    if 'metrics_list' in evolvegcn_results and 'metrics_list' in dysat_results:
        # Get evaluation metrics
        evolvegcn_metrics = evolvegcn_results['metrics_list']
        dysat_metrics = dysat_results['metrics_list']
        
        # Plot metrics over time
        plt.figure(figsize=(14, 6))
        
        # Plot NMI
        plt.subplot(1, 2, 1)
        time_steps = range(1, len(evolvegcn_metrics) + 1)
        plt.plot(time_steps, [m['nmi'] for m in evolvegcn_metrics], 'o-', label='EvolveGCN')
        plt.plot(time_steps, [m['nmi'] for m in dysat_metrics], 's-', label='DySAT')
        plt.title('NMI Over Time')
        plt.xlabel('Time Step')
        plt.ylabel('NMI')
        plt.ylim(0, 1)
        plt.legend()
        plt.grid(alpha=0.3)
        
        # Plot ARI
        plt.subplot(1, 2, 2)
        plt.plot(time_steps, [m['ari'] for m in evolvegcn_metrics], 'o-', label='EvolveGCN')
        plt.plot(time_steps, [m['ari'] for m in dysat_metrics], 's-', label='DySAT')
        plt.title('ARI Over Time')
        plt.xlabel('Time Step')
        plt.ylabel('ARI')
        plt.ylim(-0.5, 1)  # ARI can be negative when agreement is worse than chance
        plt.legend()
        plt.grid(alpha=0.3)
        
        plt.tight_layout()
        plt.show()
        
        # Calculate average metrics and standard deviations
        evolvegcn_avg_nmi = np.mean([m['nmi'] for m in evolvegcn_metrics])
        evolvegcn_std_nmi = np.std([m['nmi'] for m in evolvegcn_metrics])
        evolvegcn_avg_ari = np.mean([m['ari'] for m in evolvegcn_metrics])
        evolvegcn_std_ari = np.std([m['ari'] for m in evolvegcn_metrics])
        
        dysat_avg_nmi = np.mean([m['nmi'] for m in dysat_metrics])
        dysat_std_nmi = np.std([m['nmi'] for m in dysat_metrics])
        dysat_avg_ari = np.mean([m['ari'] for m in dysat_metrics])
        dysat_std_ari = np.std([m['ari'] for m in dysat_metrics])
        
        print("\nAverage Metrics by Model:")
        print(f"EvolveGCN: NMI = {evolvegcn_avg_nmi:.4f} ± {evolvegcn_std_nmi:.4f}, ARI = {evolvegcn_avg_ari:.4f} ± {evolvegcn_std_ari:.4f}")
        print(f"DySAT:     NMI = {dysat_avg_nmi:.4f} ± {dysat_std_nmi:.4f}, ARI = {dysat_avg_ari:.4f} ± {dysat_std_ari:.4f}")

## 9. Summary and Conclusions

In this notebook, we have:

1. Generated a sequence of dynamic graphs with evolving community structure
2. Applied dynamic GNN models to detect communities over time
   - EvolveGCN: uses a recurrent network (RNN) to evolve the GCN weights from one time step to the next
   - DySAT: uses self-attention over both a node's structural neighborhood and its history across time steps
3. Compared the detection accuracy of different models
4. Visualized community evolution over time using multiple visualization techniques
5. Analyzed how community detection quality changes across time steps

Dynamic community detection matters for understanding how communities form, evolve, merge, and dissolve over time. The dynamic GNN architectures explored here learn embeddings that reflect both graph structure and its temporal evolution, which helps track these changes. Keeping community identities consistent across time steps can additionally depend on how cluster labels are matched between snapshots.

In the next notebook, we'll explore methods for detecting overlapping communities, where nodes can belong to multiple communities simultaneously.