In [3]:
import jax
import jax.numpy as jnp
import numpy as np
import time

# Sample code

In [7]:
# Build one array from an explicit list of values ...
a = jnp.array([1.0, 2.0, 3.0])
# ... and one from a half-open range: starts at 3.0, stops before 6.0, step 1.0.
b = jnp.arange(3.0, 6.0, 1.0)

In [8]:
# Show both arrays, one per line.
print(a, b, sep="\n")

[1. 2. 3.]
[3. 4. 5.]


In [11]:
# Element-wise addition: values at matching positions are added together.
c = a + b
print("a + b:", c)

# Dot product: the sum of the element-wise products of the two vectors.
d = jnp.dot(a, b)
print("Dot product of a and b:", d)

# Reshaping: start from a flat vector of six zeros and reshape it into
# 2 rows x 3 columns, so this step really demonstrates a reshape.
e = jnp.zeros(6).reshape((2, 3))
print("Zeros array, reshaped:", e)

# Transposing: `.T` swaps the two axes, turning shape (2, 3) into (3, 2).
f = e.T
print("Transposed array 'f':\n", f)

a + b: [4. 6. 8.]
Dot product of a and b: 26.0
Zeros array, reshaped: [[0. 0. 0.]
 [0. 0. 0.]]
Transposed array 'f':
 [[0. 0.]
 [0. 0.]
 [0. 0.]]


# JAX immutability

In [14]:
# NumPy arrays are mutable: assigning to an index changes the array itself.
np_arr = np.array((1.0, 2.0, 3.0))
print("Original array 'np_arr':\n", np_arr)

# Overwrite the first element in place; no new array is created.
np_arr[0] = 99.0
print("Modified array 'np_arr':\n", np_arr)

Original array 'np_arr':
 [1. 2. 3.]
Modified array 'np_arr':
 [99.  2.  3.]


In [16]:
jax_arr = jnp.array([1.0, 2.0, 3.0])
print("JAX array 'jax_arr':\n", jax_arr)

# Item assignment on a JAX array raises TypeError; catch it so the notebook
# keeps running and the error message itself becomes the demonstration.
# The exception is bound to `err` rather than `e` on purpose: Python unbinds
# the `except ... as` target when the handler exits, which would otherwise
# delete the notebook-level array `e` created in the basic-operations cell.
try:
    jax_arr[0] = 99
except TypeError as err:
    print(f"Attempting in-place modification raised an error: {err}")

JAX array 'jax_arr':
 [1. 2. 3.]
Attempting in-place modification raised an error: JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html


# jax.grad

In [21]:
def square(x):
    """Return x raised to the second power."""
    return x ** 2

# jax.grad wraps `square` in a new function that evaluates its derivative.
gradient_square = jax.grad(square)

# Analytically d/dx x^2 = 2x, so the gradient at x = 3 must come out as 6.
value_at_3 = square(3.0)
slope_at_3 = gradient_square(3.0)
print(f"Original function 'square(3.0)': {value_at_3}")
print(f"Gradient of square(x) at x=3.0: {slope_at_3}")

Original function 'square(3.0)': 9.0
Gradient of square(x) at x=3.0: 6.0


Slightly more complicated function

In [23]:
def sin_square(x):
    """Return sin(x^2)."""
    squared = x ** 2
    return jnp.sin(squared)

# Differentiate the composed function automatically.
gradient_sin_square = jax.grad(sin_square)

# Evaluate the function and its gradient at x = 1.0.
value_at_1 = sin_square(1.0)
slope_at_1 = gradient_sin_square(1.0)
print(f"\nOriginal function 'sin_square(1.0)': {value_at_1}")
print(f"Gradient of sin_square(x) at x=1.0: {slope_at_1}")
# Chain rule: d/dx sin(x^2) = 2x * cos(x^2), which is 2 * cos(1) at x = 1.
print(f"Expected: {2 * jnp.cos(1.0)}")


Original function 'sin_square(1.0)': 0.8414709568023682
Gradient of sin_square(x) at x=1.0: 1.0806045532226562
Expected: 1.0806045532226562


Function with multiple arguments

In [28]:
def sum_of_squares(x, y):
    """Return x^2 + y^2."""
    x_squared = x**2
    y_squared = y**2
    return x_squared + y_squared

# With no `argnums`, jax.grad differentiates with respect to the first argument only.
grad_x_sum_of_squares = jax.grad(sum_of_squares)
dx_only = grad_x_sum_of_squares(2.0, 3.0)  # d/dx = 2x = 4 at x = 2
print(f"\nGradient of sum_of_squares w.r.t. x at (x=2, y=3): {dx_only}")

# `argnums=(0, 1)` asks for both partial derivatives, returned as a tuple.
grad_xy_sum_of_squares = jax.grad(sum_of_squares, argnums=(0, 1))
grad_x, grad_y = grad_xy_sum_of_squares(2.0, 3.0)  # (2x, 2y) = (4, 6)
print(f"Gradient of sum_of_squares w.r.t. x and y at (x=2, y=3): x_grad={grad_x}, y_grad={grad_y}")


Gradient of sum_of_squares w.r.t. x at (x=2, y=3): 4.0
Gradient of sum_of_squares w.r.t. x and y at (x=2, y=3): x_grad=4.0, y_grad=6.0


# jax.jit

In [31]:
# A computationally intensive function
def sum_large_array(x):
    """Return the scalar sum of x * sin(x) / cosh(x) + log(x + 1) over all elements of x."""
    return jnp.sum(x * jnp.sin(x) / jnp.cosh(x) + jnp.log(x + 1))


def _timed_call(fn, x):
    """Call fn(x), wait for the result to be ready, and return (result, elapsed_seconds).

    JAX dispatches work asynchronously, so without block_until_ready() the
    timer could stop before the computation has actually finished.
    time.perf_counter() is used because it is a monotonic, high-resolution
    clock intended for measuring short intervals.
    """
    start = time.perf_counter()
    result = fn(x).block_until_ready()
    return result, time.perf_counter() - start


# Create a large JAX array: 100,000 float32 values, 1.0 through 100000.0
large_array = jnp.arange(1, 100_001, dtype=jnp.float32)

# Test without JIT
result_non_jit, elapsed = _timed_call(sum_large_array, large_array)
print(f"Non-JIT execution time: {elapsed:.4f} seconds")

# Test with JIT
# jax.jit can be applied as a decorator or by wrapping an existing function;
# wrapping is used here so both versions share one function body.
sum_large_array_jitted = jax.jit(sum_large_array)

# First call (tracing and compilation happen here)
result_jit_first_call, elapsed = _timed_call(sum_large_array_jitted, large_array)
print(f"JIT first call (with compilation) time: {elapsed:.4f} seconds")

# Subsequent calls (reuse the already compiled code)
result_jit_subsequent_call, elapsed = _timed_call(sum_large_array_jitted, large_array)
print(f"JIT subsequent call time: {elapsed:.4f} seconds")

# Verify results are the same
print(f"Results match: {jnp.allclose(result_non_jit, result_jit_first_call)}")

Non-JIT execution time: 0.2244 seconds
JIT first call (with compilation) time: 0.0381 seconds
JIT subsequent call time: 0.0003 seconds
Results match: True


# Pure functions

Example of an impure function

In [36]:
# Module-level state that the impure function below reads and modifies.
global_counter = 0

def impure_add_and_increment(x):
    """Increment the global counter, report it, and return x plus the new count.

    The result depends on how many times the function has been called
    before, which is exactly what makes it impure.
    """
    global global_counter
    global_counter = global_counter + 1  # side effect on state outside the function
    # Expose the counter as it stands while the call is still in progress.
    print(f"  (Inside function) global_counter after increment: {global_counter}")
    return x + global_counter

print("--- Impure Function Example ---")
print(f"Initial global_counter: {global_counter}")

# Call 1: the counter moves from 0 to 1.
result_1 = impure_add_and_increment(5)
print(f"First call impure_add_and_increment(5) returned: {result_1}")
print(f"global_counter after first call: {global_counter}")

# Call 2: same argument, different result, because the counter moved again.
result_2 = impure_add_and_increment(5)
print(f"Second call impure_add_and_increment(5) returned: {result_2}")
print(f"global_counter after second call: {global_counter}\n")

# The global has been changed by the two calls above.
print(f"Final value of global_counter: {global_counter}")


--- Impure Function Example ---
Initial global_counter: 0
  (Inside function) global_counter after increment: 1
First call impure_add_and_increment(5) returned: 6
global_counter after first call: 1
  (Inside function) global_counter after increment: 2
Second call impure_add_and_increment(5) returned: 7
global_counter after second call: 2

Final value of global_counter: 2


Example of a pure function

In [39]:
def pure_add(x, y):
    """Return x + y; the result depends only on the two arguments."""
    return x + y

def pure_transform_list(my_list):
    """Return a new list holding every item of my_list multiplied by 2.

    The input list is only read, never modified.
    """
    doubled = []
    for item in my_list:
        doubled.append(item * 2)
    return doubled

print("\n--- Pure Function Example ---")
# Two identical calls give two identical answers.
print(f"pure_add(5, 3) returned: {pure_add(5, 3)}")
print(f"pure_add(5, 3) returned: {pure_add(5, 3)}")
print("Notice how pure_add always gives the same result for the same inputs and causes no external changes.")

original_list = [1, 2, 3]
# pure_transform_list is defined at the top of this cell.
new_list = pure_transform_list(original_list)
print("\nOriginal list (unchanged after pure transformation):", original_list)
print("New list (transformed by pure function):", new_list)


--- Pure Function Example ---
pure_add(5, 3) returned: 8
pure_add(5, 3) returned: 8
Notice how pure_add always gives the same result for the same inputs and causes no external changes.

Original list (unchanged after pure transformation): [1, 2, 3]
New list (transformed by pure function): [2, 4, 6]


# jax.vmap

In [4]:
# A function that operates on a single vector
# Written for a single pair of vectors; batching is added later without editing it.
def elementwise_multiply_add(x, y):
    """Return x * 2 + y / 3, computed element by element."""
    doubled_x = x * 2
    third_of_y = y / 3
    return doubled_x + third_of_y

# One input pair, with no batch dimension
x_single = jnp.array([1.0, 2.0, 3.0])
y_single = jnp.array([4.0, 5.0, 6.0])

single_result = elementwise_multiply_add(x_single, y_single)
print(f"Result for single inputs: {single_result}")

Result for single inputs: [3.3333335 5.6666665 8.       ]


In [5]:
# Two stacked input vectors of length 3: shape (2, 3), batch axis first.
x_batch = jnp.array([
    [1.0, 2.0, 3.0],
    [4.0, 5.0, 6.0],
])

# The matching second operands, laid out the same way: shape (2, 3).
y_batch = jnp.array([
    [10.0, 11.0, 12.0],
    [13.0, 14.0, 15.0],
])

In [7]:
# Batching by hand: apply the single-example function to each pair of rows,
# then stack the per-example outputs back into one array.
results = [
    elementwise_multiply_add(x_row, y_row)
    for x_row, y_row in zip(x_batch, y_batch)
]
manual_batch_result = jnp.array(results)
print(f"Manual batch result: {manual_batch_result}")

Manual batch result: [[ 5.333333  7.666667 10.      ]
 [12.333334 14.666666 17.      ]]


In [8]:
# jax.vmap lifts the single-example function to a batched one; both arguments
# are mapped over their leading axis (axis 0), which is also vmap's default.
batched_elementwise_multiply_add = jax.vmap(elementwise_multiply_add, in_axes=(0, 0))
vmap_result = batched_elementwise_multiply_add(x_batch, y_batch)
print(f"vmap-batched result: {vmap_result}")

vmap-batched result: [[ 5.333333  7.666667 10.      ]
 [12.333334 14.666666 17.      ]]


In [11]:
@jax.jit
@jax.vmap
def batched_loss_gradient(predictions, targets):
    """Return the gradient of the MSE loss with respect to the predictions.

    The body is written for one prediction/target pair; jax.vmap maps it over
    the leading axis of both arguments and jax.jit compiles the batched result.
    """
    def mse_loss(pred, target):
        # Mean squared error for one example.
        squared_error = (pred - target) ** 2
        return jnp.mean(squared_error)

    # jax.grad differentiates with respect to the first argument (the prediction).
    loss_gradient = jax.grad(mse_loss)
    return loss_gradient(predictions, targets)

In [12]:
# Two examples, each with two predicted values and their targets.
preds_batch = jnp.array([
    [1.0, 2.0],
    [3.0, 4.0],
])
targets_batch = jnp.array([
    [1.1, 2.1],
    [3.3, 4.5],
])

# One call returns the per-example gradients for the whole batch.
batched_grads = batched_loss_gradient(preds_batch, targets_batch)
print(f"\nBatched gradients using vmap and grad: \n{batched_grads}")


Batched gradients using vmap and grad: 
[[-0.10000002 -0.0999999 ]
 [-0.29999995 -0.5       ]]


# jax.pmap

In [16]:
def device_mean(x):
    """Return the mean of the slice of data passed to one pmap instance."""
    return jnp.mean(x)


def sum_across_devices(x):
    """Return the grand total of all data, as seen from every device.

    The local slice is first reduced to a scalar with jnp.sum, then psum adds
    those scalars over the 'devices' axis. Applying psum to the unreduced
    slice would only add matching positions across devices and return a
    vector, not the total. Must be called under a pmap (or similar) that
    defines an axis named 'devices'.
    """
    local_sum = jnp.sum(x)
    return jax.lax.psum(local_sum, axis_name='devices')


# Check available devices
print(f"\nAvailable devices: {jax.devices()}")
num_devices = len(jax.devices())

if num_devices < 2:
    print("\nSkipping pmap example: Requires at least 2 devices (e.g., multiple GPUs or CPU devices for simulation).")
else:
    # Size the data from the device count so the leading axis always equals
    # num_devices; a fixed 16 elements cannot be reshaped when the device
    # count does not divide 16. With 2 devices this is still 16 values.
    elements_per_device = 8
    data = jnp.arange(
        num_devices * elements_per_device, dtype=jnp.float32
    ).reshape(num_devices, elements_per_device)
    print(f"\nData for pmap (sharded across devices):\n{data}")

    # Use pmap to run device_mean on each slice of data on each device
    pmapped_mean = jax.pmap(device_mean, axis_name='devices')

    # Each device gets one row of 'data' (data[0] on device 0, data[1] on device 1, ...)
    results_per_device = pmapped_mean(data)
    print(f"Mean calculated on each device:\n{results_per_device}")

    # Often, you'll want to combine results from all devices.
    # pmap the collective function, using the same axis_name it reduces over.
    pmapped_sum_across = jax.pmap(sum_across_devices, axis_name='devices')

    # Each device computes its local sum, then psum adds these across devices;
    # every entry of the result holds the same grand total.
    total_sum = pmapped_sum_across(data)

    print(f"Total sum across all devices (collective operation):\n{total_sum[0]}")
    print(f"Verified total sum: {jnp.sum(data)}")


Available devices: [CpuDevice(id=0)]

Skipping pmap example: Requires at least 2 devices (e.g., multiple GPUs or CPU devices for simulation).
