In [2]:
import jax
import jax.numpy as jnp
from jax import grad, jit
from functools import reduce

# Generators of the exponential chain: the three Pauli matrices (all Hermitian).
A1 = jnp.array([[0.0, 1.0],
                [1.0, 0.0]])     # sigma_x
A2 = jnp.array([[0.0, -1j],
                [1j, 0.0]])      # sigma_y (the only one with complex entries)
A3 = jnp.array([[1.0, 0.0],
                [0.0, -1.0]])    # sigma_z
matrices = [A1, A2, A3]

# Vector the chain is applied to: the first basis vector (1, 0).
v0 = jnp.array([1.0, 0.0])

def mat_exp(theta, A):
    """Return the matrix exponential exp(theta * A).

    Args:
        theta: scalar coefficient (real or complex) scaling ``A``.
        A: square matrix to exponentiate.

    Returns:
        ``jax.scipy.linalg.expm`` of the scaled matrix (same shape as ``A``).
    """
    scaled_generator = theta * A
    return jax.scipy.linalg.expm(scaled_generator)

def apply_exp_chain(thetas, matrices, v):
    """Apply a chain of matrix exponentials to a vector.

    Computes
        exp(thetas[-1] * matrices[-1]) @ ... @ exp(thetas[0] * matrices[0]) @ v,
    i.e. the first (theta, matrix) pair acts on ``v`` first and the last pair
    acts last.

    Args:
        thetas: 1-D array or sequence of scalar coefficients, one per matrix.
        matrices: sequence of square matrices whose size matches ``v``.
        v: vector the chain is applied to.

    Returns:
        The transformed vector. With empty ``thetas``/``matrices`` the chain is
        the identity and ``v`` is returned unchanged.

    Raises:
        ValueError: if ``thetas`` and ``matrices`` differ in length (``zip``
            would otherwise silently drop the extra entries).
    """
    n_thetas, n_matrices = len(thetas), len(matrices)
    if n_thetas != n_matrices:
        raise ValueError(
            f"thetas and matrices must have the same length, got {n_thetas} and {n_matrices}"
        )

    result = v
    for theta, A in zip(thetas, matrices):
        # Act on the running vector (matrix-vector products) instead of
        # building the full operator with matrix-matrix products first.
        result = mat_exp(theta, A) @ result
    return result

def final_vector(thetas):
    """Evaluate the fixed exponential chain (``matrices`` acting on ``v0``) at ``thetas``.

    Only ``thetas`` is a parameter, so JAX transformations such as
    ``jax.jacrev`` differentiate with respect to it alone.
    """
    chain_output = apply_exp_chain(thetas, matrices=matrices, v=v0)
    return chain_output

# Full complex Jacobian d(final_vector)/d(thetas), shape (len(v0), len(thetas)).
# With holomorphic=True this is the complex derivative of the complex output
# (not just its real part); jax.grad would not apply because the output is a vector.
grad_fn = jax.jacrev(final_vector, holomorphic=True)

# Example parameters. holomorphic=True needs complex inputs, hence the complex dtype
# (resolves to complex64 unless JAX x64 mode is enabled).
thetas = jnp.array([0.1, 0.2, 0.3], dtype=complex)

# Evaluate function and Jacobian
output_vector = final_vector(thetas)
jacobian = grad_fn(thetas)
# One row per output component, one column per parameter.
assert jacobian.shape == (output_vector.shape[0], thetas.shape[0])

print("Output vector:", output_vector)
print("Jacobian:\n", jacobian)


Output vector: [1.3838362 -0.02722283j 0.07569441+0.14989974j]
Jacobian:
 [[ 0.13792416-0.2731351j   0.2731351 -0.13792416j  1.3838362 -0.02722283j]
 [ 0.75946563+0.01494021j  0.01494021+0.7594656j  -0.07569441-0.14989975j]]


In [3]:
# IPython help: show jax.jacrev's signature and docstring (incl. the holomorphic flag used above).
jax.jacrev?

Signature:
jax.jacrev(
    fun: 'Callable',
    argnums: 'int | Sequence[int]' = 0,
    has_aux: 'bool' = False,
    holomorphic: 'bool' = False,
    allow_int: 'bool' = False,
) -> 'Callable'
Docstring:
Jacobian of ``fun`` evaluated row-by-row using reverse-mode AD.

Args:
  fun: Function whose Jacobian is to be computed.
  argnums: Optional, integer or sequence of integers. Specifies which
    positional argument(s) to differentiate with respect to (

In [12]:
# Re-evaluate the chain at the example parameters.
final_vector(thetas=thetas)

Array([1.3838362 -0.02722283j, 0.07569441+0.14989974j], dtype=complex64)

In [17]:
# Order in which the operators act: exp(thetas[0]*A1) is applied to v0 first and
# exp(thetas[2]*A3) last, so this explicit product must reproduce final_vector(thetas).
manual_output = (
    jax.scipy.linalg.expm(thetas[2] * A3)
    @ jax.scipy.linalg.expm(thetas[1] * A2)
    @ jax.scipy.linalg.expm(thetas[0] * A1)
    @ v0
)
assert jnp.allclose(manual_output, final_vector(thetas))
manual_output

Array([1.3838362 -0.02722283j, 0.07569441+0.14989974j], dtype=complex64)

In [20]:
# A3 commutes with exp(thetas[2]*A3), so the derivative of the chain w.r.t. thetas[2] is
# A3 @ exp(thetas[2]*A3) @ exp(thetas[1]*A2) @ exp(thetas[0]*A1) @ v0.
# It should therefore match the third column of the jacrev Jacobian.
dchain_dtheta3 = (
    A3
    @ jax.scipy.linalg.expm(thetas[2] * A3)
    @ jax.scipy.linalg.expm(thetas[1] * A2)
    @ jax.scipy.linalg.expm(thetas[0] * A1)
    @ v0
)
assert jnp.allclose(dchain_dtheta3, jacobian[:, 2])
dchain_dtheta3

Array([ 1.3838362 -0.02722283j, -0.07569441-0.14989974j], dtype=complex64)

In [11]:
# NOTE(review): repeats the earlier `jax.jacrev?` help cell; candidate for removal.
jax.jacrev?

Signature:
jax.jacrev(
    fun: 'Callable',
    argnums: 'int | Sequence[int]' = 0,
    has_aux: 'bool' = False,
    holomorphic: 'bool' = False,
    allow_int: 'bool' = False,
) -> 'Callable'
Docstring:
Jacobian of ``fun`` evaluated row-by-row using reverse-mode AD.

Args:
  fun: Function whose Jacobian is to be computed.
  argnums: Optional, integer or sequence of integers. Specifies which
    positional argument(s) to differentiate with respect to (

In [10]:
# Jacobian layout: (number of output components, number of parameters).
jacobian.shape

(2, 3)