
Implicitly differentiate the KKT conditions #539

Open
amdee opened this issue Sep 22, 2023 · 5 comments
Labels: question (Further information is requested)

Comments

@amdee commented Sep 22, 2023

Hi,

I am currently learning how to use JAXopt and am trying to adapt the code below to use its features. The code originally comes from a research paper and was ported from PyTorch to JAX. In the forward method, a root-finding problem is solved with a bracketing method, while the backward method relies on implicit differentiation through the KKT conditions.

Currently, I am only making use of JAXopt's root-finding capabilities.
I have a specific question: how can I use JAXopt to differentiate through the KKT conditions?

Please find my current implementation below:

import jax
import jax.numpy as jnp
import flax.linen as nn
from jaxopt import Bisection
import numpy as np

@jax.custom_vjp
def LML_jax(x, N, eps, n_iter, branch=None, verbose=0):
    y, res = lml_forward(x, N, eps, n_iter, branch, verbose)
    return y, res

def f(nu, x, N):
    return jnp.sum(jax.nn.sigmoid(x + nu)) - N

def lml_forward(x, N, eps, n_iter, branch, verbose):
    branch = branch if branch is not None else 10 if jax.devices()[0].platform == 'cpu' else 100
    nx = x.shape[0]
    if nx <= N:
        return jnp.ones(nx, dtype=x.dtype), None

    x_sorted = jnp.sort(x)[::-1]
    nu_lower = -x_sorted[N-1] - 7.
    nu_upper = -x_sorted[N] + 7.

    # Using Bisection from jaxopt
    bisection = Bisection(optimality_fun=f, lower=nu_lower, upper=nu_upper, tol=eps, check_bracket=False)
    sol = bisection.run(x=x, N=N)
    nu = sol.params

    y = jax.nn.sigmoid(x + nu)
    return y, (y, nu, x, N)

def lml_backward(res, grad_output):
    y, nu, x, N = res
    if y is None:
        return (jnp.zeros_like(x), None, None, None, None, None)

    Hinv = 1. / (1. / y + 1. / (1. - y))
    dnu = jnp.sum(Hinv * grad_output) / jnp.sum(Hinv)
    dx = -Hinv * (-grad_output + dnu)
    return (dx, None, None, None, None, None)

LML_jax.defvjp(lml_forward, lml_backward)

class LML(nn.Module):
    N: int = 1
    eps: float = 1e-4
    n_iter: int = 100
    branch: int = None
    verbose: int = 0

    @nn.compact
    def __call__(self, x):
        return LML_jax(x, N=self.N, eps=self.eps, n_iter=self.n_iter, branch=self.branch, verbose=self.verbose)

if __name__ == '__main__':
    m = 10
    n = 2
    np.random.seed(0)
    x = np.random.random(m)
    x_jax_unbatched = jnp.array(x)
    x_jax_batched = jnp.stack([x_jax_unbatched, x_jax_unbatched])
    x = jnp.stack([x, x])
    model = LML(N=n)
    key1, key2 = jax.random.split(jax.random.PRNGKey(1))
    dummy_input = jax.random.normal(key1, (n, m))
    params = model.init(jax.random.PRNGKey(0), dummy_input)
    LML_state = model.bind(params)
    lml = lambda x_input: LML_state(x_input)[0]

    y_unbatched = lml(x_jax_unbatched)
    y_batched = jax.vmap(lml)(x_jax_batched)
    y_unbatched_check = np.array(y_unbatched, copy=False)
    y_batched_check = np.array(y_batched, copy=False)

    vyo_unbatched, dyo_unbatched = jax.value_and_grad(lml)(x_jax_unbatched)
    vyo_batched, dyo_batched = jax.vmap(jax.value_and_grad(lml))(x_jax_batched)
    print(f"value of y vyo_unbatched: {vyo_unbatched}\ngradient of y dy0 vyo_unbatched: {dyo_unbatched}")
    print(f"\nvalue of y vyo_batched: {vyo_batched}\ngradient of y dy0 vyo_batched: {dyo_batched}") 
@mblondel added the question (Further information is requested) label on Oct 2, 2023
@zaccharieramzi (Contributor) commented

Hi @amdee,

When you use solvers directly from jaxopt, as you are doing with Bisection, you don't need to worry about how to differentiate through the KKT conditions: you can use the solver as-is and think of it as a differentiable operation.
What you can play with is whether or not to use implicit differentiation (in your case you probably want to), and how the linear system involving the Jacobian is solved (but you can also keep the defaults for this).
In other words, you shouldn't need to define a custom VJP. I think the problem in your case is that you are using non-differentiable operations inside lml_forward, like the sorting, before even creating the solver.
You also cannot differentiate w.r.t. parameters of the solver, as you are doing here with lower and upper: basically, your differentiation should be agnostic to the solver's parameters.
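For example, a minimal sketch of what I mean (the bracket values, N and the data below are just placeholders, not taken from your code):

import jax
import jax.numpy as jnp
from jaxopt import Bisection

def F(nu, x, N):
    # Same optimality function as in your snippet.
    return jnp.sum(jax.nn.sigmoid(x + nu)) - N

def loss(x):
    # lower/upper are plain constants here, so no non-differentiable
    # quantity depending on x leaks into the solver's hyperparameters.
    nu = Bisection(optimality_fun=F, lower=-10.0, upper=10.0,
                   check_bracket=False).run(x=x, N=2.0).params
    # The root nu is a differentiable function of x thanks to implicit diff.
    return jnp.sum(jax.nn.sigmoid(x + nu) ** 2)

x = jnp.array([0.3, -0.1, 0.8])
print(jax.grad(loss)(x))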

Hope this helps,
Cheers

@amdee (Author) commented Nov 12, 2023

@zaccharieramzi and maybe @mblondel, thanks for answering this question. That's what I thought: if I use a JAXopt solver, I should not need to manually implement differentiation through the KKT conditions, since JAXopt solvers are differentiable out of the box. The algorithm I am trying to implement from the above-mentioned paper is summarized in the image below.
I removed all the Flax dependencies from the code above and kept bare-bones JAX/JAXopt code (see below), but I am running into the following issues:

  1. CustomVJPException: Detected differentiation of a custom_vjp function with respect to a closed-over value. That isn't supported because the custom VJP rule only specifies how to differentiate the custom_vjp function with respect to explicit input parameters. Try passing the closed-over value into the custom_vjp function as an argument, and adapting the custom_vjp fwd and bwd rules.
  2. I am not sure what I am doing wrong. I have also read the following issues: Gradient through closure #285 and Problem differentiating through solver.run in OptaxSolver #31. Any help would be appreciated.
import jax
from jax.config import config
import jax.numpy as jnp
from jaxopt import Bisection
config.update("jax_enable_x64", True)

# Implement the bracketing method using jaxopt.Bisection
# jax.jit
def find_nu_star(x, k=2, saturation=7.0, eps=1e-4, num_iter=100):
    def g(nu, x, N):
        return jnp.sum(jax.nn.sigmoid(x + nu)) - N

    x_sorted = jnp.sort(x)[::-1]
    nu_lower = -x_sorted[k-1] - saturation
    nu_upper = -x_sorted[k] + saturation

    # Using Bisection from jaxopt
    init_params = jnp.zeros(x.shape)
    bisection = Bisection(optimality_fun=g, lower=nu_lower, upper=nu_upper, maxiter=num_iter, tol=eps, check_bracket=False)
    sol, _ = bisection.run(init_params, x, k)
    # nu = sol.params
    return sol

def calculate_y_star(x, K_value):
    nu_star = find_nu_star(x, K_value)
    y_star = jax.nn.sigmoid(x + nu_star)
    return y_star

data = jax.random.normal(jax.random.PRNGKey(0), shape=(3, ))
n = 2 # this is k in the paper
grad_result = jax.grad(calculate_y_star)(data, n)
print(f"grad_result: {grad_result}")

@zaccharieramzi (Contributor) commented

So this error is due to what I was explaining above, i.e. you are trying to differentiate through non-differentiable parts of the code, namely the lower and upper parameters of Bisection. On top of that, the sort function is also non-differentiable.

You could use, for example, x_sorted = jax.lax.stop_gradient(jnp.sort(x)[::-1]).

With this I have the following error:

ValueError: Shape of cotangent input to vjp pullback function (3,) must be the same as the shape of corresponding primal input ().

which I don't understand. Ideally, you would open an issue with the bare minimum needed to reproduce this error (i.e. no need to mention which paper this refers to) to make it easier to process.
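For reference, the kind of bare-bones script I have in mind would look roughly like this (an untested sketch: the stop_gradient fix is applied, init_params is dropped since Bisection does not need one, and jax.jacobian is used instead of jax.grad because calculate_y_star returns a vector):

import jax
import jax.numpy as jnp
from jaxopt import Bisection

def find_nu_star(x, k=2, saturation=7.0, eps=1e-4, num_iter=100):
    def g(nu, x, N):
        return jnp.sum(jax.nn.sigmoid(x + nu)) - N

    # Keep the bracket out of the differentiation path.
    x_sorted = jax.lax.stop_gradient(jnp.sort(x)[::-1])
    nu_lower = -x_sorted[k - 1] - saturation
    nu_upper = -x_sorted[k] + saturation

    bisection = Bisection(optimality_fun=g, lower=nu_lower, upper=nu_upper,
                          maxiter=num_iter, tol=eps, check_bracket=False)
    # No init_params needed; pass the problem data as keyword arguments.
    return bisection.run(x=x, N=k).params

def calculate_y_star(x, k):
    nu_star = find_nu_star(x, k)
    return jax.nn.sigmoid(x + nu_star)

data = jax.random.normal(jax.random.PRNGKey(0), shape=(3,))
# The output is a vector, so take the Jacobian w.r.t. x (the first argument).
jac = jax.jacobian(calculate_y_star)(data, 2)
print(jac)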

@amdee (Author) commented Nov 16, 2023

@zaccharieramzi, @mblondel, I apologize for not being more specific in describing my question about differentiating through the KKT conditions. I am attempting to replicate an algorithm from the paper mentioned above.

  • In summary, the algorithm solves the convex constrained optimization problem below with explicit forward and backward passes. The forward pass uses a root-finding method, while the backward pass differentiates through the KKT conditions of that problem. This explains the use of the argsort function: according to my understanding, it does not affect the backward pass.

  • Returning to my question: how can I use JAXopt to differentiate through the KKT conditions of the optimization problem below without deriving them manually?

$$\min_{0 < y < 1} -x^T y - H_b(y) \quad \text{subject to} \quad 1^T y = k$$
Where:

  • $y$ is the vector variable for which we are solving.
  • $x$ is a given vector.
  • $H_b(y)$ is the binary entropy function, defined as:

$$ H_b(y) = -\sum_{i}(y_i \log y_i + (1 - y_i) \log(1 - y_i)) $$

  • I am not sure how to handle the constraints, so correct me if I am wrong: if I solve the optimization problem, I should get the gradients for free? If you could point me to how to implement the constraints, I would appreciate it. Here is what I have so far:
import jax.numpy as jnp
from jax import grad
from jaxopt import ProjectedGradient
from jaxopt.projection import projection_simplex

def binary_entropy(y):
    """Calculate the binary entropy of a vector y."""
    return -jnp.sum(y * jnp.log(y) + (1 - y) * jnp.log(1 - y))

def objective(y, x):
    """Objective function to minimize."""
    return -jnp.dot(x, y) - binary_entropy(y)

def projection(y, k):
    """Project onto the set {y : 1^T y = k, 0 < y < 1}."""
    pass

# Initialize the Projected Gradient solver
solver = ProjectedGradient(fun=objective, projection=projection, maxiter=1000)

# Solve the problem
result = solver.run(y_init, hyperparams_proj=k, x=x).params

@zaccharieramzi (Contributor) commented

@amdee when you solve $y^\star = \operatorname{argmin}_y f(x, y)$, you can get $\frac{\partial y^\star}{\partial x}$.
What I meant earlier is that you cannot differentiate through hyperparameters of the optimization algorithm, as you were trying to do in your first example (in addition to trying to differentiate through sorting).

Indeed, in your example, once you implement the projection you will be able to get the gradient of the result w.r.t. x. The projection could be just a clipping between 0 and 1, followed by a normalization by the sum and a multiplication by k, but I am saying this without thinking too much about it. I would like to point out that answering questions like how to implement the projection is out of the scope of this project, imho.
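That said, to make the hint concrete, a rough, unverified sketch of the kind of projection and setup I have in mind (the values for x, k and y_init below are arbitrary):

import jax
import jax.numpy as jnp
from jaxopt import ProjectedGradient

def binary_entropy(y):
    return -jnp.sum(y * jnp.log(y) + (1 - y) * jnp.log(1 - y))

def objective(y, x):
    return -jnp.dot(x, y) - binary_entropy(y)

def projection(y, k):
    # Rough heuristic, not the exact Euclidean projection onto
    # {y : 1^T y = k, 0 < y < 1}: clip into (0, 1), then rescale to sum to k.
    y = jnp.clip(y, 1e-6, 1.0 - 1e-6)
    return y * (k / jnp.sum(y))

x = jnp.array([0.1, 0.5, 0.2])
k = 2.0
y_init = jnp.full_like(x, k / x.shape[0])

solver = ProjectedGradient(fun=objective, projection=projection, maxiter=1000)

def solve(x):
    return solver.run(y_init, hyperparams_proj=k, x=x).params

y_star = solve(x)
# Implicit differentiation of the solver's fixed point gives d y*/d x.
jac = jax.jacobian(solve)(x)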
