# Maximum Likelihood Estimation Using Jax

The following notebook is a brief introduction to [Jax](https://github.com/google/jax), which is a combination of [Autograd](https://github.com/hips/autograd) and [XLA](https://www.tensorflow.org/xla). 

In this example, we use automatic differentiation to perform maximum likelihood estimation of the normal linear regression model.

The inspiration for this notebook comes from a blog post by [Rob Hicks](https://rlhick.people.wm.edu/posts/mle-autograd.html). We encourage interested readers to read the blog post first for an excellent introduction to the methodology and the details of automatic differentiation.



The first item is to import the required packages into Python:

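```python
import numpy as np
import jax
import jax.numpy as jnp
import jax.scipy.stats  # makes jax.scipy.stats.norm available
from jax.scipy import optimize  # Jax's own (BFGS-only) minimize
from scipy.optimize import minimize
```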

Next, we take the data generation code from the blog post with the following changes:

```python
K = 2
beta = np.array([2,2])
sigma = 0.5
```
The reason for these changes is that we want to know the true coefficients in the data generating process (DGP), so that we can verify that the optimization routine correctly recovers the parameters of interest. Specifically, we set both coefficients in $\beta$ to 2, set the number of parameters to $K = 2$ (a constant plus one explanatory variable), and set the error standard deviation to $\sigma = 0.5$.



```python
# number of observations
N = 5000
# number of parameters
K = 2
# true parameter values
# beta = 2 * np.random.randn(K)
beta = np.array([2, 2])
# true error std deviation
sigma = 0.5

def datagen(N, beta, sigma):
    """
    Generates data for OLS regression.
    Inputs:
    N: Number of observations
    beta: K x 1 true parameter values
    sigma: std dev of error
    """
    K = beta.shape[0]
    x_ = 10 + 2 * np.random.randn(N, K-1)
    # x is the N x K data matrix with column of ones
    #   in the first position for estimating a constant
    x = np.c_[np.ones(N), x_]
    # y is the N x 1 vector of dependent variables
    y = x.dot(beta) + sigma*np.random.randn(N)
    return y, x

y, x = datagen(N, beta, sigma)
```
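Because the errors in this DGP are normal, the maximum likelihood estimate of $\beta$ coincides with the ordinary least squares (OLS) estimate, so a closed-form least-squares fit gives a benchmark for whatever the optimizer returns later. The sketch below regenerates data with the same design as `datagen`; the fixed seed and the use of `np.random.default_rng` are our own additions for reproducibility, so its numbers will not match the notebook's exactly. Note that the MLE of $\sigma$ divides the residual sum of squares by $N$, not $N - K$.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed: our assumption, for reproducibility
N, beta_true, sigma_true = 5000, np.array([2.0, 2.0]), 0.5

# Same design as datagen: a column of ones plus one regressor with mean 10, sd 2
x = np.c_[np.ones(N), 10 + 2 * rng.standard_normal(N)]
y = x @ beta_true + sigma_true * rng.standard_normal(N)

# OLS solution; under normal errors this is also the MLE of beta
beta_ols, *_ = np.linalg.lstsq(x, y, rcond=None)

# MLE of sigma: root mean squared residual (divide by N, not N - K)
resid = y - x @ beta_ols
sigma_mle = np.sqrt(np.mean(resid ** 2))

print(beta_ols, sigma_mle)
```

The optimizer-based estimates later in the notebook should agree with a fit like this one to several decimal places when run on the same data.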


The following block of code defines the negative log-likelihood function. We have altered its return value by dividing by the number of observations:

```python
    return  (-1 * ll)/N
```

The reason for this change is that scaling the objective can help the optimization routine converge: dividing by $N$ turns the sum into an average, which keeps the function value and its gradient on a similar scale regardless of sample size without moving the location of the minimum. The following [StackOverflow](https://stackoverflow.com/questions/24767191/scipy-is-not-optimizing-and-returns-desired-error-not-necessarily-achieved-due) post explains how this can help; see the second answer for additional details.
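Written out, with $\theta = (\beta_0, \beta_1, s)$ and $\sigma = e^{s}$ as in the code below, and $\phi(\cdot\,; \mu, \sigma)$ the normal density, the scaled objective is the average negative log-likelihood:

$$
\ell(\theta) = -\frac{1}{N}\sum_{i=1}^{N} \log \phi\!\left(y_i;\, x_i^\top\beta,\, \sigma\right)
= \frac{1}{2}\log(2\pi) + s + \frac{1}{2N e^{2s}}\sum_{i=1}^{N}\left(y_i - x_i^\top\beta\right)^2
$$

At the minimum, $e^{2s}$ equals the mean squared residual, so $\ell = \tfrac{1}{2}\log(2\pi) + \log\hat\sigma + \tfrac{1}{2} \approx 0.919 - 0.683 + 0.5 \approx 0.736$ for $\log\hat\sigma \approx -0.683$, consistent with the function value the optimizer reports further down.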

Many NumPy functions have Jax equivalents, which become available after the following import:

```python
import jax.numpy as jnp
```

These functions can be called as usual by prefixing the NumPy function name with the alias `jnp.`, for example `jnp.sum` instead of `np.sum`.
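As a small illustration, assuming a standard Jax install: `jnp` functions mirror their NumPy names and accept NumPy arrays, but Jax arrays are immutable, so in-place assignment is replaced by the `.at[...].set(...)` syntax, which returns an updated copy. Jax also uses 32-bit floats by default unless 64-bit mode is enabled, which is why several outputs below show `float32`.

```python
import numpy as np
import jax.numpy as jnp

a_np = np.array([1.0, 2.0, 3.0])
a_jnp = jnp.asarray(a_np)  # NumPy array -> Jax array

# Same function names, same results
total_np = np.sum(a_np)
total_jnp = jnp.sum(a_jnp)

# Jax arrays are immutable: `a_jnp[0] = 10.0` would raise a TypeError.
# .at[].set() returns a modified copy and leaves a_jnp unchanged.
b_jnp = a_jnp.at[0].set(10.0)

print(total_np, total_jnp, a_jnp, b_jnp)
```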



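```python
def neg_loglike(theta):
    """Average negative log-likelihood of the normal linear regression model.

    theta holds the K regression coefficients followed by log(sigma).
    Uses the globals x, y and N defined above.
    """
    beta = theta[:-1]
    # theta[-1] is log(sigma); exponentiating it guarantees sigma > 0
    sigma = jnp.exp(theta[-1])
    mu = jnp.dot(x, beta)
    ll = jnp.sum(jax.scipy.stats.norm.logpdf(y, loc=mu, scale=sigma))
    # dividing by N rescales the objective without moving its minimum
    return (-1 * ll) / N
```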


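The next two lines of code calculate the Jacobian (`jax.jacfwd`) and Hessian (`jax.hessian`) of the negative log-likelihood using automatic differentiation. Because `neg_loglike` returns a scalar, its Jacobian is simply the gradient vector. `jax.jacfwd` uses forward-mode automatic differentiation, while `jax.hessian` composes forward mode over reverse mode. Additional details can be found in the Jax documentation.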

```python
jacobian = jax.jacfwd(neg_loglike)
hessian = jax.hessian(neg_loglike)
```
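For a scalar-valued function such as `neg_loglike`, forward-mode `jax.jacfwd`, reverse-mode `jax.jacrev` and `jax.grad` all return the same gradient vector; forward mode tends to be cheaper when there are few inputs, reverse mode when there are many. `jax.hessian` is documented as forward-over-reverse, `jacfwd(jacrev(f))`. The sketch below checks these equivalences on a toy function `f`, which is our own illustrative example and not part of the original notebook.

```python
import jax
import jax.numpy as jnp

def f(t):
    # toy scalar function of a 3-vector (illustrative only)
    return jnp.sum(t ** 2) + t[0] * t[1]

t0 = jnp.array([1.0, 2.0, 3.0])

g_fwd = jax.jacfwd(f)(t0)   # forward-mode Jacobian (a gradient, since f is scalar)
g_rev = jax.jacrev(f)(t0)   # reverse-mode Jacobian
g = jax.grad(f)(t0)         # reverse-mode gradient; analytically [4, 5, 6] here

H = jax.hessian(f)(t0)
H_manual = jax.jacfwd(jax.jacrev(f))(t0)  # forward-over-reverse by hand

print(g)
print(H)
```

Analytically, the gradient is $(2t_0 + t_1,\ 2t_1 + t_0,\ 2t_2)$ and the Hessian is constant, with rows $(2, 1, 0)$, $(1, 2, 0)$ and $(0, 0, 2)$.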

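The next block of code stacks the true $\beta$ values and $\log\sigma$ into a vector called `theta`, matching the parameterization that `neg_loglike` expects, and evaluates the Jacobian and Hessian at that point. Because the simulated data contain sampling noise, the gradient at the true parameters is small but not exactly zero. Note that some optimization algorithms use neither a Jacobian nor a Hessian, some use only the Jacobian, and some use both.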

```python
theta = jnp.append(beta, jnp.log(sigma))
print(f'Jacobian : {jacobian(theta)} \n')
print(f'Hessian: {hessian(theta)}')
```

```
Jacobian : [-0.01427949 -0.10929886 -0.01989127]

Hessian: [[3.9999952e+00 4.0061764e+01 2.8558983e-02]
 [4.0061764e+01 4.1679004e+02 2.1859752e-01]
 [2.8558968e-02 2.1859777e-01 2.0397825e+00]]
```
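Because the model is simple, parts of this Hessian can be checked by hand. Differentiating the scaled objective twice with respect to $\beta$ gives $X^\top X/(N\sigma^2)$, which does not depend on $y$; with $\sigma = 0.5$ and a regressor with mean 10 and variance 4, its entries are close to $1/0.25 = 4$, $10/0.25 = 40$ and $104/0.25 = 416$, in line with the upper-left block above. The second derivative with respect to $s = \log\sigma$ is $\frac{2}{N e^{2s}}\sum_i r_i^2$, which is close to 2 at the true parameters. The sketch below rebuilds the objective on freshly simulated data (the seed and `np.random.default_rng` are our assumptions) and compares the Jax Hessian with the closed form.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

rng = np.random.default_rng(1)  # seed is an assumption, for reproducibility
N, sigma = 5000, 0.5
beta = np.array([2.0, 2.0])
x = np.c_[np.ones(N), 10 + 2 * rng.standard_normal(N)]
y = x @ beta + sigma * rng.standard_normal(N)

def neg_loglike(theta):
    # same parameterization as the notebook: theta = (beta_0, beta_1, log sigma)
    mu = jnp.dot(x, theta[:-1])
    return -jnp.sum(norm.logpdf(y, loc=mu, scale=jnp.exp(theta[-1]))) / N

theta = jnp.append(beta, jnp.log(sigma))
H = np.asarray(jax.hessian(neg_loglike)(theta))

# closed-form beta block of the Hessian: X'X / (N sigma^2)
H_beta = x.T @ x / (N * sigma ** 2)

print(H[:2, :2])
print(H_beta)
```

The comparison uses a relative tolerance because Jax evaluates in float32 by default.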


The next block of code uses SciPy's `minimize` function to minimize our negative log-likelihood with the BFGS method, passing the Jax gradient through the `jac` argument. Following the advice in the Stack Overflow post referenced above, we also loosen the gradient tolerance with the option `'gtol': 1e-7*N` (here $5 \times 10^{-4}$).

```python
# start from beta = 0 and log(sigma) = 0, i.e. sigma = 1
theta_start = jnp.append(jnp.zeros(beta.shape[0]), 0.0)
res1 = minimize(neg_loglike, theta_start, method='BFGS',
                jac=jacobian,  # exact gradient from Jax
                options={'disp': True, 'gtol': 1e-7*N})  # tolerance added to aid in convergence
print("Convergence Achieved: ", res1.success)
print("Number of Function Evaluations: ", res1.nfev)
```


```
Optimization terminated successfully.
         Current function value: 0.735579
         Iterations: 21
         Function evaluations: 30
         Gradient evaluations: 30
Convergence Achieved:  True
Number of Function Evaluations:  30
```
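One possible variation, not used in the original notebook: `scipy.optimize.minimize` also accepts `jac=True`, in which case the objective must return a `(value, gradient)` pair, and `jax.value_and_grad` produces exactly that pair in one pass, which `jax.jit` can compile. The sketch below also turns on Jax's 64-bit mode through `jax.config.update("jax_enable_x64", True)`, on the assumption that this is acceptable for your session, since BFGS line searches are sensitive to float32 rounding; with exact float64 gradients we leave `gtol` at SciPy's default. The data, seed and wrapper name `fun_and_jac` are illustrative assumptions.

```python
import numpy as np
import jax
jax.config.update("jax_enable_x64", True)  # must run before any Jax arrays are created
import jax.numpy as jnp
from jax.scipy.stats import norm
from scipy.optimize import minimize

rng = np.random.default_rng(2)  # illustrative seed
N = 5000
x = np.c_[np.ones(N), 10 + 2 * rng.standard_normal(N)]
y = x @ np.array([2.0, 2.0]) + 0.5 * rng.standard_normal(N)

def neg_loglike(theta):
    # theta = (beta_0, beta_1, log sigma), scaled by 1/N as in the notebook
    mu = jnp.dot(x, theta[:-1])
    return -jnp.sum(norm.logpdf(y, loc=mu, scale=jnp.exp(theta[-1]))) / N

# one compiled function returning both the objective and its gradient
loss_and_grad = jax.jit(jax.value_and_grad(neg_loglike))

def fun_and_jac(theta):
    # SciPy passes a NumPy float64 array and expects NumPy results back
    v, g = loss_and_grad(theta)
    return float(v), np.asarray(g, dtype=np.float64)

res = minimize(fun_and_jac, np.zeros(3), method="BFGS", jac=True)

print(res.success, res.x[:2], np.exp(res.x[-1]))
print(np.linalg.norm(res.jac))  # gradient norm at the solution
```

Returning the value and gradient together avoids evaluating the model twice per iteration; whether that matters depends on how expensive the objective is.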


When the minimization routine finishes, it prints several diagnostics. The most important for our purposes is whether the algorithm has converged. The output above shows `Convergence Achieved:  True`, so we are ready to print the results in the next code block.

The result object contains several fields; the last one, `x`, holds the estimates returned by the minimization routine: the coefficients followed by $\log\sigma$ (not $\sigma$ itself, because of the transformation in `neg_loglike`).

```python
print(res1)
```

```
      fun: 0.7355785369873047
 hess_inv: array([[ 6.73242292e+00, -6.53264517e-01,  6.80819173e-02],
       [-6.53264517e-01,  6.56590570e-02, -4.33176909e-03],
       [ 6.80819173e-02, -4.33176909e-03,  4.80685128e-01]])
      jac: array([-4.8005677e-05, -4.3852540e-04,  2.4762345e-05], dtype=float32)
  message: 'Optimization terminated successfully.'
     nfev: 30
      nit: 21
     njev: 30
   status: 0
  success: True
        x: array([ 2.02524075,  1.997835  , -0.68334767])
```


The next code block prints the coefficient estimates and exponentiates the last element of `x`, $\log\sigma$, to recover $\sigma$ on its original scale.

```python
print(f'Coefficient estimates: {res1.x[0]}, {res1.x[1]}')

# the last element of x is log(sigma); exponentiate to recover sigma
new_sigma = np.exp(res1.x[-1])

print(f'Original sigma: {new_sigma}')
```

```
Coefficient estimates: 2.0252407490353663, 1.9978350016466289
Original sigma: 0.504923842712376
```


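Jax also has a built-in minimization routine, `jax.scipy.optimize.minimize`, which can be called with the next block of code. It computes the gradient itself by automatic differentiation, so no `jac` argument is passed. Note that its `method` option is restricted to BFGS at the time of writing.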

```python
res2 = jax.scipy.optimize.minimize(neg_loglike, theta_start, tol=1e-7*N, method='BFGS')
```
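A minimal standalone sketch of the same API on a toy quadratic (our own example, with its minimum at $(1, -2)$): the objective must be written with `jnp` operations so Jax can differentiate it, `x0` must be a one-dimensional array, and `method` is a required keyword argument. The tolerance `1e-4` is an illustrative choice that float32 arithmetic can reach comfortably.

```python
import jax.numpy as jnp
from jax.scipy.optimize import minimize

def quadratic(t):
    # convex toy objective with its minimum at t = (1, -2)
    return (t[0] - 1.0) ** 2 + 10.0 * (t[1] + 2.0) ** 2

res = minimize(quadratic, jnp.zeros(2), method="BFGS", tol=1e-4)
print(res.success, res.x)
```

The returned `OptimizeResults` object exposes the same kinds of fields as SciPy's result (`x`, `success`, `fun`, `jac`, `hess_inv`), but as Jax arrays.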

We can print out the results of this second minimization routine in the next block of code.

```python
print(res2)
```

```
OptimizeResults(x=DeviceArray([ 2.0252805,  1.9978323, -0.6833601], dtype=float32), success=DeviceArray(True, dtype=bool), status=DeviceArray(0, dtype=int32), fun=DeviceArray(0.7355785, dtype=float32), jac=DeviceArray([4.2006650e-07, 4.8113866e-06, 1.4516472e-07], dtype=float32), hess_inv=DeviceArray([[ 7.0872049e+00, -6.8336809e-01,  6.8797884e-03],
             [-6.8336803e-01,  6.8185955e-02, -1.9134890e-03],
             [ 6.8797367e-03, -1.9134427e-03,  5.0876194e-01]],            dtype=float32), nfev=DeviceArray(36, dtype=int32), njev=DeviceArray(36, dtype=int32), nit=DeviceArray(23, dtype=int32))
```


We see from the above output that the optimization routine was successful (`success=DeviceArray(True, dtype=bool)`). The next block prints the estimated coefficients and the error standard deviation $\sigma$.

```python
print(f'Betas: {res2.x[0]}, {res2.x[1]}')
new_sigma = np.exp(res2.x[-1])
print(f'Original sigma: {new_sigma}')
```

```
Betas: 2.025280475616455, 1.9978322982788086
Original sigma: 0.504917562007904
```


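Either minimization routine can be used to maximize the likelihood function: the two sets of estimates differ by less than $10^{-4}$ and are close to the true values $\beta = (2, 2)$ and $\sigma = 0.5$. For now we recommend `scipy.optimize.minimize`, since Jax's routine offers only BFGS; this may change as Jax incorporates additional minimization algorithms.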

We hope that this short tutorial shows the power of Jax and its ease of use.