# Freezing Parameters Example with Optax (Manual Masking)

This notebook demonstrates how to selectively freeze parameters when training with an Optax optimizer by manually masking the updates it produces.

We show how to:
- Create a boolean mask matching the parameter tree
- Freeze one layer while training another
- Verify that frozen parameters remain unchanged after optimization


In [12]:
import jax
import jax.numpy as jnp
import optax


In [13]:
# 1. Define dummy model parameters
params = {
    'layer1': {'w': jnp.ones((2, 2)), 'b': jnp.zeros((2,))},
    'layer2': {'w': jnp.ones((2, 2)), 'b': jnp.zeros((2,))}
}

In [14]:
# 2. Create boolean mask matching the parameter tree
# False => frozen (no updates), True => trainable (updates applied)
mask = {
    'layer1': {'w': False, 'b': False},
    'layer2': {'w': True,  'b': True}
}
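
For larger models, writing the mask by hand becomes tedious. As a minimal sketch, the mask can instead be derived from the parameter tree with `jax.tree_util.tree_map_with_path`. The helper name `make_mask` and its `frozen_layers` argument are hypothetical, and the code assumes the parameters are nested dicts keyed by layer name, as in this notebook.

```python
import jax
import jax.numpy as jnp

params = {
    'layer1': {'w': jnp.ones((2, 2)), 'b': jnp.zeros((2,))},
    'layer2': {'w': jnp.ones((2, 2)), 'b': jnp.zeros((2,))},
}

def make_mask(params, frozen_layers=('layer1',)):
    """Hypothetical helper: False for leaves under a frozen top-level key, True otherwise."""
    def is_trainable(path, _leaf):
        # For nested dicts, path[0] is a DictKey whose .key is the top-level name.
        return path[0].key not in frozen_layers
    return jax.tree_util.tree_map_with_path(is_trainable, params)

auto_mask = make_mask(params)
print(auto_mask)
```

The result has the same structure as the hand-written `mask`, so it can be passed to the same masking utility.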

In [15]:
# 3. Define a simple loss function
def loss_fn(params, inputs, targets):
    """
    Computes mean squared error between model predictions and targets.
    """
    w1, b1 = params['layer1']['w'], params['layer1']['b']
    w2, b2 = params['layer2']['w'], params['layer2']['b']

    hidden = jnp.dot(inputs, w1) + b1
    hidden = jax.nn.relu(hidden)

    output = jnp.dot(hidden, w2) + b2
    return jnp.mean((output - targets) ** 2)


In [16]:
# 4. Create optimizer
optimizer = optax.adam(learning_rate=0.01)
opt_state = optimizer.init(params)

In [17]:
# 5. Dummy data for demonstration
# Note: the same key is used for both draws, so `targets` is identical to `inputs`.
key = jax.random.PRNGKey(0)
inputs = jax.random.normal(key, (5, 2))
targets = jax.random.normal(key, (5, 2))
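
Because a JAX PRNG key is deterministic, drawing twice from the same key with the same shape yields the same array. If independent inputs and targets are wanted, one option is to split the key first. This is a sketch only: the names `inputs_key` and `targets_key` are illustrative, and the values printed later in this notebook come from the original data, so they would differ with this variant.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
# Derive two independent subkeys, one per random draw.
inputs_key, targets_key = jax.random.split(key)
inputs = jax.random.normal(inputs_key, (5, 2))
targets = jax.random.normal(targets_key, (5, 2))
```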


In [18]:
# 6. Utility to apply mask to updates
def apply_mask(updates, mask):
    """
    Zeroes out updates for parameters where mask is False (frozen).
    """
    return jax.tree_util.tree_map(lambda g, m: g if m else jnp.zeros_like(g), updates, mask)

In [19]:
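# 7. Perform one optimization step
def update(params, opt_state, inputs, targets):
    """
    Computes gradients, applies masked updates, and returns the new parameters,
    the new optimizer state, and the loss evaluated at the input parameters
    (that is, before the update is applied).
    """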
    loss, grads = jax.value_and_grad(loss_fn)(params, inputs, targets)
    updates, opt_state = optimizer.update(grads, opt_state)
    masked_updates = apply_mask(updates, mask)
    new_params = optax.apply_updates(params, masked_updates)
    return new_params, opt_state, loss

In [20]:
# 8. Run the update step
new_params, opt_state, loss = update(params, opt_state, inputs, targets)

In [21]:
# 9. Output results
print("\nlayer1 (FROZEN)")
print("Initial weights:\n", params['layer1']['w'])
print("Updated weights (should be unchanged):\n", new_params['layer1']['w'])

print("\nlayer2 (TRAINED)")
print("Initial weights:\n", params['layer2']['w'])
print("Updated weights (should be changed):\n", new_params['layer2']['w'])

print("\nLoss before the update step:", loss)


layer1 (FROZEN)
Initial weights:
 [[1. 1.]
 [1. 1.]]
Updated weights (should be unchanged):
 [[1. 1.]
 [1. 1.]]

layer2 (TRAINED)
Initial weights:
 [[1. 1.]
 [1. 1.]]
Updated weights (should be changed):
 [[0.99000007 0.99000007]
 [0.99000007 0.99000007]]

Loss before the update step: 6.2967615


In [22]:
# 10. Assertions for verification
assert jnp.allclose(new_params['layer1']['w'], params['layer1']['w']), "layer1 weights should remain unchanged"
print("\nAssertion passed: 'layer1' weights are unchanged as expected.")



Assertion passed: 'layer1' weights are unchanged as expected.
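
The assertion above checks a single array. As a sketch, the check can be extended to every leaf of the tree by comparing old and new parameters against the mask. The helper name `check_freeze` is hypothetical, and the `new_params` values below are stand-ins chosen for illustration, not the output of the optimizer.

```python
import jax
import jax.numpy as jnp

def check_freeze(old_params, new_params, mask):
    """Hypothetical helper: True where a leaf changed exactly when its mask entry is True."""
    def leaf_ok(old, new, trainable):
        changed = not bool(jnp.array_equal(old, new))
        return changed == trainable
    return jax.tree_util.tree_map(leaf_ok, old_params, new_params, mask)

old_params = {
    'layer1': {'w': jnp.ones((2, 2)), 'b': jnp.zeros((2,))},
    'layer2': {'w': jnp.ones((2, 2)), 'b': jnp.zeros((2,))},
}
# Stand-in values: layer1 untouched, layer2 shifted.
new_params = {
    'layer1': {'w': jnp.ones((2, 2)), 'b': jnp.zeros((2,))},
    'layer2': {'w': jnp.full((2, 2), 0.99), 'b': jnp.full((2,), -0.01)},
}
mask = {
    'layer1': {'w': False, 'b': False},
    'layer2': {'w': True, 'b': True},
}

report = check_freeze(old_params, new_params, mask)
all_ok = all(jax.tree_util.tree_leaves(report))
print(report, all_ok)
```

One caveat: a trainable leaf whose gradient happens to be exactly zero would not change and would be flagged here, so a failed entry for a trainable leaf is a prompt to look closer rather than proof of a bug.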


# Conclusion

We have demonstrated:
- How to selectively freeze parameters in Optax
- How to verify that frozen layers remain unchanged after optimization

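Masking the updates controls which parameters change during training. Note that with this manual approach the gradients and the Adam state are still computed for the frozen parameters; only the final updates are zeroed.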

