Constraint violation causes L-BFGS-B to fail #590

Open
gulls-on-parade opened this issue Apr 17, 2024 · 1 comment

Comments

@gulls-on-parade

I believe the line search used internally by jaxopt.LBFGSB is not respecting the bounds that are passed in, causing the objective function to generate NaNs and the overall optimization to fail. I am unsure if this is a bug or if I am doing something wrong. Any guidance is much appreciated.

import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count={}'.format(os.cpu_count())
import jax as jx
from jax import jit
import jax.numpy as jnp
import jaxopt

#%% Helper functions

@jit
def normal_vol(strike, atmf, t, alpha, beta, rho, nu):
    """Hagan-style normal (Bachelier) SABR volatility approximation."""
    eps = 1e-07  # Numerical tolerance
    f_av = jnp.sqrt(atmf * strike)  # Geometric average of forward and strike

    # (F - K) ratio term, with the beta -> 1 and ATM limits handled explicitly
    fmkr = jnp.select([(jnp.abs(atmf - strike) > eps) & (jnp.abs(1 - beta) > eps), 
                       (jnp.abs(atmf - strike) > eps) & (jnp.abs(1 - beta) <= eps), 
                       jnp.abs(atmf - strike) <= eps],
                      [(1 - beta) * (atmf - strike) / (atmf**(1 - beta) - strike**(1 - beta)),
                       (atmf - strike) / jnp.log(atmf / strike),
                       strike**beta],
                      jnp.nan)
    
    zeta = nu * (atmf - strike) / (alpha * f_av**beta)

    # zeta / x(zeta) term; note the sqrt argument can go negative when |rho| > 1
    zxz = jnp.select([jnp.abs(zeta) > eps, 
                      jnp.abs(zeta) <= eps],
                     [zeta / jnp.log(jnp.abs(((1 - 2 * rho * zeta + zeta**2)**.5 + zeta - rho) / (1 - rho))),
                      1.],
                     jnp.nan)
    
    a = - beta * (2 - beta) * alpha**2 / (24 * f_av**(2 - 2 * beta))
    b = rho * alpha * nu * beta / (4 * f_av**(1 - beta))
    c = (2 - 3 * rho**2) * nu**2 / 24

    vol = alpha * fmkr * zxz * (1 + (a + b + c) * t)

    return vol


@jit
def _obj(params, args):
    """Objective function to minimize the squared error between implied and model vols."""
    expiry, tail, strikes, vols, atmf, beta = args
    alpha, rho, nu = params
    vol_fitted = jx.vmap(normal_vol, (0, None, None, None, None, None, None))(strikes, atmf, expiry, alpha, beta, rho, nu)
    error = (vol_fitted - vols) * 1e4
    return jnp.sum(error**2)

#%% Example problem

data = [(0.09041095890410959,
  0.2465753424657534,
  jnp.array([0.0824076, 0.0849076, 0.0874076, 0.0899076, 0.0924076, 0.0949076,
         0.0974076, 0.0999076, 0.1024076, 0.1049076, 0.1074076, 0.1099076,
         0.1124076, 0.1149076, 0.1174076, 0.1199076, 0.1224076, 0.1249076,
         0.1274076, 0.1299076, 0.1324076, 0.1349076, 0.1374076, 0.1399076,
         0.1424076]),
  jnp.array([0.02100495, 0.02000676, 0.01897691, 0.01791351, 0.016814,
         0.01567488, 0.01449142, 0.0132571 , 0.01196264, 0.0105943 ,
         0.00913049, 0.00753422, 0.00573621, 0.00368666, 0.00298916,
         0.00351651, 0.00417858, 0.00485768, 0.00553383, 0.00620241,
         0.00686251, 0.00751431, 0.00815832, 0.00879512, 0.00942527]),
  0.11240760359238675,
  0.25),
 (0.09041095890410959,
  1.0027397260273974,
  jnp.array([0.07611851, 0.07861851, 0.08111851, 0.08361851, 0.08611851,
         0.08861851, 0.09111851, 0.09361851, 0.09611851, 0.09861851,
         0.10111851, 0.10361851, 0.10611851, 0.10861851, 0.11111851,
         0.11361851, 0.11611851, 0.11861851, 0.12111851, 0.12361851,
         0.12611851, 0.12861851, 0.13111851, 0.13361851, 0.13611851]),
  jnp.array([0.02571163, 0.02466922, 0.02359377, 0.02248411, 0.02133859,
         0.02015503, 0.01893064, 0.01766194, 0.01634481, 0.01497479,
         0.01354828, 0.01206712, 0.01055505, 0.00911653, 0.00807032,
         0.00778549, 0.00810574, 0.00870589, 0.0094221 , 0.01018791,
         0.01097495, 0.01177004, 0.01256661, 0.01336128, 0.01415225]),
  0.10611850901102435,
  0.25),
 (0.09041095890410959,
  2.0027397260273974,
  jnp.array([0.06970405, 0.07220405, 0.07470405, 0.07720405, 0.07970405,
         0.08220405, 0.08470405, 0.08720405, 0.08970405, 0.09220405,
         0.09470405, 0.09720405, 0.09970405, 0.10220405, 0.10470405,
         0.10720405, 0.10970405, 0.11220405, 0.11470405, 0.11720405,
         0.11970405, 0.12220405, 0.12470405, 0.12720405, 0.12970405]),
  jnp.array([0.02641612, 0.02545857, 0.02447167, 0.02345581, 0.02241125,
         0.02133829, 0.02023758, 0.01911054, 0.01796036, 0.01679381,
         0.01562486, 0.01448212, 0.01342213, 0.01254578, 0.01198868,
         0.01184018, 0.01206377, 0.01254688, 0.01318733, 0.01391874,
         0.01470226, 0.01551549, 0.0163453 , 0.01718381, 0.01802615]),
  0.09970405414511939,
  0.25)]


x0 = jnp.array([0.01, 0.00, 0.10])
bounds = (jnp.array([0.0001, -0.9999, 0.0001]), jnp.array([999, 0.9999, 999]))
args = data[0]

# This fails: the objective function produces nans when the step size immediately
# violates the bounds as part of the implicit differentiation
solver = jaxopt.LBFGSB(fun=_obj)
results = solver.run(x0, bounds=bounds, args=args)


# However, the objective function evaluates correctly at x0
_obj(x0, args)
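
For what it's worth, here is a quick illustration of where the nans come from (the out-of-bounds point below is hypothetical, not necessarily one the solver actually visits). Once rho leaves [-1, 1], the square-root argument 1 - 2 * rho * zeta + zeta**2 inside normal_vol can go negative for the extreme strikes, so the whole objective evaluates to nan:

# rho = 2.0 is well outside the [-0.9999, 0.9999] bound; the sqrt term in
# normal_vol goes negative for some strikes and the nan propagates through zxz
x_bad = jnp.array([0.01, 2.0, 0.10])
_obj(x_bad, args)  # nan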

@charles-zhng

I might be having the same issue too; please let me know if you learn anything new! For now I am clipping the parameters to the bounds myself inside my loss function, as sketched below.
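
For reference, a minimal sketch of that clipping workaround, reusing _obj, x0, bounds, and args from the original post (the _obj_clipped wrapper is mine):

lower, upper = bounds

@jit
def _obj_clipped(params, args):
    # Project each trial point back into the feasible box before evaluating,
    # so an out-of-bounds line-search step can no longer reach the
    # nan-producing branches of normal_vol
    return _obj(jnp.clip(params, lower, upper), args)

solver = jaxopt.LBFGSB(fun=_obj_clipped)
results = solver.run(x0, bounds=bounds, args=args)

Note this only guards the forward evaluation: gradients are taken through the clip, so they are zero along any coordinate that gets clipped.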
