# Graph Theory for Machine Learning
## Graph Neural Networks and Network Analysis

Welcome to the **mathematics of relationships and connections**! Graph theory provides the mathematical framework for understanding networks, relationships, and structured data that goes beyond traditional grid-based representations.

### What You'll Master
By the end of this notebook, you'll understand:
1. **Graph fundamentals** - Vertices, edges, and graph properties
2. **Graph matrices** - Adjacency, Laplacian, and incidence matrices
3. **Spectral graph theory** - Eigenvalues reveal graph structure
4. **Graph neural networks** - Learning on non-Euclidean data
5. **Network analysis** - Centrality, communities, and connectivity
6. **Random walks on graphs** - Diffusion and PageRank algorithms

### Why This is Revolutionary
- **Social networks** - Understanding influence and information spread
- **Molecular analysis** - Drug discovery through graph neural networks
- **Recommendation systems** - User-item interaction graphs
- **Knowledge graphs** - Reasoning over structured knowledge

### Real-World Applications
- **Social media**: Friend recommendation, community detection
- **Biology**: Protein folding, molecular property prediction
- **Transportation**: Route optimization, traffic flow analysis
- **Finance**: Fraud detection, risk assessment through transaction networks

Let's explore the beautiful mathematics of networks and connections! 🕸️

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
from scipy import sparse, linalg
from scipy.sparse.linalg import eigsh
from sklearn.datasets import make_classification
from sklearn.manifold import SpectralEmbedding
from sklearn.cluster import SpectralClustering, KMeans
from sklearn.metrics import silhouette_score
import pandas as pd
from collections import defaultdict
import warnings
warnings.filterwarnings('ignore')

# Set style
plt.style.use('seaborn-v0_8')
sns.set_palette("Set1")
np.random.seed(42)

print("🕸️ Graph Theory toolkit loaded!")
print("Ready to analyze networks and relationships!")

## 1. Graph Fundamentals and Representations

### What is a Graph?
A **graph** G = (V, E) consists of:
- **V**: Set of vertices (nodes) representing entities
- **E**: Set of edges connecting vertices, representing relationships

### Types of Graphs
1. **Undirected**: Edges have no direction (friendship networks)
2. **Directed**: Edges have direction (web page links, Twitter follows)
3. **Weighted**: Edges have weights (distance, strength of connection)
4. **Unweighted**: All edges are equal

### Graph Matrices
**Adjacency Matrix A**:
```
A[i,j] = { w_ij  if edge (i,j) exists with weight w_ij
         { 0     otherwise
```

**Degree Matrix D**:
```
D[i,i] = degree of vertex i = Σⱼ A[i,j]
D[i,j] = 0 for i ≠ j
```

**Graph Laplacian L**:
```
L = D - A
```

**Normalized Laplacian L̃** (defined when every vertex has nonzero degree):
```
L̃ = D^(-1/2) L D^(-1/2) = I - D^(-1/2) A D^(-1/2)
```
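
The four definitions above can be checked by hand on a tiny graph. The following is a minimal NumPy sketch, assuming a hypothetical weighted edge list (`edges`) and test vector (`x`) chosen only for illustration. It also verifies the quadratic-form identity xᵀLx = Σ over edges of w_ij (x_i - x_j)², which is why L measures how much a signal varies across edges.

```python
import numpy as np

# Hypothetical weighted, undirected edge list: (i, j, w_ij)
edges = [(0, 1, 0.5), (1, 2, 1.2), (2, 3, 0.8), (3, 0, 1.5), (1, 3, 0.3)]
n = 4

# Adjacency matrix: symmetric because the graph is undirected
A = np.zeros((n, n))
for i, j, w in edges:
    A[i, j] = w
    A[j, i] = w

deg = A.sum(axis=1)              # weighted degree of each vertex
D = np.diag(deg)                 # degree matrix
L = D - A                        # combinatorial Laplacian

# Normalized Laplacian (every degree is positive in this example)
D_inv_sqrt = np.diag(1.0 / np.sqrt(deg))
L_norm = D_inv_sqrt @ L @ D_inv_sqrt

# Quadratic form: x^T L x equals the weighted sum of squared differences over edges
x = np.array([1.0, -2.0, 0.5, 3.0])
quad_form = x @ L @ x
edge_sum = sum(w * (x[i] - x[j]) ** 2 for i, j, w in edges)

print("Row sums of L:", L.sum(axis=1))
print("x^T L x:", quad_form, "| edge sum:", edge_sum)
```

Because every row of L sums to zero, the all-ones vector is always in its null space.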

### Why Laplacians Matter
- **Spectrum reveals structure**: Eigenvalues encode connectivity
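- **Number of connected components**: Equals the multiplicity of the eigenvalue 0
- **Graph cuts**: The second-smallest eigenvalue (Fiedler value, or algebraic connectivity) bounds how cheaply the graph can be split in two
- **Diffusion processes**: The graph heat equation dx/dt = -Lx models diffusion along edges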
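
The link between zero eigenvalues and connected components is easy to verify numerically. This is a small sketch, assuming a hypothetical two-piece graph built with `nx.disjoint_union` and a tolerance `tol` for deciding what counts as a numerical zero.

```python
import numpy as np
import networkx as nx

# Two disjoint pieces: a triangle (nodes 0-2) and a path (nodes 3-6)
G = nx.disjoint_union(nx.cycle_graph(3), nx.path_graph(4))

L = nx.laplacian_matrix(G).toarray().astype(float)
eigenvalues = np.linalg.eigvalsh(L)   # symmetric input: real, ascending output

tol = 1e-9                            # assumed tolerance for "numerically zero"
n_zero = int(np.sum(np.abs(eigenvalues) < tol))
n_components = nx.number_connected_components(G)

print("Zero eigenvalues:", n_zero)
print("Connected components:", n_components)
```

Adding a single edge between the two pieces would merge them, and one of the two zero eigenvalues would become positive.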

In [None]:
def demonstrate_graph_fundamentals():
    """Explore fundamental concepts in graph theory"""
    
    print("📊 Graph Theory Fundamentals")
    print("=" * 30)
    
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    
    # 1. Basic graph types
    print("\n1. Graph Types and Representations")
    
    # Create different types of graphs
    graphs = {}
    
    # Undirected graph
    G_undirected = nx.Graph()
    G_undirected.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 0), (1, 3), (0, 2)])
    graphs['Undirected'] = G_undirected
    
    # Directed graph
    G_directed = nx.DiGraph()
    G_directed.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 0), (1, 3)])
    graphs['Directed'] = G_directed
    
    # Weighted graph
    G_weighted = nx.Graph()
    G_weighted.add_weighted_edges_from([(0, 1, 0.5), (1, 2, 1.2), (2, 3, 0.8), 
                                       (3, 0, 1.5), (1, 3, 0.3)])
    graphs['Weighted'] = G_weighted
    
    # Plot different graph types
    positions = {0: (0, 1), 1: (1, 1), 2: (1, 0), 3: (0, 0)}
    
    for i, (graph_type, G) in enumerate(graphs.items()):
        if i < 3:
            ax = axes[0, i]
            
            if graph_type == 'Weighted':
                # Draw with edge weights
                nx.draw(G, positions, ax=ax, with_labels=True, node_color='lightblue',
                       node_size=500, font_size=16, font_weight='bold')
                edge_labels = nx.get_edge_attributes(G, 'weight')
                nx.draw_networkx_edge_labels(G, positions, edge_labels, ax=ax)
            else:
                nx.draw(G, positions, ax=ax, with_labels=True, node_color='lightblue',
                       node_size=500, font_size=16, font_weight='bold',
                       arrows=True if graph_type == 'Directed' else False)
            
            ax.set_title(f'{graph_type} Graph')
            ax.axis('off')
    
    print(f"   Undirected: {G_undirected.number_of_nodes()} nodes, {G_undirected.number_of_edges()} edges")
    print(f"   Directed: {G_directed.number_of_nodes()} nodes, {G_directed.number_of_edges()} edges")
    print(f"   Weighted: {G_weighted.number_of_nodes()} nodes, {G_weighted.number_of_edges()} edges")
    
    # 2. Graph matrices
    print("\n2. Graph Matrix Representations")
    
    # Use the undirected graph for matrix analysis
    G = G_undirected
    n = G.number_of_nodes()
    
    # Adjacency matrix
    A = nx.adjacency_matrix(G).toarray()
    
    # Degree matrix
    degrees = dict(G.degree())
    D = np.diag([degrees[i] for i in range(n)])
    
    # Laplacian matrix
    L = D - A
    
    # Normalized Laplacian
    D_inv_sqrt = np.diag([1/np.sqrt(degrees[i]) if degrees[i] > 0 else 0 for i in range(n)])
    L_norm = D_inv_sqrt @ L @ D_inv_sqrt
    
    # Visualize matrices
    matrices = [('Adjacency (A)', A), ('Degree (D)', D), ('Laplacian (L)', L)]
    
    for i, (name, matrix) in enumerate(matrices):
        if i < 3:
            ax = axes[1, i]
            im = ax.imshow(matrix, cmap='RdBu', aspect='equal')
            ax.set_title(name)
            ax.set_xticks(range(n))
            ax.set_yticks(range(n))
            
            # Add values to matrix
            for row in range(n):
                for col in range(n):
                    ax.text(col, row, f'{matrix[row, col]:.0f}', 
                           ha='center', va='center', 
                           color='white' if abs(matrix[row, col]) > matrix.max()/2 else 'black')
            
            plt.colorbar(im, ax=ax, shrink=0.8)
    
    print(f"   Matrix dimensions: {n}×{n}")
    print(f"   Adjacency matrix sum: {A.sum()} (twice the number of edges)")
    print(f"   Degree matrix trace: {np.trace(D)} (sum of all degrees)")
    print(f"   Laplacian properties: symmetric, positive semidefinite")
    
    plt.tight_layout()
    plt.show()
    
    return G, A, L, L_norm

G_demo, A_demo, L_demo, L_norm_demo = demonstrate_graph_fundamentals()

## 2. Spectral Graph Theory

### The Magic of Graph Spectra
The **eigenvalues and eigenvectors** of graph matrices reveal deep structural properties:

**Laplacian Spectrum Properties**:
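1. **λ₁ = 0**: Always zero, with the all-ones vector as eigenvector
2. **λ₂ (Fiedler value)**: Algebraic connectivity; λ₂ > 0 exactly when the graph is connected, and larger values mean the graph is harder to disconnect
3. **Number of 0 eigenvalues**: Number of connected components
4. **Largest eigenvalue**: Bounded by the maximum degree d_max: d_max + 1 ≤ λ_max ≤ 2·d_max for an unweighted graph with at least one edge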
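
These properties can be spot-checked on graphs whose spectra are known in closed form. The sketch below assumes three standard NetworkX generators and the standard degree bounds d_max + 1 ≤ λ_max ≤ 2·d_max for unweighted graphs with at least one edge; the dictionary name `spectra` is only illustrative.

```python
import numpy as np
import networkx as nx

n = 10
graphs = {
    "path": nx.path_graph(n),
    "complete": nx.complete_graph(n),
    "star": nx.star_graph(n - 1),   # 1 hub + 9 leaves
}

spectra = {}
for name, G in graphs.items():
    L = nx.laplacian_matrix(G).toarray().astype(float)
    eigenvalues = np.linalg.eigvalsh(L)               # real, sorted ascending
    d_max = max(degree for _, degree in G.degree())
    spectra[name] = eigenvalues

    # The all-ones vector is always in the null space of L
    ones_in_kernel = np.allclose(L @ np.ones(n), 0.0)
    print(f"{name:9s} L·1 = 0: {ones_in_kernel}  λ2 = {eigenvalues[1]:.3f}  "
          f"λmax = {eigenvalues[-1]:.3f}  degree bounds = [{d_max + 1}, {2 * d_max}]")
```

The path graph has the smallest Fiedler value of the three, matching the intuition that a path is the easiest of these graphs to disconnect.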

**Graph Cuts and Clustering**:
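The Fiedler vector (the eigenvector of λ₂) gives an approximate graph bisection: splitting vertices by the sign of its entries solves a continuous relaxation of the NP-hard ratio cut problem:
```
RatioCut(S, S̄) = cut(S, S̄) × (1/|S| + 1/|S̄|),   cut(S, S̄) = Σ_{i∈S, j∈S̄} A[i,j]
Relaxation: minimize fᵀ L f = ½ Σᵢⱼ A[i,j] (fᵢ - fⱼ)²   subject to f ⊥ 1, ‖f‖ = 1
```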
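
To see how the spectral heuristic relates to the combinatorial objective, the sketch below compares the Fiedler sign split against an exhaustive search on a graph small enough to enumerate. It assumes the standard ratio cut definition, cut(S, S̄) × (1/|S| + 1/|S̄|), and a hypothetical helper `ratio_cut`; the barbell graph is chosen because its best split is obvious.

```python
import itertools
import numpy as np
import networkx as nx

def ratio_cut(A, mask):
    """cut(S, S-bar) * (1/|S| + 1/|S-bar|) for a boolean membership mask."""
    cut = A[np.ix_(mask, ~mask)].sum()
    return cut * (1.0 / mask.sum() + 1.0 / (~mask).sum())

# Two 4-cliques (nodes 0-3 and 4-7) joined by a single edge
G = nx.barbell_graph(4, 0)
nodes = sorted(G.nodes())
A = nx.to_numpy_array(G, nodelist=nodes)
L = np.diag(A.sum(axis=1)) - A

# Spectral heuristic: split by the sign of the Fiedler vector
_, eigenvectors = np.linalg.eigh(L)
fiedler_mask = eigenvectors[:, 1] > 0
spectral_value = ratio_cut(A, fiedler_mask)

# Exhaustive search over all subsets up to half the graph (tiny graphs only)
n = len(nodes)
best_value = min(
    ratio_cut(A, np.isin(np.arange(n), subset))
    for size in range(1, n // 2 + 1)
    for subset in itertools.combinations(range(n), size)
)

print("Fiedler split:", np.flatnonzero(fiedler_mask), "ratio cut:", spectral_value)
print("Best ratio cut by brute force:", best_value)
```

On this graph the two values coincide. In general the sign split is only an approximation, and the exhaustive search grows exponentially with the number of nodes.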

**Cheeger's Inequality**:
Connects algebraic (eigenvalues) and combinatorial (cuts) properties:
```
λ₂/2 ≤ h(G) ≤ √(2λ₂)
```
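where λ₂ is the second-smallest eigenvalue of the normalized Laplacian L̃ and h(G) is the Cheeger constant, measured here as conductance: the minimum of cut(S, S̄)/vol(S) over vertex sets S with vol(S) ≤ vol(V)/2, where vol(S) is the sum of the degrees in S.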
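
The inequality can be checked by brute force on a small graph. This sketch assumes the normalized-Laplacian form of Cheeger's inequality, with h(G) computed as conductance (cut divided by the volume of the smaller side); the 8-node cycle is an arbitrary illustrative choice.

```python
import itertools
import numpy as np
import networkx as nx

G = nx.cycle_graph(8)
A = nx.to_numpy_array(G, nodelist=sorted(G.nodes()))
deg = A.sum(axis=1)
n = len(deg)

# Second-smallest eigenvalue of the normalized Laplacian I - D^(-1/2) A D^(-1/2)
D_inv_sqrt = np.diag(1.0 / np.sqrt(deg))
L_norm = np.eye(n) - D_inv_sqrt @ A @ D_inv_sqrt
lambda_2 = np.linalg.eigvalsh(L_norm)[1]

# Conductance by brute force: min cut(S) / vol(S) over sets with vol(S) <= vol(V) / 2
total_volume = deg.sum()
h = np.inf
for size in range(1, n):
    for subset in itertools.combinations(range(n), size):
        mask = np.zeros(n, dtype=bool)
        mask[list(subset)] = True
        volume = deg[mask].sum()
        if volume <= total_volume / 2:
            cut = A[np.ix_(mask, ~mask)].sum()
            h = min(h, cut / volume)

lower, upper = lambda_2 / 2, np.sqrt(2 * lambda_2)
print(f"{lower:.3f} <= h(G) = {h:.3f} <= {upper:.3f}")
```

For the cycle, the best set is a run of four consecutive nodes: it cuts two edges and has volume 8.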

### Applications in Machine Learning
- **Spectral clustering**: Embed nodes with the first few Laplacian eigenvectors, then cluster the embedding (e.g. with k-means)
- **Graph signal processing**: Fourier transform on graphs
- **Manifold learning**: Laplacian eigenmaps
- **Graph neural networks**: Spectral convolutions

In [None]:
def demonstrate_spectral_graph_theory():
    """Explore spectral properties of graphs"""
    
    print("🌈 Spectral Graph Theory: Eigenvalues Reveal Structure")
    print("=" * 58)
    
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    
    # 1. Eigenvalues of different graph structures
    print("\n1. Graph Structure and Eigenvalue Patterns")
    
    graph_types = {
        'Path': nx.path_graph(10),
        'Cycle': nx.cycle_graph(10),
        'Complete': nx.complete_graph(10),
        'Star': nx.star_graph(9),  # 9 + 1 center = 10 nodes
        'Grid': nx.grid_2d_graph(3, 3),
        'Random': nx.erdos_renyi_graph(10, 0.3, seed=42)
    }
    
    eigenvalue_data = {}
    
    for name, G in graph_types.items():
        # Get Laplacian matrix
        L = nx.laplacian_matrix(G).toarray()
        
        # Compute eigenvalues (eigvalsh returns real, ascending values for symmetric matrices)
        eigenvals = np.linalg.eigvalsh(L)
        eigenvalue_data[name] = eigenvals
        
        print(f"   {name} graph: λ₂ = {eigenvals[1]:.3f}, max λ = {eigenvals[-1]:.3f}")
    
    # Plot eigenvalue spectra
    for i, (name, eigenvals) in enumerate(list(eigenvalue_data.items())[:3]):
        ax = axes[0, i]
        ax.stem(range(len(eigenvals)), eigenvals, basefmt=' ')
        ax.set_xlabel('Eigenvalue Index')
        ax.set_ylabel('Eigenvalue')
        ax.set_title(f'{name} Graph Spectrum')
        ax.grid(True, alpha=0.3)
        
        # Highlight the Fiedler value
        ax.scatter([1], [eigenvals[1]], color='red', s=100, zorder=5, label='Fiedler value')
        ax.legend()
    
    # 2. Graph clustering using spectral methods
    print("\n2. Spectral Clustering")
    print("   Using Fiedler vector for graph partitioning")
    
    # Create a graph with clear community structure
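    G_communities = nx.Graph()
    # Add nodes up front so matrix row i corresponds to node i and no node is dropped
    G_communities.add_nodes_from(range(10))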
    
    # Community 1: nodes 0-4
    for i in range(5):
        for j in range(i+1, 5):
            if np.random.random() > 0.3:  # Dense connections within community
                G_communities.add_edge(i, j)
    
    # Community 2: nodes 5-9
    for i in range(5, 10):
        for j in range(i+1, 10):
            if np.random.random() > 0.3:  # Dense connections within community
                G_communities.add_edge(i, j)
    
    # Few connections between communities
    G_communities.add_edge(2, 7)  # Bridge connection
    G_communities.add_edge(4, 5)  # Another bridge
    
    # Get Laplacian and compute Fiedler vector
    L_communities = nx.laplacian_matrix(G_communities).toarray()
    eigenvals, eigenvecs = np.linalg.eigh(L_communities)
    fiedler_vector = eigenvecs[:, 1]  # Second eigenvector
    
    # Partition based on sign of Fiedler vector
    partition = fiedler_vector > 0
    
    # Plot original graph with spectral clustering
    pos = nx.spring_layout(G_communities, seed=42)
    
    # Color nodes by partition
    node_colors = ['red' if partition[i] else 'blue' for i in range(len(partition))]
    
    nx.draw(G_communities, pos, ax=axes[1, 0], node_color=node_colors,
           with_labels=True, node_size=300, font_size=10)
    axes[1, 0].set_title('Spectral Clustering Result')
    
    print(f"   Fiedler value: {eigenvals[1]:.3f}")
    print(f"   Partition 1 (red): {np.where(partition)[0]}")
    print(f"   Partition 2 (blue): {np.where(~partition)[0]}")
    
    # 3. Graph signal processing
    print("\n3. Graph Signal Processing")
    print("   Fourier transform on graphs using Laplacian eigenvectors")
    
    # Create a path graph for clear demonstration
    G_path = nx.path_graph(20)
    L_path = nx.laplacian_matrix(G_path).toarray()
    
    # Compute full eigendecomposition
    eigenvals_path, eigenvecs_path = np.linalg.eigh(L_path)
    
    # Create a signal on the graph (smooth vs. rough)
    nodes = np.arange(20)
    
    # Smooth signal
    signal_smooth = np.sin(2 * np.pi * nodes / 20) + 0.5 * np.sin(4 * np.pi * nodes / 20)
    
    # Add some noise (rough signal)
    signal_noisy = signal_smooth + 0.3 * np.random.randn(20)
    
    # Graph Fourier transform
    signal_freq_smooth = eigenvecs_path.T @ signal_smooth
    signal_freq_noisy = eigenvecs_path.T @ signal_noisy
    
    # Plot signals
    axes[1, 1].plot(nodes, signal_smooth, 'b-o', linewidth=2, markersize=4, label='Smooth signal')
    axes[1, 1].plot(nodes, signal_noisy, 'r-s', linewidth=1, markersize=3, alpha=0.7, label='Noisy signal')
    axes[1, 1].set_xlabel('Node Index')
    axes[1, 1].set_ylabel('Signal Value')
    axes[1, 1].set_title('Signals on Path Graph')
    axes[1, 1].legend()
    axes[1, 1].grid(True, alpha=0.3)
    
    # Plot frequency domain
    axes[1, 2].stem(eigenvals_path, np.abs(signal_freq_smooth), basefmt=' ', 
                   label='Smooth signal', linefmt='b-', markerfmt='bo')
    axes[1, 2].stem(eigenvals_path, np.abs(signal_freq_noisy), basefmt=' ',
                   label='Noisy signal', linefmt='r-', markerfmt='rs')
    axes[1, 2].set_xlabel('Eigenvalue (Frequency)')
    axes[1, 2].set_ylabel('|Fourier Coefficient|')
    axes[1, 2].set_title('Graph Fourier Transform')
    axes[1, 2].legend()
    axes[1, 2].grid(True, alpha=0.3)
    
    # Graph signal denoising
    # Keep only low-frequency components (small eigenvalues)
    cutoff = 2.0
    signal_freq_filtered = signal_freq_noisy.copy()
    signal_freq_filtered[eigenvals_path > cutoff] = 0
    
    # Inverse graph Fourier transform
    signal_denoised = eigenvecs_path @ signal_freq_filtered
    
    print(f"   Original signal norm: {np.linalg.norm(signal_smooth):.3f}")
    print(f"   Noisy signal norm: {np.linalg.norm(signal_noisy):.3f}")
    print(f"   Denoised signal norm: {np.linalg.norm(signal_denoised):.3f}")
    print(f"   Denoising error: {np.linalg.norm(signal_denoised - signal_smooth):.3f}")
    
    plt.tight_layout()
    plt.show()
    
    return G_communities, fiedler_vector, partition

G_comm, fiedler_vec, spectral_partition = demonstrate_spectral_graph_theory()

## 3. Graph Neural Networks (GNNs)

### The Revolution: Learning on Graphs
Traditional neural networks work on **Euclidean data** (images, sequences). Graph Neural Networks extend this to **non-Euclidean data** where relationships matter more than spatial structure.

### Core GNN Operations
**Message Passing Framework**:
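1. **Message**: m^(l)_{ji} = Message(h^(l)_j, h^(l)_i, e_{ji}), the message sent from neighbor j to node i
2. **Aggregation**: m^(l)_i = Aggregate({m^(l)_{ji} : j ∈ N(i)}), using a permutation-invariant function such as sum, mean, or max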
3. **Update**: h^(l+1)_i = Update(h^(l)_i, m^(l)_i)
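
The three steps can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not a reference implementation: messages are a linear map of the sender's state (edge features are ignored), aggregation is a mean, and the update is a linear map plus ReLU. The function name `message_passing_round` and the weights `W_msg`, `W_upd` are hypothetical.

```python
import numpy as np

def message_passing_round(H, edges, W_msg, W_upd):
    """One round of message -> mean aggregation -> update on an undirected edge list."""
    n = H.shape[0]
    aggregated = np.zeros((n, W_msg.shape[1]))
    counts = np.zeros(n)
    for i, j in edges:
        # Message: each endpoint sends a linear function of its own state
        aggregated[i] += H[j] @ W_msg   # message j -> i
        aggregated[j] += H[i] @ W_msg   # message i -> j
        counts[i] += 1
        counts[j] += 1
    # Aggregate: mean over neighbors (isolated nodes keep a zero message)
    aggregated /= np.maximum(counts, 1)[:, None]
    # Update: combine the node's own state with the aggregated message, then ReLU
    return np.maximum(0.0, H @ W_upd + aggregated)

rng = np.random.default_rng(0)
edges = [(0, 1), (1, 2), (2, 3)]        # path graph 0-1-2-3
H0 = rng.normal(size=(4, 3))
W_msg = rng.normal(size=(3, 3)) * 0.5
W_upd = rng.normal(size=(3, 3)) * 0.5

H1 = message_passing_round(H0, edges, W_msg, W_upd)
print(H1.shape)
```

After one round, each node has only seen its direct neighbors; node 0 needs three rounds before node 3's features can influence it.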

### Types of GNNs
**Graph Convolutional Networks (GCN)**:
```
H^(l+1) = σ(D̃^(-1/2) Ã D̃^(-1/2) H^(l) W^(l))
```
where Ã = A + I (adjacency plus self-loops), D̃ is the degree matrix of Ã, and W^(l) is the trainable weight matrix of layer l
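
The propagation rule above can be sketched as two small functions. This assumes ReLU for σ and a hypothetical 4-node path graph; the helper names `normalized_adjacency` and `gcn_layer` are illustrative. The sketch also inspects the spectrum of the normalized matrix, since its eigenvalues staying within (-1, 1] is what keeps repeated propagation from blowing up.

```python
import numpy as np

def normalized_adjacency(A):
    """Return D~^(-1/2) (A + I) D~^(-1/2), where D~ is the degree matrix of A + I."""
    A_tilde = A + np.eye(A.shape[0])
    d_tilde = A_tilde.sum(axis=1)            # >= 1 thanks to the self-loops
    d_inv_sqrt = 1.0 / np.sqrt(d_tilde)
    return A_tilde * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]

def gcn_layer(A_hat, H, W):
    """One GCN layer, using ReLU as the nonlinearity."""
    return np.maximum(0.0, A_hat @ H @ W)

# Path graph 0-1-2-3
A = np.array([[0, 1, 0, 0],
              [1, 0, 1, 0],
              [0, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)

A_hat = normalized_adjacency(A)
rng = np.random.default_rng(0)
H = rng.normal(size=(4, 3))      # 3 input features per node
W = rng.normal(size=(3, 2))      # maps 3 features to 2

H_next = gcn_layer(A_hat, H, W)
spectrum = np.linalg.eigvalsh(A_hat)
print("Output shape:", H_next.shape)
print("Eigenvalue range of A_hat:", spectrum.min(), spectrum.max())
```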

**GraphSAGE (Sample and Aggregate)**:
```
h^(l+1)_v = σ(W^(l) · CONCAT(h^(l)_v, AGG({h^(l)_u : u ∈ N(v)})))
```
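
A minimal sketch of this update, assuming the mean aggregator, ReLU for σ, and no neighbor sampling (every neighbor is used). The function name `graphsage_mean_layer` and the adjacency-list dictionary `neighbors` are hypothetical.

```python
import numpy as np

def graphsage_mean_layer(H, neighbors, W):
    """h_v' = ReLU(W · CONCAT(h_v, mean of neighbor states)), without sampling."""
    out = []
    for v, h_v in enumerate(H):
        nbrs = neighbors[v]
        # Mean aggregator; a node without neighbors gets a zero vector
        agg = H[nbrs].mean(axis=0) if nbrs else np.zeros_like(h_v)
        out.append(W @ np.concatenate([h_v, agg]))
    return np.maximum(0.0, np.array(out))

neighbors = {0: [1, 2], 1: [0], 2: [0, 3], 3: [2]}   # adjacency lists
rng = np.random.default_rng(0)
H = rng.normal(size=(4, 3))       # 3 features per node
W = rng.normal(size=(5, 6))       # output dim 5, input dim 2 * 3

H_next = graphsage_mean_layer(H, neighbors, W)
print(H_next.shape)
```

Concatenation keeps a node's own state separate from its neighborhood summary, so W can weight the two differently. GCN, by contrast, mixes them through the self-loop.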

**Graph Attention Networks (GAT)**:
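```
α_{ij} = softmax_j(LeakyReLU(a^T [W h_i || W h_j]))
h'_i = σ(Σ_{j ∈ N(i)} α_{ij} W h_j)
```
Here || denotes concatenation, and the softmax normalizes over the neighborhood j ∈ N(i), which typically includes i itself.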
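
The two attention equations can be vectorized with a masked softmax. The sketch below is a single-head layer under stated assumptions: a LeakyReLU slope of 0.2, self-attention included in each neighborhood, and tanh standing in for σ. The function `gat_layer` and the example star graph are hypothetical.

```python
import numpy as np

def leaky_relu(x, slope=0.2):
    return np.where(x > 0, x, slope * x)

def gat_layer(H, A, W, a):
    """Single-head graph attention; returns (attention matrix, new features)."""
    n = H.shape[0]
    Wh = H @ W.T                                  # row i holds W h_i
    d_out = Wh.shape[1]
    # a^T [W h_i || W h_j] splits into a term for i plus a term for j
    e = leaky_relu((Wh @ a[:d_out])[:, None] + (Wh @ a[d_out:])[None, :])
    mask = (A + np.eye(n)) > 0                    # attend to neighbors and self
    e = np.where(mask, e, -np.inf)                # non-neighbors get zero weight
    e = e - e.max(axis=1, keepdims=True)          # numerical stability
    alpha = np.exp(e)
    alpha /= alpha.sum(axis=1, keepdims=True)     # softmax over each row
    return alpha, np.tanh(alpha @ Wh)             # tanh stands in for σ

# Star graph centered at node 1
A = np.array([[0, 1, 0, 0],
              [1, 0, 1, 1],
              [0, 1, 0, 0],
              [0, 1, 0, 0]], dtype=float)
rng = np.random.default_rng(0)
H = rng.normal(size=(4, 3))
W = rng.normal(size=(2, 3))
a = rng.normal(size=4)            # length 2 * d_out

alpha, H_next = gat_layer(H, A, W, a)
print("Row sums:", alpha.sum(axis=1))
```

Masking with negative infinity before the softmax guarantees that non-neighbors receive exactly zero attention.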

### Why GNNs Work
- **Permutation equivariance**: Reordering the nodes reorders the node-level outputs in the same way; graph-level readouts (sum, mean) are permutation invariant
- **Locality**: Information propagates through graph structure
- **Weight sharing**: Same function applied at each node
- **Inductive bias**: Graph structure guides learning
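
The symmetry property in the first bullet can be tested directly: relabel the nodes, rerun the layer, and compare. This sketch assumes a simple sum-aggregation layer, ReLU((A + I) H W), as a stand-in for any message-passing layer; the function name and example graph are hypothetical.

```python
import numpy as np

def sum_aggregation_layer(A, H, W):
    """ReLU((A + I) H W): every node applies the same W to itself plus its neighbors."""
    return np.maximum(0.0, (A + np.eye(A.shape[0])) @ H @ W)

rng = np.random.default_rng(0)
A = np.array([[0, 1, 1, 0],
              [1, 0, 0, 0],
              [1, 0, 0, 1],
              [0, 0, 1, 0]], dtype=float)
H = rng.normal(size=(4, 3))
W = rng.normal(size=(3, 2))

perm = rng.permutation(4)
P = np.eye(4)[perm]               # permutation matrix: row k selects node perm[k]

out = sum_aggregation_layer(A, H, W)
out_permuted = sum_aggregation_layer(P @ A @ P.T, P @ H, W)

# Node-level outputs are permuted the same way as the inputs (equivariance)
equivariant = np.allclose(out_permuted, P @ out)
# A sum readout over nodes does not change at all (invariance)
invariant_readout = np.allclose(out_permuted.sum(axis=0), out.sum(axis=0))
print("Equivariant:", equivariant, "| Invariant readout:", invariant_readout)
```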

In [None]:
def demonstrate_graph_neural_networks():
    """Explore Graph Neural Network concepts and implementations"""
    
    print("🧠 Graph Neural Networks: Learning on Non-Euclidean Data")
    print("=" * 57)
    
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    
    # 1. Simple GCN implementation
    print("\n1. Graph Convolutional Network (GCN) Implementation")
    
    class SimpleGCN:
        def __init__(self, input_dim, hidden_dim, output_dim):
            """Simple GCN implementation"""
            self.W1 = np.random.randn(input_dim, hidden_dim) * 0.1
            self.W2 = np.random.randn(hidden_dim, output_dim) * 0.1
            self.b1 = np.zeros(hidden_dim)
            self.b2 = np.zeros(output_dim)
        
        def relu(self, x):
            return np.maximum(0, x)
        
        def softmax(self, x):
            exp_x = np.exp(x - np.max(x, axis=1, keepdims=True))
            return exp_x / np.sum(exp_x, axis=1, keepdims=True)
        
        def forward(self, A_norm, X):
            """Forward pass: H^(l+1) = σ(A_norm H^(l) W^(l))"""
            # First GCN layer
            H1 = A_norm @ X @ self.W1 + self.b1
            H1 = self.relu(H1)
            
            # Second GCN layer
            H2 = A_norm @ H1 @ self.W2 + self.b2
            return self.softmax(H2)
    
    # Create a synthetic graph for node classification
    np.random.seed(42)
    n_nodes = 20
    n_features = 5
    n_classes = 3
    
    # Create graph with community structure
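    G_gcn = nx.Graph()
    # Add nodes up front so adjacency row i matches X[i] and y_true[i], even for isolated nodes
    G_gcn.add_nodes_from(range(n_nodes))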
    
    # Three communities
    communities = [list(range(0, 7)), list(range(7, 14)), list(range(14, 20))]
    
    # Dense connections within communities
    for community in communities:
        for i in community:
            for j in community:
                if i < j and np.random.random() > 0.4:
                    G_gcn.add_edge(i, j)
    
    # Sparse connections between communities
    G_gcn.add_edge(3, 10)  # Bridge 1
    G_gcn.add_edge(6, 15)  # Bridge 2
    G_gcn.add_edge(12, 17) # Bridge 3
    
    # Create adjacency matrix
    A = nx.adjacency_matrix(G_gcn).toarray()
    
    # Add self-loops and normalize
    A_tilde = A + np.eye(n_nodes)
    D_tilde = np.diag(np.sum(A_tilde, axis=1))
    D_inv_sqrt = np.diag(1.0 / np.sqrt(np.diag(D_tilde)))
    A_norm = D_inv_sqrt @ A_tilde @ D_inv_sqrt
    
    # Create node features (community-based)
    X = np.random.randn(n_nodes, n_features)
    for i, community in enumerate(communities):
        # Add community-specific bias to features
        community_bias = np.random.randn(n_features) * 2
        for node in community:
            X[node] += community_bias
    
    # True labels based on communities
    y_true = np.zeros(n_nodes, dtype=int)
    for i, community in enumerate(communities):
        for node in community:
            y_true[node] = i
    
    # Initialize the GCN and run one forward pass (weights are random; no training is performed)
    gcn = SimpleGCN(n_features, 8, n_classes)
    y_pred_probs = gcn.forward(A_norm, X)
    y_pred = np.argmax(y_pred_probs, axis=1)
    
    # Visualize graph with predictions
    pos = nx.spring_layout(G_gcn, seed=42)
    colors = ['red', 'blue', 'green']
    
    # True labels
    node_colors_true = [colors[y_true[i]] for i in range(n_nodes)]
    nx.draw(G_gcn, pos, ax=axes[0, 0], node_color=node_colors_true,
           with_labels=True, node_size=300, font_size=8)
    axes[0, 0].set_title('True Community Labels')
    
    # Predicted labels
    node_colors_pred = [colors[y_pred[i]] for i in range(n_nodes)]
    nx.draw(G_gcn, pos, ax=axes[0, 1], node_color=node_colors_pred,
           with_labels=True, node_size=300, font_size=8)
    axes[0, 1].set_title('GCN Predicted Labels (untrained)')
    
    accuracy = np.mean(y_true == y_pred)
    print(f"   Nodes: {n_nodes}, Features: {n_features}, Classes: {n_classes}")
    print(f"   Untrained GCN accuracy (random weights): {accuracy:.3f}")
    print(f"   True communities: {communities}")
    
    # 2. Attention mechanism visualization
    print("\n2. Graph Attention Mechanism")
    print("   How nodes decide which neighbors to focus on")
    
    def compute_attention_weights(h_i, h_j, W, a, negative_slope=0.2):
        """Compute the unnormalized attention score e_ij between nodes i and j"""
        # Transform features
        Wh_i = W @ h_i
        Wh_j = W @ h_j
        
        # Attention mechanism: LeakyReLU(a^T [W h_i || W h_j])
        concat_features = np.concatenate([Wh_i, Wh_j])
        raw_score = a @ concat_features
        return raw_score if raw_score > 0 else negative_slope * raw_score
    
    # Simple attention example with 5 nodes
    n_att_nodes = 5
    feature_dim = 3
    hidden_dim = 4
    
    # Random node features
    H = np.random.randn(n_att_nodes, feature_dim)
    # Attention parameters
    W_att = np.random.randn(hidden_dim, feature_dim) * 0.1
    a_att = np.random.randn(2 * hidden_dim) * 0.1
    
    # Create a small complete graph for attention demo
    G_att = nx.complete_graph(n_att_nodes)
    A_att = nx.adjacency_matrix(G_att).toarray()
    
    # Compute attention weights
    attention_matrix = np.zeros((n_att_nodes, n_att_nodes))
    
    for i in range(n_att_nodes):
        attention_scores = []
        neighbors = [j for j in range(n_att_nodes) if A_att[i, j] == 1 or i == j]
        
        for j in neighbors:
            score = compute_attention_weights(H[i], H[j], W_att, a_att)
            attention_scores.append((j, score))
        
        # Softmax normalization (subtract the max for numerical stability)
        scores = np.array([score for _, score in attention_scores])
        exp_scores = np.exp(scores - scores.max())
        scores_normalized = exp_scores / exp_scores.sum()
        
        for idx, (j, _) in enumerate(attention_scores):
            attention_matrix[i, j] = scores_normalized[idx]
    
    # Visualize attention matrix
    im = axes[0, 2].imshow(attention_matrix, cmap='Blues', aspect='auto')
    axes[0, 2].set_xlabel('Neighbor j (attended to)')
    axes[0, 2].set_ylabel('Node i (attending)')
    axes[0, 2].set_title('Attention Weight Matrix')
    axes[0, 2].set_xticks(range(n_att_nodes))
    axes[0, 2].set_yticks(range(n_att_nodes))
    
    # Add values to attention matrix
    for i in range(n_att_nodes):
        for j in range(n_att_nodes):
            if attention_matrix[i, j] > 0:
                axes[0, 2].text(j, i, f'{attention_matrix[i, j]:.2f}',
                               ha='center', va='center', fontsize=8,
                               color='white' if attention_matrix[i, j] > 0.5 else 'black')
    
    plt.colorbar(im, ax=axes[0, 2], shrink=0.8)
    
    print(f"   Attention nodes: {n_att_nodes}")
    print(f"   Each row sums to 1: {[f'{row.sum():.3f}' for row in attention_matrix]}")
    
    # 3. Message passing visualization
    print("\n3. Message Passing in GNNs")
    print("   How information flows through the graph")
    
    # Create a simple star graph for clear message passing demo
    G_star = nx.star_graph(4)  # Center node + 4 outer nodes
    pos_star = nx.spring_layout(G_star, seed=42)
    
    # Simulate message passing for 3 layers
    n_star_nodes = 5
    initial_features = np.random.randn(n_star_nodes, 1)
    A_star = nx.adjacency_matrix(G_star).toarray().astype(float)
    
    # Add self-loops and row-normalize (D^-1 Ã): each node averages itself and its neighbors
    A_star_norm = A_star + np.eye(n_star_nodes)
    D_star = np.diag(np.sum(A_star_norm, axis=1))
    A_star_norm = np.linalg.inv(D_star) @ A_star_norm
    
    # Track features through layers
    features_layers = [initial_features.copy()]
    current_features = initial_features.copy()
    
    for layer in range(3):
        # Simple message passing: aggregate neighbors
        current_features = A_star_norm @ current_features
        features_layers.append(current_features.copy())
    
    # Plot layers 0-2 (the bottom row has three panels; layer 3 only appears in the printout)
    for layer_idx in range(3):
        ax = axes[1, layer_idx]
        
        # Node colors based on feature values
        feature_values = features_layers[layer_idx].flatten()
        node_colors = plt.cm.RdYlBu((feature_values - feature_values.min()) / 
                                   (feature_values.max() - feature_values.min()))
        
        nx.draw(G_star, pos_star, ax=ax, node_color=node_colors,
               with_labels=True, node_size=400, font_size=12)
        ax.set_title(f'Layer {layer_idx} Features')
        
        # Add feature values as text
        for node, (x, y) in pos_star.items():
            ax.text(x, y-0.15, f'{feature_values[node]:.2f}',
                   ha='center', va='top', fontsize=8, weight='bold')
    
    print(f"   Star graph: 1 center node, 4 outer nodes")
    print(f"   Initial center feature: {initial_features[0, 0]:.3f}")
    print(f"   Final center feature: {features_layers[-1][0, 0]:.3f}")
    print(f"   Information aggregated from all neighbors through layers")
    
    plt.tight_layout()
    plt.show()
    
    return G_gcn, gcn, attention_matrix

G_gnn, gnn_model, att_matrix = demonstrate_graph_neural_networks()