In [1]:
import dynamiqs as dq
import jax
import jax.numpy as jnp

n = 10
b = dq.destroy(n)
H0 = 0.5 * b.dag() @ b.dag() @ b @ b + 0.3 * b.dag() @ b
jump_ops = [jnp.sqrt(1.0) * b]

solver = dq.SteadyStateGMRES(exact_dm=False)
def loss(epsilon):
    H = H0 + epsilon * (b.dag() + b)
    result = dq.steadystate(H, jump_ops)
    return jnp.real(jnp.trace(result.rho.to_jax()))

grad_fn = jax.grad(loss)
print(grad_fn(jnp.float64(0.8)))

  return asarray(x, dtype=self.dtype)


-2.6066838e-05


In [2]:
import jax
import jax.numpy as jnp
import dynamiqs as dq
from dynamiqs.steady_state import SteadyStateGMRES

jax.config.update('jax_enable_x64', True)

# === Simple single-mode test ===
n = 8
a = dq.destroy(n)
H0 = 0.5 * a.dag() @ a
jump_ops = [jnp.sqrt(1.0) * a]

solver = SteadyStateGMRES(tol=1e-8, max_iteration=100, krylov_size=32, exact_dm=False)

# 1. Does the forward pass work?
print("=== Forward pass ===")
H = H0 + 0.8 * (a.dag() + a)
result = dq.steadystate(H, jump_ops, solver=solver)
print(f"trace(rho) = {jnp.trace(result.rho.to_jax()):.6f}")
print(f"success = {result.infos.success}")

# 2. Does jax.jvp work?
print("\n=== JVP test ===")
def loss(epsilon):
    H = H0 + epsilon * (a.dag() + a)
    result = dq.steadystate(H, jump_ops, solver=solver)
    return jnp.real(jnp.trace(result.rho.to_jax()))

eps0 = jnp.float64(0.8)
try:
    val, grad_jvp = jax.jvp(loss, (eps0,), (jnp.float64(1.0),))
    print(f"loss = {val:.6f}, jvp grad = {grad_jvp:.6e}")
except Exception as e:
    print(f"JVP failed: {type(e).__name__}: {e}")

# 3. Does jax.grad work?
print("\n=== grad test ===")
try:
    grad_fn = jax.grad(loss)
    g = grad_fn(eps0)
    print(f"grad = {g:.6e}")
except Exception as e:
    print(f"grad failed: {type(e).__name__}: {e}")

# 4. Finite difference reference
print("\n=== Finite difference ===")
h = 1e-4
grad_fd = (loss(eps0 + h) - loss(eps0 - h)) / (2 * h)
print(f"fd grad = {grad_fd:.6e}")

# 5. Compare
print("\n=== Comparison ===")
try:
    _, grad_jvp = jax.jvp(loss, (eps0,), (jnp.float64(1.0),))
    print(f"jvp:  {grad_jvp:.6e}")
    print(f"fd:   {grad_fd:.6e}")
    print(f"match: {jnp.isclose(grad_jvp, grad_fd, rtol=5e-2)}")
except Exception as e:
    print(f"Comparison failed: {e}")

=== Forward pass ===
trace(rho) = 1.000000+0.000000j
success = True

=== JVP test ===
loss = 1.000000, jvp grad = -4.076600e-17

=== grad test ===
grad = 0.000000e+00

=== Finite difference ===
fd grad = -1.110223e-12

=== Comparison ===
jvp:  -4.076600e-17
fd:   -1.110223e-12
match: True


In [3]:
import jax
import jax.numpy as jnp
import dynamiqs as dq
from dynamiqs.steady_state import SteadyStateGMRES

jax.config.update('jax_enable_x64', True)

# === Simple single-mode test ===
n = 8
a = dq.destroy(n)
H0 = 0.5 * a.dag() @ a
jump_ops = [jnp.sqrt(1.0) * a]
n_op = a.dag() @ a  # number operator

solver = SteadyStateGMRES(tol=1e-8, max_iteration=100, krylov_size=32, exact_dm=True)

# Loss = <n> = tr(a†a ρ), which depends nontrivially on epsilon
def loss(epsilon):
    H = H0 + epsilon * (a.dag() + a)
    result = dq.steadystate(H, jump_ops, solver=solver)
    return jnp.real(dq.expect(n_op, result.rho))

eps0 = jnp.float64(0.8)

# 1. Forward pass
print("=== Forward pass ===")
val = loss(eps0)
print(f"<n> = {val:.6f}")

# 2. Finite difference reference
print("\n=== Finite difference ===")
h = 1e-4
grad_fd = (loss(eps0 + h) - loss(eps0 - h)) / (2 * h)
print(f"fd grad = {grad_fd:.6e}")

# 3. JVP
print("\n=== JVP test ===")
try:
    val, grad_jvp = jax.jvp(loss, (eps0,), (jnp.float64(1.0),))
    print(f"<n> = {val:.6f}, jvp grad = {grad_jvp:.6e}")
except Exception as e:
    print(f"JVP failed: {type(e).__name__}: {e}")
    grad_jvp = None

# 4. grad (reverse mode)
print("\n=== grad test ===")
try:
    grad_fn = jax.grad(loss)
    g = grad_fn(eps0)
    print(f"grad = {g:.6e}")
except Exception as e:
    print(f"grad failed: {type(e).__name__}: {e}")
    g = None

# 5. Compare
print("\n=== Comparison ===")
print(f"fd:   {grad_fd:.6e}")
if grad_jvp is not None:
    print(f"jvp:  {grad_jvp:.6e}")
    print(f"jvp match fd: {jnp.isclose(grad_jvp, grad_fd, rtol=5e-2)}")
if g is not None:
    print(f"grad: {g:.6e}")
    print(f"grad match fd: {jnp.isclose(g, grad_fd, rtol=5e-2)}")

=== Forward pass ===
<n> = 1.277428

=== Finite difference ===
fd grad = 3.158413e+00

=== JVP test ===
<n> = 1.277428, jvp grad = 3.158413e+00

=== grad test ===
grad = 3.158413e+00

=== Comparison ===
fd:   3.158413e+00
jvp:  3.158413e+00
jvp match fd: True
grad: 3.158413e+00
grad match fd: True


In [4]:
import jax
import jax.numpy as jnp
import dynamiqs as dq
from dynamiqs.steady_state import SteadyStateGMRES

jax.config.update('jax_enable_x64', True)

# === build_two_modes (copied from systems.py) ===
to_rad_MHz = 2 * jnp.pi * 1e-3

def eps_d_from_na(na, g2):
    eps_map = {12: 1.0, 24: 2.0, 32: 7.0, 46: 12.0}
    if na in eps_map:
        return eps_map[na]
    n_target = na - 1
    alpha = (-15.0 + jnp.sqrt(225.0 + 4.0 * n_target)) / 2.0
    return float(g2 * alpha**2)

def build_two_modes(n_a, n_b, g2=2, eps_d=None, kappa_b=8, kappa_a=1):
    if eps_d is None:
        eps_d = eps_d_from_na(n_a, g2)
    g2 = g2 * to_rad_MHz
    kappa_b = kappa_b * to_rad_MHz
    kappa_a = kappa_a * to_rad_MHz
    eps_d = eps_d * to_rad_MHz
    a, b = dq.destroy(n_a, n_b)
    H0 = g2 * (a @ a @ b.dag() + a.dag() @ a.dag() @ b) + eps_d * (b + b.dag())
    Ls = [dq.asqarray(jnp.sqrt(kappa_b) * b), dq.asqarray(jnp.sqrt(kappa_a) * a)]
    return H0, Ls

# === Test ===
na, nb = 12, 3
print(f"Hilbert space: {na}x{nb}={na*nb}, vec dim: {(na*nb)**2}")

solver = SteadyStateGMRES(
    tol=1e-7, max_iteration=200, krylov_size=64, exact_dm=False
)

_, b_op = dq.destroy(na, nb)
n_b = b_op.dag() @ b_op

# 1. Forward pass
print("\n=== Forward pass ===")
H, Ls = build_two_modes(na, nb, kappa_a=1)
result = dq.steadystate(H, Ls, solver=solver)
print(f"<n_b> = {jnp.real(dq.expect(n_b, result.rho)):.6f}")
print(f"success = {result.infos.success}")

# 2. Loss function (differentiable w.r.t. eps_d)
def loss(eps_d):
    H, Ls = build_two_modes(na, nb, eps_d=eps_d, kappa_a=1)
    result = dq.steadystate(H, Ls, solver=solver)
    return jnp.real(dq.expect(n_b, result.rho))

eps0 = jnp.float64(2.0)

# 3. Finite difference
print("\n=== Finite difference ===")
h = 5e-3
grad_fd = (loss(eps0 + h) - loss(eps0 - h)) / (2 * h)
print(f"fd grad = {grad_fd:.6e}")

# 4. JVP
print("\n=== JVP test ===")
try:
    val, grad_jvp = jax.jvp(loss, (eps0,), (jnp.float64(1.0),))
    print(f"<n_b> = {val:.6f}, jvp grad = {grad_jvp:.6e}")
except Exception as e:
    print(f"JVP failed: {type(e).__name__}: {e}")
    import traceback; traceback.print_exc()
    grad_jvp = None

# 5. grad
print("\n=== grad test ===")
try:
    g = jax.grad(loss)(eps0)
    print(f"grad = {g:.6e}")
except Exception as e:
    print(f"grad failed: {type(e).__name__}: {e}")
    import traceback; traceback.print_exc()
    g = None

# 6. Compare
print("\n=== Comparison ===")
print(f"fd:   {grad_fd:.6e}")
if grad_jvp is not None:
    print(f"jvp:  {grad_jvp:.6e}")
    rel = abs(grad_jvp - grad_fd) / (abs(grad_fd) + 1e-15)
    print(f"rel error: {rel:.2e}")
    print(f"match (rtol=5%): {jnp.isclose(grad_jvp, grad_fd, rtol=5e-2)}")
if g is not None:
    print(f"grad: {g:.6e}")
    rel = abs(g - grad_fd) / (abs(grad_fd) + 1e-15)
    print(f"rel error: {rel:.2e}")
    print(f"match (rtol=5%): {jnp.isclose(g, grad_fd, rtol=5e-2)}")

Hilbert space: 12x3=36, vec dim: 1296

=== Forward pass ===
<n_b> = 0.005958
success = True

=== Finite difference ===
fd grad = 9.630774e-03

=== JVP test ===
<n_b> = 0.016554, jvp grad = 9.630787e-03

=== grad test ===
grad = 9.630786e-03

=== Comparison ===
fd:   9.630774e-03
jvp:  9.630787e-03
rel error: 1.35e-06
match (rtol=5%): True
grad: 9.630786e-03
rel error: 1.29e-06
match (rtol=5%): True
