In [7]:

import jax
import jax.numpy as jnp
import jax.random as random
from jax import grad, jit, vmap, pmap, device_put, device_get, hessian, jacfwd, jacrev
import jax.tree_util as tree_util
import jax.lax as lax
import jax.nn as nn
#from jax.experimental import loops
import jax.profiler as profiler

# Set up JAX to use 64-bit precision
jax.config.update("jax_enable_x64", True)

# Random number generation
key = random.PRNGKey(0)
x = random.normal(key, (10,))
print("Random numbers:", x)

# Basic operations with grad
def simple_func(x):
    return jnp.sin(x) * jnp.cos(x)

grad_simple_func = grad(simple_func)
print("Gradient of simple_func at x=1.0:", grad_simple_func(1.0))

# Just-in-time compilation with jit
@jit
def compute_square(x):
    return x ** 2

print("Square of 3:", compute_square(3))

# Vectorized mapping with vmap
@vmap
def vectorized_square(x):
    return x ** 2

print("Vectorized squares:", vectorized_square(jnp.arange(5)))

# # Parallelized mapping with pmap
# @pmap
# def parallelized_square(x):
#     return x ** 2

# print("Parallelized squares:", parallelized_square(jnp.arange(5)))

# Using lax for control flow
def while_loop_example(x):
    def cond_fun(val):
        return val < 10

    def body_fun(val):
        return val + 1

    return lax.while_loop(cond_fun, body_fun, x)

print("Result of while loop starting at 0:", while_loop_example(0))

# Device transfer
x_device = device_put(x)
print("Data on device:", x_device)
x_host = device_get(x_device)
print("Data back on host:", x_host)

# Tree operations
tree = {'a': x, 'b': (2, 3)}
flat_tree, tree_def = tree_util.tree_flatten(tree)
print("Flattened tree:", flat_tree)
restored_tree = tree_util.tree_unflatten(tree_def, flat_tree)
print("Restored tree:", restored_tree)

# Lax scan
def scan_func(carry, x):
    return carry + x, carry + x

carry, result = lax.scan(scan_func, 0, jnp.arange(5))
print("Result of lax.scan:", result)

# Neural network module
x = jnp.linspace(-5, 5, 100)
logits = nn.relu(x)
print("ReLU activation:", logits)

# Profiling
profiler.start_trace("profiler_output")
y = compute_square(x)
profiler.stop_trace()

# Custom gradient
@custom_grad
def custom_func(x):
    value = x ** 2
    def grad_func(grad):
        return grad * 2 * x
    return value, grad_func

print("Custom gradient at x=3:", grad(custom_func)(3.0))

# Hessian
def quadratic_func(x):
    return x ** 2

hess_func = hessian(quadratic_func)
print("Hessian of quadratic_func at x=1.0:", hess_func(1.0))

# Jacobian forward mode
def func(x):
    return jnp.array([x ** 2, x ** 3])

jacobian_fwd = jacfwd(func)
print("Jacobian forward mode at x=2.0:", jacobian_fwd(2.0))

# Jacobian reverse mode
jacobian_rev = jacrev(func)
print("Jacobian reverse mode at x=2.0:", jacobian_rev(2.0))


Random numbers: [ 1.05451609 -0.96928879 -0.59460177 -0.03188579  2.41093278 -1.87844856
 -0.78476944 -0.31370829  0.33370904  1.76770368]
Gradient of simple_func at x=1.0: -0.4161468365471423
Square of 3: 9
Vectorized squares: [ 0  1  4  9 16]
Result of while loop starting at 0: 10
Data on device: [ 1.05451609 -0.96928879 -0.59460177 -0.03188579  2.41093278 -1.87844856
 -0.78476944 -0.31370829  0.33370904  1.76770368]
Data back on host: [ 1.05451609 -0.96928879 -0.59460177 -0.03188579  2.41093278 -1.87844856
 -0.78476944 -0.31370829  0.33370904  1.76770368]
Flattened tree: [Array([ 1.05451609, -0.96928879, -0.59460177, -0.03188579,  2.41093278,
       -1.87844856, -0.78476944, -0.31370829,  0.33370904,  1.76770368],      dtype=float64), 2, 3]
Restored tree: {'a': Array([ 1.05451609, -0.96928879, -0.59460177, -0.03188579,  2.41093278,
       -1.87844856, -0.78476944, -0.31370829,  0.33370904,  1.76770368],      dtype=float64), 'b': (2, 3)}
Result of lax.scan: [ 0  1  3  6 10]
ReLU acti

NameError: name 'custom_grad' is not defined