# Installing and using JAX

JAX is an automatic differentiation library for native Python and NumPy code, which makes it well suited to gradient-based optimization. Automatic differentiation forms the backbone of deep learning libraries like PyTorch.

Activate your standard environment from Assignment-3. Then run:
```
pip install --upgrade pip     
pip install --upgrade jax jaxlib 
```

See the [JAX installation instructions](https://github.com/google/jax#installation) for more information. The CPU version should be enough for this project.
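
To confirm that the installation worked, a quick check along the following lines can help. This is only a sketch: the variable names `devices`, `x`, and `total` are illustrative, and on a CPU-only install the device list should contain just a CPU device.

```python
import jax
import jax.numpy as jnp

# Report the installed version and the devices JAX can see.
print("jax version:", jax.__version__)
devices = jax.devices()
print("devices:", devices)

# Run a tiny computation to make sure the backend actually executes.
x = jnp.arange(3.0)               # [0., 1., 2.]
total = float(jnp.sum(x * 2.0))   # (0 + 1 + 2) * 2
print("total:", total)
```

If the import fails, re-run the `pip install` commands above inside the activated environment.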

In [4]:
import jax.numpy as jnp
"""
Use "jnp" in place of "np", our favourite NumPy library.
The functions required for this project work the same way under both names.
Be careful though:

JAX works on Python functions that are "functionally pure" (no side effects;
the output depends only on the inputs). For the sake of our project, that mostly
means using the array datatype everywhere ('jnp.array()' in particular) instead of
other datatypes, say lists, for storing arrays or matrices. Whenever you face a
datatype issue with JAX, first try converting the value to a JAX NumPy array
using `jnp.array()`.

Tip: jnp's error messages tend to be less readable than np's.
So write most of the code with "np" first, and once "jnp" becomes necessary,
replace every np with jnp. Direct replacement should work fine. This is only a
tip for easier debugging.
"""
from jax import jacfwd
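
The advice in the cell above can be illustrated with a small sketch. It converts a plain Python list to a JAX array with `jnp.array()` and checks that `jnp.dot` agrees with `np.dot`. It also shows one difference that direct np-to-jnp replacement can run into: JAX arrays are immutable, so in-place item assignment raises an error and the `.at[...].set(...)` form is used instead. The names `rows`, `A`, `v`, and `A_new` are made up for this example.

```python
import numpy as np
import jax.numpy as jnp

rows = [[1.0, 2.0], [3.0, 4.0]]        # plain Python list of lists
A = jnp.array(rows)                    # convert to a JAX array before using it
v = jnp.array([1.0, -1.0])

# The same call works under both names and gives the same numbers.
same_as_numpy = np.allclose(np.dot(np.array(rows), np.array([1.0, -1.0])),
                            jnp.dot(A, v))

# Unlike NumPy arrays, JAX arrays cannot be modified in place.
try:
    A[0, 0] = 10.0
    in_place_ok = True
except TypeError:
    in_place_ok = False

# The functional alternative returns a new array and leaves A untouched.
A_new = A.at[0, 0].set(10.0)
print(same_as_numpy, in_place_ok, A_new)
```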

In [6]:
# Define a simple function.
def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2) + 1)

# Note that here I want the derivative of a vector-valued function (inputs*a + b is a vector)
# with respect to an input vector a, evaluated at a0.
# The derivative of a vector with respect to another vector is a matrix: the Jacobian.
def simpleJ(a, b, inputs): #inputs is a matrix, a & b are vectors
    return sigmoid(jnp.dot(inputs, a) + b)

inputs = jnp.array([[0.52, 1.12,  0.77],
                   [0.88, -1.08, 0.15],
                   [0.52, 0.06, -1.30],
                   [0.74, -2.49, 1.39]])

b = jnp.array([0.2, 0.1, 0.3, 0.2])
a0 = jnp.array([0.1,0.7,0.7])

# Isolate the function: separate the variable to be differentiated from the constant parameters.
f = lambda a: simpleJ(a, b, inputs) # Now f is a function of only the variable to be differentiated

J = jacfwd(f)
# So far I have only built the derivative function; it still needs to be evaluated at a0.
J(a0)

DeviceArray([[ 0.07388726,  0.1591418 ,  0.10940997],
             [ 0.20861849, -0.2560318 ,  0.03555997],
             [ 0.12171669,  0.01404423, -0.30429173],
             [ 0.17407255, -0.58573055,  0.3269741 ]], dtype=float32)
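
As a sanity check on the result above, the Jacobian can also be written by hand. Since the derivative of the sigmoid is s(1 - s), entry (i, j) of the Jacobian is s_i (1 - s_i) * inputs[i, j], where s = sigmoid(inputs · a0 + b). The sketch below compares that closed form with `jacfwd` and with the reverse-mode `jacrev`. The names `J_fwd`, `J_rev`, `J_analytic`, and `max_err` are illustrative and not part of the original cell.

```python
import jax.numpy as jnp
from jax import jacfwd, jacrev

def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2) + 1)

inputs = jnp.array([[0.52, 1.12,  0.77],
                    [0.88, -1.08, 0.15],
                    [0.52, 0.06, -1.30],
                    [0.74, -2.49, 1.39]])
b = jnp.array([0.2, 0.1, 0.3, 0.2])
a0 = jnp.array([0.1, 0.7, 0.7])

# Same function as in the cell above, with b and inputs held constant.
f = lambda a: sigmoid(jnp.dot(inputs, a) + b)

J_fwd = jacfwd(f)(a0)   # forward-mode Jacobian, shape (4, 3)
J_rev = jacrev(f)(a0)   # reverse-mode Jacobian, same values

# Hand-derived Jacobian: scale each row of inputs by s_i * (1 - s_i).
s = f(a0)
J_analytic = (s * (1 - s))[:, None] * inputs

max_err = float(jnp.max(jnp.abs(J_fwd - J_analytic)))
print("shape:", J_fwd.shape)
print("max abs difference from analytic Jacobian:", max_err)
```

The output has 4 rows (one per output component) and 3 columns (one per entry of `a`), which matches the shape of the array printed above.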

In [2]:
!pip install --upgrade pip

Collecting pip
  Downloading pip-21.3.1-py3-none-any.whl (1.7 MB)
Installing collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 21.2.4
    Uninstalling pip-21.2.4:
      Successfully uninstalled pip-21.2.4
Successfully installed pip-21.3.1


In [3]:
!pip install --upgrade jax jaxlib

Collecting jax
  Downloading jax-0.2.24.tar.gz (786 kB)
  Preparing metadata (setup.py) ... done
Collecting jaxlib
  Downloading jaxlib-0.1.73-cp38-none-manylinux2010_x86_64.whl (50.0 MB)
Collecting absl-py
  Downloading absl_py-0.15.0-py3-none-any.whl (132 kB)
Collecting opt_einsum
  Downloading opt_einsum-3.3.0-py3-none-any.whl (65 kB)
Collecting typing_extensions
  Downloading typing_extensions-3.10.0.2-py3-none-any.whl (26 kB)
Collecting flatbuffers<3.0,>=1.12
  Downloading flatbuffers-2.0-py2.py3-none-any.whl (26 kB)
Building wheels for collected packages: jax
  Building wheel for jax (setup.py) ... done
  Created wheel for jax: filename=jax-0.2.24-py3-none-any.whl size=90311