In [14]:
import numpy as np
from globalConstants import *
import dynamiqs as dq
import jax
import jax.numpy as jnp # the JAX version of numpy
import matplotlib.pyplot as plt
import optax

In [15]:
# --- System parameters ---
n_a = 20  # Fock states in memory mode
n_b = 5   # Fock states in buffer mode
g2 = 1.0
kappa_b = 10.0
T = 3.0  # Total simulation time
ntpulse = 101  # Number of time bins for piecewise constant epsilon_d
learning_rate = 0.2  # Learning rate for optimization
nepochs = 300  # Number of optimization steps

# --- Initial state ---
psi0 = dq.tensor(dq.basis(n_a, 0), dq.basis(n_b, 0))  # Vacuum state in both modes

# --- Operators (correctly expanded for two-mode system) ---
a = dq.destroy(n_a)
b = dq.destroy(n_b)

# Expand operators to full Hilbert space
a_full = dq.tensor(a, b * 0)  # 'a' acts on memory mode (mode 0)
b_full = dq.tensor(a * 0, b)  # 'b' acts on buffer mode (mode 1)

# --- Target Cat State (Fixing Error 1) ---
alpha_target = 2.0  # Target |α|² = 4, so α = 2
coherent_plus = dq.coherent(n_a, alpha_target)  # |α⟩
coherent_minus = dq.coherent(n_a, -alpha_target)  # |-α⟩
cat_target = (coherent_plus + coherent_minus).unit()  # Normalize

# --- Time-discretized ε_d(t) ---
tpulse = jnp.linspace(0.0, T, ntpulse)


# --- Define Hamiltonian ---
def get_hamiltonian(epsilon_d):
    H_2ph = g2 * (a_full.dag() @ a_full.dag() @ b_full) + jnp.conj(g2) * (a_full @ a_full @ b_full.dag())
    H_d = dq.pwc(tpulse, epsilon_d.real, b_full.dag() + b_full)  # Fix: Ensure real values
    return H_2ph + H_d


# --- Lindblad Dissipator ---
jump_ops = [jnp.sqrt(kappa_b) * b_full]  # Losses only in buffer mode

# --- Loss Function (Fixing Error 2) ---
def compute_fidelity(epsilon_d):
    H = get_hamiltonian(epsilon_d)
    options = dq.Options(progress_meter=None)  # Disable progress meter
    result = dq.mesolve(H, jump_ops, psi0, tpulse, options=options)

    # Fix: Partial trace over buffer mode
    rho_memory = dq.partial_trace(result.states[-1], [n_a, n_b], [1])  # Trace out mode 1 (buffer)

    # Compute fidelity F = |⟨ψ_target|ρ|ψ_target⟩|
    fidelity = (cat_target.dag() @ rho_memory @ cat_target).tr().real
    return fidelity  # Fix: Ensure scalar real output

# Define loss function to *minimize* (Fixing Error 3)
@jax.jit
def compute_loss(epsilon_d):
    return 1.0 - compute_fidelity(epsilon_d)

# --- Optimization Setup ---
optimizer = optax.adam(learning_rate)
epsilon_d = -4.0 * jnp.ones(ntpulse)  # Initial guess for ε_d(t)
opt_state = optimizer.init(epsilon_d)
losses = []

# --- Optimization Loop ---
for _ in range(nepochs):
    loss, grads = jax.value_and_grad(compute_loss)(epsilon_d)  # Compute loss & gradient
    updates, opt_state = optimizer.update(grads, opt_state)
    epsilon_d = optax.apply_updates(epsilon_d, updates)  # Update parameters
    losses.append(loss)

# --- Plot Results ---
plt.figure(figsize=(8, 4))
plt.plot(tpulse, epsilon_d, label="Optimized $\epsilon_d(t)$")
plt.xlabel("Time")
plt.ylabel("$\epsilon_d$")
plt.legend()
plt.grid()
plt.show()


TypeError: Argument `values` must have shape `(..., len(times)-1)`, but has shape `(101,).