Implicitly differentiate the KKT conditions #539
Hi @amdee, when you are using solvers directly from jaxopt, like what you are doing with `Bisection`, implicit differentiation is already handled for you: jaxopt solvers are differentiable out of the box. Hope this helps,
@zaccharieramzi and maybe @mblondel, thanks for answering this question. That's what I thought: if I use a jaxopt solver, I should not have to worry about manually implementing differentiation through the KKT conditions, since jaxopt solvers are differentiable out of the box. The algorithm I am trying to implement from the above-mentioned paper is summarized in the image below.
```python
import jax
import jax.numpy as jnp
from jaxopt import Bisection

# `from jax.config import config` is deprecated in recent JAX releases.
jax.config.update("jax_enable_x64", True)

# Implement the bracketing method using jaxopt.Bisection
# jax.jit
def find_nu_star(x, k=2, saturation=7.0, eps=1e-4, num_iter=100):
    def g(nu, x, N):
        return jnp.sum(jax.nn.sigmoid(x + nu)) - N

    x_sorted = jnp.sort(x)[::-1]
    nu_lower = -x_sorted[k - 1] - saturation
    nu_upper = -x_sorted[k] + saturation
    # Using Bisection from jaxopt
    init_params = jnp.zeros(x.shape)
    bisection = Bisection(optimality_fun=g, lower=nu_lower, upper=nu_upper,
                          maxiter=num_iter, tol=eps, check_bracket=False)
    sol, _ = bisection.run(init_params, x, k)
    # nu = sol.params
    return sol

def calculate_y_star(x, K_value):
    nu_star = find_nu_star(x, K_value)
    y_star = jax.nn.sigmoid(x + nu_star)
    return y_star

data = jax.random.normal(jax.random.PRNGKey(0), shape=(3,))
n = 2  # this is k in the paper
grad_result = jax.grad(calculate_y_star)(data, n)
print(f"grad_result: {grad_result}")
```
So this error is due to what I was explaining above, i.e. you are trying to differentiate through non-differentiable parts of the code, namely the `lower` and `upper` parameters of `Bisection`. On top of that, the `sort` function is also non-differentiable, so these quantities need to be kept out of the differentiation path.

With this I have the following error, which I don't understand.

Ideally, you would write an issue with the bare minimum needed to reproduce this error (i.e. no need to mention which paper this refers to) to allow for easier processing.
@zaccharieramzi, @mblondel, I apologize for not being specific in describing my question about differentiating through the KKT condition. I am attempting to replicate an algorithm from the paper mentioned above.
```python
import jax.numpy as jnp
from jax import grad
from jaxopt import ProjectedGradient
from jaxopt.projection import projection_simplex

def binary_entropy(y):
    """Calculate the binary entropy of a vector y."""
    return -jnp.sum(y * jnp.log(y) + (1 - y) * jnp.log(1 - y))

def objective(y, x):
    """Objective function to minimize."""
    return -jnp.dot(x, y) - binary_entropy(y)

def projection(y, k):
    """Project onto the set {y : 1^T y = k, 0 < y < 1}."""
    pass

# Initialize the Projected Gradient solver
solver = ProjectedGradient(fun=objective, projection=projection, maxiter=1000)

# Solve the problem
result = solver.run(y_init, hyperparams_proj=k, x=x).params
```
@amdee when you solve the problem with a jaxopt solver such as `ProjectedGradient`, implicit differentiation is again handled for you. Indeed, in your example, once you implement the projection you will be able to get the gradient of the solution with respect to `x`.
Hi,
I am currently learning how to use jaxopt and am trying to adapt the code below to utilize its features. This code originally comes from a research paper and was ported from PyTorch to JAX. In the forward method, a root-finding problem is solved using the bracketing method, while the backward method relies on implicit differentiation through the KKT conditions.
Currently, I am only making use of Jaxopt's root-finding capabilities.
I have a specific question: How can I employ Jaxopt to perform differentiation through the KKT conditions?
Please find my current implementation below: