In [1]:
import jax.numpy as jnp
from jax import jvp, grad, jacobian
from jax.scipy.sparse.linalg import cg as conjugate_gradient
import optax

## Implicit Differentiation Toy Example

$$ \mathcal{L}^{in}(w, \theta) = (w-\theta^2)^2$$
$$ \mathcal{L}^{out}(w^*, \theta) = 2 w^*$$
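We want to find $\min_{\theta} \mathcal{L}^{out}(w^*, \theta)$, where $w^* \in \operatorname{argmin}_{w} \mathcal{L}^{in}(w, \theta)$. In this toy problem the inner solution has the closed form $w^* = \theta^2$, so the outer objective is $2\theta^2$. It is minimized at $\theta = 0$, which gives a ground truth to check the method against.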

In [2]:
def inner_objective(w, theta):
    """Inner loss (w - theta^2)^2, minimised (value 0) exactly at w = theta^2."""
    residual = w - theta ** 2
    return residual ** 2

def outer_objective(w_opt):
    """Outer loss: twice the inner solution ``w_opt`` (no direct theta term)."""
    return w_opt * 2

# Autodiff gradients w.r.t. the first argument: grad_inner = dL_in/dw (theta held fixed),
# grad_outer = dL_out/dw_opt.
grad_outer, grad_inner = grad(outer_objective, argnums=0), grad(inner_objective, argnums=0)

## Solve the inner problem

In [3]:
def solve_inner(theta, w_init, n_steps=50, learning_rate=1e-2):
    """Approximately minimise the inner objective over ``w`` for a fixed ``theta``.

    Runs a fixed budget of Adam steps on ``grad_inner(w, theta)`` starting from
    ``w_init``. There is no convergence check, so accuracy depends on the step
    budget and on how close ``w_init`` already is (``solve_outer`` warm-starts
    from the previous inner solution for this reason).

    Args:
        theta: outer parameter, held fixed during the inner solve.
        w_init: starting point for ``w``; also used to initialise Adam's state.
        n_steps: number of Adam iterations (default 50, the previous hard-coded value).
        learning_rate: Adam learning rate (default 1e-2, the previous hard-coded value).

    Returns:
        The approximate inner minimiser ``w`` after ``n_steps`` updates.
    """
    optimizer = optax.adam(learning_rate=learning_rate)
    opt_state = optimizer.init(w_init)
    w = w_init
    for _ in range(n_steps):
        grad_w = grad_inner(w, theta)
        updates, opt_state = optimizer.update(grad_w, opt_state, w)
        w = optax.apply_updates(w, updates)
    return w

## Implicit Differentiation

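Let $f(\theta) = w^*(\theta)$ denote the map from $\theta$ to the inner solution. Here $f(\theta) = \theta^2$ (recall $\mathcal{L}^{in}(w, \theta) = (w-\theta^2)^2$). The implicit method below never uses this closed form; it only needs derivatives of $\mathcal{L}^{in}$ evaluated at $w^*$.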

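At the inner optimum the stationarity condition $\partial_w \mathcal{L}^{in}(f(\theta), \theta) = 0$ holds for every $\theta$. Differentiating it with respect to $\theta$ (implicit function theorem) gives $\partial^2_w \mathcal{L}^{in}\, f'(\theta) + \partial_\theta \partial_w \mathcal{L}^{in} = 0$, i.e.

$Af'(\theta) = B$, where
- $A = -\left[\partial^{2}_{w}\mathcal{L}^{in}(w^*, \theta)\right]$
- $B = \partial_{\theta}\partial_{w} \mathcal{L}^{in}(w^*, \theta)$

For this example $A = -2$ and $B = -4\theta$, so $f'(\theta) = 2\theta$, matching $f(\theta) = \theta^2$.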


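Let $v = \partial_{w}\mathcal{L}^{out}(f(\theta), \theta)$. Since $\mathcal{L}^{out}$ has no direct dependence on $\theta$ here, the hypergradient is $\nabla_{\theta}\mathcal{L}^{out}(f(\theta), \theta) = v^T f'(\theta)$. We compute $v^T f'(\theta)$ directly, without ever forming $f'(\theta)$:
- Solve the linear system $Au = v$ for $u$ using conjugate gradient. This only needs Hessian-vector products with $\partial^2_w \mathcal{L}^{in}$. CG formally requires a positive-definite operator, and $A$ is negative definite when the inner problem is strongly convex (as here). We can therefore equivalently solve $\partial^2_w \mathcal{L}^{in}\, u = -v$.
- Compute $u^TB$.
    - [$u^T B = u^T A f'(\theta) = (A^T u)^T f'(\theta) = (Au)^T f'(\theta) = v^T f'(\theta)$, using that $A$ is symmetric]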

In [4]:
def solve_outer(theta_init, n_steps=500, step_size=1e-3, learning_rate=1e-2,
                w_init=1.0, verbose=True):
    """Minimise the outer objective over ``theta`` using implicit differentiation.

    Each outer iteration:
      1. approximately solves the inner problem with ``solve_inner``,
         warm-started from the previous inner solution;
      2. computes the hypergradient v^T f'(theta) = u^T B, where u solves the
         linear system from the implicit function theorem (via conjugate gradient);
      3. takes an Adam step on ``theta``.

    Args:
        theta_init: initial value of the outer parameter.
        n_steps: number of outer (Adam) iterations.
        step_size: unused; kept only so existing calls that pass it keep working.
        learning_rate: Adam learning rate for ``theta``.
        w_init: initial guess for the first inner solve (previously hard-coded to 1.0).
        verbose: if True, print ``theta`` and its gradient at every iteration.

    Returns:
        ``theta`` after ``n_steps`` Adam updates.
    """
    theta = theta_init
    optimizer = optax.adam(learning_rate)
    opt_state = optimizer.init(theta)

    for i in range(n_steps):
        # Solve the inner problem, warm-starting from the previous inner solution.
        w_star = solve_inner(theta, w_init)
        w_init = w_star

        # v = dL_out/dw at w*;  B = d/dtheta (dL_in/dw) at (w*, theta).
        v = grad_outer(w_star)
        B = jacobian(grad_inner, argnums=1)(w_star, theta)

        def hessian_vp(u):
            """Hessian-vector product H u with H = d^2 L_in / dw^2 at (w*, theta)."""
            return jvp(lambda w: grad_inner(w, theta), (w_star,), (u,))[1]

        # The notes solve A u = v with A = -H. jax's CG expects a positive-definite
        # operator, so solve the equivalent system H u = -v instead (for this inner
        # objective H = 2 > 0). No `.T` on v/u: it is a no-op for scalar/1-D w and
        # would give the wrong shape for a 2-D w.
        u = conjugate_gradient(hessian_vp, -v)[0]
        grad_theta = jnp.dot(u, B)

        # Gradient descent on theta with Adam.
        if verbose:
            print(f'Iteration {i} of {n_steps} theta: {theta}, grad theta: {grad_theta}')
        updates, opt_state = optimizer.update(grad_theta, opt_state)
        theta = optax.apply_updates(theta, updates)

    return theta

In [5]:
# Run the bilevel optimisation from theta = 1; the analytic optimum is theta = 0.
theta_opt = solve_outer(jnp.array(1.0), n_steps=100, learning_rate=8e-2)
theta_opt

Iteration 0 of 100 theta: 1.0, grad theta: 4.0
Iteration 1 of 100 theta: 0.9200005531311035, grad theta: 3.680002212524414
Iteration 2 of 100 theta: 0.8402443528175354, grad theta: 3.3609774112701416
Iteration 3 of 100 theta: 0.7609260082244873, grad theta: 3.043704032897949
Iteration 4 of 100 theta: 0.6822706460952759, grad theta: 2.7290825843811035
Iteration 5 of 100 theta: 0.604533314704895, grad theta: 2.41813325881958
Iteration 6 of 100 theta: 0.5280033349990845, grad theta: 2.112013339996338
Iteration 7 of 100 theta: 0.4530062675476074, grad theta: 1.8120250701904297
Iteration 8 of 100 theta: 0.3799028992652893, grad theta: 1.5196115970611572
Iteration 9 of 100 theta: 0.309088796377182, grad theta: 1.236355185508728
Iteration 10 of 100 theta: 0.24098969995975494, grad theta: 0.9639587998390198
Iteration 11 of 100 theta: 0.17605474591255188, grad theta: 0.7042189836502075
Iteration 12 of 100 theta: 0.11474530398845673, grad theta: 0.4589812159538269
Iteration 13 of 100 theta: 0.05

Array(-0.0042658, dtype=float32)