
mdmm-jax

Gradient-based constrained optimization for JAX (implementation by Katherine Crowson).

The Modified Differential Multiplier Method was proposed in Platt and Barr (1988), "Constrained Differential Optimization".

MDMM minimizes a main objective f(x) subject to equality (g(x) = 0) and inequality (h(x) ≥ 0) constraints, where the constraints can be arbitrary differentiable functions of your parameters and data.
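For a single equality constraint g(x) = 0, the energy MDMM works with can be sketched as follows (following Platt and Barr; here c > 0 is a fixed penalty coefficient and the Lagrange multiplier \lambda is the constraint's trainable parameter):

E(x, \lambda) = f(x) + \lambda \, g(x) + \frac{c}{2} \, g(x)^2

Gradient descent is performed on x and gradient ascent on \lambda (since \partial E / \partial \lambda = g(x)), so the multiplier keeps growing until the constraint is satisfied.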

Quick usage

Creating an equality constraint and its trainable parameters:

import mdmm_jax

constraint = mdmm_jax.eq(my_function)
# init() internally calls my_function(main_params, x) to create the
# constraint's trainable parameters (the Lagrange multiplier)
mdmm_params = constraint.init(main_params, x)
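
Here my_function can be any differentiable function of the parameters and data; eq drives its return value to zero. A hypothetical example (the pytree layout of main_params is an illustrative assumption, not part of the library):

import jax.numpy as jnp

# Hypothetical constraint: keep the mean of a weight matrix at zero
def my_function(main_params, x):
    return jnp.mean(main_params['w'])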

Constructing the loss function for the augmented Lagrangian system, which adds the constraint loss to the main objective (the combined loss is far less interpretable than the original, so return the original loss as part of an auxiliary output):

def system(params, x):
    main_params, mdmm_params = params
    loss = loss_fn(main_params, x)
    # inf is the constraint infeasibility (how far the constraint is
    # from being satisfied)
    mdmm_loss, inf = constraint.loss(mdmm_params, main_params, x)
    # Return the augmented loss, with the original loss and the
    # infeasibility as an interpretable auxiliary output
    return loss + mdmm_loss, (loss, inf)
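
Because system has an auxiliary output, gradients are taken with has_aux=True. A minimal sketch, continuing from the snippets above:

import jax

params = main_params, mdmm_params
# With has_aux=True, value_and_grad returns ((total_loss, aux), grads)
(total, (loss, inf)), grads = jax.value_and_grad(system, has_aux=True)(params, x)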

Turning an Optax base optimizer into an MDMM constrained optimizer:

import optax

optimizer = optax.chain(
    optax.sgd(1e-3),
    mdmm_jax.optax_prepare_update(),
)
params = main_params, mdmm_params
opt_state = optimizer.init(params)
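
Putting it together, a jitted update step might look like this (a sketch assuming the system, optimizer, and params defined above, plus standard Optax update/apply_updates usage):

import jax
import optax

@jax.jit
def update(params, opt_state, x):
    # Differentiate the augmented Lagrangian; the auxiliary output carries
    # the original loss and the constraint infeasibility for logging
    grads, (loss, inf) = jax.grad(system, has_aux=True)(params, x)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss, inf

Each call returns the updated (main_params, mdmm_params) tuple along with the interpretable original loss and the infeasibility.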
