In [1]:
import os, glob, urllib.request, gzip, shutil, copy
import networkx as nx, numpy as np
import random
import torch
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.utils import from_networkx
from torch_geometric.loader import DataLoader
import torch.nn.functional as F
from scipy.stats import pearsonr
import matplotlib.pyplot as plt

In [2]:
def monte_carlo_influence(G, p=0.4, n_sims=1000):
    spreads = []
    if G.number_of_nodes() == 0:
        return 0.0
    for _ in range(n_sims):
        seeds = [np.random.choice(list(G.nodes()))]
        active = set(seeds)
        new_active = set(seeds)
        while new_active:
            next_active = set()
            for u in new_active:
                for v in G.neighbors(u):
                    if v not in active and np.random.rand() < p:
                        next_active.add(v)
            active |= next_active
            new_active = next_active
        spreads.append(len(active))
    return np.mean(spreads)


In [3]:
def sample_synthetic():
    choice = np.random.choice(['ER', 'BA', 'WS'])
    n = np.random.randint(120, 200)
    if choice == 'ER':
        p = np.random.uniform(0.05, 0.15)
        return nx.erdos_renyi_graph(n, p)
    elif choice == 'BA':
        m = np.random.randint(3, 7)
        return nx.barabasi_albert_graph(n, m)
    else: # WS
        k = np.random.randint(4, 10)
        beta = np.random.uniform(0.05, 0.25)
        return nx.watts_strogatz_graph(n, k, beta)

In [4]:
def node_features(G):
    if G.number_of_nodes() == 0:
        return torch.empty((0, 3), dtype=torch.float)
    deg = np.array([d for _, d in G.degree()])
    cc  = np.array(list(nx.clustering(G).values()))
    avg_nd = np.array(list(nx.average_neighbor_degree(G).values()))
    def z(x): return (x - x.mean()) / (x.std() + 1e-8)
    X = np.stack([z(deg), z(cc), z(avg_nd)], axis=1)
    return torch.tensor(X, dtype=torch.float)

In [5]:
class InfluenceGNN(torch.nn.Module):
    def __init__(self, in_feats=3, hidden_dim=64):
        super().__init__()
        self.conv1 = GCNConv(in_feats, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 1)
        )

    def forward(self, x, edge_index, batch, edge_weight=None):
        if edge_weight is None:
            edge_weight = torch.ones(edge_index.shape[1], device=x.device)
        x = torch.relu(self.conv1(x, edge_index, edge_weight=edge_weight))
        x = torch.relu(self.conv2(x, edge_index, edge_weight=edge_weight))
        x = global_mean_pool(x, batch)
        return self.mlp(x).squeeze()


In [6]:
def get_trained_model(num_graphs=1000, epochs=501):
    print("--- Training Surrogate GNN (to be frozen) ---")
    dataset = []
    synthetic_graphs = [sample_synthetic() for _ in range(num_graphs)]
    for G in synthetic_graphs:
        try:
            y = monte_carlo_influence(G, p=0.4, n_sims=400) / G.number_of_nodes()
            data = from_networkx(G)
            data.x = node_features(G)
            data.y = torch.tensor([y], dtype=torch.float)
            dataset.append(data)
        except Exception as e:
            pass # Ignore disconnected graphs
    
    split = int(0.8 * len(dataset))
    train_dataset, test_dataset = dataset[:split], dataset[split:]
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    test_loader  = DataLoader(test_dataset, batch_size=16)
    
    model = InfluenceGNN()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    for epoch in range(epochs):
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            out = model(batch.x, batch.edge_index, batch.batch)
            loss = F.mse_loss(out, batch.y)
            loss.backward()
            optimizer.step()
        if epoch % 100 == 0:
            model.eval()
            preds, trues = [], []
            with torch.no_grad():
                for batch in test_loader:
                    out = model(batch.x, batch.edge_index, batch.batch)
                    preds += out.cpu().tolist()
                    trues += batch.y.cpu().tolist()
            corr, _ = pearsonr(preds, trues)
            print(f"Epoch {epoch} | Test Corr: {corr:.3f}")
    print("--- Surrogate Training Finished ---")
    model.eval() # Freeze model
    return model


In [7]:
model = get_trained_model()

--- Training Surrogate GNN (to be frozen) ---
Epoch 0 | Test Corr: -0.669
Epoch 100 | Test Corr: 0.959
Epoch 200 | Test Corr: 0.962
Epoch 300 | Test Corr: 0.969
Epoch 400 | Test Corr: 0.973
Epoch 500 | Test Corr: 0.975
--- Surrogate Training Finished ---


In [8]:
def prepare_edge_mappings(G, data):
    undirected_edges = sorted([tuple(sorted(e)) for e in G.edges()])
    undirected_edge_to_index = {e: i for i, e in enumerate(undirected_edges)}
    index_map = {}
    ei = data.edge_index.cpu().numpy()
    for col in range(ei.shape[1]):
        u, v = int(ei[0, col]), int(ei[1, col])
        index_map[(u, v)] = col
    return undirected_edge_to_index, index_map

In [9]:
def make_edge_weight(G, w, data, undirected_edge_to_index, index_map):
    E_dir = data.edge_index.size(1)
    ew = torch.zeros(E_dir, dtype=torch.float32)
    for (u, v), idx in index_map.items():
        key = tuple(sorted((u, v)))
        ew[idx] = w[undirected_edge_to_index[key]]
    return ew

In [10]:
def continuous_relaxation_optimize(G0, model, lam, steps=200, lr=0.03):
    model.eval()
    undirected_edges = sorted([tuple(sorted(e)) for e in G0.edges()])
    num_edges = len(undirected_edges)
    w = torch.full((num_edges,), 0.99, requires_grad=True)
    optimizer = torch.optim.Adam([w], lr=lr)
    
    data = from_networkx(G0)
    data.x = node_features(G0)
    batch = torch.zeros(data.num_nodes, dtype=torch.long)
    undirected_edge_to_index, index_map = prepare_edge_mappings(G0, data)

    for step in range(steps):
        optimizer.zero_grad()
        ew = make_edge_weight(G0, w, data, undirected_edge_to_index, index_map)
        pred = model(data.x, data.edge_index, batch, edge_weight=ew)
        loss = -pred + lam * torch.mean(torch.abs(1 - w))
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            w.clamp_(0.0, 1.0)
    return w.detach().clone()

In [11]:
def discrete_prune(G0, w_opt, tau=0.5):
    undirected_edges = sorted([tuple(sorted(e)) for e in G0.edges()])
    G_pruned = nx.Graph()
    G_pruned.add_nodes_from(G0.nodes())
    kept_edges = [e for e, val in zip(undirected_edges, w_opt) if val > tau]
    G_pruned.add_edges_from(kept_edges)
    return G_pruned

In [12]:
# STEP 1: DEFINE A CLEAN EVALUATION GRAPH SET
eval_graphs = [sample_synthetic() for _ in range(10)]
print(f"Generated {len(eval_graphs)} evaluation graphs.")

# STEP 2: DEFINE LAMBDA SWEEP
lambda_list = [0.0005, 0.001, 0.005, 0.01, 0.05, 0.1, 0.2, 0.5]

results = {lam: {'edges_kept': [], 'influence': []} for lam in lambda_list}
tau = 0.5 # Discretization threshold

# STEP 3: REPEAT RELAXATION FOR EACH LAMBDA


Generated 10 evaluation graphs.


In [13]:
for i, G0 in enumerate(eval_graphs):
    print(f"\n--- Processing Graph {i+1}/{len(eval_graphs)} ({G0.number_of_nodes()} nodes, {G0.number_of_edges()} edges) ---")
    if G0.number_of_edges() == 0: continue
    
    # Compute original influence
    norm_infl_orig = monte_carlo_influence(G0) / G0.number_of_nodes()
    
    for lam in lambda_list:
        # a. Run continuous relaxation
        w_opt = continuous_relaxation_optimize(G0, model, lam=lam)

        # b. Convert continuous weights to discrete edges
        G_pruned = discrete_prune(G0, w_opt, tau)

        # c. Compute TRUE influence
        norm_infl_pruned = monte_carlo_influence(G_pruned) / G0.number_of_nodes()

        # d. Record results
        edges_kept_frac = G_pruned.number_of_edges() / G0.number_of_edges()
        results[lam]['edges_kept'].append(edges_kept_frac)
        results[lam]['influence'].append(norm_infl_pruned)
        print(f"  lambda={lam:.4f} | Edges: {edges_kept_frac:.2f} | Influence: {norm_infl_pruned:.3f}")



--- Processing Graph 1/10 (187 nodes, 1262 edges) ---
  lambda=0.0005 | Edges: 0.30 | Influence: 0.414
  lambda=0.0010 | Edges: 0.30 | Influence: 0.440
  lambda=0.0050 | Edges: 0.30 | Influence: 0.415
  lambda=0.0100 | Edges: 0.31 | Influence: 0.461
  lambda=0.0500 | Edges: 0.33 | Influence: 0.524
  lambda=0.1000 | Edges: 0.34 | Influence: 0.560
  lambda=0.2000 | Edges: 0.44 | Influence: 0.793
  lambda=0.5000 | Edges: 0.77 | Influence: 0.974

--- Processing Graph 2/10 (194 nodes, 1721 edges) ---
  lambda=0.0005 | Edges: 0.27 | Influence: 0.578
  lambda=0.0010 | Edges: 0.26 | Influence: 0.563
  lambda=0.0050 | Edges: 0.28 | Influence: 0.602
  lambda=0.0100 | Edges: 0.28 | Influence: 0.590
  lambda=0.0500 | Edges: 0.32 | Influence: 0.670
  lambda=0.1000 | Edges: 0.35 | Influence: 0.779
  lambda=0.2000 | Edges: 0.46 | Influence: 0.907
  lambda=0.5000 | Edges: 0.76 | Influence: 0.986

--- Processing Graph 3/10 (198 nodes, 594 edges) ---
  lambda=0.0005 | Edges: 0.48 | Influence: 0.036
  l

In [14]:
mean_results = {'edges_kept': [], 'influence': []}
for lam in lambda_list:
    mean_results['edges_kept'].append(np.mean(results[lam]['edges_kept']))
    mean_results['influence'].append(np.mean(results[lam]['influence']))



In [15]:
def random_edge_removal_baseline(G0, num_edges_to_keep):
    if num_edges_to_keep > G0.number_of_edges():
        return G0.copy()
    
    edges_to_remove = random.sample(list(G0.edges()), G0.number_of_edges() - num_edges_to_keep)
    G_baseline = G0.copy()
    G_baseline.remove_edges_from(edges_to_remove)
    return G_baseline

baseline_results = {edges_kept: [] for edges_kept in mean_results['edges_kept']}



In [19]:
for edges_kept_frac in mean_results['edges_kept']:
     for G0 in eval_graphs:
         if G0.number_of_edges() > 0:
             num_edges_to_keep = int(edges_kept_frac * G0.number_of_edges())
             G_baseline = random_edge_removal_baseline(G0, num_edges_to_keep)
             norm_infl_baseline = monte_carlo_influence(G_baseline) / G0.number_of_nodes()
             baseline_results[edges_kept_frac].append(norm_infl_baseline)

In [21]:
mean_baseline_influence = [
    np.mean(baseline_results[ek]) if baseline_results[ek] else 0
    for ek in mean_results['edges_kept']
]

plt.style.use('seaborn-v0_8-whitegrid')
fig, ax = plt.subplots(figsize=(10, 7))

# Plot Continuous Relaxation Curve
ax.plot(
    mean_results['edges_kept'],
    mean_results['influence'],
    marker='o',
    linestyle='-',
    label='Continuous Relaxation'
)

# Plot Baseline Curve
ax.plot(
    mean_results['edges_kept'],
    mean_baseline_influence,
    marker='x',
    linestyle='--',
    label='Random Edge Removal'
)

# Annotate lambda values
for i, lam in enumerate(lambda_list):
    ax.annotate(
        f"$\\lambda={lam}$",
        (mean_results['edges_kept'][i], mean_results['influence'][i]),
        textcoords="offset points",
        xytext=(0, 10),
        ha='center',
        fontsize=9,
        color='darkgreen'
    )

ax.set_title('Sparsity vs. Influence Tradeoff', fontsize=16)
ax.set_xlabel('Fraction of Edges Kept', fontsize=12)
ax.set_ylabel('Normalized Influence Spread', fontsize=12)
ax.legend(fontsize=12)
ax.set_xlim(left=0)
ax.set_ylim(bottom=0)

# Save figure
plt.tight_layout()
plt.savefig("sparsityVInfluence.png", dpi=300, bbox_inches="tight")
plt.close(fig)
