In [2]:
import numpy as onp
import jax.numpy as np

import matplotlib.pyplot as plt
%matplotlib inline 
from scipy.optimize import minimize

In [10]:
def rosen(x):
    """The Rosenbrock function.

    f(x) = sum_{i=0}^{n-2} [100 * (x[i+1] - x[i]**2)**2 + (1 - x[i])**2]

    Parameters
    ----------
    x : array
        Point at which to evaluate, with the n coordinates along the first
        axis. It must support slicing and elementwise arithmetic (a NumPy or
        jax.numpy array; a plain list will not work).

    Returns
    -------
    The function value (reduced over the first axis). It is 0 at
    x = (1, ..., 1).
    """
    terms = 100.0*(x[1:] - x[:-1]**2.0)**2.0 + (1 - x[:-1])**2.0
    # Reduce with the array's own vectorized ``.sum(axis=0)`` instead of the
    # builtin ``sum``. The builtin iterates over axis 0 one element at a time
    # in Python. The method gives the same axis-0 reduction in one call and
    # keeps the function usable with jax.numpy arrays.
    return terms.sum(axis=0)

def rosen_der(x):
    """Analytic gradient of the Rosenbrock function ``rosen``.

    Parameters
    ----------
    x : array_like
        Point at which to evaluate the gradient, of length n >= 2 along its
        first axis (``x[0]`` and ``x[1]`` are always read). Lists and integer
        arrays are accepted; they are converted to a floating-point NumPy array.

    Returns
    -------
    numpy.ndarray
        Array shaped like ``x`` holding d rosen / d x_j for every component j.
    """
    x = onp.asarray(x)
    # zeros_like below inherits x's dtype. With integer input the gradient
    # buffer would be integer (overflow-prone for small int types), so
    # promote anything that is not already float/complex.
    if not onp.issubdtype(x.dtype, onp.inexact):
        x = x.astype(float)

    xm = x[1:-1]
    xm_m1 = x[:-2]
    xm_p1 = x[2:]
    der = onp.zeros_like(x)
    # Interior components appear in two terms of the sum: as x[i+1] (paired
    # with the previous coordinate) and as x[i] (paired with the next one).
    der[1:-1] = 200*(xm-xm_m1**2) - 400*(xm_p1 - xm**2)*xm - 2*(1-xm)
    # The first component only appears as x[i] in the i = 0 term...
    der[0] = -400*x[0]*(x[1]-x[0]**2) - 2*(1-x[0])
    # ...and the last only as x[i+1] in the final term.
    der[-1] = 200*(x[-1]-x[-2]**2)
    return der

In [24]:
# BFGS on the Rosenbrock function, using the hand-written NumPy gradient.
x0 = onp.asarray([1.3, 0.7, 0.8, 1.9, 1.2])

res = minimize(
    rosen,
    x0,
    jac=rosen_der,
    method='BFGS',
    options={'disp': True},
)

Optimization terminated successfully.
         Current function value: 0.000000
         Iterations: 25
         Function evaluations: 30
         Gradient evaluations: 30


In [21]:
def jax_rosen_der(x):
    """Gradient of the Rosenbrock function, assembled with ``jax.numpy.hstack``.

    Produces the same components as ``rosen_der``. Instead of writing into a
    preallocated buffer, it concatenates three pieces: the first component,
    the interior components and the last component. JAX arrays do not support
    in-place item assignment.

    The per-component arithmetic is done on whatever array type ``x`` is.
    Presumably that is a NumPy array when called from
    ``scipy.optimize.minimize`` (verify). Only the final ``hstack`` produces a
    JAX array. The recorded ``jac`` output of ``res_j`` shows that this result
    is float32.
    """
    prev = x[:-2]
    mid = x[1:-1]
    nxt = x[2:]

    first = -400*x[0]*(x[1]-x[0]**2) - 2*(1-x[0])
    middle = 200*(mid-prev**2) - 400*(nxt - mid**2)*mid - 2*(1-mid)
    last = 200*(x[-1]-x[-2]**2)

    return np.hstack((first, middle, last))

In [18]:
# Start from the same float64 point as the NumPy run. Building the start
# vector with jax.numpy would round it to float32 (JAX's default precision
# unless 64-bit mode is enabled). The two runs would then not begin at the
# same x, confounding the comparison of the two gradient implementations.
jax_x0 = onp.array([1.3, 0.7, 0.8, 1.9, 1.2])

res_j = minimize(rosen, jax_x0, method='BFGS', jac=jax_rosen_der, options={'disp': True})

Optimization terminated successfully.
         Current function value: 0.000000
         Iterations: 25
         Function evaluations: 30
         Gradient evaluations: 30


In [19]:
# Display the full OptimizeResult of the run that used the jax.numpy-assembled gradient.
res_j

      fun: 4.0799799208644213e-13
 hess_inv: array([[0.00758745, 0.01243817, 0.0234388 , 0.04614586, 0.09221449],
       [0.01243817, 0.02481622, 0.04712762, 0.09298084, 0.18568164],
       [0.0234388 , 0.04712762, 0.09456061, 0.18673856, 0.37279784],
       [0.04614586, 0.09298084, 0.18673856, 0.37380675, 0.74615709],
       [0.09221449, 0.18568164, 0.37279784, 0.74615709, 1.49432005]])
      jac: DeviceArray([-5.7060729e-06, -2.7486874e-06, -2.5910108e-06,
             -7.7260538e-06,  5.7977527e-06], dtype=float32)
  message: 'Optimization terminated successfully.'
     nfev: 30
      nit: 25
     njev: 30
   status: 0
  success: True
        x: array([1.00000004, 1.0000001 , 1.00000022, 1.00000045, 1.00000092])

In [23]:
# Display the full OptimizeResult of the run that used the pure-NumPy gradient, for comparison with res_j.
res

      fun: 4.0130879949972905e-13
 hess_inv: array([[0.00758796, 0.01243893, 0.02344025, 0.04614953, 0.09222281],
       [0.01243893, 0.02481725, 0.04712952, 0.09298607, 0.18569385],
       [0.02344025, 0.04712952, 0.09456412, 0.18674836, 0.37282072],
       [0.04614953, 0.09298607, 0.18674836, 0.37383212, 0.74621435],
       [0.09222281, 0.18569385, 0.37282072, 0.74621435, 1.49444705]])
      jac: array([-5.68982937e-06, -2.73296557e-06, -2.54520599e-06, -7.73460770e-06,
        5.78142698e-06])
  message: 'Optimization terminated successfully.'
     nfev: 30
      nit: 25
     njev: 30
   status: 0
  success: True
        x: array([1.00000004, 1.0000001 , 1.00000021, 1.00000044, 1.00000092])