In [25]:
import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import scipy as sp

In [10]:
a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([4., 5., 6.])


In [14]:
jnp.sum(a)

Array(6., dtype=float32)

In [19]:
a = np.random.rand(3)
# a_jx = jax.device_put(a)
a_jx = jnp.array(a)
print(type(a_jx))

<class 'jaxlib.xla_extension.ArrayImpl'>


In [27]:
def f(x: jnp.ndarray) -> jnp.ndarray:
    x = x + 1
    
    return jnp.sum(x**3)

x = jnp.array([1.0, 2.0, 3.0])
print(f(x))

99.0


In [28]:
grad_f = jax.grad(f)
print(grad_f(x))

hess_f = jax.hessian(f)

[12. 27. 48.]


In [29]:
def f_vector_valued(x: jnp.ndarray) -> jnp.ndarray:
    return jnp.array([x[0]**2 + x[1]**2 + x[2],
                       x[0] + x[1] - x[2]**2])

In [31]:
jac_fwd_f = jax.jacfwd(f_vector_valued)

x = jnp.array([1.0, 2.0, 3.0])
print(jac_fwd_f(x))

jac_rev_f = jax.jacrev(f_vector_valued)
print(jac_rev_f(x))

[[ 2.  4.  1.]
 [ 1.  1. -6.]]
[[ 2.  4.  1.]
 [ 1.  1. -6.]]


In [32]:
jac_rev_f(x)

Array([[ 2.,  4.,  1.],
       [ 1.,  1., -6.]], dtype=float32)

In [37]:
def sum_two_arrays(a: jnp.ndarray, b: jnp.ndarray, c: jnp.ndarray) -> jnp.ndarray:
    return a + b**2 + c**(.5)

grad_wrt_first = jax.jacrev(sum_two_arrays, argnums=(0, 1))

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0, 6.0])
z = jnp.array([7.0, 8.0, 9.0])
print(grad_wrt_first(x, y, z))


(Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32), Array([[ 8.,  0.,  0.],
       [ 0., 10.,  0.],
       [ 0.,  0., 12.]], dtype=float32))


In [63]:
from jax import jit

def f(x):
    large_entries = x[(x > 1).astype('int')]
    x = jnp.ones(large_entries.shape)
    print(x)
    return x

f_jit =jit(f)

In [64]:
f_jit(np.array([-1, 0, 1, 2, 3]))

Traced<ShapedArray(float32[5])>with<DynamicJaxprTrace>


  return lax_numpy.astype(self, dtype, copy=copy, device=device)


Array([1., 1., 1., 1., 1.], dtype=float32)

In [None]:
f_jit.

  return lax_numpy.astype(self, dtype, copy=copy, device=device)


<jax._src.stages.Traced at 0x71abc079c450>

In [67]:
from jax import vmap

def f(x, y):
    return x[0]**2 + x[1]**2 + y[0]**2

vmapped_f = vmap(f, in_axes=(0, 0))

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([[5.0, 6.0], [7.0, 8.0]])
print(vmapped_f(x, y))

[30. 74.]
