In [1]:
import jax.numpy as jnp

def square(x):
    return x * x

x = jnp.array([1.0, 2.0, 3.0])
print(square(x))

[1. 4. 9.]


In [2]:
import jax.numpy as jnp

vector = jnp.array([1, 2, 3])
matrix = jnp.array([[1, 2], [3, 4]])
print("Vector:", vector)
print("Matrix:", matrix)

sum_result = vector + vector
product_result = matrix * matrix
print("Sum:", sum_result)
print("Product:", product_result)

Vector: [1 2 3]
Matrix: [[1 2]
 [3 4]]
Sum: [2 4 6]
Product: [[ 1  4]
 [ 9 16]]


In [3]:
from jax import grad

def quadratic(x):
    return 3 * x ** 2 + 2 * x + 1


gradient_fn = grad(quadratic)
print("Gradient at x=2:", gradient_fn(2.0))


Gradient at x=2: 14.0


In [4]:
import numpy as np
import time

large_matrix_jax = jnp.array(np.random.rand(10000, 10000))
large_matrix_np = np.array(np.random.rand(10000, 10000))

# Timing JAX operation
start_time = time.time()
result_jax = jnp.dot(large_matrix_jax, large_matrix_jax)
jax_time = time.time() - start_time

# Timing NumPy operation
start_time = time.time()
result_np = np.dot(large_matrix_np, large_matrix_np)
np_time = time.time() - start_time

print(f"JAX Time: {jax_time} seconds")
print(f"NumPy Time: {np_time} seconds")


JAX Time: 2.5511696338653564 seconds
NumPy Time: 11.365630149841309 seconds


In [6]:
from jax import jit

def complicated_function(x):
    return jnp.cos(x) * jnp.sin(x) + jnp.log(x + 1)

jit_function = jit(complicated_function)

# Timing the original function
start_time = time.time()
for _ in range(10000):
    complicated_function(1.0)
normal_time = time.time() - start_time

# Timing the JIT-compiled function
start_time = time.time()
for _ in range(10000):
    jit_function(1.0)
jit_time = time.time() - start_time

print(f"Normal Time: {normal_time} seconds")
print(f"JIT Time: {jit_time} seconds")


Normal Time: 0.2695322036743164 seconds
JIT Time: 0.06746506690979004 seconds
