<a href="https://colab.research.google.com/github/bhadreshpsavani/LearningJax/blob/main/Notebooks/LearningJaxDoc.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Simple Numpy Operation:

In [5]:
import jax
import jax.numpy as jnp

In [2]:
long_array = jnp.arange(int(1e7))

In [3]:
%%timeit 
jnp.dot(long_array, long_array).block_until_ready()

The slowest run took 617.39 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 5: 421 µs per loop


## Jax first transformation with Grad

In [7]:
def sum_of_squares(x):
  """
  A simple function to make square of 
  given input jax numpy array
  """
  return jnp.sum(x**2)

sum_of_sqaures_dx = jax.grad(sum_of_squares)
x = jnp.asarray([1.0, 2.0, 3.0, 4.0])
print(sum_of_squares(x))

# we can directly call jax grad on python function
# and it will create transformation function
print(sum_of_sqaures_dx(x))

30.0
[2. 4. 6. 8.]


In [11]:
def sum_squared_error(x, y):
  return jnp.sum((x-y*2)**2)

sum_squared_error_dx = jax.grad(sum_squared_error)

y = jnp.asarray([1.1, 2.1, 3.1, 4.1])

print(sum_squared_error_dx(x, y))

[-2.4       -4.3999996 -6.3999996 -8.4      ]


In [12]:
jax.grad(sum_squared_error, argnums=(0, 1))(x, y) 
# to find gradient with respect to both x and y

(DeviceArray([-2.4      , -4.3999996, -6.3999996, -8.4      ], dtype=float32),
 DeviceArray([ 4.8     ,  8.799999, 12.799999, 16.8     ], dtype=float32))

In [13]:
jax.grad(sum_squared_error, argnums=(1))(x, y) # with respect to y

DeviceArray([ 4.8     ,  8.799999, 12.799999, 16.8     ], dtype=float32)

In [15]:
jax.grad(sum_squared_error)(x, y) 
# with respect to only first argument only

DeviceArray([-2.4      , -4.3999996, -6.3999996, -8.4      ], dtype=float32)

In [19]:
jax.value_and_grad(sum_squared_error, argnums=(1, 0))(x, y)

(DeviceArray(34.159996, dtype=float32),
 (DeviceArray([ 4.8     ,  8.799999, 12.799999, 16.8     ], dtype=float32),
  DeviceArray([-2.4      , -4.3999996, -6.3999996, -8.4      ], dtype=float32)))

## Auxillary Argument

In [20]:
def squared_error_with_aux(x, y):
  return sum_squared_error(x, y), x-y

In [21]:
jax.grad(squared_error_with_aux)(x, y)

TypeError: ignored

In [23]:
jax.grad(squared_error_with_aux, has_aux=True)(x, y)

(DeviceArray([-2.4      , -4.3999996, -6.3999996, -8.4      ], dtype=float32),
 DeviceArray([-0.10000002, -0.0999999 , -0.0999999 , -0.0999999 ], dtype=float32))

## Jax vs Numpy

In [24]:
import numpy as np
x = np.array([1, 2, 3])

def in_place_modify(x):
  x[0] = 123 
  return None

in_place_modify(x)
x

array([123,   2,   3])

In [25]:
in_place_modify(jnp.array(x))

TypeError: ignored

In [26]:
def jax_in_place_modify(x):
  return jax.ops.index_update(x, 0, 123)

y = jnp.array([1, 2, 3])
jax_in_place_modify(y)

DeviceArray([123,   2,   3], dtype=int32)

In [27]:
y

DeviceArray([1, 2, 3], dtype=int32)

## How jax tranformation works?

In [30]:
global_list = []

def log2(x):
  global_list.append(x)
  in_x = jnp.log(x)
  in_2 = jnp.log(2.0)
  return in_x / in_2

print(jax.make_jaxpr(log2)(3.0))

{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0
    d:f32[] = div b c
  in (d,) }


In [32]:
def log2_with_print(x):
  print("print x:", x)
  ln_x = jnp.log(x)
  ln_2 = jnp.log(2)
  return ln_x / ln_2

print(jax.make_jaxpr(log2_with_print(3.)))

print x: 3.0


TypeError: ignored

In [35]:
def log2_with_print(x):
  print("Printed x:", x)
  ln_x = jnp.log(x)
  ln_2 = jnp.log(2.0)
  return ln_x / ln_2

print(jax.make_jaxpr(log2_with_print)(3.0))

Printed x: Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0
    d:f32[] = div b c
  in (d,) }


In [37]:
def log2_if_rank_2(x):
  if x.ndim == 2:
    in_x = jnp.log(x)
    in_2 = jnp.log(2.0)
    return ln_x / ln_2
  else:
    return x

print(jax.make_jaxpr(log2_if_rank_2)(jax.numpy.array([1, 2, 3])))

{ lambda ; a:i32[3]. let  in (a,) }
