# Overlapping Community Detection Methods

This notebook explores methods for detecting overlapping communities, where nodes can belong to multiple communities simultaneously. We'll cover:

1. Generating synthetic graphs with overlapping communities
2. Implementing and applying overlapping community detection algorithms
3. Evaluating detection results with metrics for overlapping communities
4. Visualizing overlapping community structures
5. Comparing different detection methods

In [ ]:
import sys
import os
import numpy as np
import torch
import polars as pl
import rustworkx as rx
import networkx as nx  # Still needed for some algorithms and visualization
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import time
import warnings
warnings.filterwarnings('ignore')

# Import visualization functions
from community_detection.visualization import visualize_overlapping_communities

# Import from evaluation module
from community_detection.evaluation import save_result

# Import overlapping community detection functions
from community_detection.overlapping_community_detection import (
    generate_synthetic_overlapping_graph, plot_overlapping_communities,
    run_bigclam, run_demon, run_slpa, GNN_Overlapping, rwx_to_pyg_overlapping,
    train_gnn_overlapping, predict_gnn_overlapping, run_gnn_overlapping,
    evaluate_overlapping_communities
)

## 1. Check Required Libraries

Let's check if the required libraries for overlapping community detection are available.

In [None]:
# Check if cdlib is available for traditional overlapping methods
try:
    from cdlib import algorithms, evaluation as cdlib_eval
    from cdlib.classes import NodeClustering
    CDLIB_AVAILABLE = True
    print("cdlib is available.")
except ImportError:
    CDLIB_AVAILABLE = False
    print("cdlib is not available. Some methods will be skipped.")
    print("You can install it with: pip install cdlib")

# Check if PyTorch Geometric is available for GNN-based methods
try:
    import torch_geometric
    from torch_geometric.data import Data
    from torch_geometric.nn import GCNConv
    TORCH_GEOMETRIC_AVAILABLE = True
    print("PyTorch Geometric is available.")
except ImportError:
    TORCH_GEOMETRIC_AVAILABLE = False
    print("PyTorch Geometric is not available. GNN-based methods will be skipped.")
    print("You can install it with: pip install torch-geometric")

## 2. Define Functions for Overlapping Community Detection

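Let's define the functions for generating and analyzing overlapping communities. These inline definitions replace the versions imported from `community_detection.overlapping_community_detection` in the first cell, so the rest of the notebook uses the code shown here.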

In [None]:
def generate_synthetic_overlapping_graph(n_nodes: int = 100, 
                                       n_communities: int = 4,
                                       overlap_size: int = 20,
                                       p_in: float = 0.3,
                                       p_out: float = 0.05,
                                       seed: int = 42) -> tuple:
    """
    Generate a synthetic graph with overlapping communities
    
    Parameters:
    -----------
    n_nodes: int
        Number of nodes
    n_communities: int
        Number of communities
    overlap_size: int
        Number of nodes that will belong to multiple communities
    p_in: float
        Probability of edge within a community
    p_out: float
        Probability of edge between communities
    seed: int
        Random seed
        
    Returns:
    --------
    G: rustworkx.PyGraph
        The generated graph
    ground_truth: list
        List of lists containing node indices for each community
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    
    # Calculate nodes per community
    base_size = (n_nodes - overlap_size) // n_communities
    remainder = (n_nodes - overlap_size) % n_communities
    
    # Size of each community's non-overlapping core; the remainder goes to the first communities
    community_sizes = [base_size + (1 if i < remainder else 0) for i in range(n_communities)]
    
    # Assign non-overlapping nodes to communities
    communities = [[] for _ in range(n_communities)]
    node_id = 0
    
    for i, size in enumerate(community_sizes):
        for _ in range(size):
            communities[i].append(node_id)
            node_id += 1
    
    # Distribute overlapping nodes
    overlap_assignments = []
    
    # Each overlapping node belongs to 2 communities
    for i in range(overlap_size):
        # Randomly select 2 distinct communities
        comm_indices = np.random.choice(n_communities, 2, replace=False)
        overlap_assignments.append(list(comm_indices))
        
        # Add the overlapping node to both communities
        for comm_idx in comm_indices:
            communities[comm_idx].append(node_id)
        
        node_id += 1
    
    # Create the graph using rustworkx
    G = rx.PyGraph()
    
    # Add all nodes
    for i in range(n_nodes):
        G.add_node(None)
    
    # Add edges based on communities
    # Nodes in the same community have higher probability (p_in) of being connected
    # Nodes in different communities have lower probability (p_out) of being connected
    
    # Build a per-node list of the communities each node belongs to
    node_memberships = [[] for _ in range(n_nodes)]
    for comm_idx, comm_nodes in enumerate(communities):
        for node in comm_nodes:
            node_memberships[node].append(comm_idx)
    
    # Add edges based on community membership
    for i in range(n_nodes):
        for j in range(i+1, n_nodes):
            # Check if nodes share a community
            shared_community = any(comm in node_memberships[j] for comm in node_memberships[i])
            
            # Set probability based on community membership
            p = p_in if shared_community else p_out
            
            # Add edge with probability p
            if np.random.random() < p:
                G.add_edge(i, j, None)
    
    return G, communities

def _rwx_to_nx(G: rx.PyGraph) -> nx.Graph:
    """
    Convert a RustworkX graph to a NetworkX graph
    
    Parameters:
    -----------
    G: rx.PyGraph
        RustworkX graph
        
    Returns:
    --------
    G_nx: nx.Graph
        NetworkX graph
    """
    G_nx = nx.Graph()
    
    # Add nodes
    for i in range(len(G)):
        G_nx.add_node(i, **({} if G.get_node_data(i) is None else G.get_node_data(i)))
    
    # Add edges
    for edge in G.edge_list():
        source, target = edge[0], edge[1]
        edge_data = G.get_edge_data(source, target)
        G_nx.add_edge(source, target, **({} if edge_data is None else edge_data))
    
    return G_nx

def run_bigclam(G: rx.PyGraph, k: int = 5, iterations: int = 50) -> tuple:
    """
    Run BigCLAM algorithm for overlapping community detection
    
    Parameters:
    -----------
    G: rx.PyGraph
        Graph to analyze
    k: int
        Number of communities to detect
    iterations: int
        Number of iterations for the algorithm
        
    Returns:
    --------
    communities: list
        List of lists containing node indices for each community
    execution_time: float
        Execution time in seconds
    """
    if not CDLIB_AVAILABLE:
        raise ImportError("cdlib is required for BigCLAM algorithm")
    
    start_time = time.time()
    
    # Convert to NetworkX for cdlib
    G_nx = _rwx_to_nx(G)
    
    # cdlib exposes BigCLAM as `big_clam`; `dimensions` is the number of communities
    result = algorithms.big_clam(G_nx, dimensions=k, iterations=iterations)
    
    # Extract communities
    communities = result.communities
    
    execution_time = time.time() - start_time
    return communities, execution_time

def run_demon(G: rx.PyGraph, epsilon: float = 0.25, min_com_size: int = 3) -> tuple:
    """
    Run DEMON algorithm for overlapping community detection
    
    Parameters:
    -----------
    G: rx.PyGraph
        Graph to analyze
    epsilon: float
        Merging threshold
    min_com_size: int
        Minimum community size
        
    Returns:
    --------
    communities: list
        List of lists containing node indices for each community
    execution_time: float
        Execution time in seconds
    """
    if not CDLIB_AVAILABLE:
        raise ImportError("cdlib is required for DEMON algorithm")
    
    start_time = time.time()
    
    # Convert to NetworkX for cdlib
    G_nx = _rwx_to_nx(G)
    
    # Run DEMON using cdlib
    result = algorithms.demon(G_nx, epsilon=epsilon, min_com_size=min_com_size)
    
    # Extract communities
    communities = result.communities
    
    execution_time = time.time() - start_time
    return communities, execution_time

def run_slpa(G: rx.PyGraph, t: int = 21, r: float = 0.1) -> tuple:
    """
    Run SLPA algorithm for overlapping community detection
    
    Parameters:
    -----------
    G: rx.PyGraph
        Graph to analyze
    t: int
        Number of iterations
    r: float
        Threshold for community inclusion
        
    Returns:
    --------
    communities: list
        List of lists containing node indices for each community
    execution_time: float
        Execution time in seconds
    """
    if not CDLIB_AVAILABLE:
        raise ImportError("cdlib is required for SLPA algorithm")
    
    start_time = time.time()
    
    # Convert to NetworkX for cdlib
    G_nx = _rwx_to_nx(G)
    
    # Run SLPA using cdlib
    result = algorithms.slpa(G_nx, t=t, r=r)
    
    # Extract communities
    communities = result.communities
    
    execution_time = time.time() - start_time
    return communities, execution_time

# String annotation so this cell still runs when torch_geometric (and thus Data) is missing
def rwx_to_pyg_overlapping(G: rx.PyGraph, communities: list = None) -> "Data":
    """
    Convert a RustworkX graph to a PyTorch Geometric Data object
    with overlapping community information
    
    Parameters:
    -----------
    G: rx.PyGraph
        Graph to convert
    communities: list
        List of lists containing node indices for each community
        
    Returns:
    --------
    data: torch_geometric.data.Data
        PyTorch Geometric Data object
    """
    if not TORCH_GEOMETRIC_AVAILABLE:
        raise ImportError("PyTorch Geometric is required")
    
    # Get edges
    edge_list = []
    for edge in G.edge_list():
        edge_list.append((edge[0], edge[1]))
    
    # Convert to PyG format
    if edge_list:
        edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
        # For undirected graph in PyG, we need both directions
        edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)
    else:
        edge_index = torch.zeros((2, 0), dtype=torch.long)
    
    # Node features: one-hot encoding of each node's degree
    degrees = torch.tensor([G.degree(i) for i in range(len(G))])
    max_degree = degrees.max().item()
    x = torch.zeros((len(G), max_degree + 1), dtype=torch.float)
    for i, degree in enumerate(degrees):
        x[i, degree] = 1.0
    
    # Create community membership matrix if communities provided
    if communities is not None:
        y = torch.zeros((len(G), len(communities)), dtype=torch.float)
        for i, comm in enumerate(communities):
            for node in comm:
                y[node, i] = 1.0
    else:
        y = None
    
    # Create PyG data object
    data = Data(x=x, edge_index=edge_index, y=y)
    
    return data

class GNN_Overlapping(torch.nn.Module):
    """
    GNN model for overlapping community detection
    """
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GNN_Overlapping, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, output_dim)
    
    def forward(self, x, edge_index):
        x = torch.relu(self.conv1(x, edge_index))
        x = torch.relu(self.conv2(x, edge_index))
        x = torch.sigmoid(self.conv3(x, edge_index))  # Sigmoid for multi-label output
        return x

def train_gnn_overlapping(model, data, epochs=100, lr=0.01, weight_decay=5e-4):
    """
    Train a GNN model for overlapping community detection
    
    Parameters:
    -----------
    model: torch.nn.Module
        GNN model to train
    data: torch_geometric.data.Data
        PyG data object with overlapping community labels
    epochs: int
        Number of training epochs
    lr: float
        Learning rate
    weight_decay: float
        Weight decay factor
        
    Returns:
    --------
    model: torch.nn.Module
        Trained model
    losses: list
        List of training losses
    """
    if not TORCH_GEOMETRIC_AVAILABLE:
        raise ImportError("PyTorch Geometric is required")
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    data = data.to(device)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = torch.nn.BCELoss()  # Binary cross-entropy for multi-label classification
    
    # Training loop
    model.train()
    losses = []
    
    for epoch in range(epochs):
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        
        if (epoch + 1) % 20 == 0:
            print(f'Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}')
    
    return model, losses

def predict_gnn_overlapping(model, data, threshold=0.5):
    """
    Predict overlapping communities using a trained GNN model
    
    Parameters:
    -----------
    model: torch.nn.Module
        Trained GNN model
    data: torch_geometric.data.Data
        PyG data object
    threshold: float
        Threshold for community membership
        
    Returns:
    --------
    communities: list
        List of lists containing node indices for each community
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    data = data.to(device)
    
    model.eval()
    with torch.no_grad():
        out = model(data.x, data.edge_index).cpu()
    
    # Convert predictions to communities
    predictions = (out > threshold).float().numpy()
    n_communities = predictions.shape[1]
    
    # Extract communities
    communities = []
    for i in range(n_communities):
        community = np.where(predictions[:, i] > 0)[0].tolist()
        if community:  # Only add non-empty communities
            communities.append(community)
    
    return communities

def run_gnn_overlapping(G: rx.PyGraph, ground_truth: list, hidden_dim: int = 64,
                      epochs: int = 100, threshold: float = 0.5) -> tuple:
    """
    Run GNN-based overlapping community detection
    
    Parameters:
    -----------
    G: rx.PyGraph
        Graph to analyze
    ground_truth: list
        List of lists containing node indices for each community
    hidden_dim: int
        Hidden dimension of the GNN model
    epochs: int
        Number of training epochs
    threshold: float
        Threshold for community membership
        
    Returns:
    --------
    communities: list
        List of lists containing node indices for each community
    execution_time: float
        Execution time in seconds
    losses: list
        List of training losses
    """
    if not TORCH_GEOMETRIC_AVAILABLE:
        raise ImportError("PyTorch Geometric is required for GNN-based methods")
    
    start_time = time.time()
    
    # Convert to PyG data
    data = rwx_to_pyg_overlapping(G, ground_truth)
    
    # Create model
    model = GNN_Overlapping(
        input_dim=data.x.size(1),
        hidden_dim=hidden_dim,
        output_dim=data.y.size(1)
    )
    
    # Train model
    model, losses = train_gnn_overlapping(model, data, epochs=epochs)
    
    # Predict communities
    communities = predict_gnn_overlapping(model, data, threshold=threshold)
    
    execution_time = time.time() - start_time
    return communities, execution_time, losses

def evaluate_overlapping_communities(detected: list, ground_truth: list) -> dict:
    """
    Evaluate overlapping community detection results
    
    Parameters:
    -----------
    detected: list
        List of lists containing node indices for each detected community
    ground_truth: list
        List of lists containing node indices for each ground truth community
        
    Returns:
    --------
    metrics: dict
        Dictionary containing evaluation metrics
    """
    if not CDLIB_AVAILABLE:
        raise ImportError("cdlib is required for evaluation")
    
    # Create a dummy graph for cdlib evaluation
    # Get all unique nodes
    all_nodes = set()
    for comm in ground_truth + detected:
        all_nodes.update(comm)
    
    G_dummy = nx.Graph()
    G_dummy.add_nodes_from(all_nodes)
    
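    # Convert to cdlib format; overlap=True marks both as overlapping covers
    # (cdlib may raise ValueError if the two covers do not contain the same node set)
    ground_truth_clustering = NodeClustering(ground_truth, G_dummy, overlap=True)
    detected_clustering = NodeClustering(detected, G_dummy, overlap=True)
    
    # Calculate metrics
    metrics = {}
    
    # Overlapping NMI (LFK variant); plain NMI is defined only for partitions
    metrics['nmi'] = cdlib_eval.overlapping_normalized_mutual_information_LFK(
        ground_truth_clustering, detected_clustering
    ).score
    
    # Omega Index
    metrics['omega'] = cdlib_eval.omega(ground_truth_clustering, detected_clustering).score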
    
    # F1 score: match each detected community to its best-matching ground-truth
    # community (one direction only), average precision and recall over the
    # detected communities, then take the harmonic mean of those averages
    total_precision = 0
    total_recall = 0
    
    # For each detected community
    for det_comm in detected:
        det_comm_set = set(det_comm)
        best_f1 = 0
        best_precision = 0
        best_recall = 0
        
        # Find best matching ground truth community
        for gt_comm in ground_truth:
            gt_comm_set = set(gt_comm)
            
            # Calculate overlap
            intersection = len(det_comm_set & gt_comm_set)
            
            if intersection > 0:
                precision = intersection / len(det_comm_set)  # TP / (TP + FP)
                recall = intersection / len(gt_comm_set)      # TP / (TP + FN)
                f1 = 2 * precision * recall / (precision + recall)
                
                if f1 > best_f1:
                    best_f1 = f1
                    best_precision = precision
                    best_recall = recall
        
        total_precision += best_precision
        total_recall += best_recall
    
    # Calculate average precision and recall
    avg_precision = total_precision / len(detected) if detected else 0
    avg_recall = total_recall / len(detected) if detected else 0
    
    # Calculate F1 score
    if avg_precision + avg_recall > 0:
        metrics['f1'] = 2 * avg_precision * avg_recall / (avg_precision + avg_recall)
    else:
        metrics['f1'] = 0
    
    return metrics
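The F1 above matches in one direction only: a method that finds one ground-truth community perfectly and misses the others still scores 1.0. A commonly used variant averages both directions (detected to ground truth and ground truth to detected). The sketch below implements that symmetric F1 together with a plain-Python Omega index, the chance-corrected agreement on how many communities each node pair shares. The names `average_f1` and `omega_index` are hypothetical; this is meant to make the metrics transparent, not to reproduce cdlib's internals.

```python
from collections import Counter
from itertools import combinations


def _f1(a, b):
    """F1 between two node sets (symmetric in a and b)."""
    inter = len(a & b)
    if inter == 0:
        return 0.0
    precision, recall = inter / len(a), inter / len(b)
    return 2 * precision * recall / (precision + recall)


def average_f1(detected, ground_truth):
    """Symmetric best-match F1: missed ground-truth communities lower the score."""
    det = [set(c) for c in detected]
    gt = [set(c) for c in ground_truth]
    if not det or not gt:
        return 0.0
    det_to_gt = sum(max(_f1(d, g) for g in gt) for d in det) / len(det)
    gt_to_det = sum(max(_f1(g, d) for d in det) for g in gt) / len(gt)
    return (det_to_gt + gt_to_det) / 2


def omega_index(cover_a, cover_b, nodes):
    """Omega index over all node pairs; `nodes` must contain every node in both covers."""
    nodes = list(nodes)

    def shared_counts(cover):
        member = {v: set() for v in nodes}
        for idx, comm in enumerate(cover):
            for v in comm:
                member[v].add(idx)
        # Number of communities shared by each unordered node pair
        return {(u, v): len(member[u] & member[v]) for u, v in combinations(nodes, 2)}

    a, b = shared_counts(cover_a), shared_counts(cover_b)
    pairs = len(a)
    # Observed agreement: fraction of pairs sharing the same number of communities
    observed = sum(a[p] == b[p] for p in a) / pairs
    # Expected agreement from the marginal distributions of shared-community counts
    count_a, count_b = Counter(a.values()), Counter(b.values())
    expected = sum(count_a[j] * count_b[j] for j in count_a) / pairs ** 2
    # If chance agreement is already perfect, the covers must agree everywhere
    return 1.0 if expected == 1 else (observed - expected) / (1 - expected)
```

For example, `average_f1([[0, 1, 2]], [[0, 1, 2], [3, 4, 5]])` returns 0.75, whereas the one-directional F1 in `evaluate_overlapping_communities` would report 1.0 for the same input.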

## 3. Generate a Synthetic Graph with Overlapping Communities

Let's create a synthetic graph with overlapping communities for our experiments.

In [None]:
# Generate a synthetic graph with overlapping communities
print("Generating a synthetic graph with overlapping communities...")
n_communities = 4
n_nodes = 100
overlap_size = 20  # Number of nodes that will belong to multiple communities

G, ground_truth = generate_synthetic_overlapping_graph(
    n_nodes=n_nodes,
    n_communities=n_communities,
    overlap_size=overlap_size,
    p_in=0.3,  # probability of edge within community
    p_out=0.05  # probability of edge between communities
)

print(f"Generated graph with {len(G)} nodes and {G.num_edges()} edges.")
print(f"Created {n_communities} overlapping communities.")

# Visualize the ground truth overlapping communities
print("\nVisualizing ground truth overlapping communities...")
visualize_overlapping_communities(G, ground_truth, figsize=(12, 10), alpha=0.6)
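The edge rule in `generate_synthetic_overlapping_graph` can also be written in vectorized form. In this sketch, the hypothetical helper `sample_edges` takes a binary node-by-community membership matrix `M`; `M @ M.T` is positive exactly for pairs that share at least one community, which selects `p_in` or `p_out` for every pair at once. It samples from the same distribution as the double loop but consumes random numbers differently, so the same seed will not reproduce the same graph.

```python
import numpy as np


def sample_edges(M, p_in, p_out, rng):
    """Sample an undirected edge list from a binary membership matrix M (nodes x communities)."""
    n = M.shape[0]
    shared = (M @ M.T) > 0                  # True where two nodes share at least one community
    probs = np.where(shared, p_in, p_out)   # per-pair edge probability
    draws = rng.random((n, n)) < probs      # one Bernoulli draw per ordered pair
    upper = np.triu(draws, k=1)             # keep i < j only: no self-loops, no duplicates
    rows, cols = np.nonzero(upper)
    return list(zip(rows.tolist(), cols.tolist()))


# Node 2 belongs to both communities; with p_in=1 and p_out=0 only co-members are linked
M = np.array([[1, 0], [1, 0], [1, 1], [0, 1]])
edges = sample_edges(M, p_in=1.0, p_out=0.0, rng=np.random.default_rng(0))
```

With `p_in=1.0` and `p_out=0.0` the sampling becomes deterministic, so `edges` is `[(0, 1), (0, 2), (1, 2), (2, 3)]`: node 2 links to both communities, which is exactly how overlap shows up in the generated graph.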

## 4. Analyze Overlapping Community Structure

Let's analyze the structure of our overlapping communities to better understand them.

In [None]:
# Analyze ground truth overlapping communities
print("Analyzing ground truth overlapping communities:")

# Count community sizes
community_sizes = [len(comm) for comm in ground_truth]
print(f"Community sizes: {community_sizes}")

# Count number of communities per node
node_memberships = {}
for i, comm in enumerate(ground_truth):
    for node in comm:
        if node not in node_memberships:
            node_memberships[node] = []
        node_memberships[node].append(i)

membership_counts = {node: len(comms) for node, comms in node_memberships.items()}
membership_distribution = {}
for count in membership_counts.values():
    if count not in membership_distribution:
        membership_distribution[count] = 0
    membership_distribution[count] += 1

print("\nMembership distribution (number of nodes with N community memberships):")
for count, num_nodes in sorted(membership_distribution.items()):
    print(f"  {num_nodes} nodes belong to {count} communities")

# Plot the distribution of community memberships
plt.figure(figsize=(10, 6))
plt.bar(list(membership_distribution.keys()), list(membership_distribution.values()))
plt.title('Distribution of Community Memberships')
plt.xlabel('Number of Communities a Node Belongs To')
plt.ylabel('Number of Nodes')
plt.xticks(range(1, max(membership_distribution.keys()) + 1))
plt.grid(alpha=0.3)
plt.show()

# Analyze overlap between communities
print("\nOverlap between communities:")
overlap_matrix = np.zeros((n_communities, n_communities))
for i in range(n_communities):
    for j in range(i+1, n_communities):
        overlap = len(set(ground_truth[i]) & set(ground_truth[j]))
        overlap_matrix[i, j] = overlap
        overlap_matrix[j, i] = overlap
        print(f"  Communities {i} and {j} share {overlap} nodes")

# Visualize the overlap matrix
plt.figure(figsize=(8, 6))
sns.heatmap(overlap_matrix, annot=True, fmt=".0f", cmap="YlGnBu")
plt.title('Overlap Between Communities')
plt.xlabel('Community ID')
plt.ylabel('Community ID')
plt.show()
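The counting loops above can be expressed with a binary membership matrix `M` (nodes by communities). In this sketch, with hypothetical helpers `membership_matrix` and `overlap_stats`, column sums give community sizes, row sums give the number of memberships per node, and `M.T @ M` gives the pairwise overlap matrix (its diagonal holds the community sizes, zeroed here to match the heatmap).

```python
import numpy as np


def membership_matrix(communities, n_nodes):
    """Binary matrix M with M[v, c] = 1 if node v belongs to community c."""
    M = np.zeros((n_nodes, len(communities)), dtype=int)
    for c, nodes in enumerate(communities):
        M[list(nodes), c] = 1
    return M


def overlap_stats(communities, n_nodes):
    M = membership_matrix(communities, n_nodes)
    sizes = M.sum(axis=0)              # nodes per community
    per_node = M.sum(axis=1)           # communities per node
    # distribution[k] = number of nodes that belong to exactly k communities
    distribution = np.bincount(per_node)
    # (M^T M)[i, j] = number of nodes shared by communities i and j
    overlap = M.T @ M
    np.fill_diagonal(overlap, 0)       # drop the sizes on the diagonal
    return sizes, distribution, overlap


sizes, distribution, overlap = overlap_stats([[0, 1, 2], [2, 3], [3, 4]], n_nodes=5)
```

For this toy cover, `sizes` is `[3, 2, 2]`, `distribution` is `[0, 3, 2]` (three nodes in one community, two nodes in two), and communities 0 and 1 share one node, as do communities 1 and 2.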

## 5. BigCLAM Algorithm

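BigCLAM (Cluster Affiliation Model for Big Networks) is designed to detect overlapping communities in large networks. It gives every node a nonnegative affiliation strength for each community and models the probability of an edge between two nodes as increasing with their shared affiliations. A node belongs to every community in which its affiliation strength exceeds a threshold, so memberships can overlap.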
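The BigCLAM model itself is compact enough to write down directly. Given a nonnegative affiliation matrix `F` (nodes by communities), the edge probability is $P(u, v) = 1 - \exp(-F_u \cdot F_v)$, the log-likelihood sums $\log P(u, v)$ over edges and $-F_u \cdot F_v$ over non-edges, and communities are read off by thresholding `F`. This sketch only evaluates the model; it does not fit `F` (cdlib does that for us), and the helper names and the fixed threshold `delta` are illustrative assumptions.

```python
import numpy as np


def bigclam_edge_prob(F):
    """P(u, v) = 1 - exp(-F_u . F_v) for a nonnegative affiliation matrix F."""
    return 1.0 - np.exp(-(F @ F.T))


def bigclam_log_likelihood(F, A):
    """Log-likelihood of a 0/1 adjacency matrix A under BigCLAM (self-pairs excluded)."""
    dots = F @ F.T
    off_diag = ~np.eye(A.shape[0], dtype=bool)
    edges = (A == 1) & off_diag
    non_edges = (A == 0) & off_diag
    # Edges contribute log(1 - exp(-F_u . F_v)); clipping avoids log(0) when the dot is 0
    ll_edges = np.log(np.clip(1.0 - np.exp(-dots[edges]), 1e-12, None)).sum()
    # Non-edges contribute -F_u . F_v
    ll_non_edges = -dots[non_edges].sum()
    return (ll_edges + ll_non_edges) / 2.0   # each unordered pair was counted twice


def memberships(F, delta):
    """Node u joins community c when F[u, c] >= delta, so a node can join several."""
    return [np.flatnonzero(F[:, c] >= delta).tolist() for c in range(F.shape[1])]


F = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.5, 0.5]])
P = bigclam_edge_prob(F)
```

Here nodes 0 and 1 share a full affiliation, so `P[0, 1]` is $1 - e^{-1} \approx 0.632$, while nodes 0 and 2 share none and `P[0, 2]` is 0. With `delta=0.5`, node 3 lands in both communities.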

In [None]:
if CDLIB_AVAILABLE:
    # Run BigCLAM algorithm
    print("Running BigCLAM algorithm...")
    
    # Set the number of communities to detect
    k = len(ground_truth)
    
    # Run BigCLAM
    bigclam_communities, bigclam_time = run_bigclam(G, k=k)
    
    # Evaluate against ground truth
    bigclam_metrics = evaluate_overlapping_communities(bigclam_communities, ground_truth)
    print(f"Execution time: {bigclam_time:.4f} seconds")
    print(f"Number of communities detected: {len(bigclam_communities)}")
    print(f"NMI: {bigclam_metrics['nmi']:.4f}")
    print(f"Omega Index: {bigclam_metrics['omega']:.4f}")
    print(f"F1 Score: {bigclam_metrics['f1']:.4f}")
    
    # Save results
    bigclam_result = {
        'communities': bigclam_communities,
        'execution_time': bigclam_time,
        'metrics': bigclam_metrics,
        'num_communities': len(bigclam_communities)
    }
    os.makedirs('results', exist_ok=True)
    with open('results/bigclam_result.pkl', 'wb') as f:
        pickle.dump(bigclam_result, f)
    
    # Visualize detected communities
    print("\nVisualizing BigCLAM detected communities...")
    visualize_overlapping_communities(G, bigclam_communities, figsize=(12, 10), alpha=0.6)
else:
    print("BigCLAM requires cdlib. Skipping this method.")

## 6. DEMON Algorithm

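DEMON (Democratic Estimate of the Modular Organization of a Network) works bottom-up from local views of the graph. For every node it takes the ego network, removes the ego, partitions the remainder with label propagation, and adds the ego back to each resulting local community. Local communities that overlap heavily are then merged, with `epsilon` controlling how much overlap is required. Because a node can appear in the local communities of many egos, the merged result is naturally overlapping.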
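The sketch below follows DEMON's ego-network, local-detection, and merge steps on a plain adjacency dict. Two simplifications are assumptions made here: connected components stand in for label propagation as the local partitioner (to keep the result deterministic), and two communities are merged when at most an `epsilon` fraction of the smaller one lies outside the larger. The merge is a single greedy pass, and `demon_sketch` is a hypothetical name, not cdlib's implementation.

```python
from collections import deque


def _components(nodes, adj):
    """Connected components of the subgraph induced by `nodes` (breadth-first search)."""
    nodes = set(nodes)
    seen, comps = set(), []
    for start in nodes:
        if start in seen:
            continue
        comp, queue = set(), deque([start])
        seen.add(start)
        while queue:
            u = queue.popleft()
            comp.add(u)
            for v in adj[u]:
                if v in nodes and v not in seen:
                    seen.add(v)
                    queue.append(v)
        comps.append(comp)
    return comps


def demon_sketch(adj, epsilon=0.25, min_com_size=3):
    """DEMON-style detection on an adjacency dict {node: set of neighbors}."""
    communities = []
    for ego in adj:
        # Ego network without the ego, split into local communities
        # (connected components here, instead of DEMON's label propagation)
        for local in _components(adj[ego], adj):
            candidate = local | {ego}   # put the ego back into each local community
            if len(candidate) < min_com_size:
                continue
            for i, existing in enumerate(communities):
                shared = len(candidate & existing)
                # Assumed merge rule: at most `epsilon` of the smaller set may lie outside
                if shared >= (1 - epsilon) * min(len(candidate), len(existing)):
                    communities[i] = existing | candidate
                    break
            else:
                communities.append(candidate)
    return [sorted(c) for c in communities]


# Two 5-cliques sharing node 4
adj = {v: set() for v in range(9)}
for group in ([0, 1, 2, 3, 4], [4, 5, 6, 7, 8]):
    for u in group:
        adj[u].update(v for v in group if v != u)
result = sorted(demon_sketch(adj))
```

Removing node 4 from its own ego network splits it into two components, so node 4 is placed in both communities and `result` is `[[0, 1, 2, 3, 4], [4, 5, 6, 7, 8]]`.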

In [None]:
if CDLIB_AVAILABLE:
    # Run DEMON algorithm
    print("Running DEMON algorithm...")
    
    # Run DEMON
    demon_communities, demon_time = run_demon(G, epsilon=0.25)
    
    # Evaluate against ground truth
    demon_metrics = evaluate_overlapping_communities(demon_communities, ground_truth)
    print(f"Execution time: {demon_time:.4f} seconds")
    print(f"Number of communities detected: {len(demon_communities)}")
    print(f"NMI: {demon_metrics['nmi']:.4f}")
    print(f"Omega Index: {demon_metrics['omega']:.4f}")
    print(f"F1 Score: {demon_metrics['f1']:.4f}")
    
    # Save results
    demon_result = {
        'communities': demon_communities,
        'execution_time': demon_time,
        'metrics': demon_metrics,
        'num_communities': len(demon_communities)
    }
    os.makedirs('results', exist_ok=True)
    with open('results/demon_result.pkl', 'wb') as f:
        pickle.dump(demon_result, f)
    
    # Visualize detected communities
    print("\nVisualizing DEMON detected communities...")
    visualize_overlapping_communities(G, demon_communities, figsize=(12, 10), alpha=0.6)
else:
    print("DEMON requires cdlib. Skipping this method.")

## 7. SLPA Algorithm

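The Speaker-Listener Label Propagation Algorithm (SLPA) extends label propagation by giving every node a memory of labels. In each of `t` iterations, every node acts as a listener: each neighbor (speaker) sends a label drawn from its memory in proportion to how often it occurs there, and the listener stores the most frequent label it received. Afterwards, labels that make up less than a fraction `r` of a node's memory are discarded, and nodes sharing a label form a community, so a node that keeps several labels belongs to several communities.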
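Here is a minimal SLPA sketch on an adjacency dict, assuming a seeded `random.Random` for reproducibility and omitting the nested-community cleanup that full implementations may apply. `slpa_sketch` is a hypothetical name; cdlib's `slpa` is what the notebook actually runs.

```python
import random
from collections import Counter


def slpa_sketch(adj, t=21, r=0.1, seed=0):
    """SLPA on an adjacency dict {node: iterable of neighbors}; returns overlapping communities."""
    rng = random.Random(seed)
    memory = {v: [v] for v in adj}   # every node starts with its own id as a label
    nodes = list(adj)
    for _ in range(t):
        rng.shuffle(nodes)           # asynchronous updates in random order
        for listener in nodes:
            if not adj[listener]:
                continue
            # Each speaker sends a label drawn uniformly from its memory list,
            # i.e. with probability proportional to the label's frequency
            heard = [rng.choice(memory[speaker]) for speaker in adj[listener]]
            # The listener stores the most frequent label it heard (ties broken at random)
            counts = Counter(heard)
            top = max(counts.values())
            memory[listener].append(rng.choice(sorted(l for l, c in counts.items() if c == top)))
    # Post-processing: keep labels that make up at least a fraction r of a node's memory
    communities = {}
    for v, mem in memory.items():
        for label, c in Counter(mem).items():
            if c / len(mem) >= r:
                communities.setdefault(label, set()).add(v)
    return [sorted(c) for c in communities.values()]


# Two disconnected 5-cliques: labels can never cross between them
adj = {v: set() for v in range(10)}
for group in (range(5), range(5, 10)):
    for u in group:
        adj[u].update(v for v in group if v != u)
comms = slpa_sketch(adj, t=21, r=0.1, seed=1)
```

The exact communities depend on the random draws, but two properties always hold for this input: every detected community stays inside one clique, and every node keeps at least one label (each memory holds at most five distinct labels, so the most frequent one exceeds `r`).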

In [None]:
if CDLIB_AVAILABLE:
    # Run SLPA algorithm
    print("Running SLPA algorithm...")
    
    # Run SLPA
    slpa_communities, slpa_time = run_slpa(G, t=21, r=0.1)
    
    # Evaluate against ground truth
    slpa_metrics = evaluate_overlapping_communities(slpa_communities, ground_truth)
    print(f"Execution time: {slpa_time:.4f} seconds")
    print(f"Number of communities detected: {len(slpa_communities)}")
    print(f"NMI: {slpa_metrics['nmi']:.4f}")
    print(f"Omega Index: {slpa_metrics['omega']:.4f}")
    print(f"F1 Score: {slpa_metrics['f1']:.4f}")
    
    # Save results
    slpa_result = {
        'communities': slpa_communities,
        'execution_time': slpa_time,
        'metrics': slpa_metrics,
        'num_communities': len(slpa_communities)
    }
    os.makedirs('results', exist_ok=True)
    with open('results/slpa_result.pkl', 'wb') as f:
        pickle.dump(slpa_result, f)
    
    # Visualize detected communities
    print("\nVisualizing SLPA detected communities...")
    visualize_overlapping_communities(G, slpa_communities, figsize=(12, 10), alpha=0.6)
else:
    print("SLPA requires cdlib. Skipping this method.")

## 8. GNN-Based Overlapping Community Detection

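Next, we train a graph convolutional network (three `GCNConv` layers with a sigmoid output) to predict a multi-label community-membership vector for each node. Unlike the methods above, this approach is supervised: it is trained on the ground-truth memberships of all nodes and then evaluated against those same labels. Its scores therefore measure how well the model fits the labels rather than how well it discovers communities, and they are not directly comparable with the unsupervised methods.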
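A fairer evaluation holds out some nodes: train the loss on one subset and score predictions on the rest. This NumPy sketch shows the two pieces, a random node split and a binary cross-entropy restricted to a boolean mask. The helper names are hypothetical, and wiring them into `train_gnn_overlapping` is left as an assumption: with torch boolean tensors, the loss line would become roughly `criterion(out[train_mask], data.y[train_mask])`.

```python
import numpy as np


def train_test_masks(n_nodes, test_frac=0.2, seed=0):
    """Random boolean masks splitting nodes into training and held-out sets."""
    rng = np.random.default_rng(seed)
    test = np.zeros(n_nodes, dtype=bool)
    n_test = int(round(test_frac * n_nodes))
    test[rng.choice(n_nodes, size=n_test, replace=False)] = True
    return ~test, test


def masked_bce(probs, targets, mask, eps=1e-7):
    """Mean binary cross-entropy over the nodes (rows) selected by mask."""
    p = np.clip(probs[mask], eps, 1 - eps)   # avoid log(0)
    y = targets[mask]
    return float(-(y * np.log(p) + (1 - y) * np.log(1 - p)).mean())


train_mask, test_mask = train_test_masks(10, test_frac=0.2, seed=0)
probs = np.array([[0.5, 0.5], [0.9, 0.1]])
targets = np.array([[1.0, 0.0], [1.0, 0.0]])
loss = masked_bce(probs, targets, np.array([True, False]))
```

Only the first node contributes to `loss`, and with predictions of 0.5 for both communities its value is $\ln 2 \approx 0.693$.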

In [None]:
if TORCH_GEOMETRIC_AVAILABLE:
    # Run GNN-based overlapping community detection
    print("Running GNN-based overlapping community detection...")
    
    # Run the GNN model
    gnn_communities, gnn_time, gnn_losses = run_gnn_overlapping(
        G, 
        ground_truth,  # Training labels for every node (no held-out split), so scores are optimistic
        hidden_dim=64,
        epochs=100,
        threshold=0.5
    )
    
    # Evaluate against ground truth
    gnn_metrics = evaluate_overlapping_communities(gnn_communities, ground_truth)
    print(f"Execution time: {gnn_time:.4f} seconds")
    print(f"Number of communities detected: {len(gnn_communities)}")
    print(f"NMI: {gnn_metrics['nmi']:.4f}")
    print(f"Omega Index: {gnn_metrics['omega']:.4f}")
    print(f"F1 Score: {gnn_metrics['f1']:.4f}")
    
    # Save results
    gnn_result = {
        'communities': gnn_communities,
        'execution_time': gnn_time,
        'metrics': gnn_metrics,
        'num_communities': len(gnn_communities),
        'losses': gnn_losses
    }
    os.makedirs('results', exist_ok=True)  # the cdlib cells may not have created it
    with open('results/gnn_overlapping_result.pkl', 'wb') as f:
        pickle.dump(gnn_result, f)
    
    # Visualize detected communities
    print("\nVisualizing GNN detected communities...")
    visualize_overlapping_communities(G, gnn_communities, figsize=(12, 10), alpha=0.6)
    
    # Plot the training loss curve
    plt.figure(figsize=(10, 6))
    plt.plot(gnn_losses)
    plt.title('GNN Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.grid(alpha=0.3)
    plt.show()
else:
    print("GNN-based methods require PyTorch Geometric. Skipping this method.")

## 9. Comparing Overlapping Community Detection Methods

Now let's gather the results from every method that ran above and compare them side by side.

In [None]:
# Check if we have at least one method available
if CDLIB_AVAILABLE or TORCH_GEOMETRIC_AVAILABLE:
    # Set up the results for comparison
    results = []
    
    if CDLIB_AVAILABLE:
        # Add BigCLAM results
        if 'bigclam_metrics' in locals():
            results.append({
                'Method': 'BigCLAM',
                'Num Communities': len(bigclam_communities),
                'Avg Community Size': np.mean([len(comm) for comm in bigclam_communities]),
                'NMI': bigclam_metrics['nmi'],
                'Omega': bigclam_metrics['omega'],
                'F1': bigclam_metrics['f1'],
                'Execution Time (s)': bigclam_time
            })
        
        # Add DEMON results
        if 'demon_metrics' in locals():
            results.append({
                'Method': 'DEMON',
                'Num Communities': len(demon_communities),
                'Avg Community Size': np.mean([len(comm) for comm in demon_communities]),
                'NMI': demon_metrics['nmi'],
                'Omega': demon_metrics['omega'],
                'F1': demon_metrics['f1'],
                'Execution Time (s)': demon_time
            })
        
        # Add SLPA results
        if 'slpa_metrics' in locals():
            results.append({
                'Method': 'SLPA',
                'Num Communities': len(slpa_communities),
                'Avg Community Size': np.mean([len(comm) for comm in slpa_communities]),
                'NMI': slpa_metrics['nmi'],
                'Omega': slpa_metrics['omega'],
                'F1': slpa_metrics['f1'],
                'Execution Time (s)': slpa_time
            })
    
    if TORCH_GEOMETRIC_AVAILABLE:
        # Add GNN results
        if 'gnn_metrics' in locals():
            results.append({
                'Method': 'GNN',
                'Num Communities': len(gnn_communities),
                'Avg Community Size': np.mean([len(comm) for comm in gnn_communities]),
                'NMI': gnn_metrics['nmi'],
                'Omega': gnn_metrics['omega'],
                'F1': gnn_metrics['f1'],
                'Execution Time (s)': gnn_time
            })
    
    # Create DataFrame with results
    if results:
        # Use polars for DataFrame
        comparison_df = pl.DataFrame(results)
        
        # Display results
        print("\nComparison Results:")
        print(comparison_df)
        
        # Save the comparison results
        comparison_df.write_parquet('results/overlapping_methods_comparison.parquet', compression="zstd")
        
        # Convert to pandas for visualization
        comparison_pd = comparison_df.to_pandas()
        
        # Create visualizations
        plt.figure(figsize=(14, 10))
        
        # Plot metrics
        plt.subplot(2, 2, 1)
        comparison_pd.plot(x='Method', y=['NMI', 'Omega', 'F1'], kind='bar', ax=plt.gca())
        plt.title('Quality Metrics by Method')
        plt.ylabel('Score')
        plt.ylim(0, 1)
        plt.grid(alpha=0.3)
        
        # Plot execution time
        plt.subplot(2, 2, 2)
        comparison_pd.plot(x='Method', y='Execution Time (s)', kind='bar', ax=plt.gca(), color='green')
        plt.title('Execution Time by Method')
        plt.ylabel('Time (seconds)')
        plt.grid(alpha=0.3)
        
        # Plot number of communities
        plt.subplot(2, 2, 3)
        comparison_pd.plot(x='Method', y='Num Communities', kind='bar', ax=plt.gca(), color='orange')
        plt.title('Number of Communities Detected')
        plt.ylabel('Count')
        plt.grid(alpha=0.3)
        
        # Plot average community size
        plt.subplot(2, 2, 4)
        comparison_pd.plot(x='Method', y='Avg Community Size', kind='bar', ax=plt.gca(), color='purple')
        plt.title('Average Community Size')
        plt.ylabel('Size')
        plt.grid(alpha=0.3)
        
        plt.tight_layout()
        plt.savefig('results/overlapping_methods_comparison.png')
        plt.show()
    else:
        print("No results available for comparison.")
else:
    print("No methods were run. Please install cdlib or PyTorch Geometric to run the methods.")

## 10. Summary and Conclusions

In this notebook, we have:

1. Generated a synthetic graph with overlapping communities
2. Applied various overlapping community detection methods
   - BigCLAM (if cdlib was available)
   - DEMON (if cdlib was available)
   - SLPA (if cdlib was available)
   - GNN-based approach (if PyTorch Geometric was available)
3. Evaluated the methods using metrics designed for overlapping communities
   - NMI (Normalized Mutual Information)
   - Omega Index
   - F1 Score
4. Visualized the detected overlapping communities
5. Compared the methods in terms of quality and performance

Overlapping community detection matters for many real-world networks in which nodes naturally belong to several communities at once; in a social network, for example, a person can be part of several social circles. The methods explored here take different approaches: BigCLAM fits a probabilistic affiliation model, DEMON merges communities found in ego networks, SLPA lets each node keep several labels, and the GNN learns memberships from labeled examples.

In the next notebook, we'll perform a comprehensive evaluation of all the community detection methods we've explored in this series, comparing their performance across different types of graphs and community structures.