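# Understanding JAX

This notebook is a hands-on introduction to core JAX features: NumPy-style arrays, automatic differentiation, JIT compilation, GPU execution, vectorization, parallelization, and linear algebra.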

In [1]:
!pip install jax

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


# Simple Function using JAX

In [2]:
import jax
import jax.numpy as jnp


def my_function(x):
    """Return sin(x) + cos(x), evaluated element-wise on a JAX array or scalar."""
    return jnp.add(jnp.sin(x), jnp.cos(x))


# 100 evenly spaced sample points covering one full period, [0, 2*pi].
x = jnp.linspace(0.0, 2 * jnp.pi, 100)

y = my_function(x)
print(y)




[ 1.          1.0614105   1.1185472   1.17118     1.2190967   1.2621045
  1.3000304   1.3327215   1.3600461   1.3818944   1.3981782   1.4088321
  1.4138131   1.4131011   1.4066992   1.3946328   1.3769509   1.3537245
  1.325047    1.2910341   1.2518226   1.2075704   1.158456    1.1046767
  1.0464492   0.9840081   0.91760474  0.8475066   0.77399576  0.69736826
  0.6179329   0.5360092   0.45192713  0.36602533  0.27864963  0.19015199
  0.10088861  0.01121938 -0.07849538 -0.16789412 -0.25661677 -0.3443061
 -0.43060905 -0.5151781  -0.5976724  -0.67776036 -0.7551192  -0.8294375
 -0.90041596 -0.96776867 -1.0312246  -1.090528   -1.1454405  -1.1957402
 -1.2412255  -1.2817128  -1.317039   -1.347062   -1.3716608  -1.3907366
 -1.404212   -1.4120334  -1.4141691  -1.4106103  -1.4013715  -1.3864899
 -1.3660253  -1.3400604  -1.3086994  -1.2720686  -1.2303158  -1.1836089
 -1.132136   -1.0761049  -1.0157402  -0.9512855  -0.8830004  -0.8111596
 -0.73605263 -0.6579818  -0.57726157 -0.49421686 -0.4091821  -

# Automatic Differentiation with JAX

In [3]:
import jax
import jax.numpy as jnp


def my_function(x):
    """Scalar test function f(x) = sin(x) + cos(x)."""
    return jnp.sin(x) + jnp.cos(x)


# Evaluation point; jax.grad rejects integer inputs, so this must be a float.
x = 1.0

# jax.grad turns a scalar-output function into a function returning its derivative,
# here d/dx [sin(x) + cos(x)] = cos(x) - sin(x).
grad_fn = jax.grad(my_function)
gradient = grad_fn(x)

print(gradient)

-0.30116868


# JIT Compilation

In [4]:
import jax
import jax.numpy as jnp


def my_function(x):
    """Element-wise sin(x) + cos(x)."""
    return jnp.sin(x) + jnp.cos(x)


x = jnp.linspace(0.0, 2 * jnp.pi, 100)

# jax.jit returns an XLA-compiled version of my_function; the original
# Python function is left unchanged and can still be called directly.
compiled_fn = jax.jit(my_function)
result = compiled_fn(x)

print(result)


[ 1.          1.0614105   1.1185472   1.17118     1.2190967   1.2621045
  1.3000304   1.3327215   1.3600461   1.3818944   1.3981782   1.4088321
  1.4138131   1.4131011   1.4066992   1.3946328   1.3769509   1.3537245
  1.325047    1.2910341   1.2518226   1.2075704   1.158456    1.1046767
  1.0464492   0.9840081   0.91760474  0.8475066   0.77399576  0.69736826
  0.6179329   0.5360092   0.45192713  0.36602533  0.27864963  0.19015199
  0.10088861  0.01121938 -0.07849538 -0.16789412 -0.25661677 -0.3443061
 -0.43060905 -0.5151781  -0.5976724  -0.67776036 -0.7551192  -0.8294375
 -0.90041596 -0.96776867 -1.0312246  -1.090528   -1.1454405  -1.1957402
 -1.2412255  -1.2817128  -1.317039   -1.347062   -1.3716608  -1.3907366
 -1.404212   -1.4120334  -1.4141691  -1.4106103  -1.4013715  -1.3864899
 -1.3660253  -1.3400604  -1.3086994  -1.2720686  -1.2303158  -1.1836089
 -1.132136   -1.0761049  -1.0157402  -0.9512855  -0.8830004  -0.8111596
 -0.73605263 -0.6579818  -0.57726157 -0.49421686 -0.4091821  -

# GPU Acceleration

In [2]:
import jax
import jax.numpy as jnp

# Pick the first GPU if JAX can see one, otherwise fall back to the CPU.
# This replaces jax.config.update("jax_platform_name", "gpu"), which changes the
# default backend for the whole kernel (leaking into every later cell) and makes
# the cell fail on machines without a GPU.
try:
    device = jax.devices("gpu")[0]
except RuntimeError:
    device = jax.devices("cpu")[0]


def my_function(x):
    """Element-wise sin(x) + cos(x)."""
    return jnp.sin(x) + jnp.cos(x)


# jax.jit only compiles the function; it does not move anything to a device.
# The compiled computation runs where its (committed) inputs live.
my_function_gpu = jax.jit(my_function)

# Create the float32 input and commit it to the chosen device.
x_gpu = jax.device_put(jnp.linspace(0.0, 2 * jnp.pi, 100, dtype=jnp.float32), device)

# Runs on `device` because x_gpu was placed there.
result_gpu = my_function_gpu(x_gpu)

print(result_gpu)


[ 1.          1.0614105   1.1185472   1.17118     1.2190967   1.2621045
  1.3000304   1.3327215   1.3600461   1.3818944   1.3981782   1.4088321
  1.4138131   1.4131012   1.4066992   1.3946328   1.376951    1.3537245
  1.325047    1.2910341   1.2518226   1.2075704   1.158456    1.1046767
  1.0464492   0.9840081   0.91760474  0.8475066   0.77399576  0.69736826
  0.6179329   0.5360092   0.45192716  0.36602533  0.2786497   0.19015193
  0.10088861  0.01121938 -0.07849532 -0.16789412 -0.2566167  -0.34430605
 -0.43060905 -0.5151781  -0.5976724  -0.67776036 -0.7551192  -0.82943755
 -0.90041596 -0.96776867 -1.0312246  -1.090528   -1.1454405  -1.1957402
 -1.2412255  -1.2817128  -1.317039   -1.347062   -1.371661   -1.3907366
 -1.4042121  -1.4120336  -1.4141691  -1.4106103  -1.4013715  -1.3864899
 -1.3660253  -1.3400604  -1.3086994  -1.2720686  -1.2303158  -1.1836089
 -1.132136   -1.0761049  -1.0157402  -0.9512855  -0.8830003  -0.81115955
 -0.73605263 -0.6579819  -0.57726157 -0.4942169  -0.4091821

# Vectorization in JAX

In [3]:
import jax
import jax.numpy as jnp


def elementwise_multiply(x, y):
    """Return the product x * y (for scalars, just their product)."""
    return x * y


# jax.vmap maps elementwise_multiply over the leading axis of both arguments,
# so each mapped call sees one pair (x[i], y[i]). `x * y` is already element-wise
# on arrays; vmap is used here purely to illustrate the transformation.
vectorized_fn = jax.vmap(elementwise_multiply)

x, y = jnp.array([1, 2, 3]), jnp.array([4, 5, 6])

result = vectorized_fn(x, y)
print(result)


[ 4 10 18]


# Parallelization in JAX

In [8]:
import jax
import jax.numpy as jnp

# Define a function
def elementwise_multiply(x, y):
    return x * y

# Vectorize the function using `jax.vmap`
vectorized_fn = jax.vmap(elementwise_multiply)

# Generate input arrays
x = jnp.array([1, 2, 3])
y = jnp.array([4, 5, 6])

# Apply the vectorized function to the inputs
result = vectorized_fn(x, y)

print(result)


[ 4 10 18]


# Solving a Linear System

In [5]:
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve

# Linear system A @ x = b, i.e.
#   2*x0 + 1*x1 = 1
#   1*x0 + 3*x1 = 2
# whose exact solution is x = [0.2, 0.6].
A = jnp.array([
    [2, 1],
    [1, 3],
])
b = jnp.array([1, 2])

# The integer inputs come back as a floating-point solution (see the printed output).
x = solve(A, b)
print(x)


[0.19999999 0.6       ]
