# Pruning Connections

In this example, we train a `DenseMLP` and then prune (remove) every connection whose weight is below a chosen threshold in absolute value.

In [1]:
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import optax

import connex as cnx


# Model / optimizer hyperparameters. INPUT_SIZE and OUTPUT_SIZE match the toy data
# below: one input feature x, and two targets (cos x, sin x).
INPUT_SIZE, OUTPUT_SIZE = 1, 2
WIDTH, DEPTH = 128, 4
LEARNING_RATE = 1e-3

# Initialize the model: a dense MLP built with connex
network = cnx.nn.DenseMLP(
    input_size=INPUT_SIZE,
    output_size=OUTPUT_SIZE,
    width=WIDTH,
    depth=DEPTH,
)

# Initialize the Adam optimizer. Its state is built only over the model's array
# leaves; non-array fields of the model are filtered out first.
optim = optax.adam(LEARNING_RATE)
trainable_params = eqx.filter(network, eqx.is_array)
opt_state = optim.init(trainable_params)


# Define the loss function (mean squared error over a batch)
@eqx.filter_value_and_grad
def loss_fn(model, x, y, keys):
    """Mean squared error of ``model`` on a batch ``(x, y)``.

    Because of the ``eqx.filter_value_and_grad`` decorator, a call returns
    ``(loss, grads)``, where ``grads`` is the gradient of the loss with respect
    to ``model`` (the first argument).

    Args:
        model: network applied to one sample at a time; it is vmapped over the
            leading axis of ``x`` and ``keys``.
        x: batch of inputs.
        y: batch of targets, compared elementwise with the stacked predictions.
        keys: one PRNG key per sample, vmapped alongside ``x`` (presumably
            consumed by stochastic layers such as dropout -- confirm in connex).
    """
    batched_forward = jax.vmap(model)
    errors = batched_forward(x, keys) - y
    return (errors ** 2).mean()


# Define a single training step
@eqx.filter_jit
def step(model, opt_state, x, y, keys):
    """Run one full-batch optimisation step.

    Args:
        model: the network being trained.
        opt_state: optimizer state, created by ``optim.init`` on the model's
            array leaves.
        x: batch of inputs.
        y: batch of targets.
        keys: one PRNG key per sample in ``x``, forwarded to ``loss_fn``.

    Returns:
        ``(model, opt_state, loss)``: the updated model, the updated optimizer
        state, and the loss evaluated *before* this update was applied.
    """
    loss, grads = loss_fn(model, x, y, keys)
    # Hand the optimizer only the array leaves as ``params`` so they line up with
    # how ``opt_state`` was initialised (``eqx.filter(network, eqx.is_array)``).
    # Adam ignores ``params``, but optimizers that read them (e.g. weight decay)
    # would otherwise also receive the model's non-array fields.
    params = eqx.filter(model, eqx.is_array)
    updates, opt_state = optim.update(grads, opt_state, params)
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss


# Toy data: 250 evenly spaced points on [0, 2*pi] as a (250, 1) column, with
# targets (cos x, sin x) stacked side by side into shape (250, 2).
x = jnp.expand_dims(jnp.linspace(0, 2 * jnp.pi, 250), 1)
y = jnp.hstack((jnp.cos(x), jnp.sin(x)))

# Training loop
n_epochs = 500
log_every = 1  # print the loss every `log_every` epochs; raise to shorten the output
key = jr.PRNGKey(0)
for epoch in range(n_epochs):
    # One split per epoch gives a fresh key for every training sample, plus a
    # carry key for the next epoch. Slicing the split array yields exactly the same
    # keys as star-unpacking it into a Python list and re-stacking with
    # `jnp.array`, without iterating over a JAX array in Python every epoch.
    split_keys = jr.split(key, x.shape[0] + 1)
    keys, key = split_keys[:-1], split_keys[-1]
    network, opt_state, loss = step(network, opt_state, x, y, keys)
    if (epoch + 1) % log_every == 0:
        print(f"Epoch: {epoch + 1}   Loss: {loss}")

Epoch: 1   Loss: 0.7124180793762207
Epoch: 2   Loss: 0.471740186214447
Epoch: 3   Loss: 0.5636301636695862
Epoch: 4   Loss: 0.4581114947795868
Epoch: 5   Loss: 0.34007692337036133
Epoch: 6   Loss: 0.32127997279167175
Epoch: 7   Loss: 0.3536826968193054
Epoch: 8   Loss: 0.345887690782547
Epoch: 9   Loss: 0.29597505927085876
Epoch: 10   Loss: 0.2544102668762207
Epoch: 11   Loss: 0.24971188604831696
Epoch: 12   Loss: 0.26612362265586853
Epoch: 13   Loss: 0.2702641189098358
Epoch: 14   Loss: 0.24971722066402435
Epoch: 15   Loss: 0.22087214887142181
Epoch: 16   Loss: 0.20580922067165375
Epoch: 17   Loss: 0.20780447125434875
Epoch: 18   Loss: 0.21007828414440155
Epoch: 19   Loss: 0.1984233856201172
Epoch: 20   Loss: 0.17706890404224396
Epoch: 21   Loss: 0.16016307473182678
Epoch: 22   Loss: 0.15511168539524078
Epoch: 23   Loss: 0.155961811542511
Epoch: 24   Loss: 0.1516830176115036
Epoch: 25   Loss: 0.13962800800800323
Epoch: 26   Loss: 0.12779438495635986
Epoch: 27   Loss: 0.123714648187160

Epoch: 220   Loss: 0.0011627969797700644
Epoch: 221   Loss: 0.001527357380837202
Epoch: 222   Loss: 0.002259867498651147
Epoch: 223   Loss: 0.0035799986217170954
Epoch: 224   Loss: 0.00548159796744585
Epoch: 225   Loss: 0.006942305713891983
Epoch: 226   Loss: 0.00605309521779418
Epoch: 227   Loss: 0.0027579243760555983
Epoch: 228   Loss: 0.0006995780277065933
Epoch: 229   Loss: 0.00195307657122612
Epoch: 230   Loss: 0.0037953979335725307
Epoch: 231   Loss: 0.003040145616978407
Epoch: 232   Loss: 0.0009854678064584732
Epoch: 233   Loss: 0.0009928109357133508
Epoch: 234   Loss: 0.002404278377071023
Epoch: 235   Loss: 0.0022513431031256914
Epoch: 236   Loss: 0.0008820308721624315
Epoch: 237   Loss: 0.0008390072616748512
Epoch: 238   Loss: 0.001778544276021421
Epoch: 239   Loss: 0.0015943855978548527
Epoch: 240   Loss: 0.0007007258245721459
Epoch: 241   Loss: 0.000819189939647913
Epoch: 242   Loss: 0.0014122298453003168
Epoch: 243   Loss: 0.0011303203646093607
Epoch: 244   Loss: 0.00059150

Epoch: 420   Loss: 0.0001888366969069466
Epoch: 421   Loss: 0.0001890507701318711
Epoch: 422   Loss: 0.00020762355416081846
Epoch: 423   Loss: 0.00019945208623539656
Epoch: 424   Loss: 0.00018090875528287143
Epoch: 425   Loss: 0.0001857195602497086
Epoch: 426   Loss: 0.00019606210116762668
Epoch: 427   Loss: 0.00018706954142544419
Epoch: 428   Loss: 0.00017589307390153408
Epoch: 429   Loss: 0.00018070625083055347
Epoch: 430   Loss: 0.00018610002007335424
Epoch: 431   Loss: 0.00017869115981739014
Epoch: 432   Loss: 0.00017168254998978227
Epoch: 433   Loss: 0.00017510591715108603
Epoch: 434   Loss: 0.00017795473104342818
Epoch: 435   Loss: 0.00017249849042855203
Epoch: 436   Loss: 0.00016764766769483685
Epoch: 437   Loss: 0.00016957851767074317
Epoch: 438   Loss: 0.00017121086420957
Epoch: 439   Loss: 0.00016743128071539104
Epoch: 440   Loss: 0.00016370552475564182
Epoch: 441   Loss: 0.00016439516912214458
Epoch: 442   Loss: 0.0001654062361922115
Epoch: 443   Loss: 0.0001629268517717719


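Next, we prune every connection whose weight is less than 0.01 in absolute value. To find these connections, we first export the trained network to a NetworkX weighted `DiGraph`, whose edges carry the connection weights.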

In [2]:
# Export the trained network as a NetworkX DiGraph whose edges carry a "weight" attribute
graph = network.to_networkx_weighted_digraph()

# Collect every (source, target) pair whose weight magnitude is strictly below the threshold
threshold = 0.01
edges = graph.edges(data=True)
edges_below_threshold = [
    (src, dst)
    for src, dst, attrs in edges
    if threshold > abs(attrs["weight"])
]

# Drop those connections. The result gets a new name so the original `network`
# stays available for comparison below.
pruned_network = cnx.remove_connections(network, edges_below_threshold)

Let's see how many connections are in the pruned network compared to the original.

In [3]:
# Report edge counts before and after pruning.
# NOTE(review): this reads the private `_graph` attribute; prefer a public
# accessor if connex provides one.
for label, net in (("original", network), ("pruned", pruned_network)):
    print(f"Number of edges in {label} network: {net._graph.number_of_edges()}")

Number of edges in original network: 99842
Number of edges in pruned network: 92059


Finally, let's compute the pruned network's loss on the training data and compare it with the final training loss of the original network above.

In [4]:
# Evaluate the pruned network's MSE on the training data. The model is called
# directly rather than through `loss_fn`: that function is wrapped in
# `eqx.filter_value_and_grad`, so calling it would also compute (and then
# discard) gradients for every parameter.
keys = jr.split(key, x.shape[0])  # already an array of keys; no `jnp.array` copy needed
pruned_preds = jax.vmap(pruned_network)(x, keys)
jnp.mean((pruned_preds - y) ** 2)

Array(0.00117882, dtype=float32)