In [83]:
import jaxopt
import optimistix as optx
import jax.numpy as jnp
from jax import jvp, grad, jacobian 
import optax
import matplotlib.pyplot as plt
import time
import random
import numpy as np

In [98]:
def inner_objective(w, theta):
    return (w-theta**2)**2

def outer_objective(w_opt):
    return 2 * w_opt

grad_outer = grad(outer_objective)
grad_inner = grad(inner_objective)

@jaxopt.implicit_diff.custom_root(grad_inner)
def solve_inner1(w_init, theta):
    solver = jaxopt.BFGS(fun=inner_objective)
    w_opt, state = solver.run(w_init, theta=theta)
    return w_opt

def solve_inner2(w_init, theta):
    solver = jaxopt.BFGS(fun=inner_objective)
    w_opt, state = solver.run(w_init, theta=theta)
    return w_opt

@jaxopt.implicit_diff.custom_root(grad_inner)
def score1(w_init, theta):
    w_opt = solve_inner2(w_init, theta)
    #print(f"theta: {theta}, w_opt: {w_opt}, outer_objective: {outer_objective(w_opt)}")
    return outer_objective(w_opt)

def score2(theta, hi="hi"):
    w_init = 0.0
    w_opt = solve_inner1(w_init, theta)
    #print(f"theta: {theta}, w_opt: {w_opt}, outer_objective: {outer_objective(w_opt)}")
    return outer_objective(w_opt)

In [99]:
grad(score2)(1.0)

INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 0.40000009536743164 Stepsize:1.0  Decrease Error:0.00039999998989515007  Curvature Error:0.40000009536743164 
INFO: jaxopt.ZoomLineSearch: Iter: 2 Minimum Decrease & Curvature Errors (stop. crit.): 0.0 Stepsize:0.5  Decrease Error:0.0  Curvature Error:0.0 


Array(4., dtype=float32, weak_type=True)

In [81]:
times1 = []
for i in range(100):
    theta = random.random()
    w_init = random.random()
    t1 = time.time()
    g = grad(score1, argnums=1)(w_init, theta)
    t2 = time.time()
    times1.append(t2 - t1)

times2 = []
for i in range(100):
    theta = random.random()
    w_init = random.random()
    t1 = time.time()
    g = grad(score2, argnums=1)(w_init, theta)
    t2 = time.time()
    times2.append(t2-t1)

INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 0.07799084484577179 Stepsize:1.0  Decrease Error:7.793102849973366e-05  Curvature Error:0.07799084484577179 
INFO: jaxopt.ZoomLineSearch: Iter: 2 Minimum Decrease & Curvature Errors (stop. crit.): 0.0 Stepsize:0.5000000596046448  Decrease Error:0.0  Curvature Error:0.0 
INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 0.01274079829454422 Stepsize:1.0  Decrease Error:1.2761708603648003e-05  Curvature Error:0.01274079829454422 
INFO: jaxopt.ZoomLineSearch: Iter: 2 Minimum Decrease & Curvature Errors (stop. crit.): 0.0 Stepsize:0.4999999403953552  Decrease Error:0.0  Curvature Error:0.0 
INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 0.00595747958868742 Stepsize:1.0  Decrease Error:5.963049716228852e-06  Curvature Error:0.00595747958868742 
INFO: jaxopt.ZoomLineSearch: Iter: 2 Minimum Decrease & Curvature Errors (stop. crit

INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 0.07437632232904434 Stepsize:1.0  Decrease Error:7.438189641106874e-05  Curvature Error:0.07437632232904434 
INFO: jaxopt.ZoomLineSearch: Iter: 2 Minimum Decrease & Curvature Errors (stop. crit.): 0.0 Stepsize:0.5  Decrease Error:0.0  Curvature Error:0.0 
INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 0.01661502756178379 Stepsize:1.0  Decrease Error:1.662038266658783e-05  Curvature Error:0.01661502756178379 
INFO: jaxopt.ZoomLineSearch: Iter: 2 Minimum Decrease & Curvature Errors (stop. crit.): 0.0 Stepsize:0.5  Decrease Error:0.0  Curvature Error:0.0 
INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 0.1355343610048294 Stepsize:1.0  Decrease Error:0.00013552873861044645  Curvature Error:0.1355343610048294 
INFO: jaxopt.ZoomLineSearch: Iter: 2 Minimum Decrease & Curvature Errors (stop. crit.): 0.0 Stepsize:0.5  Decrease E

INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 0.03754311427474022 Stepsize:1.0  Decrease Error:3.752645352506079e-05  Curvature Error:0.03754311427474022 
INFO: jaxopt.ZoomLineSearch: Iter: 2 Minimum Decrease & Curvature Errors (stop. crit.): 0.0 Stepsize:0.5  Decrease Error:0.0  Curvature Error:0.0 
INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 0.20416219532489777 Stepsize:1.0  Decrease Error:0.00020413454330991954  Curvature Error:0.20416219532489777 
INFO: jaxopt.ZoomLineSearch: Iter: 2 Minimum Decrease & Curvature Errors (stop. crit.): 0.0 Stepsize:0.5  Decrease Error:0.0  Curvature Error:0.0 
INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 0.0058980644680559635 Stepsize:1.0  Decrease Error:5.894345576962223e-06  Curvature Error:0.0058980644680559635 
INFO: jaxopt.ZoomLineSearch: Iter: 2 Minimum Decrease & Curvature Errors (stop. crit.): 0.0 Stepsize:0.5000000

INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 0.05116962641477585 Stepsize:1.0  Decrease Error:5.119141496834345e-05  Curvature Error:0.05116962641477585 
INFO: jaxopt.ZoomLineSearch: Iter: 2 Minimum Decrease & Curvature Errors (stop. crit.): 0.0 Stepsize:0.5  Decrease Error:0.0  Curvature Error:0.0 
INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 0.009876185096800327 Stepsize:1.0  Decrease Error:9.866439540928695e-06  Curvature Error:0.009876185096800327 
INFO: jaxopt.ZoomLineSearch: Iter: 2 Minimum Decrease & Curvature Errors (stop. crit.): 0.0 Stepsize:0.5000000596046448  Decrease Error:0.0  Curvature Error:0.0 
INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 0.008619850501418114 Stepsize:1.0  Decrease Error:8.620354492450133e-06  Curvature Error:0.008619850501418114 
INFO: jaxopt.ZoomLineSearch: Iter: 2 Minimum Decrease & Curvature Errors (stop. crit.): 0.0 Step

INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 0.13481533527374268 Stepsize:1.0  Decrease Error:0.000134817193611525  Curvature Error:0.13481533527374268 
INFO: jaxopt.ZoomLineSearch: Iter: 2 Minimum Decrease & Curvature Errors (stop. crit.): 0.0 Stepsize:0.5  Decrease Error:0.0  Curvature Error:0.0 
INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 0.1388055831193924 Stepsize:1.0  Decrease Error:0.0001387987722409889  Curvature Error:0.1388055831193924 
INFO: jaxopt.ZoomLineSearch: Iter: 2 Minimum Decrease & Curvature Errors (stop. crit.): 0.0 Stepsize:0.5  Decrease Error:0.0  Curvature Error:0.0 
INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 0.0059468792751431465 Stepsize:1.0  Decrease Error:5.9506073739612475e-06  Curvature Error:0.0059468792751431465 
INFO: jaxopt.ZoomLineSearch: Iter: 2 Minimum Decrease & Curvature Errors (stop. crit.): 0.0 Stepsize:0.4999999701

INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 0.02988378144800663 Stepsize:1.0  Decrease Error:2.9883165552746505e-05  Curvature Error:0.02988378144800663 
INFO: jaxopt.ZoomLineSearch: Iter: 2 Minimum Decrease & Curvature Errors (stop. crit.): 0.0 Stepsize:0.5  Decrease Error:0.0  Curvature Error:0.0 
INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 0.015863997861742973 Stepsize:1.0  Decrease Error:1.5856991012697108e-05  Curvature Error:0.015863997861742973 
INFO: jaxopt.ZoomLineSearch: Iter: 2 Minimum Decrease & Curvature Errors (stop. crit.): 0.0 Stepsize:0.5  Decrease Error:0.0  Curvature Error:0.0 
INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 0.018184037879109383 Stepsize:1.0  Decrease Error:1.817659540392924e-05  Curvature Error:0.018184037879109383 
INFO: jaxopt.ZoomLineSearch: Iter: 2 Minimum Decrease & Curvature Errors (stop. crit.): 0.0 Stepsize:0.5  Dec

INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 0.10399794578552246 Stepsize:1.0  Decrease Error:0.00010407037916593254  Curvature Error:0.10399794578552246 
INFO: jaxopt.ZoomLineSearch: Iter: 2 Minimum Decrease & Curvature Errors (stop. crit.): 0.0 Stepsize:0.5  Decrease Error:0.0  Curvature Error:0.0 
INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 0.07720669358968735 Stepsize:1.0  Decrease Error:7.723034650553018e-05  Curvature Error:0.07720669358968735 
INFO: jaxopt.ZoomLineSearch: Iter: 2 Minimum Decrease & Curvature Errors (stop. crit.): 0.0 Stepsize:0.4999999701976776  Decrease Error:0.0  Curvature Error:0.0 
INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 1.2907727978017647e-05 Stepsize:1.0  Decrease Error:1.3287090538938173e-08  Curvature Error:1.2907727978017647e-05 
INFO: jaxopt.ZoomLineSearch: Iter: 2 Minimum Decrease & Curvature Errors (stop. crit.): 0.0 

INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 0.1371079981327057 Stepsize:1.0  Decrease Error:0.00013709787162952125  Curvature Error:0.1371079981327057 
INFO: jaxopt.ZoomLineSearch: Iter: 2 Minimum Decrease & Curvature Errors (stop. crit.): 0.0 Stepsize:0.5  Decrease Error:0.0  Curvature Error:0.0 
INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 0.052090831100940704 Stepsize:1.0  Decrease Error:5.211851021158509e-05  Curvature Error:0.052090831100940704 
INFO: jaxopt.ZoomLineSearch: Iter: 2 Minimum Decrease & Curvature Errors (stop. crit.): 0.0 Stepsize:0.4999999403953552  Decrease Error:0.0  Curvature Error:0.0 
INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 0.13514554500579834 Stepsize:1.0  Decrease Error:0.0001351439714198932  Curvature Error:0.13514554500579834 
INFO: jaxopt.ZoomLineSearch: Iter: 2 Minimum Decrease & Curvature Errors (stop. crit.): 0.0 Stepsiz

INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 0.24577751755714417 Stepsize:1.0  Decrease Error:0.0002457625523675233  Curvature Error:0.24577751755714417 
INFO: jaxopt.ZoomLineSearch: Iter: 2 Minimum Decrease & Curvature Errors (stop. crit.): 0.0 Stepsize:0.5  Decrease Error:0.0  Curvature Error:0.0 
INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 0.16683028638362885 Stepsize:1.0  Decrease Error:0.00016691081691533327  Curvature Error:0.16683028638362885 
INFO: jaxopt.ZoomLineSearch: Iter: 2 Minimum Decrease & Curvature Errors (stop. crit.): 0.0 Stepsize:0.4999999701976776  Decrease Error:0.0  Curvature Error:0.0 
INFO: jaxopt.ZoomLineSearch: Iter: 1 Minimum Decrease & Curvature Errors (stop. crit.): 0.19885686039924622 Stepsize:1.0  Decrease Error:0.00019886475638486445  Curvature Error:0.19885686039924622 
INFO: jaxopt.ZoomLineSearch: Iter: 2 Minimum Decrease & Curvature Errors (stop. crit.): 0.0 Stepsi

In [86]:
times1 = np.array(times1)
times2 = np.array(times2)
print(f"grad(score1) compute time. Mean: {times1.mean()}, std: {times1.std()}")
print(f"grad(score2) compute time. Mean: {times2.mean()}, std: {times2.std()}")

grad(score1) compute time. Mean: 0.5935719418525696, std: 0.4197386625193027
grad(score2) compute time. Mean: 0.554382643699646, std: 0.3975916253284205
