# Pruning Connections

In this example, we will train a DenseMLP, prune (remove) all connections whose weights are below some threshold in absolute value, and continue training with the pruned network.

In [1]:
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import optax

import connex as cnx


# Build the model: a DenseMLP mapping 1 input to 2 outputs
network = cnx.nn.DenseMLP(
    input_size=1,
    output_size=2,
    width=128,
    depth=4,
)

# Build the Adam optimizer and its state; the state is created only for the
# array leaves of the network
optim = optax.adam(learning_rate=1e-3)
opt_state = optim.init(eqx.filter(network, eqx.is_array))


# Define the loss function
@eqx.filter_value_and_grad
def loss_fn(model, x, y, keys):
    """Mean squared error of the batched model predictions against ``y``.

    ``model`` is mapped over the leading axis of ``x`` and ``keys`` with
    ``jax.vmap``, so each sample is evaluated with its own key.

    Because of the ``eqx.filter_value_and_grad`` decorator, calling this
    function yields a ``(loss, grads)`` pair rather than the bare loss.
    """
    batched_model = jax.vmap(model)
    residuals = batched_model(x, keys) - y
    return jnp.mean(residuals**2)


# Define a single training step
@eqx.filter_jit
def step(model, opt_state, x, y, keys):
    """Perform one optimizer update and return the new training state.

    Evaluates ``loss_fn`` to obtain the loss and its gradients, asks the
    module-level ``optim`` for parameter updates, and applies them.

    Returns a ``(model, opt_state, loss)`` tuple, where ``loss`` is the value
    computed *before* the update was applied.
    """
    loss_value, gradients = loss_fn(model, x, y, keys)
    param_updates, next_opt_state = optim.update(gradients, opt_state, model)
    updated_model = eqx.apply_updates(model, param_updates)
    return updated_model, next_opt_state, loss_value


# Toy data: 250 evenly spaced inputs on [0, 2*pi] as a column vector, with
# targets (cos(x), sin(x)) stacked side by side
x = jnp.linspace(0, 2 * jnp.pi, 250)[:, None]
y = jnp.concatenate((jnp.cos(x), jnp.sin(x)), axis=1)

# Training loop
n_epochs = 500
key = jr.PRNGKey(0)
for epoch in range(n_epochs):
    # Split into one key per sample plus a final key carried over to the next
    # epoch. Slicing the split result yields the same keys as unpacking it
    # into a Python list and re-stacking with `jnp.array`, but avoids creating
    # one small array per sample on every epoch.
    split_keys = jr.split(key, x.shape[0] + 1)
    keys, key = split_keys[:-1], split_keys[-1]
    network, opt_state, loss = step(network, opt_state, x, y, keys)
    print(f"Epoch: {epoch + 1}   Loss: {loss}")

Epoch: 1   Loss: 0.7124180793762207
Epoch: 2   Loss: 0.471740186214447
Epoch: 3   Loss: 0.5636301636695862
Epoch: 4   Loss: 0.4581114947795868
Epoch: 5   Loss: 0.34007692337036133
Epoch: 6   Loss: 0.32127997279167175
Epoch: 7   Loss: 0.3536826968193054
Epoch: 8   Loss: 0.345887690782547
Epoch: 9   Loss: 0.29597505927085876
Epoch: 10   Loss: 0.2544102668762207
Epoch: 11   Loss: 0.24971188604831696
Epoch: 12   Loss: 0.26612362265586853
Epoch: 13   Loss: 0.2702641189098358
Epoch: 14   Loss: 0.24971722066402435
Epoch: 15   Loss: 0.22087214887142181
Epoch: 16   Loss: 0.20580922067165375
Epoch: 17   Loss: 0.20780447125434875
Epoch: 18   Loss: 0.21007828414440155
Epoch: 19   Loss: 0.1984233856201172
Epoch: 20   Loss: 0.17706890404224396
Epoch: 21   Loss: 0.16016307473182678
Epoch: 22   Loss: 0.15511168539524078
Epoch: 23   Loss: 0.155961811542511
Epoch: 24   Loss: 0.1516830176115036
Epoch: 25   Loss: 0.13962800800800323
Epoch: 26   Loss: 0.12779438495635986
Epoch: 27   Loss: 0.123714648187160

Epoch: 211   Loss: 0.0009491420933045447
Epoch: 212   Loss: 0.0009288907749578357
Epoch: 213   Loss: 0.0009103944757953286
Epoch: 214   Loss: 0.0008943004067987204
Epoch: 215   Loss: 0.0008825139957480133
Epoch: 216   Loss: 0.0008769791456870735
Epoch: 217   Loss: 0.0008840184891596437
Epoch: 218   Loss: 0.0009137787274084985
Epoch: 219   Loss: 0.0009908685460686684
Epoch: 220   Loss: 0.0011627969797700644
Epoch: 221   Loss: 0.001527357380837202
Epoch: 222   Loss: 0.002259867498651147
Epoch: 223   Loss: 0.0035799986217170954
Epoch: 224   Loss: 0.00548159796744585
Epoch: 225   Loss: 0.006942305713891983
Epoch: 226   Loss: 0.00605309521779418
Epoch: 227   Loss: 0.0027579243760555983
Epoch: 228   Loss: 0.0006995780277065933
Epoch: 229   Loss: 0.00195307657122612
Epoch: 230   Loss: 0.0037953979335725307
Epoch: 231   Loss: 0.003040145616978407
Epoch: 232   Loss: 0.0009854678064584732
Epoch: 233   Loss: 0.0009928109357133508
Epoch: 234   Loss: 0.002404278377071023
Epoch: 235   Loss: 0.002251

Epoch: 420   Loss: 0.0001888366969069466
Epoch: 421   Loss: 0.0001890507701318711
Epoch: 422   Loss: 0.00020762355416081846
Epoch: 423   Loss: 0.00019945208623539656
Epoch: 424   Loss: 0.00018090875528287143
Epoch: 425   Loss: 0.0001857195602497086
Epoch: 426   Loss: 0.00019606210116762668
Epoch: 427   Loss: 0.00018706954142544419
Epoch: 428   Loss: 0.00017589307390153408
Epoch: 429   Loss: 0.00018070625083055347
Epoch: 430   Loss: 0.00018610002007335424
Epoch: 431   Loss: 0.00017869115981739014
Epoch: 432   Loss: 0.00017168254998978227
Epoch: 433   Loss: 0.00017510591715108603
Epoch: 434   Loss: 0.00017795473104342818
Epoch: 435   Loss: 0.00017249849042855203
Epoch: 436   Loss: 0.00016764766769483685
Epoch: 437   Loss: 0.00016957851767074317
Epoch: 438   Loss: 0.00017121086420957
Epoch: 439   Loss: 0.00016743128071539104
Epoch: 440   Loss: 0.00016370552475564182
Epoch: 441   Loss: 0.00016439516912214458
Epoch: 442   Loss: 0.0001654062361922115
Epoch: 443   Loss: 0.0001629268517717719


Next, we will prune the network of all connections whose weight is less than 0.1 in absolute value.

In [2]:
# Convert the trained network into a weighted NetworkX DiGraph
graph = network.to_networkx_weighted_digraph()

# Pick the cutoff, then gather every (source, target) pair whose edge weight
# is strictly smaller than the cutoff in magnitude
threshold = 0.1
edges = graph.edges(data=True)
edges_below_threshold = [
    (source, target)
    for source, target, attributes in edges
    if abs(attributes["weight"]) < threshold
]

# Build a new network without those connections
pruned_network = cnx.remove_connections(network, edges_below_threshold)

Let's see how many connections are in the pruned network compared to the original.

In [3]:
# Read the edge counts from each network's underlying graph, then report them
original_edge_count = network._graph.number_of_edges()
pruned_edge_count = pruned_network._graph.number_of_edges()
print(f"Number of edges in original network: {original_edge_count}")
print(f"Number of edges in pruned network: {pruned_edge_count}")

Number of edges in original network: 99842
Number of edges in pruned network: 32691


Now, let's continue training with the pruned network.

In [4]:
# The pruned network has a different structure, so build a fresh optimizer
# state from its array leaves
optim = optax.adam(learning_rate=1e-3)
opt_state = optim.init(eqx.filter(pruned_network, eqx.is_array))

# Training loop
n_epochs = 500
key = jr.PRNGKey(0)
for epoch in range(n_epochs):
    # Split into one key per sample plus a final key carried over to the next
    # epoch. Slicing the split result yields the same keys as unpacking it
    # into a Python list and re-stacking with `jnp.array`, but avoids creating
    # one small array per sample on every epoch.
    split_keys = jr.split(key, x.shape[0] + 1)
    keys, key = split_keys[:-1], split_keys[-1]
    pruned_network, opt_state, loss = step(pruned_network, opt_state, x, y, keys)
    print(f"Epoch: {epoch + 1}   Loss: {loss}")

Epoch: 1   Loss: 0.44865846633911133
Epoch: 2   Loss: 0.297453373670578
Epoch: 3   Loss: 0.2075083702802658
Epoch: 4   Loss: 0.17005939781665802
Epoch: 5   Loss: 0.16803623735904694
Epoch: 6   Loss: 0.178756982088089
Epoch: 7   Loss: 0.18404240906238556
Epoch: 8   Loss: 0.17665764689445496
Epoch: 9   Loss: 0.1583200991153717
Epoch: 10   Loss: 0.13481727242469788
Epoch: 11   Loss: 0.11238483339548111
Epoch: 12   Loss: 0.09569007158279419
Epoch: 13   Loss: 0.08682015538215637
Epoch: 14   Loss: 0.08502709120512009
Epoch: 15   Loss: 0.08732927590608597
Epoch: 16   Loss: 0.08992750197649002
Epoch: 17   Loss: 0.08984372019767761
Epoch: 18   Loss: 0.08592817932367325
Epoch: 19   Loss: 0.07887118309736252
Epoch: 20   Loss: 0.07050615549087524
Epoch: 21   Loss: 0.06289120763540268
Epoch: 22   Loss: 0.057517021894454956
Epoch: 23   Loss: 0.05483124405145645
Epoch: 24   Loss: 0.05418158322572708
Epoch: 25   Loss: 0.05419432744383812
Epoch: 26   Loss: 0.05344879627227783
Epoch: 27   Loss: 0.051126

Epoch: 211   Loss: 0.0004889486590400338
Epoch: 212   Loss: 0.0004856523300986737
Epoch: 213   Loss: 0.00048154551768675447
Epoch: 214   Loss: 0.0004779237788170576
Epoch: 215   Loss: 0.0004745451151393354
Epoch: 216   Loss: 0.00047061595250852406
Epoch: 217   Loss: 0.00046736691729165614
Epoch: 218   Loss: 0.000463794480310753
Epoch: 219   Loss: 0.00046033455873839557
Epoch: 220   Loss: 0.00045719873742200434
Epoch: 221   Loss: 0.00045368174323812127
Epoch: 222   Loss: 0.0004505550896283239
Epoch: 223   Loss: 0.00044734077528119087
Epoch: 224   Loss: 0.0004441167984623462
Epoch: 225   Loss: 0.00044116273056715727
Epoch: 226   Loss: 0.00043799146078526974
Epoch: 227   Loss: 0.00043504146742634475
Epoch: 228   Loss: 0.00043209639261476696
Epoch: 229   Loss: 0.0004291234945412725
Epoch: 230   Loss: 0.00042633965495042503
Epoch: 231   Loss: 0.00042344536632299423
Epoch: 232   Loss: 0.000420688244048506
Epoch: 233   Loss: 0.0004179648240096867
Epoch: 234   Loss: 0.00041522574611008167
Epoc

Epoch: 411   Loss: 0.00020207458874210715
Epoch: 412   Loss: 0.0002014394267462194
Epoch: 413   Loss: 0.0002008067531278357
Epoch: 414   Loss: 0.00020017723727505654
Epoch: 415   Loss: 0.0001995499769691378
Epoch: 416   Loss: 0.00019892584532499313
Epoch: 417   Loss: 0.00019830396922770888
Epoch: 418   Loss: 0.00019768507627304643
Epoch: 419   Loss: 0.00019706928287632763
Epoch: 420   Loss: 0.0001964567054528743
Epoch: 421   Loss: 0.00019584513211157173
Epoch: 422   Loss: 0.00019523708033375442
Epoch: 423   Loss: 0.0001946312841027975
Epoch: 424   Loss: 0.00019402915495447814
Epoch: 425   Loss: 0.00019343008170835674
Epoch: 426   Loss: 0.00019283218716736883
Epoch: 427   Loss: 0.00019223646086174995
Epoch: 428   Loss: 0.00019164411060046405
Epoch: 429   Loss: 0.0001910543505800888
Epoch: 430   Loss: 0.00019046735542360693
Epoch: 431   Loss: 0.00018988156807608902
Epoch: 432   Loss: 0.00018930055375676602
Epoch: 433   Loss: 0.00018871987413149327
Epoch: 434   Loss: 0.0001881423086160794

After training for 500 more epochs, the pruned network achieves almost the same training loss as the original network, with fewer than a third of the connections.