In [1]:
import dynamiqs as dq
import jax.numpy as jnp
from matplotlib import pyplot as plt
from globalConstants import *
# --- Parameters ---------------------------------------------------------------
na, nb = 20, 5            # Fock cutoffs of modes a and b (joint dimension na*nb)
T = 4                     # total simulation time
omega = jnp.pi / T        # oscillator frequency
eps = -4                  # coupling strength (example value)
Kb = 10                   # decay rate (kappa in the paper)
numTimes = 1000           # number of saved time points
t_save = jnp.linspace(0, T, numTimes)  # times at which states are saved

# --- Operators ----------------------------------------------------------------
# Given several cutoffs, dq.destroy returns one annihilation operator per mode,
# each embedded in the joint space (shape (na*nb, na*nb), see printed output).
a_op, b_op = dq.destroy(na, nb)
adag_op, bdag_op = dq.dag(a_op), dq.dag(b_op)

# Identity acting on mode b alone (dimension nb).
Ib = dq.eye(nb)

# --- Initial state: joint vacuum |0>_a ⊗ |0>_b, built two ways ----------------
# (1) directly from the per-mode cutoffs and per-mode Fock indices;
psi0_fock = dq.fock((na, nb), (0, 0))
# (2) as a tensor product of zero-amplitude coherent states (i.e. vacua).
psi0a, psi0b = dq.coherent(na, 0), dq.coherent(nb, 0)
psi0_tensor = dq.tensor(psi0a, psi0b)

# Sanity check of the joint-space dimensions (optional).
print("Shape of psi0_fock:", psi0_fock.shape)
print("Shape of psi0_tensor:", psi0_tensor.shape)
print("Shape of a_op:", a_op.shape)
print("Shape of b_op:", b_op.shape)

# Either psi0_fock or psi0_tensor can serve as the simulation's initial state.


Shape of psi0_fock: (100, 1)
Shape of psi0_tensor: (100, 1)
Shape of a_op: (100, 100)
Shape of b_op: (100, 100)


In [10]:
# Inspect the cat amplitude `alpha`. No cell above assigns it: the value shown
# came either from `from globalConstants import *` or from a previous run of the
# next cell (which sets alpha = 2.0). Look it up defensively so that a fresh
# Restart & Run All does not stop here with a NameError.
alpha_display = globals().get("alpha", "alpha is not defined yet (it is set in the next cell)")
alpha_display

2.0

In [11]:
import dynamiqs as dq
import jax
import jax.numpy as jnp
from matplotlib import pyplot as plt
import optax

# Set device to GPU if available (uncomment if a GPU is present).
# dq.set_device("gpu")

# --- Simulation parameters (coarse settings for debugging) ---
# NOTE: these overwrite the globals T, t_save, a_op, adag_op from the first cell.
N_cutoff = 20   # Fock-space truncation of the single mode
T = 3.0         # total evolution time (also sets the drive's bin layout)
numTimes = 200  # number of *saved* time points; the ODE solver picks its own steps
t_save = jnp.linspace(0, T, numTimes)

# Bosonic annihilation and creation operators for a single mode.
a_op = dq.destroy(N_cutoff)
adag_op = dq.dag(a_op)

# Initial state: the vacuum, i.e. the 0-th Fock state.
psi0 = dq.fock(N_cutoff, 0)

# Define a piecewise constant drive function.
def eps_d(t, eps_params, duration=None):
    """Evaluate the piecewise-constant drive amplitude at time ``t``.

    The interval [0, duration] is split into ``len(eps_params)`` equal bins and
    the value of the bin containing ``t`` is returned. Times before 0 map to the
    first bin; times at or after ``duration`` map to the last bin.

    Args:
        t: evaluation time (scalar; may be a JAX tracer inside the solver).
        eps_params: 1-D array of per-bin amplitudes (complex in this notebook).
        duration: total drive duration; defaults to the global ``T``.

    Returns:
        The amplitude of the bin containing ``t`` (dtype of ``eps_params``).
    """
    if duration is None:
        duration = T
    t = jnp.asarray(t)
    n_bins = eps_params.shape[0]
    raw_index = (t / duration * n_bins).astype(jnp.int32)
    # Clamp both ends: t == duration would index one past the last bin, and a
    # negative index would silently wrap around to the final bins in JAX.
    bin_index = jnp.clip(raw_index, 0, n_bins - 1)
    return eps_params[bin_index]

def drive_Hamiltonian(t, eps_params):
    """Linear drive conj(eps(t)) * a + eps(t) * a^dag at time ``t``.

    Hermitian by construction, since adag_op = dag(a_op): each term is the
    adjoint of the other.
    """
    amplitude = eps_d(t, eps_params)
    annihilation_term = jnp.conjugate(amplitude) * a_op
    creation_term = amplitude * adag_op
    return annihilation_term + creation_term

def H_2ph():
    """Static Hamiltonian term; currently the zero operator.

    Returns a zero operator shaped like ``a_op`` so it can be summed with the
    drive. NOTE(review): the name suggests a two-photon Hamiltonian is meant to
    go here eventually — confirm.
    """
    zero_scale = 0
    return zero_scale * a_op

def H_total(t, eps_params):
    """Total Hamiltonian at time ``t``: static part ``H_2ph`` plus the drive."""
    static_part = H_2ph()
    drive_part = drive_Hamiltonian(t, eps_params)
    return static_part + drive_part

# Target: even cat state (|alpha> + |-alpha>) / norm with alpha = 2.
# NOTE(review): `loss` below currently targets only the photon number; the
# fidelity against cat_target is commented out there.
alpha = 2.0
psi_alpha, psi_minus = (dq.coherent(N_cutoff, amp) for amp in (alpha, -alpha))
# Analytic normalization: <alpha|-alpha> = exp(-2 alpha^2) for real alpha, so
# || |alpha> + |-alpha> ||^2 = 2 (1 + exp(-2 alpha^2)) (ignores truncation error).
norm = jnp.sqrt(2 * (1 + jnp.exp(-2 * alpha**2)))
cat_target = (psi_alpha + psi_minus) / norm

# Drive discretization: N_bins piecewise-constant complex amplitudes, all
# initialized to 0.1 (fewer bins -> faster optimization).
N_bins = 50
eps_params_init = jnp.full(N_bins, 0.1, dtype=jnp.complex64)

def make_H_time(eps_params):
    """Wrap ``H_total`` as a dynamiqs callable time-dependent Hamiltonian.

    The drive is piecewise constant (see ``eps_d``), so its interior bin edges
    are passed as ``discontinuity_ts``. This lets the adaptive ODE solver step
    exactly onto each jump instead of repeatedly rejecting steps that straddle
    one — a likely contributor to the "maximum number of solver steps was
    reached" failure.

    ``eps_params`` is bound with ``jax.tree_util.Partial`` instead of a lambda
    closure, so it is a pytree leaf of the time array. JAX can then trace and
    differentiate it as a regular input, and recompilation across calls may be
    avoided because the pytree structure stays the same.
    """
    n_bins = eps_params.shape[0]
    # Interior bin edges k*T/n_bins for k = 1..n_bins-1 (same layout as eps_d).
    bin_edges = jnp.linspace(0.0, T, n_bins + 1)[1:-1]
    hamiltonian_fn = jax.tree_util.Partial(H_total, eps_params=eps_params)
    return dq.timecallable(hamiltonian_fn, discontinuity_ts=bin_edges)
# Dissipation: one two-photon loss channel with jump operator a^2
# (no rate prefactor is applied).
jumpOp = [a_op @ a_op]

# Dissipation-free alternative kept from experimentation:
# jumpOp = [dq.zeros(a_op.shape[0])]
def loss(eps_params):
    """Squared relative error of the final mean photon number w.r.t. alpha**2.

    Evolves the vacuum ``psi0`` with ``dq.mesolve`` under the drive encoded by
    ``eps_params`` (via ``make_H_time``) and the jump operators ``jumpOp``,
    saving at ``t_save``. It then compares <a^dag a> of the final state with
    ``alpha**2``. The target was previously hard-coded as 4, which matches
    alpha**2 only for alpha = 2.

    Returns:
        Real scalar ``((<n> - alpha**2) / alpha**2) ** 2``.
    """
    result = dq.mesolve(make_H_time(eps_params), jumpOp, psi0, t_save)
    # NOTE(review): if the solver still exhausts its step budget, raise
    # max_steps through dq.mesolve's solver/method argument (the keyword name
    # depends on the installed dynamiqs version).
    final_state = result.states[-1]

    # Reuse the module-level single-mode operators (same N_cutoff space as psi0)
    # instead of rebuilding them on every call.
    photon_number = dq.expect(adag_op @ a_op, final_state)
    target = alpha**2
    relative_error = (jnp.real(photon_number) - target) / target
    # Alternative objective: 1 - dq.fidelity(final_state, cat_target).
    return relative_error**2

# --- Gradient-based optimization of the drive with Adam ---
# value_and_grad returns the loss and its gradient from a single forward/backward
# pass; the original loop ran an extra full solve just to print the loss.
loss_and_grad = jax.value_and_grad(loss)
grad_loss = jax.grad(loss)  # kept for interactive use / later cells

optimizer = optax.adam(learning_rate=0.01)
opt_state = optimizer.init(eps_params_init)
eps_params = eps_params_init

n_iterations = 50  # kept small for debugging
for i in range(n_iterations):
    loss_value, grads = loss_and_grad(eps_params)
    # For a real loss of complex parameters, jax.grad returns
    # dL/dRe - i*dL/dIm; its complex conjugate is the steepest-ascent direction,
    # which is what a descent optimizer such as Adam must be given.
    grads = jnp.conj(grads)
    updates, opt_state = optimizer.update(grads, opt_state)
    eps_params = optax.apply_updates(eps_params, updates)
    # Loss evaluated at the parameters *before* this iteration's update.
    print(f"Iteration {i}: Loss = {float(loss_value):.6f}")

# --- Plot the optimized drive as a step function over its true bins ---
# Bin k spans [k*T/N_bins, (k+1)*T/N_bins); repeat the last value at t = T so
# the final bin is drawn with its full width.
bin_edges = jnp.linspace(0, T, N_bins + 1)
times = bin_edges[:-1]  # bin start times
fig, ax = plt.subplots()
for values, label in (
    (jnp.real(eps_params), "Real part"),
    (jnp.imag(eps_params), "Imaginary part"),
):
    ax.step(bin_edges, jnp.append(values, values[-1]), where="post", label=label)
ax.set_xlabel("Time")
ax.set_ylabel("Drive amplitude")
ax.set_title("Optimized piecewise-constant drive")
ax.legend()
plt.show()


|██████████| 100.0% ◆ elapsed 20.65ms ◆ remaining 0.00ms
|██████████| 100.0% ◆ elapsed 8.31s ◆ remaining 0.00ms   
ERROR:2025-02-02 05:34:25,249:jax._src.callback:97: jax.pure_callback failed
Traceback (most recent call last):
  File "/Users/amer_/Documents/Obsidian Vault/Personal/Project Notes/Hackathon/iQuHack/lib/python3.12/site-packages/jax/_src/callback.py", line 95, in pure_callback_impl
    return tree_util.tree_map(np.asarray, callback(*args))
                                          ^^^^^^^^^^^^^^^
  File "/Users/amer_/Documents/Obsidian Vault/Personal/Project Notes/Hackathon/iQuHack/lib/python3.12/site-packages/jax/_src/callback.py", line 72, in __call__
    return tree_util.tree_leaves(self.callback_func(*args, **kwargs))
                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/amer_/Documents/Obsidian Vault/Personal/Project Notes/Hackathon/iQuHack/lib/python3.12/site-packages/equinox/_errors.py", line 89, in raises
    raise _EquinoxRuntimeError(

XlaRuntimeError: FAILED_PRECONDITION: Buffer Definition Event: Error dispatching computation: %sError dispatching computation: %sError dispatching computation: %sError dispatching computation: %sError dispatching computation: %sError dispatching computation: %sError dispatching computation: %sError dispatching computation: %sError dispatching computation: %sCpuCallback error: Traceback (most recent call last):
  File "/Users/amer_/Documents/Obsidian Vault/Personal/Project Notes/Hackathon/iQuHack/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 2960, in _wrapped_callback
  File "/Users/amer_/Documents/Obsidian Vault/Personal/Project Notes/Hackathon/iQuHack/lib/python3.12/site-packages/jax/_src/callback.py", line 306, in _callback
  File "/Users/amer_/Documents/Obsidian Vault/Personal/Project Notes/Hackathon/iQuHack/lib/python3.12/site-packages/jax/_src/callback.py", line 98, in pure_callback_impl
  File "/Users/amer_/Documents/Obsidian Vault/Personal/Project Notes/Hackathon/iQuHack/lib/python3.12/site-packages/jax/_src/callback.py", line 72, in __call__
  File "/Users/amer_/Documents/Obsidian Vault/Personal/Project Notes/Hackathon/iQuHack/lib/python3.12/site-packages/equinox/_errors.py", line 89, in raises
_EquinoxRuntimeError: The maximum number of solver steps was reached. Try increasing `max_steps`.


--------------------
An error occurred during the runtime of your JAX program! Unfortunately you do not appear to be using `equinox.filter_jit` (perhaps you are using `jax.jit` instead?) and so further information about the error cannot be displayed. (Probably you are seeing a very large but uninformative error message right now.) Please wrap your program with `equinox.filter_jit`.
--------------------
