# BRIDGE: Block Rewiring from Inference-Derived Graph Ensembles

This notebook demonstrates the key features of the BRIDGE package, including:
1. Graph rewiring using Stochastic Block Models
2. GNN model training and evaluation
3. Sensitivity analysis and SNR calculations
4. Synthetic graph generation
5. Hyperparameter optimization

First, let's install the required dependencies:

In [1]:
!pip install dgl numpy optuna ortools pandas scipy scikit-learn torch tqdm pyyaml


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.1.2[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [None]:
import torch
import dgl
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

from bridge.models import GCN, SelectiveGCN
from bridge.rewiring import run_bridge_pipeline
from bridge.utils import generate_all_symmetric_permutation_matrices
from bridge.sensitivity import (
    estimate_snr_theorem,
    estimate_sensitivity_autograd,
    run_sensitivity_experiment,
    plot_snr_vs_homophily
)
from bridge.optimization import objective_gcn, objective_rewiring, collect_float_metrics
import optuna

ImportError: cannot import name 'optimize_hyperparameters' from 'bridge.optimization' (/Users/jonathanrubin/Desktop/Work/GNN-Analysis/BRIDGE/BRIDGE/bridge/optimization/__init__.py)

## 1. Loading and Preparing Data

Let's start by loading the Cora dataset and preparing it for our experiments:

In [5]:
# Seed the random number generators of the libraries imported above so that
# a fresh "Restart & Run All" is as repeatable as the backend allows.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
dgl.seed(SEED)

# Load Cora dataset
dataset = dgl.data.CoraGraphDataset()
g = dataset[0]

# Print graph statistics
print(f"Number of nodes: {g.number_of_nodes()}")
print(f"Number of edges: {g.number_of_edges()}")
print(f"Number of node features: {g.ndata['feat'].shape[1]}")
print(f"Number of classes: {len(torch.unique(g.ndata['label']))}")

# Move graph to GPU if available; `device` is reused by the later cells
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
g = g.to(device)

  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.
Number of nodes: 2708
Number of edges: 10556
Number of node features: 1433
Number of classes: 7


## 2. Graph Rewiring with BRIDGE

Now let's demonstrate the BRIDGE rewiring technique:

In [6]:
# Generate permutation matrices for rewiring; their size is the class count
k = len(torch.unique(g.ndata['label']))
all_matrices = generate_all_symmetric_permutation_matrices(k)
P_k = all_matrices[0]  # Choose the first permutation matrix

# The GCN and the selective model are given identical settings, so the
# per-model keyword arguments are generated from one shared table.
shared_hparams = {
    'h_feats': 64,
    'n_layers': 2,
    'dropout_p': 0.5,
    'model_lr': 1e-3,
}
pipeline_kwargs = {
    f'{param}_{suffix}': value
    for suffix in ('gcn', 'selective')
    for param, value in shared_hparams.items()
}

# Run the rewiring pipeline
results = run_bridge_pipeline(
    g=g,
    P_k=P_k,
    num_graphs=1,
    device=device,
    **pipeline_kwargs,
)

# Print results
print(f"Base GCN accuracy: {results['cold_start']['test_acc']:.4f}")
print(f"Selective GCN accuracy: {results['selective']['test_acc']:.4f}")

Base GCN accuracy: 0.7820
Selective GCN accuracy: 0.7710


## 3. Sensitivity Analysis

Let's analyze the Signal-to-Noise Ratio (SNR) and sensitivity of the graph neural network:

In [12]:
from bridge.sensitivity.models import LinearGCN
from bridge.sensitivity.feature_gen import create_feature_generator

# Feature dimension shared by the generator, the model and the experiment
feature_dim = 5


def _feature_matrices():
    """Build a fresh set of the four scaled identity matrices used below."""
    return {
        'sigma_intra': 0.1 * torch.eye(feature_dim),
        'sigma_inter': -0.05 * torch.eye(feature_dim),
        'tau': 1.0 * torch.eye(feature_dim),
        'eta': 1.0 * torch.eye(feature_dim),
    }


# Create a feature generator with our parameters
feature_generator = create_feature_generator(**_feature_matrices())

# Create a simple linear GCN model with one output per class label
n_classes = len(torch.unique(g.ndata['label']))
model = LinearGCN(
    in_feats=feature_dim,
    hidden_feats=64,
    out_feats=n_classes,
).to(device)

# Run sensitivity experiment (new tensors are built for this call, exactly
# as when the matrices are written out inline)
results = run_sensitivity_experiment(
    model=model,
    graph=g,
    feature_generator=feature_generator,
    in_feats=feature_dim,
    **_feature_matrices(),
    num_acc_repeats=10,
    num_monte_carlo_samples=100,
    num_epochs=200,
    device=device,
)

# Visualize results
plot_snr_vs_homophily(results)

  mu_concatenated = np.random.multivariate_normal(
Training repetitions:  10%|█         | 1/10 [00:20<03:06, 20.75s/it]


KeyboardInterrupt: 

## 4. Synthetic Graph Generation

Let's create and analyze a synthetic graph with controlled properties:

In [15]:
from bridge.datasets import SyntheticGraphDataset

# All generation settings in one place, passed through as keyword arguments
synthetic_config = {
    'n': 1000,                    # number of nodes
    'k': 4,                       # number of classes
    'h': 0.7,                     # homophily ratio
    'd_mean': 3,                  # mean degree scaling factor
    'sigma_intra_scalar': 0.1,    # intra-class covariance
    'sigma_inter_scalar': -0.05,  # inter-class covariance
    'tau_scalar': 1.0,            # global covariance
    'eta_scalar': 1.0,            # noise covariance
    'in_feats': 16,               # feature dimension
    'sym': True,                  # undirected graph
}
syn_dataset = SyntheticGraphDataset(**synthetic_config)

# The first item of the dataset is the graph used from here on
syn_g = syn_dataset[0]

  mu_concatenated = np.random.multivariate_normal(


## 5. Hyperparameter Optimization

Let's optimize the hyperparameters of our GNN model using Optuna:


In [18]:
from bridge.models import GCN, SelectiveGCN
import optuna   
from bridge.optimization import objective_gcn
# Create Optuna study; the objective value is maximized
study = optuna.create_study(direction='maximize')


# Define the optimization objective
def objective(trial):
    """Evaluate one Optuna trial by delegating to ``objective_gcn``.

    Args:
        trial: the Optuna trial passed in by ``study.optimize``.

    Returns:
        Whatever ``objective_gcn`` returns for the graph ``g`` with 200
        epochs on ``device``; the study treats it as the score to maximize.
    """
    return objective_gcn(
        trial,
        g,
        n_epochs=200,
        device=device
    )


# Run optimization
study.optimize(objective, n_trials=50)

# Print best parameters
print("Best parameters:")
print(study.best_params)
print(f"Best accuracy: {study.best_value:.4f}")

# Visualize optimization results.
# A notebook cell only renders its last bare expression, so each figure is
# shown explicitly; otherwise the history plot would be silently discarded.
history_fig = optuna.visualization.plot_optimization_history(study)
importance_fig = optuna.visualization.plot_param_importances(study)
history_fig.show()
importance_fig.show()

[I 2025-03-17 15:48:55,946] A new study created in memory with name: no-name-478605da-47b7-4348-bb59-dd581342d0c4
 20%|██        | 2/10 [00:05<00:20,  2.56s/it]
[W 2025-03-17 15:49:01,079] Trial 0 failed with parameters: {'h_feats': 128, 'n_layers': 3, 'dropout_p': 0.1169322582368547, 'model_lr': 0.001948916870958118, 'weight_decay': 0.00031204616012533857} because of the following error: KeyboardInterrupt().
Traceback (most recent call last):
  File "/Users/jonathanrubin/.pyenv/versions/3.10.14/lib/python3.10/site-packages/optuna/study/_optimize.py", line 197, in _run_trial
    value_or_values = func(trial)
  File "/var/folders/gz/smbwgd590tsg7lzrhft0dk8w0000gn/T/ipykernel_47887/254347896.py", line 9, in objective
    return objective_gcn(
  File "/Users/jonathanrubin/Desktop/Work/GNN-Analysis/BRIDGE/BRIDGE/bridge/optimization/optuna_objectives.py", line 199, in objective_gcn
    train_acc, val_acc, test_acc, train_acc_ci, val_acc_ci, test_acc_ci = train_and_evaluate_gcn(
  File "/Use

KeyboardInterrupt: 

## 6. Model Comparison

Let's compare different GNN architectures on our dataset:

In [None]:
def train_and_evaluate(model, g, epochs=200):
    """Train ``model`` on ``g`` and return the test accuracy at the best validation epoch.

    Each epoch performs one Adam step (lr=1e-3) on the cross-entropy loss of
    the training nodes, then scores the validation and test nodes.

    Args:
        model: a torch module called as ``model(g)`` whose output is indexed
            by the node masks and arg-maxed over dim 1.
        g: graph object whose ``ndata`` holds 'label', 'train_mask',
            'val_mask' and 'test_mask'.
        epochs: number of training epochs.

    Returns:
        float: test accuracy measured at the first epoch that reached the
        highest validation accuracy. A plain Python float is returned (not a
        tensor) so the value can be formatted or plotted regardless of the
        device the model ran on. Returns 0.0 when ``epochs`` is 0.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = torch.nn.CrossEntropyLoss()

    # Masks and labels do not change during training; read them once.
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']

    # Start below any reachable accuracy so the first epoch is always
    # recorded, even if its validation accuracy is exactly 0.
    best_val_acc = float('-inf')
    best_test_acc = 0.0

    for _ in tqdm(range(epochs)):
        model.train()
        optimizer.zero_grad()
        # NOTE(review): the model is called with the graph only; confirm this
        # matches the forward() signature of the models passed in.
        out = model(g)
        loss = criterion(out[train_mask], labels[train_mask])
        loss.backward()
        optimizer.step()

        # Evaluate
        model.eval()
        with torch.no_grad():
            pred = model(g).argmax(dim=1)
            val_acc = (pred[val_mask] == labels[val_mask]).float().mean().item()
            test_acc = (pred[test_mask] == labels[test_mask]).float().mean().item()

        # Strict '>' keeps the earliest epoch on ties.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc

    return best_test_acc

# Initialize models: both architectures receive the same four positional
# arguments (presumably input width, hidden width, class count and layer
# count - verify against the constructors).
num_features = g.ndata['feat'].shape[1]
num_classes = len(torch.unique(g.ndata['label']))
architectures = {'GCN': GCN, 'SelectiveGCN': SelectiveGCN}
models = {
    label: build(num_features, 64, num_classes, 2)
    for label, build in architectures.items()
}

# Train and evaluate each model
results = {}
for name, model in models.items():
    print(f"Training {name}...")
    model = model.to(device)
    acc = train_and_evaluate(model, g)
    results[name] = acc
    print(f"{name} accuracy: {acc:.4f}")

## 7. Visualization of Results

Let's create some visualizations to compare the results:

In [None]:
# Plot model comparison
plt.figure(figsize=(10, 6))
plt.bar(results.keys(), results.values())
plt.title('Model Comparison')
plt.ylabel('Test Accuracy')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

# Plot SNR vs Homophily
plt.figure(figsize=(10, 6))
plot_snr_vs_homophily(results)
plt.title('SNR vs Homophily')
plt.xlabel('Homophily')
plt.ylabel('Signal-to-Noise Ratio')
plt.tight_layout()
plt.show()