## JAX basics

In [4]:
import jax
import jax.numpy as jnp

In [5]:
# Create JAX arrays
a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([4.0, 5.0, 6.0])

# Perform operations
c = a + b
d = jnp.dot(a, b)

### gradient

In [None]:
# Define a simple function
def f(x):
    return x**2 + 3 * x + 2

# Compute the gradient
grad_f = jax.grad(f)
x = 2.0
grad_value = grad_f(x)
print(grad_value)  # Output will be the derivative of f at x=2


7.0


### JIT: just in time

In [6]:
from jax import jit

# Define a function
def compute_sum(x):
    return jnp.sum(x ** 2)

# JIT compile the function
jit_compute_sum = jit(compute_sum)

# Use the JIT compiled function
x = jnp.arange(10.0)
result = jit_compute_sum(x)
print(result)  # Output will be the sum of squares of elements in x

285.0


### Vectorization

In [7]:
from jax import vmap

# Define a function
def square(x):
    return x ** 2

# Vectorize the function
vectorized_square = vmap(square)

# Use the vectorized function
x = jnp.array([1.0, 2.0, 3.0])
result = vectorized_square(x)
print(result)  # Output will be [1.0, 4.0, 9.0]


[1. 4. 9.]


### random number generation

In [8]:
import jax.random as random

# Create a random key
key = random.PRNGKey(0)

# Generate random numbers
rand_numbers = random.normal(key, (3,))
print(rand_numbers)


[ 1.8160863  -0.48262316  0.33988908]


In [10]:
pip install jax[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

zsh:1: no matches found: jax[cuda]
Note: you may need to restart the kernel to use updated packages.


  pid, fd = os.forkpty()


In [18]:
y0 = jnp.array([2.0, 0.0])  # Initial state (y, v)
y = jnp.zeros((len(t), len(y0)))

In [23]:
print(y.at[0])

_IndexUpdateRef(Array([[0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.],
       [0., 0.]

In [24]:
t_bounds = (0, 1)

collocation_points = jnp.linspace(t_bounds[0], t_bounds[1], 2)
collocation_points

Array([0., 1.], dtype=float32)