In [1]:
!pip install optax pandas



In [None]:
import jax.numpy as jnp
import jax.random as jrandom
import jax.scipy.linalg as jla
import jax
import matplotlib.pyplot as plt
from jax import vmap
from jax import random
import numpy as np
import optax
import time
import pandas as pd
from jax import grad
import math
import logging
import sys
import os

# ========================================
# Logging Configuration
# ========================================
# ----- Logging Configuration (Print to Console) -----
# NOTE(review): jax is imported above this point, so setting TF_CPP_MIN_LOG_LEVEL
# here probably does not silence XLA's C++ logging; the variable normally has
# to be set before the first `import jax`. Kept as-is to preserve behavior.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # Suppress TensorFlow/XLA logs

# logging.getLogger() with no name returns the root logger.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Drop any handlers already attached (e.g. from re-running this notebook cell)
# so messages are not printed more than once. Clearing an empty list is a no-op.
logger.handlers.clear()

# Single console handler writing formatted records to stdout.
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
stream_handler = logging.StreamHandler(sys.stdout)
stream_handler.setLevel(logging.INFO)
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)

# ========================================
# Part 1: Control System Model (PIC + SLM)
# ========================================

# Constants and Parameters
def calculate_Omega_rabi_prefactor(I_mW_per_cm2, Detuning_MHz):
    """
    Effective Rabi-frequency prefactor for a given laser intensity and detuning.

    The single-photon couplings are Omega_je = mu_je * E_0 / hbar (j = 0, 1),
    with peak field E_0 = sqrt(2 * I / (c * epsilon_0)). The effective coupling
    is Omega_0e * Omega_1e / (2 * Delta) with Delta = Detuning_MHz * 1e6, and
    the result is divided by 1e6 before being returned.

    Parameters:
    - I_mW_per_cm2: laser intensity in mW/cm^2 (converted internally to W/m^2).
    - Detuning_MHz: detuning in MHz.

    Returns:
    - Effective Rabi frequency prefactor, labelled MHz by this code.

    NOTE(review): mu*E/hbar is an angular frequency (rad/s), while Delta is
    Detuning_MHz * 1e6 without a 2*pi factor. Confirm whether the detuning is
    meant to be angular before relying on the absolute scale.
    """
    # Physical constants (SI units)
    hbar = 1.0545718e-34            # reduced Planck constant, J*s
    c = 3e8                         # speed of light in vacuum, m/s
    epsilon_0 = 8.854187817e-12     # vacuum permittivity, F/m
    dipole_0e = 3.336e-29           # |0> <-> |e> transition dipole, C*m
    dipole_1e = 3.400e-29           # |1> <-> |e> transition dipole, C*m

    # 1 mW/cm^2 = 1e-3 W / 1e-4 m^2 = 10 W/m^2
    intensity_si = I_mW_per_cm2 * 10
    field_peak = jnp.sqrt(2 * intensity_si / (c * epsilon_0))  # V/m

    # Keep the multiply/divide order: the dipoles (~1e-29) must meet the field
    # before hbar, otherwise float32 intermediates underflow.
    rabi_0e = (dipole_0e * field_peak) / hbar
    rabi_1e = (dipole_1e * field_peak) / hbar
    rabi_eff = (rabi_0e * rabi_1e) / (2 * Detuning_MHz * 1e6)
    return rabi_eff / 1e6


# Grid and atom positions
def generate_grid(grid_size=600, grid_range=(-6, 6)):
    """
    Build a square 2D coordinate mesh for placing atoms and evaluating beams.

    Args:
    grid_size (int): resolution control. The mesh has ``grid_size - 1`` samples
        per axis (599 for the default).
    grid_range (tuple): (low, high) bounds shared by the x and y axes.

    Returns:
    X, Y: 2D arrays from jnp.meshgrid with default 'xy' indexing, so X varies
    along columns and Y along rows.

    NOTE(review): the mesh has grid_size - 1 (not grid_size) points per axis.
    Confirm this is intended before changing it, since downstream array shapes
    depend on it.
    """
    low, high = grid_range[0], grid_range[1]
    axis = jnp.linspace(low, high, grid_size - 1)
    X, Y = jnp.meshgrid(axis, axis)
    return X, Y

# Function to generate atom positions
def generate_atom_positions_equilateral(N_a, side_length=3.0, center=(0, 0), triangle_spacing_factor=2.0):
    """
    Place N_a atoms on the vertices of equilateral triangles laid out along +x.

    Triangle slot k is centred at
    (center_x + k * triangle_spacing_factor * side_length, center_y). Its
    vertices are bottom-left, bottom-right and top, with the base at
    center_y - height/2 and the apex at center_y + height/2, where
    height = sqrt(3)/2 * side_length.

    Atoms fill the slots in order. When N_a is not a multiple of 3, the one or
    two leftover atoms go to the bottom-left (and bottom-right) vertices of the
    next, otherwise empty, slot.

    Parameters:
    - N_a (int): total number of atoms.
    - side_length (float): side length of each triangle.
    - center (tuple): (x, y) centre of the first triangle slot.
    - triangle_spacing_factor (float): distance between neighbouring slot
      centres, in units of side_length.

    Returns:
    - positions (list of tuples): N_a (x, y) positions.
    """
    x_center, y_center = center
    height = (math.sqrt(3) / 2) * side_length

    def _slot_vertices(slot):
        # Vertices of triangle slot `slot`: bottom-left, bottom-right, top.
        cx = x_center + slot * triangle_spacing_factor * side_length
        return [
            (cx - side_length / 2, y_center - height / 2),
            (cx + side_length / 2, y_center - height / 2),
            (cx, y_center + height / 2),
        ]

    num_full_triangles, remaining_atoms = divmod(N_a, 3)
    positions = []
    for slot in range(num_full_triangles):
        positions.extend(_slot_vertices(slot))

    # Leftover atoms start a new triangle slot. Reusing the last full
    # triangle's vertices would stack two atoms on the same site.
    if remaining_atoms:
        positions.extend(_slot_vertices(num_full_triangles)[:remaining_atoms])

    return positions

# Dipoles for each atom
def generate_dipoles(N_a):
    """
    Return one dipole vector per atom, all equal to [1.0, 0.0].

    Each entry is a separate 2-element jax array; the list has length N_a.
    """
    dipole_list = []
    for _atom in range(N_a):
        dipole_list.append(jnp.array([1.0, 0]))
    return dipole_list

# Function to generate SLM modulations
def generate_slm_mod(N_slm):
    """
    Default (neutral) SLM settings for N_slm channels.

    Returns:
        (phase_mod, amp_mod): zero phase and unit amplitude arrays of length N_slm.
    """
    return jnp.zeros(N_slm), jnp.ones(N_slm)

# Functions for crosstalk model setup (with fabrication variations considered)
def generate_distances(num_channels, pitch, random_variation=0.0, seed=None):
    """
    Pairwise waveguide separations for a linear array, with optional uniform jitter.

    Entry (i, j) is |i - j| * pitch plus an independent draw from
    U(-random_variation, random_variation). The diagonal stays exactly 0.
    Because (i, j) and (j, i) are drawn separately, the matrix is not symmetric
    when random_variation > 0.

    Parameters:
        num_channels (int): number of waveguides.
        pitch (float): spacing between adjacent waveguides (µm).
        random_variation (float): half-width of the uniform jitter (µm).
        seed (int or None): if given, reseeds NumPy's global RNG
            (np.random.seed) before drawing.

    Returns:
        jax.numpy array: (num_channels, num_channels) float32 distance matrix (µm).
    """
    if seed is not None:
        np.random.seed(seed)

    index = np.arange(num_channels)
    separation = np.abs(index[:, None] - index[None, :])
    off_diagonal = separation != 0  # same as i != j

    result = np.zeros((num_channels, num_channels), dtype=np.float32)
    # One jitter draw per off-diagonal entry. Boolean-mask assignment fills in
    # row-major order, which matches the draw order of a nested i/j loop.
    jitter = np.random.uniform(-random_variation, random_variation, int(off_diagonal.sum()))
    result[off_diagonal] = separation[off_diagonal] * pitch + jitter
    return jnp.array(result)

def generate_coupling_lengths(num_channels, base_length, scaling_factor=1.0, random_variation=0.0, seed=None):
    """
    Pairwise coupling lengths for a linear waveguide array, with optional uniform jitter.

    Entry (i, j) is base_length * scaling_factor**|i - j| plus an independent
    draw from U(-random_variation, random_variation). The diagonal stays 0.

    Parameters:
        num_channels (int): number of waveguides.
        base_length (float): coupling length for adjacent waveguides (µm).
        scaling_factor (float): per-step multiplier applied |i - j| times.
        random_variation (float): half-width of the uniform jitter (µm).
        seed (int or None): if given, reseeds NumPy's global RNG
            (np.random.seed) before drawing.

    Returns:
        jax.numpy array: (num_channels, num_channels) float32 matrix (µm).
    """
    if seed is not None:
        np.random.seed(seed)

    index = np.arange(num_channels)
    separation = np.abs(index[:, None] - index[None, :])
    off_diagonal = separation != 0  # same as i != j

    # Nominal length for each possible separation k, computed with Python
    # arithmetic so the values match a per-element loop exactly.
    nominal_by_separation = np.array(
        [base_length * (scaling_factor ** k) for k in range(num_channels)]
    )

    result = np.zeros((num_channels, num_channels), dtype=np.float32)
    # One jitter draw per off-diagonal entry, in row-major order.
    jitter = np.random.uniform(-random_variation, random_variation, int(off_diagonal.sum()))
    result[off_diagonal] = nominal_by_separation[separation[off_diagonal]] + jitter
    return jnp.array(result)

def generate_n_eff_list(num_channels, base_n_eff, random_variation=0.0, seed=None):
    """
    Per-waveguide effective refractive indices with optional uniform jitter.

    Parameters:
        num_channels (int): number of waveguides.
        base_n_eff (float): nominal effective index shared by all waveguides.
        random_variation (float): half-width of the uniform jitter added to each index.
        seed (int or None): if given, reseeds NumPy's global RNG
            (np.random.seed) before drawing.

    Returns:
        jax.numpy array of shape (num_channels,): base_n_eff + U(-rv, rv) draws.
    """
    if seed is not None:
        np.random.seed(seed)
    offsets = np.random.uniform(low=-random_variation, high=random_variation, size=num_channels)
    return jnp.array(base_n_eff + offsets)


# 1.1 Control Signal Construction

def construct_V_smooth_with_carrier(tmin, tmax, t_steps, voltage_levels, omega_0, phi=0, max_step=30):
    """
    Piecewise-constant, slew-limited voltage envelope multiplied by a cosine carrier.

    The t_steps samples of linspace(tmin, tmax, t_steps) are split into
    len(voltage_levels) consecutive pieces of t_steps // len(voltage_levels)
    samples each; the last piece absorbs any remainder. Before filling, each
    level is limited so that it differs from the (already limited) previous
    level by at most max_step.

    Parameters:
    - tmin, tmax (float): time window.
    - t_steps (int): number of samples.
    - voltage_levels (sequence): requested level per piece (cast to float32).
      Must be non-empty.
    - omega_0 (float): carrier angular frequency, in radians per unit of t.
    - phi (float, optional): carrier phase offset in radians. Default 0.
    - max_step (float, optional): largest allowed jump between consecutive
      pieces. Default 30.

    Returns:
    - V_t (jax array, shape (t_steps,)): envelope * cos(omega_0 * t + phi).

    The jump test uses Python control flow on a JAX value, so this function
    needs concrete (non-traced) voltage_levels.
    """
    t_grid = jnp.linspace(tmin, tmax, t_steps)
    n_pieces = len(voltage_levels)
    samples_per_piece = t_steps // n_pieces
    levels = jnp.array(voltage_levels, dtype=jnp.float32)

    # Pass 1: slew-rate limit, left to right. Each piece only depends on the
    # already-limited previous piece.
    for k in range(1, n_pieces):
        step = levels[k] - levels[k - 1]
        if jnp.abs(step) > max_step:
            levels = levels.at[k].set(levels[k - 1] + jnp.sign(step) * max_step)

    # Pass 2: hold each level for its piece; the last piece takes the remainder.
    segment_lengths = [samples_per_piece] * (n_pieces - 1)
    segment_lengths.append(t_steps - (n_pieces - 1) * samples_per_piece)
    envelope = jnp.concatenate(
        [jnp.full(length, levels[k]) for k, length in enumerate(segment_lengths)]
    )

    carrier = jnp.cos(omega_0 * t_grid + phi)
    return envelope * carrier


# 1.2 Unitary Matrix Construction
def dn(V):
    """Linear electro-optic index change: 4e-5 index units per unit of V (presumably volts)."""
    index_change_per_volt = 4e-5
    return index_change_per_volt * V

def phi_func(L, n, dn_val, lambda_):
    """
    Optical phase accumulated over length L: 2*pi*L*(n + dn_val)/lambda_.

    L and lambda_ must use the same length unit. n is the unperturbed index and
    dn_val the voltage-induced change (see dn).
    """
    prefactor = 2 * jnp.pi * L
    return prefactor * (n + dn_val) / lambda_

def D_func(a, t_val, phi_val):
    """
    Complex transmission with the form of an all-pass ring response:

        D = e^{i*pi} * (a - t * e^{-i*phi}) / (1 - t * a * e^{-i*phi})

    Parameters:
        a, t_val: the two real coefficients of the response.
        phi_val: round-trip phase (radians).
    """
    round_trip = jnp.exp(-1j * phi_val)
    numerator = a - t_val * round_trip
    denominator = 1 - t_val * a * round_trip
    return jnp.exp(1j * jnp.pi) * numerator / denominator

def U_drmzm_single_channel(V0, V1, L, n0, lambda_0, a0, t0, a1, t1, psi_0=0):
    """
    2x2 transfer matrix of one modulator channel.

    Each arm's phase is phi = 2*pi*L*(n0 + dn(V)) / lambda_0 and its response
    is D_func(a, t, phi). The channel matrix is U_bs @ diag(D_00, D_11) @ U_bs,
    where U_bs = [[1, 1], [1, -1]] / sqrt(2).

    Parameters:
        V0, V1: drive voltages of arm 0 and arm 1. Assumed scalars; the 2x2
            assembly below expects 0-d values.
        L, n0, lambda_0: arm length, base index and wavelength (L and lambda_0
            in the same unit).
        a0, t0, a1, t1: per-arm coefficients passed to D_func.
        psi_0 (float): static phase offset (radians) applied to arm 0.
            Default 0.

    Returns:
        (2, 2) complex jax array.
    """
    dn_0 = dn(V0)
    dn_1 = dn(V1)
    phi0 = phi_func(L, n0, dn_0, lambda_0)
    phi1 = phi_func(L, n0, dn_1, lambda_0)

    # psi_0 is a phase, so it enters as the unit-modulus factor exp(1j*psi_0).
    # exp(psi_0) would rescale the arm's amplitude instead of rotating it.
    D_00 = D_func(a0, t0, phi0) * jnp.exp(1j * psi_0)
    D_11 = D_func(a1, t1, phi1)

    U_drmzm_matrix = jnp.array([[D_00, jnp.zeros_like(D_00)], [jnp.zeros_like(D_11), D_11]])
    U_bs = (1 / jnp.sqrt(2)) * jnp.array([[1, 1], [1, -1]])
    U_total = jnp.dot(U_bs, jnp.dot(U_drmzm_matrix, U_bs))

    return U_total

# Add Crosstalk
def U_drmzm_multi_channel(
    V0_t_list, V1_t_list, L, n0, lambda_0, a0, t0, a1, t1,
    N_ch, distances, coupling_lengths, n_eff_list, kappa0, alpha, enable_crosstalk=True
):
    """
    Multi-channel modulator matrix with optional inter-waveguide crosstalk.

    Channel i contributes the [0, 0] element of its single-channel DR-MZM matrix
    on the diagonal. When crosstalk is enabled, this diagonal matrix is
    right-multiplied by a coupling matrix. For every pair i < j with
    distances[i, j] > 0, its (i, j) and (j, i) entries are the off-diagonal
    elements of a two-waveguide coupled-mode transfer matrix.

    Parameters:
    V0_t_list, V1_t_list: per-channel arm voltages (one value per channel).
    L, n0, lambda_0, a0, t0, a1, t1: physical parameters forwarded to
        U_drmzm_single_channel.
    N_ch (int): number of waveguides (channels).
    distances: (N_ch, N_ch) pairwise distances (µm); only the upper triangle is read.
    coupling_lengths: (N_ch, N_ch) coupling lengths (µm); only the upper
        triangle is read.
    n_eff_list: effective refractive index of each waveguide.
    kappa0 (float): coupling coefficient at zero distance.
    alpha (float): exponential decay rate of the coupling with distance.
    enable_crosstalk (bool): if False, return the diagonal matrix only.

    Returns:
    (N_ch, N_ch) complex64 jax array.

    NOTE(review): the coupling matrix keeps 1 on its diagonal (the cos/delta
    terms of each pairwise transfer matrix are not applied), so it is a
    first-order approximation and not exactly unitary.
    """
    # Step 1: crosstalk-free diagonal, one DR-MZM response per channel
    U_multi_channel_no_ct = jnp.zeros((N_ch, N_ch), dtype=jnp.complex64)
    for i in range(N_ch):
        U_single_channel = U_drmzm_single_channel(
            V0_t_list[i], V1_t_list[i], L, n0, lambda_0, a0, t0, a1, t1
        )
        U_multi_channel_no_ct = U_multi_channel_no_ct.at[i, i].set(U_single_channel[0, 0])

    if not enable_crosstalk:
        return U_multi_channel_no_ct

    # Step 2: coupling matrix, starting from identity
    U_wg_coupling_ct = jnp.eye(N_ch, dtype=jnp.complex64)
    k = 2 * jnp.pi / lambda_0  # free-space wavenumber

    def compute_transfer_matrix(L_val, beta_1, beta_2, kappa):
        """Coupled-mode transfer matrix of two waveguides over length L_val.

        With delta = (beta_1 - beta_2) / 2 and g = sqrt(kappa**2 + delta**2):
            [[cos gL - i (delta/g) sin gL,  -i (kappa/g) sin gL],
             [-i (kappa/g) sin gL,          cos gL + i (delta/g) sin gL]]
        This matrix is unitary. Without the kappa/g factor on the off-diagonal,
        power is not conserved whenever beta_1 != beta_2.
        """
        delta_beta = (beta_1 - beta_2) / 2.0
        kappa_eff = jnp.sqrt(kappa**2 + delta_beta**2)  # Effective coupling coefficient
        cos_term = jnp.cos(kappa_eff * L_val)
        sin_term = jnp.sin(kappa_eff * L_val)
        # kappa_eff == 0 only when kappa == delta_beta == 0 (no coupling, no detuning).
        delta_term = delta_beta / kappa_eff if kappa_eff != 0 else 0
        kappa_term = kappa / kappa_eff if kappa_eff != 0 else 0

        return jnp.array([
            [cos_term - 1j * delta_term * sin_term, -1j * kappa_term * sin_term],
            [-1j * kappa_term * sin_term, cos_term + 1j * delta_term * sin_term]
        ])

    # Step 3: off-diagonal coupling for every channel pair i < j
    for i in range(N_ch):
        beta_i = k * n_eff_list[i]
        for j in range(i + 1, N_ch):
            if distances[i, j] > 0:
                beta_j = k * n_eff_list[j]
                kappa_ij = kappa0 * jnp.exp(-alpha * distances[i, j])
                M = compute_transfer_matrix(coupling_lengths[i, j], beta_i, beta_j, kappa_ij)

                # A lossless coupler has equal off-diagonal entries
                # (M[0, 1] == M[1, 0]). Using conj(M[0, 1]) would flip the sign
                # of one of them and break unitarity even to first order.
                U_wg_coupling_ct = U_wg_coupling_ct.at[i, j].set(M[0, 1])
                U_wg_coupling_ct = U_wg_coupling_ct.at[j, i].set(M[1, 0])

    # Step 4: combined matrix with crosstalk
    return jnp.matmul(U_multi_channel_no_ct, U_wg_coupling_ct)

# 1.3 SLM and Scattering Matrices
def construct_U_multi_channel_slm(N_slm, phase_mod, amp_mod, t_steps):
    """
    Time-stacked SLM matrix: the same diagonal amp * exp(i * phase) at every step.

    Parameters:
        N_slm (int): number of SLM channels. The first N_slm entries of
            phase_mod / amp_mod are used; both must have length >= N_slm.
        phase_mod: per-channel phase (radians).
        amp_mod: per-channel amplitude.
        t_steps (int): number of time steps.

    Returns:
        (t_steps, N_slm, N_slm) complex64 array. Every slice equals
        diag(amp_mod * exp(1j * phase_mod)).

    The diagonal is built once and broadcast over time. A per-element
    `.at[t, i, i].set` loop copies the whole (t_steps, N_slm, N_slm) array on
    every update outside jit, which is quadratic work for a constant result.
    """
    phases = jnp.exp(1j * jnp.asarray(phase_mod)[:N_slm])
    amplitudes = jnp.asarray(amp_mod)[:N_slm]
    U_single_step = jnp.diag(amplitudes * phases).astype(jnp.complex64)
    return jnp.broadcast_to(U_single_step, (t_steps, N_slm, N_slm))

def construct_I_prime(N_scat, delta, t_steps):
    """
    Time-stacked scatterer matrix: (1 + delta) * identity at every step.

    Parameters:
        N_scat (int): number of scatterer modes.
        delta: scalar added to every diagonal element (real or complex).
        t_steps (int): number of time steps.

    Returns:
        (t_steps, N_scat, N_scat) complex64 array.

    The matrix is time-independent, so it is built once and broadcast. A
    per-step `.at[t].set` loop would copy the full stack on every iteration.
    """
    single_step = (jnp.eye(N_scat) + jnp.diag(jnp.full(N_scat, delta))).astype(jnp.complex64)
    return jnp.broadcast_to(single_step, (t_steps, N_scat, N_scat))

# 1.4 E-field Calculations
def lg00_mode_profile(X, Y, beam_center, beam_waist):
    """
    Gaussian field envelope exp(-r^2 / (2 * beam_waist^2)) centred on beam_center.

    Parameters:
        X, Y: coordinate meshes.
        beam_center: (x, y) centre of the beam.
        beam_waist: width parameter. With the factor 2 in the denominator, the
            field falls to exp(-1/2) at r = beam_waist.

    Returns:
        Real array shaped like X with peak value 1.
    """
    x0, y0 = beam_center
    dx = X - x0
    dy = Y - y0
    return jnp.exp(-(dx**2 + dy**2) / (2 * beam_waist**2))

def compute_E_field_for_channel(X, Y, E_t, beam_center, beam_waist, t_idx):
    """
    Field of one beam on the grid: Gaussian envelope times the channel's complex amplitude.

    Parameters:
        X, Y: coordinate meshes.
        E_t: complex amplitude of this channel at the current time step.
        beam_center, beam_waist: forwarded to lg00_mode_profile.
        t_idx: accepted for call-site compatibility; not used.

    Returns:
        Complex array shaped like X.
    """
    envelope = lg00_mode_profile(X, Y, beam_center, beam_waist)
    # Rebuild the amplitude from its real and imaginary parts; this also makes
    # a real-valued E_t complex.
    amplitude = jnp.real(E_t) + 1j * jnp.imag(E_t)
    return envelope * amplitude

def compute_total_E_field_profile(X, Y, b_slm_out, beam_centers, beam_waist):
    """
    Superpose all beams on the grid at every time step.

    Parameters:
        X, Y: coordinate meshes.
        b_slm_out: (T, N) complex amplitudes; column k drives beam_centers[k].
        beam_centers: sequence of (x, y) beam centres.
        beam_waist: Gaussian width shared by all beams.

    Returns:
        (T, *X.shape) complex64 array of total field per time step.
    """
    frames = []
    for step in range(b_slm_out.shape[0]):
        superposed = sum(
            (
                compute_E_field_for_channel(X, Y, b_slm_out[step, channel], centre, beam_waist, step)
                for channel, centre in enumerate(beam_centers)
            ),
            jnp.zeros_like(X, dtype=jnp.complex64),
        )
        frames.append(superposed)
    return jnp.array(frames)

def extract_E_field_at_atoms(E_field_profiles, atom_positions, X, Y):
    """
    Sample each field frame at the grid pixel nearest to every atom.

    Parameters:
        E_field_profiles: (T, ny, nx) field frames.
        atom_positions: sequence of (x, y) positions.
        X, Y: 'xy'-indexed coordinate meshes; X[0, :] is the x axis and Y[:, 0]
            the y axis.

    Returns:
        (T, N_atoms) array of sampled field values.
    """
    x_axis = X[0, :]
    y_axis = Y[:, 0]
    # The nearest-pixel lookup does not depend on time, so compute it once.
    pixel_indices = [
        (jnp.argmin(jnp.abs(y_axis - y0)), jnp.argmin(jnp.abs(x_axis - x0)))
        for x0, y0 in atom_positions
    ]
    per_step = [
        jnp.array([E_field_profiles[step][row, col] for row, col in pixel_indices])
        for step in range(E_field_profiles.shape[0])
    ]
    return jnp.array(per_step)

def compute_alpha_t(E_fields_at_atoms, dipoles, Omega_prefactor_MHz):
    """
    Per-atom coupling alpha = Omega_prefactor_MHz * dipole_x * E at each time step.

    Only the first component of each dipole vector is used.

    Parameters:
        E_fields_at_atoms: (T, N_atoms) field samples.
        dipoles: sequence of N_atoms dipole vectors.
        Omega_prefactor_MHz: scalar prefactor.

    Returns:
        (T, N_atoms) array.
    """
    n_atoms = len(dipoles)
    rows = []
    for step in range(E_fields_at_atoms.shape[0]):
        row = [
            Omega_prefactor_MHz * dipoles[atom][0] * E_fields_at_atoms[step, atom]
            for atom in range(n_atoms)
        ]
        rows.append(jnp.array(row))
    return jnp.array(rows)


# ========================================
# Part 2: Atomic Model
# ========================================

# Two levels per qubit
N_qubit_level = 2

# Single-qubit operators (complex64) in the (first, second) basis-state order:
# s_plus maps the second basis state to the first, s_minus is its transpose,
# and s_z = diag(+1, -1).
s_plus = jnp.array([[0, 1], [0, 0]], dtype=jnp.complex64)
s_minus = s_plus.T
s_z = jnp.diag(jnp.array([1, -1], dtype=jnp.complex64))

def construct_multi_qubit_operator(single_qubit_op, N_a, N_qubit_level, qubit_idx):
    """
    Embed a single-qubit operator into an N_a-qubit register.

    Returns kron(F_0, F_1, ..., F_{N_a-1}), where F_qubit_idx = single_qubit_op
    and every other factor is the complex64 identity of size N_qubit_level.
    Qubit 0 is the leftmost Kronecker factor.
    """
    identity = jnp.eye(N_qubit_level, dtype=jnp.complex64)
    factors = [identity] * N_a
    factors[qubit_idx] = single_qubit_op

    result = factors[0]
    for position in range(1, N_a):
        result = jnp.kron(result, factors[position])
    return result

# Creation and annihilation operators for the quantized field (Fock mode)
def construct_annihilation_operator(fock_dim):
    """
    Truncated bosonic ladder operators on a Fock space of dimension fock_dim.

    Returns:
        (a, a_dag): complex64 (fock_dim, fock_dim) arrays with
        a[n-1, n] = sqrt(n) for n = 1..fock_dim-1, and a_dag the conjugate
        transpose of a.
    """
    lowering = jnp.zeros((fock_dim, fock_dim), dtype=jnp.complex64)
    for level in range(1, fock_dim):
        # a|n> = sqrt(n)|n-1>
        lowering = lowering.at[level - 1, level].set(jnp.sqrt(level))
    raising = lowering.conj().T
    return lowering, raising

# Fock space dimensions
# Truncation of the quantized field mode's Fock space, plus its ladder
# operators and identity (all module-level, used by the Hamiltonian builders).
fock_dim = 5
a, a_dag = construct_annihilation_operator(fock_dim)
I_fock = jnp.identity(fock_dim, dtype=jnp.complex64)

# Drift Hamiltonian (H_0)
def construct_H_0(N_a, omega_0, omega_r):  # omega_0/r in MHz
    """
    Drift Hamiltonian: sum_i (omega_0 / 2) * Z_i on the qubits plus omega_r * a_dag a on the field.

    Ordering is qubits (x) field, i.e. the field is the rightmost Kronecker
    factor. Uses the module-level s_z, a, a_dag and I_fock.
    """
    qubit_part = 0
    for qubit in range(N_a):
        qubit_part = qubit_part + 0.5 * omega_0 * construct_multi_qubit_operator(s_z, N_a, 2, qubit)
    field_part = omega_r * jnp.kron(jnp.eye(2 ** N_a), a_dag @ a)
    return jnp.kron(qubit_part, I_fock) + field_part

# Control Hamiltonian
def construct_H_control(N_a, N_qubit_level, g_real_t, g_imag_t):
    """
    Jaynes-Cummings-type control Hamiltonian at one time step.

    For each atom i, with P_i = s_plus_i (x) a and M_i = s_minus_i (x) a_dag:
        H += Re g_i * (P_i + M_i) + Im g_i * (1j * P_i - 1j * M_i)
    Both brackets are Hermitian. Uses the module-level s_plus, s_minus, a,
    a_dag and fock_dim.

    Parameters:
        g_real_t, g_imag_t: per-atom real and imaginary coupling at this step.

    Returns:
        (N_qubit_level**N_a * fock_dim) square complex64 matrix.
    """
    dim = N_qubit_level ** N_a * fock_dim
    H_control = jnp.zeros((dim, dim), dtype=jnp.complex64)
    for atom in range(N_a):
        sigma_plus_a = jnp.kron(construct_multi_qubit_operator(s_plus, N_a, N_qubit_level, atom), a)
        sigma_minus_adag = jnp.kron(construct_multi_qubit_operator(s_minus, N_a, N_qubit_level, atom), a_dag)
        H_control += g_real_t[atom] * (sigma_plus_a + sigma_minus_adag)
        H_control += g_imag_t[atom] * (1j * sigma_plus_a - 1j * sigma_minus_adag)
    return H_control

# Time Evolution Hamiltonian (H(t))
def construct_H_time(N_a, N_qubit_level, omega_0, omega_r, g_real_t, g_imag_t,
                   atom_positions, gate_type='single'):
    """
    Stack H(t) = H_0 + H_control(t) for every time step.

    Parameters:
        g_real_t, g_imag_t: (T, N_a) per-step couplings; T = len(g_real_t).
        atom_positions, gate_type: accepted for interface compatibility; not used.

    Returns:
        (T, dim, dim) JAX array of Hamiltonians.
    """
    H_drift = construct_H_0(N_a, omega_0, omega_r)
    n_steps = len(g_real_t)
    stacked = [
        H_drift + construct_H_control(N_a, N_qubit_level, g_real_t[step], g_imag_t[step])
        for step in range(n_steps)
    ]
    return jnp.array(stacked)

# Compute accumulated propagator (U(t)) over time
def compute_accumulated_propagator(H_t, dt, N_a, N_qubit_level):
    """
    Cumulative piecewise-constant propagators.

    Entry k is exp(-i H_k dt) @ ... @ exp(-i H_0 dt): each new step multiplies
    on the left. Uses the module-level fock_dim for the identity's size.

    Returns:
        (T, dim, dim) array of U(t_k), one per Hamiltonian in H_t.
    """
    dim = N_qubit_level ** N_a * fock_dim
    running = jnp.eye(dim, dtype=jnp.complex64)
    snapshots = []
    for H_step in H_t:
        running = jla.expm(-1j * H_step * dt) @ running
        snapshots.append(running)
    return jnp.array(snapshots)

# Trace out field from unitary matrix
def trace_out_field_from_unitary(U_t, N_a, N_qubit_level, field_dim=5):
    """
    Reduce a joint (qubits (x) field) operator to the qubit space via a normalized partial trace.

    The joint index is qubits-major, field-minor (the order produced by
    jnp.kron(qubit_op, field_op)), so U_t is viewed as
    (dim_qubits, field_dim, dim_qubits, field_dim) and

        U_q[i, j] = (1 / field_dim) * sum_k U_t[(i, k), (j, k)].

    Only field-diagonal pairs (k, k) are contracted, as in a partial trace;
    summing over all (k, k') pairs would also mix in field off-diagonal
    amplitudes. Dividing by field_dim maps U_q (x) I_field back to U_q exactly
    and keeps |Tr(U_target^dag U_q)|^2 / d^2 within [0, 1].

    Parameters:
        U_t: (dim_qubits * field_dim) square matrix.
        N_a (int): number of qubits.
        N_qubit_level (int): levels per qubit.
        field_dim (int): Fock-space dimension of U_t's field factor.
            NOTE(review): keep in sync with the cutoff used to build U_t.

    Returns:
        (dim_qubits, dim_qubits) array with dim_qubits = N_qubit_level ** N_a.
    """
    dim_qubits = N_qubit_level ** N_a
    U_t_reshaped = U_t.reshape(dim_qubits, field_dim, dim_qubits, field_dim)
    U_t_traced = jnp.trace(U_t_reshaped, axis1=1, axis2=3)
    return U_t_traced / field_dim

# Fidelity Calculation
def compute_fidelity_unitary(U_t_traced, U_target):
    """
    Gate fidelity F = |Tr(U_target^dag U)|^2 / d^2, with d = U_target.shape[0].

    F is insensitive to a global phase of U.

    Parameters:
    - U_t_traced: qubit-space operator to evaluate.
    - U_target: target unitary (d x d).

    Returns:
    - Real-valued fidelity.
    """
    dim = U_target.shape[0]
    overlap = jnp.trace(U_target.conj().T @ U_t_traced)
    squared_overlap = jnp.abs(overlap) ** 2
    return jnp.real(squared_overlap / (dim**2))

# record F(t)
def compute_fidelity_unitary_t_all(U_t_traced_all, U_target):
    """
    Gate fidelity |Tr(U_target^dag U_k)|^2 / d^2 for every step k, vectorized with jax.vmap.

    Parameters:
    - U_t_traced_all: (T, d, d) stack of qubit-space operators.
    - U_target: target unitary (d x d).

    Returns:
    - (T,) array of fidelities.
    """
    d_squared = U_target.shape[0] ** 2
    target_adjoint = U_target.conj().T

    def _fidelity_one(U_step):
        return jnp.abs(jnp.trace(target_adjoint @ U_step)) ** 2 / d_squared

    return jax.vmap(_fidelity_one)(U_t_traced_all)


# Clifford gates
def clifford_group_and_t_gate():
    """
    The 24 single-qubit Clifford gates (one per element modulo global phase), followed by T.

    Every single-qubit Clifford can be written P @ R with P in {I, X, Y, Z}
    and R one of the six coset representatives {I, H, S, HS, SH, SHS}. The
    list below covers all 4 x 6 combinations exactly once. Names are
    matrix-product words, so "HS" means H @ S.

    Returns:
    - list of 25 (name, 2x2 complex64 matrix) tuples; the T gate is last.
    """
    # Define single-qubit gates
    I = jnp.array([[1, 0], [0, 1]], dtype=jnp.complex64)  # Identity
    H = (1 / jnp.sqrt(2)) * jnp.array([[1, 1], [1, -1]], dtype=jnp.complex64)  # Hadamard
    S = jnp.array([[1, 0], [0, 1j]], dtype=jnp.complex64)  # Phase gate
    X = jnp.array([[0, 1], [1, 0]], dtype=jnp.complex64)  # Pauli-X
    Y = jnp.array([[0, -1j], [1j, 0]], dtype=jnp.complex64)  # Pauli-Y
    Z = jnp.array([[1, 0], [0, -1]], dtype=jnp.complex64)  # Pauli-Z
    T = jnp.array([[1, 0], [0, jnp.exp(1j * jnp.pi / 4)]], dtype=jnp.complex64)  # T gate (pi/4 phase)

    # Slots 14-17 and 21-23 hold the SH- and SHS-coset elements. Those slots
    # previously repeated earlier gates up to phase (ZH = HX, XS ~ SY,
    # YS ~ SX, ZS = SZ, XHS ~ HSZ, YHS ~ HSX, ZHS ~ HSY). The indices of all
    # other gates are unchanged.
    clifford_group = [
        ("X", X),               # X . I
        ("I", I),               # I . I
        ("H", H),               # I . H
        ("S", S),               # I . S
        ("Y", Y),               # Y . I
        ("Z", Z),               # Z . I
        ("HS", H @ S),          # I . HS
        ("SH", S @ H),          # I . SH
        ("HX", H @ X),          # Z . H
        ("SX", S @ X),          # Y . S (up to phase)
        ("SY", S @ Y),          # X . S (up to phase)
        ("SZ", S @ Z),          # Z . S
        ("XH", X @ H),          # X . H
        ("YH", Y @ H),          # Y . H
        ("XSH", X @ S @ H),     # X . SH
        ("YSH", Y @ S @ H),     # Y . SH
        ("ZSH", Z @ S @ H),     # Z . SH
        ("SHS", S @ H @ S),     # I . SHS
        ("HSX", H @ S @ X),     # Y . HS (up to phase)
        ("HSY", H @ S @ Y),     # Z . HS (up to phase)
        ("HSZ", H @ S @ Z),     # X . HS
        ("XSHS", X @ S @ H @ S),  # X . SHS
        ("YSHS", Y @ S @ H @ S),  # Y . SHS
        ("ZSHS", Z @ S @ H @ S),  # Z . SHS
    ]

    # Append the (non-Clifford) T gate for a universal gate set
    return clifford_group + [("T", T)]


# ========================================
# Part 3: Function to Compute Gate Fidelity
# ========================================

def compute_multi_qubit_fidelity_closed_system(
    V0_t_list, V1_t_list, L, n0, lambda_0, a0, t0, a1, t1, phase_mod, amp_mod, delta,
    atom_positions, dipoles, beam_centers, beam_waist, X, Y, Omega_prefactor_MHz,
    t_steps, dt,
    N_ch, distances, coupling_lengths, n_eff_list, kappa0, alpha, enable_crosstalk,
    N_slm, N_ch_slm_in, N_scat_1, N_scat_2, N_a, N_qubit_level, omega_0, omega_r, a_pic, a_scat_1,
    U_target, gate_type='single'
):
    """
    End-to-end simulation: control voltages -> optical fields at the atoms -> qubit gate -> fidelity.

    The pipeline, per time step unless noted:
      1. Build the N_ch x N_ch modulator matrix (U_drmzm_multi_channel).
      2. Stage 1: apply kron(modulator, I1_prime) to kron(a_pic, a_scat_1),
         then split the joint output back into PIC and scatterer-1 amplitudes
         by summing over the other subsystem's index.
      3. Route amplitudes into the SLM (b_slm_in) and scatterer 2 (b_scat2_in).
      4. Stage 2: apply kron(SLM, I2_prime) and keep the SLM amplitudes by
         summing over the scatterer-2 index.
      5. Superpose the beams on the (X, Y) grid and sample the field at each atom.
      6. Convert to couplings g(t) = alpha(t) / 2, build H(t), accumulate the
         propagator and reduce the final one to the qubit space.
      7. Return compute_fidelity_unitary(reduced final propagator, U_target).

    Parameters:
        V0_t_list, V1_t_list: per-channel voltage series; element [c][i] is
            channel c at time step i.
        L, n0, lambda_0, a0, t0, a1, t1: modulator physical parameters.
        phase_mod, amp_mod: SLM phase and amplitude per SLM channel.
        delta: scalar added to the diagonal of both scatterer matrices
            (see construct_I_prime). NOTE(review): previously described as a
            detuning; confirm its physical meaning.
        atom_positions, dipoles: atom (x, y) positions and dipole vectors.
        beam_centers, beam_waist: beam centre per SLM output and shared Gaussian width.
        X, Y: 2D grid for the beam field.
        Omega_prefactor_MHz: Rabi-frequency prefactor.
        t_steps, dt: number of time steps and step size. NOTE(review):
            presumably microseconds, since the Hamiltonian frequencies are
            labelled MHz; confirm.
        N_ch: number of photonic channels.
        distances, coupling_lengths, n_eff_list, kappa0, alpha, enable_crosstalk:
            crosstalk model inputs (see U_drmzm_multi_channel).
        N_slm, N_ch_slm_in, N_scat_1, N_scat_2: mode counts of the SLM, of
            PIC channels routed into the SLM, and of the two scatterers.
        N_a, N_qubit_level: number of atoms and levels per atom.
        omega_0, omega_r: qubit and field frequencies in MHz.
        a_pic, a_scat_1: input mode-amplitude vectors of the PIC (length N_ch)
            and scatterer 1 (length N_scat_1), implied by the reshape below.
        U_target: target unitary on the qubit space.
        gate_type: forwarded to construct_H_time.

    Returns:
        Fidelity of the reduced final propagator with respect to U_target.
    """
    # Compute Unitary Matrix for Multi-Channel System
    # (one N_ch x N_ch matrix per time step i, from the i-th sample of every channel)
    U_system_multi_channel = jnp.array([
        U_drmzm_multi_channel(
            [V0_t[i] for V0_t in V0_t_list],
            [V1_t[i] for V1_t in V1_t_list],
            L, n0, lambda_0, a0, t0, a1, t1,
            N_ch, distances, coupling_lengths, n_eff_list,
            kappa0, alpha, enable_crosstalk
        )
        for i in range(t_steps)
    ])

    # Compute scattering and SLM matrices
    # Stage 1 joint operator: modulator (x) scatterer 1, per time step
    I1_prime = construct_I_prime(N_scat_1, delta, t_steps)
    U_tensor_product = jnp.array([jnp.kron(U_system_multi_channel[i], I1_prime[i]) for i in range(t_steps)])

    # Output modes calculations
    # NOTE(review): the joint output is split back into per-subsystem amplitudes
    # by summing over the other index, not by a projection or partial trace.
    # Confirm that this is the intended mode-splitting model.
    a_total = jnp.kron(a_pic, a_scat_1)
    output_modes = jnp.array([jnp.dot(U_tensor_product[i], a_total) for i in range(t_steps)])
    output_modes_reshaped = output_modes.reshape(t_steps, N_ch, N_scat_1)
    a_pic_out = jnp.sum(output_modes_reshaped, axis=-1)
    a_scat_1_out = jnp.sum(output_modes_reshaped, axis=-2)

    # Compute final output modes after SLM and second scattering
    b_slm_in = jnp.zeros((t_steps, N_slm), dtype=jnp.complex64)
    b_scat2_in = jnp.zeros((t_steps, N_scat_2), dtype=jnp.complex64)

    # Routing: SLM gets the first N_ch_slm_in PIC outputs plus the first
    # (N_slm - N_ch_slm_in) scatterer-1 outputs. Scatterer 2 gets the remaining
    # PIC outputs plus a tail of the scatterer-1 outputs.
    # NOTE(review): the last scatterer-1 slice starts at
    # N_scat_2 - (N_ch - N_ch_slm_in) and runs to N_scat_1, while its
    # destination has length N_scat_2 - (N_ch - N_ch_slm_in). These lengths only
    # agree for specific N_* combinations; confirm the intended start index
    # (possibly N_slm - N_ch_slm_in).
    for t in range(t_steps):
        b_slm_in = b_slm_in.at[t, :N_ch_slm_in].set(a_pic_out[t, :N_ch_slm_in])
        b_slm_in = b_slm_in.at[t, N_ch_slm_in:].set(a_scat_1_out[t, :(N_slm - N_ch_slm_in)])

        b_scat2_in = b_scat2_in.at[t, :(N_ch - N_ch_slm_in)].set(a_pic_out[t, N_ch_slm_in:])
        b_scat2_in = b_scat2_in.at[t, (N_ch - N_ch_slm_in):].set(a_scat_1_out[t, (N_scat_2 - (N_ch - N_ch_slm_in)):])

    # Stage 2 joint operator: SLM (x) scatterer 2, applied to kron(b_slm_in, b_scat2_in)
    b_total_in = jnp.array([jnp.kron(b_slm_in[t], b_scat2_in[t]) for t in range(t_steps)])
    U_multi_channel_slm = construct_U_multi_channel_slm(N_slm, phase_mod, amp_mod, t_steps)
    I2_prime = construct_I_prime(N_scat_2, delta, t_steps)
    U_tensor_product_stage_2 = jnp.array([jnp.kron(U_multi_channel_slm[t], I2_prime[t]) for t in range(t_steps)])
    b_total_out = jnp.array([jnp.dot(U_tensor_product_stage_2[t], b_total_in[t]) for t in range(t_steps)])

    # Keep the SLM amplitudes by summing over the scatterer-2 index
    b_total_out_reshaped = b_total_out.reshape(t_steps, N_slm, N_scat_2)
    b_slm_out = jnp.sum(b_total_out_reshaped, axis=-1)

    # Compute the total E-field on the atom plane
    E_field_profiles = compute_total_E_field_profile(X, Y, b_slm_out, beam_centers, beam_waist)

    # Extract E-fields at the atom positions
    E_fields_at_atoms = extract_E_field_at_atoms(E_field_profiles, atom_positions, X, Y)

    # Compute interaction strength α(t) for each atom
    alpha_t = compute_alpha_t(E_fields_at_atoms, dipoles, Omega_prefactor_MHz)

    # g_real_t and g_imag_t represent g(t) in the Jaynes-Cummings model, g(t) = eff_Rabi(t)/2
    g_real_t = alpha_t.real / 2
    g_imag_t = alpha_t.imag / 2

    # Construct the time-dependent Hamiltonian
    H_t = construct_H_time(N_a, N_qubit_level, omega_0, omega_r, g_real_t, g_imag_t, atom_positions, gate_type)

    # Evolve the unitary matrix (propagator)
    U_t_all = compute_accumulated_propagator(H_t, dt, N_a, N_qubit_level)

    # Trace out the field from the final unitary matrix
    # NOTE(review): field_dim is left at trace_out_field_from_unitary's default;
    # it must match the module-level fock_dim used to build H_t.
    U_t_traced = trace_out_field_from_unitary(U_t_all[-1], N_a, N_qubit_level)

    # Compute fidelity with the target gate
    fidelity = compute_fidelity_unitary(U_t_traced, U_target)

    return fidelity

# ========================================
# Part 4: Program Instruction function
# ========================================
def program_instruction(N_a, key_number, gate_type='single', selected_atoms=None, control_atom=None, target_atom=None):
    """
    Build the target unitary for an N_a-qubit register and log the gate chosen for each atom.

    Only gate_type='single' is implemented. Each selected atom gets a gate
    drawn uniformly from clifford_group_and_t_gate() using a JAX PRNG seeded
    with key_number; every other atom gets the identity.

    Args:
        N_a (int): number of atoms (qubits); must be >= 1.
        key_number (int): seed for jax.random.PRNGKey.
        gate_type (str): only 'single' is supported.
        selected_atoms: collection of atom indices that receive a random gate;
            None selects every atom.
        control_atom, target_atom: accepted for interface compatibility;
            currently unused.

    Returns:
        (2**N_a, 2**N_a) Kronecker product of the per-atom gates, with atom 0
        as the leftmost factor.

    Raises:
        NotImplementedError: for any gate_type other than 'single'. Returning
            None here would only fail later, far from the cause.
        ValueError: if N_a < 1.
    """
    if gate_type != 'single':
        raise NotImplementedError(
            f"gate_type={gate_type!r} is not supported; only 'single' is implemented"
        )
    if N_a < 1:
        raise ValueError(f"N_a must be >= 1, got {N_a}")

    clifford_t_gates = clifford_group_and_t_gate()
    key = jrandom.PRNGKey(key_number)  # Initialize the random key for gate selection

    per_atom_gates = []
    for atom_idx in range(N_a):
        if selected_atoms is None or atom_idx in selected_atoms:
            # Split the key per selected atom so each draws an independent gate index
            key, subkey = jrandom.split(key)
            gate_idx = int(jrandom.randint(subkey, (), minval=0, maxval=len(clifford_t_gates)))
            gate_name, gate = clifford_t_gates[gate_idx]
            logger.info("Applying %s gate to Atom %d", gate_name, atom_idx)
            per_atom_gates.append(gate)
        else:
            logger.info("Applying Identity gate to Atom %d", atom_idx)
            per_atom_gates.append(jnp.eye(2))  # Identity gate for unselected atoms

    # Combine the gates using the Kronecker product (atom 0 leftmost)
    U_target = per_atom_gates[0]
    for gate in per_atom_gates[1:]:
        U_target = jnp.kron(U_target, gate)

    return U_target

# ========================================
# Part 5: Optimization : SADE-ADAM
# ========================================

# Multi-qubit objective function: gate error (1 - fidelity); no penalty/Lagrangian terms are applied
def multi_qubit_objective(
    V_combined, APIC_params, atom_beam_params, control_Vt_params, system_params  # Removed penalty_weights
):
    """
    Objective minimised by the optimizers: the gate error ``1 - F``.

    ``V_combined`` holds only control signals (no tmax). Its first half is the
    flattened V0 signal set and its second half the flattened V1 set. Each half
    is reshaped to ``(N_ch, t_steps)`` before being handed, together with every
    physical parameter, to ``compute_multi_qubit_fidelity_closed_system``. The
    time grid comes from the fixed ``t_steps``/``dt`` in ``control_Vt_params``.

    Returns:
    - 1 - F, where F is the fidelity returned by the simulation.
    """
    apic = APIC_params
    beam = atom_beam_params
    ctrl = control_Vt_params
    sysp = system_params

    signal_shape = (sysp['N_ch'], ctrl['t_steps'])
    split_at = V_combined.shape[0] // 2
    V0_t_list = V_combined[:split_at].reshape(signal_shape)
    V1_t_list = V_combined[split_at:].reshape(signal_shape)

    fidelity = compute_multi_qubit_fidelity_closed_system(
        V0_t_list, V1_t_list,
        apic['L'], apic['n0'], apic['lambda_0'],
        apic['a0'], apic['t0'], apic['a1'], apic['t1'],
        apic['phase_mod'], apic['amp_mod'],
        sysp['delta'],
        beam['atom_positions'], beam['dipoles'],
        beam['beam_centers'], beam['beam_waist'],
        beam['X'], beam['Y'], beam['Omega_prefactor_MHz'],
        ctrl['t_steps'], ctrl['dt'],
        sysp['N_ch'], sysp['distances'], sysp['coupling_lengths'],
        sysp['n_eff_list'], sysp['kappa0'], sysp['alpha'],
        sysp['enable_crosstalk'],
        sysp['N_slm'], sysp['N_ch_slm_in'], sysp['N_scat_1'],
        sysp['N_scat_2'], sysp['N_a'],
        sysp['N_qubit_level'], sysp['omega_0'], sysp['omega_r'],
        beam['a_pic'], beam['a_scat_1'],
        sysp['U_target'], sysp['gate_type'],
    )

    # Minimising the gate error is equivalent to maximising the fidelity
    return 1 - fidelity

def adam_optimization(
    V_combined_init, APIC_params, atom_beam_params, control_Vt_params,
    system_params, optimizer_params, fidelity_gens=None
):
    """
    Fine-tune flattened control signals with Adam by minimising gate error 1 - F.

    Parameters:
    - V_combined_init: Flat array of V0 followed by V1 control signals, in the
      layout expected by ``multi_qubit_objective``.
    - APIC_params, atom_beam_params, control_Vt_params, system_params: Forwarded
      unchanged to ``multi_qubit_objective``.
    - optimizer_params: Reads 'adam_lr', 'adam_steps' and 'tol', plus the
      optional 'fidelity_decay_thresholds', 'fidelity_decay_factors' and 'min_lr'.
    - fidelity_gens: Optional fidelity history to extend (e.g. from SaDE). It is
      copied, not mutated. When None or empty, the history starts with the
      initial fidelity.

    Behaviour:
    - After every step the new fidelity is evaluated.
    - A step that yields a non-finite fidelity, or one above 1 (unphysical), is
      rejected: the parameters revert to their previous value, Adam's state is
      re-initialised and nothing is appended to the history. This mirrors the
      SaDE stage, which also discards candidates with fidelity above 1.
    - When the fidelity reaches the next entry of 'fidelity_decay_thresholds',
      the learning rate is multiplied by the matching decay factor (floored at
      'min_lr') and Adam's state is re-initialised.
    - Stops early once an accepted step reaches fidelity >= 1 - tol.

    Returns:
    - params: Last accepted parameters.
    - fidelity_gens: The history with one entry per accepted Adam step appended.
    """
    def objective(V):
        return multi_qubit_objective(
            V, APIC_params, atom_beam_params, control_Vt_params, system_params
        )

    # Gradient function for the objective
    grad_obj = grad(objective)

    # Initialize parameters and optimizer
    params = V_combined_init
    current_lr = optimizer_params['adam_lr']  # Initialize current learning rate
    optimizer = optax.adam(current_lr)
    opt_state = optimizer.init(params)

    # Compute starting fidelity
    starting_fidelity = float(1 - objective(params))
    logger.info(f"Initial Fidelity: {starting_fidelity:.10f}, Fixed tmax: {control_Vt_params['tmax']:.6f} us")

    # Initialize tracking variables; copy the history so the caller's list is untouched
    best_fidelity = starting_fidelity
    fidelity_gens = list(fidelity_gens) if fidelity_gens else [starting_fidelity]
    decay_idx = 0  # Pointer to the next threshold to trigger decay

    # Extract thresholds and decay factors
    thresholds = optimizer_params.get('fidelity_decay_thresholds', [0.95, 0.99, 0.995])
    decay_factors = optimizer_params.get('fidelity_decay_factors', [0.10, 0.50, 0.20])
    min_lr = optimizer_params.get('min_lr', 1e-6)

    # Adam optimization loop
    for step in range(optimizer_params['adam_steps']):
        # Propose an update without committing to it yet
        grads = grad_obj(params)
        updates, new_opt_state = optimizer.update(grads, opt_state)
        candidate = optax.apply_updates(params, updates)

        # Evaluate the candidate's fidelity
        current_fidelity = float(1 - objective(candidate))

        if not math.isfinite(current_fidelity) or current_fidelity > 1.0:
            # Unphysical result: keep the previous parameters and restart Adam's
            # moment estimates from there instead of continuing from a bad point.
            logger.info(f"Adam Step {step}: rejected unphysical fidelity {current_fidelity:.10f}; "
                        f"reverting parameters and resetting optimizer state")
            opt_state = optimizer.init(params)
            continue

        # Accept the step
        params, opt_state = candidate, new_opt_state
        fidelity_gens.append(current_fidelity)

        # Log step status
        logger.info(f"Adam Step {step}, Fidelity: {current_fidelity:.10f}, Learning Rate: {current_lr:.6f}")

        # Update the best fidelity
        if current_fidelity > best_fidelity:
            best_fidelity = current_fidelity

        # Decay learning rate when fidelity exceeds the next threshold
        if decay_idx < len(thresholds) and current_fidelity >= thresholds[decay_idx]:
            new_lr = current_lr * decay_factors[decay_idx]  # Apply the corresponding decay factor
            new_lr = max(new_lr, min_lr)  # Ensure the learning rate doesn't fall below min_lr
            logger.info(f"Reducing learning rate from {current_lr:.6f} to {new_lr:.6f} using decay factor {decay_factors[decay_idx]:.2f}.")
            current_lr = new_lr
            optimizer = optax.adam(current_lr)
            opt_state = optimizer.init(params)  # Reinitialize optimizer state with the new learning rate
            decay_idx += 1  # Move to the next threshold

        # Early stopping based on fidelity threshold
        if current_fidelity >= 1 - optimizer_params['tol']:
            logger.info(f"Stopping at step {step}, Fidelity: {best_fidelity:.10f}")
            break

    return params, fidelity_gens

def evaluate_population_sade(
    population, APIC_params, atom_beam_params, control_Vt_params, system_params
):
    """
    Score each candidate in ``population`` by its gate error 1 - F at fixed tmax.

    Each candidate is a flat array: the first half is the V0 signal set and the
    second half the V1 set. Each half is reshaped to ``(N_ch, t_steps)``.

    Candidates whose fidelity is above 1 (unphysical) or non-finite (NaN/inf)
    get a score of ``np.inf`` so that the optimizer never selects them. A NaN
    score would otherwise be chosen by ``np.argmin``, which returns the index
    of the first NaN.

    Returns:
    - np.ndarray of scores, one per candidate, in population order.
    """
    scores = []
    for control_signals in population:
        # Reshape control signals
        half_len = control_signals.shape[0] // 2
        V0_t_list = control_signals[:half_len].reshape(system_params['N_ch'], control_Vt_params['t_steps'])
        V1_t_list = control_signals[half_len:].reshape(system_params['N_ch'], control_Vt_params['t_steps'])

        # Compute fidelity with the fixed tmax
        fidelity = compute_multi_qubit_fidelity_closed_system(
            V0_t_list, V1_t_list, APIC_params['L'], APIC_params['n0'],
            APIC_params['lambda_0'], APIC_params['a0'], APIC_params['t0'],
            APIC_params['a1'], APIC_params['t1'], APIC_params['phase_mod'],
            APIC_params['amp_mod'], system_params['delta'],
            atom_beam_params['atom_positions'], atom_beam_params['dipoles'],
            atom_beam_params['beam_centers'], atom_beam_params['beam_waist'],
            atom_beam_params['X'], atom_beam_params['Y'], atom_beam_params['Omega_prefactor_MHz'],
            control_Vt_params['t_steps'], control_Vt_params['dt'],
            system_params['N_ch'], system_params['distances'], system_params['coupling_lengths'],
            system_params['n_eff_list'], system_params['kappa0'], system_params['alpha'],
            system_params['enable_crosstalk'],
            system_params['N_slm'], system_params['N_ch_slm_in'], system_params['N_scat_1'],
            system_params['N_scat_2'], system_params['N_a'],
            system_params['N_qubit_level'], system_params['omega_0'],
            system_params['omega_r'], atom_beam_params['a_pic'],
            atom_beam_params['a_scat_1'],
            system_params['U_target'], system_params['gate_type']
        )

        if not np.isfinite(fidelity) or fidelity > 1.0:
            scores.append(np.inf)  # Reject unphysical / numerically broken candidates
        else:
            # SaDE minimizes (1 - fidelity)
            scores.append(1.0 - fidelity)

    return np.array(scores)

def optimize_multi_qubit_sade_adam(
    APIC_params, atom_beam_params, control_Vt_params, system_params, optimizer_params, V0_init, V1_init
):
    """
    Optimise the control signals with differential evolution, then Adam.

    Stage 1 ("SaDE") is a differential-evolution search over the flattened
    signals that minimises the gate error 1 - F at fixed tmax:
    - Mutation is DE/rand/1.
    - Crossover uses a per-gene random mask.
    - F and CR are drawn uniformly from [0.1, 0.9) for every trial.
    - The initial population is the initial guess plus 0.2 * standard-normal
      noise, clipped to [min_V_level, max_V_level].
    - The search stops after 'num_generations', or as soon as the best fidelity
      reaches 'fidelity_threshold'.

    Stage 2: only if that threshold was reached, ``adam_optimization``
    fine-tunes the best SaDE individual.

    Parameters:
    - APIC_params, atom_beam_params, control_Vt_params, system_params: Passed
      through to the fidelity evaluation.
    - optimizer_params: Needs 'popsize' (>= 4), 'num_generations' and
      'fidelity_threshold', plus the keys read by ``adam_optimization``.
    - V0_init, V1_init: Initial control signals. They are flattened and
      concatenated, with V0 first.

    Returns:
    - best_solution: Flat array of optimised V0 followed by V1 signals.
    - fidelity_gens: Best fidelity per SaDE generation, followed by Adam's
      per-step fidelities if Adam ran.

    Raises:
    - ValueError: if 'popsize' < 4. DE/rand/1 needs three distinct donors
      besides the target individual.
    """
    popsize = optimizer_params['popsize']
    if popsize < 4:
        # Fail before the (expensive) initial population evaluation
        raise ValueError(
            f"optimizer_params['popsize'] must be at least 4 for DE/rand/1 mutation (got {popsize})"
        )

    v_min = control_Vt_params['min_V_level']
    v_max = control_Vt_params['max_V_level']

    def evaluate(candidates):
        return evaluate_population_sade(
            candidates, APIC_params, atom_beam_params, control_Vt_params, system_params
        )

    def run_sade(V_combined_init):
        """Run SaDE to minimize gate error 1 - F with a fixed tmax."""
        fixed_tmax = control_Vt_params['tmax']
        logger.info(f"Fixed tmax for SaDE: {fixed_tmax:.6f} us")

        # Initial population: perturbed copies of the initial guess (tmax is not optimised)
        population = [
            np.clip(V_combined_init + 0.2 * np.random.randn(*V_combined_init.shape), v_min, v_max)
            for _ in range(popsize)
        ]
        fitness = evaluate(population)

        best_solution = population[np.argmin(fitness)]
        best_fitness = np.min(fitness)
        fidelity_gens = [1.0 - best_fitness]  # Track fidelity over generations

        logger.info(f"Initial Fidelity: {fidelity_gens[-1]:.10f}, Fixed tmax: {fixed_tmax:.6f} us")

        for generation in range(optimizer_params['num_generations']):
            new_population = []
            for i in range(popsize):
                # Mutation (DE/rand/1) using three distinct donors other than i
                donors = [j for j in range(popsize) if j != i]
                idxs = np.random.choice(donors, 3, replace=False)
                a, b, c = (population[k] for k in idxs)
                F = np.random.uniform(0.1, 0.9)  # Mutation factor
                mutant = np.clip(a + F * (b - c), v_min, v_max)

                # Crossover
                CR = np.random.uniform(0.1, 0.9)  # Crossover rate
                trial = np.copy(population[i])
                crossover_mask = np.random.rand(*mutant.shape) < CR
                trial[crossover_mask] = mutant[crossover_mask]

                trial_fitness = evaluate([trial])[0]

                # Greedy selection. A rejected (unphysical) trial scores inf,
                # which is never < fitness[i], so the parent is kept.
                if trial_fitness < fitness[i]:
                    new_population.append(trial)
                    fitness[i] = trial_fitness
                else:
                    new_population.append(population[i])

            population = new_population

            # Update best solution
            best_idx = np.argmin(fitness)
            best_solution = population[best_idx]
            best_fitness = fitness[best_idx]
            current_fidelity = 1.0 - best_fitness
            fidelity_gens.append(current_fidelity)

            logger.info(f"Generation {generation}, Fidelity: {current_fidelity:.10f}, Fixed tmax: {fixed_tmax:.6f} us")

            # Check if the fidelity threshold has been reached
            if current_fidelity >= optimizer_params['fidelity_threshold']:
                logger.info(f"Switching to Adam after reaching fidelity threshold: {current_fidelity:.10f}")
                break

        return best_solution, fidelity_gens

    # Flatten the initial guess (V0 first, then V1) and run SaDE
    V_combined_init = np.concatenate([V0_init.flatten(), V1_init.flatten()])
    best_solution, fidelity_gens = run_sade(V_combined_init)

    # Run Adam fine-tuning only if fidelity_threshold was reached
    if fidelity_gens[-1] >= optimizer_params['fidelity_threshold']:
        logger.info("\n--- Switching to Adam Optimizer ---\n")
        best_solution, fidelity_gens = adam_optimization(
            best_solution, APIC_params, atom_beam_params, control_Vt_params,
            system_params, optimizer_params,
            fidelity_gens
        )
    else:
        logger.info("\n--- Fidelity Threshold Not Reached. Skipping Adam Optimizer ---\n")

    return best_solution, fidelity_gens


# ========================================
# Part 6: main execution
# ========================================

# Main function to run the hybrid optimization
def run_multi_qubit_optimization():
    """
    Set up the multi-atom PIC + SLM control problem, optimise it and save results.

    Steps:
    - Builds the physical setup (N_a = 3 atoms in an equilateral arrangement,
      6 waveguide channels, optional crosstalk).
    - Builds random initial V0/V1 control signals.
    - Draws a random single-qubit target gate for the selected atoms.
    - Runs SaDE (+ Adam if SaDE reaches the fidelity threshold).

    Outputs, written to the working directory:
    - initial and optimised control signals as CSV;
    - the fidelity history as CSV;
    - a fidelity plot as PNG.
    """
    logger.info(" --- Initial Setups --- \n")

    # Laser intensity setup
    I_mW_per_cm2 = 20  # Laser intensity in mW/cm²
    Detuning_MHz = 1000  # 1 GHz detuning
    Omega_prefactor_MHz = calculate_Omega_rabi_prefactor(I_mW_per_cm2, Detuning_MHz)
    logger.info(f"Omega_prefactor_MHz: {Omega_prefactor_MHz} \n")

    # Control system parameters
    m, n0, lambda_0 = 600, 1.95, 780e-9
    L = m * lambda_0 / n0
    a0, t0, a1, t1 = 0.998, 0.998, 0.998, 0.998

    # Define frequencies in MHz
    omega_0, omega_r = 6.835e3, 6.835e3

    # Time parameters [us]
    tmin = 0
    tmax_fixed = 0.1  # Fixed tmax in microseconds
    t_steps = 100
    dt = (tmax_fixed - tmin) / t_steps
    logger.info(f"Chosen dt: {dt * 1e3:.2f} ns")

    # System and channel setup
    N_ch, N_scat_1, N_slm, N_ch_slm_in = 6, 4, 6, 4
    N_total = N_ch + N_scat_1
    N_scat_2, N_a = N_total - N_slm, 3

    selected_atoms = [0, 1, 2]

    # Crosstalk flag
    enable_crosstalk = True  # Set to True to enable crosstalk
    distances, coupling_lengths, n_eff_list, kappa0, alpha = _crosstalk_params(N_ch, enable_crosstalk)

    # 2D atomic grid and atom positions
    atom_spacing = 3.0  # µm
    X, Y = generate_grid(grid_size=600, grid_range=(-6, 6))
    atom_positions = generate_atom_positions_equilateral(N_a, side_length=atom_spacing, center=(0, 0))
    logger.info(f"Atom positions: {atom_positions}, in [um]")

    # Beam and dipole parameters
    beam_centers, beam_waist = atom_positions, 2.0
    logger.info(f"Beam centers: {beam_centers}, Beam waist: {beam_waist} in [um]")
    dipoles = generate_dipoles(N_a)

    # SLM modulation parameters
    phase_mod, amp_mod = generate_slm_mod(N_slm)

    # Weak scattering: pic inter-channel
    delta = 0.001

    # Control signal limits (also used for the random initial voltage levels)
    min_V_level, max_V_level = -15, 15

    # Initial fields
    a_pic = jnp.array([1.0] * N_ch)
    a_scat_1 = jnp.array([1.0] * N_scat_1)

    # Initialize control signals V0_t_list and V1_t_list
    V0_t_list, V1_t_list = _initial_control_signals(
        N_ch, tmin, tmax_fixed, t_steps, omega_0, min_V_level, max_V_level
    )
    logger.info(f"V0_t_list shape: {V0_t_list.shape}")
    logger.info(f"V1_t_list shape: {V1_t_list.shape}")

    # Save initial control signals (optional)
    pd.DataFrame(V0_t_list).to_csv('V0_initial_multi_qubit_withCT_3sQ_dwg_0p6um.csv', index_label='Channel')
    pd.DataFrame(V1_t_list).to_csv('V1_initial_multi_qubit_withCT_3sQ_dwg_0p6um.csv', index_label='Channel')

    # Select the gate type and generate target gate
    gate_type = 'single'
    U_target = program_instruction(N_a, key_number=91, gate_type=gate_type, selected_atoms=selected_atoms)
    logger.info(f"U_target: {U_target}")

    # Define parameter dictionaries
    APIC_params = {
        'L': L, 'n0': n0, 'lambda_0': lambda_0, 'a0': a0, 't0': t0,
        'a1': a1, 't1': t1, 'phase_mod': phase_mod, 'amp_mod': amp_mod
    }

    atom_beam_params = {
        'atom_positions': atom_positions, 'dipoles': dipoles,
        'a_pic': a_pic, 'a_scat_1': a_scat_1, 'beam_centers': beam_centers,
        'beam_waist': beam_waist, 'X': X, 'Y': Y, 'Omega_prefactor_MHz': Omega_prefactor_MHz,
    }

    control_Vt_params = {
        'tmin': tmin,
        'tmax': tmax_fixed,  # Fixed tmax
        't_steps': t_steps,
        'dt': dt,
        'min_V_level': min_V_level,
        'max_V_level': max_V_level
    }

    system_params = {
        'N_ch': N_ch, 'distances': distances,
        'coupling_lengths': coupling_lengths,
        'n_eff_list': n_eff_list,
        'enable_crosstalk': enable_crosstalk,
        'kappa0': kappa0, 'alpha': alpha,
        'N_slm': N_slm, 'N_ch_slm_in': N_ch_slm_in, 'N_scat_1': N_scat_1,
        'N_scat_2': N_scat_2, 'N_a': N_a, 'N_qubit_level': 2, 'omega_0': omega_0,
        'omega_r': omega_r, 'U_target': U_target, 'gate_type': gate_type, 'delta': delta
    }

    optimizer_params = {
        'num_generations': 500,
        'popsize': 10,
        'adam_steps': 500,
        'adam_lr': 0.0001,
        'fidelity_threshold': 0.95,
        'max_no_improvement': 5,
        'fidelity_decay_factors': [0.5, 0.2, 0.5, 0.2],
        'min_lr': 1e-6,          # Minimum learning rate
        'fidelity_decay_thresholds': [0.98, 0.99, 0.995, 0.997],  # Multiple fidelity thresholds
        'tol': 1e-3,
        'stability_steps': 10,
        'tmax_tolerance': 1e-6
    }

    # Start optimization; report the crosstalk setting that is actually in use
    start_time = time.time()
    crosstalk_label = "with" if enable_crosstalk else "without"
    logger.info(f"\n--- Starting multi-qubit gate optimization, {crosstalk_label} CrossTalk...\n")

    best_solution, fidelity_gens = optimize_multi_qubit_sade_adam(
        APIC_params, atom_beam_params, control_Vt_params, system_params,
        optimizer_params, V0_t_list, V1_t_list
    )

    # Extract optimized control signals (tmax is fixed)
    half_len = best_solution.shape[0] // 2
    V0_opt = best_solution[:half_len].reshape(N_ch, t_steps)
    V1_opt = best_solution[half_len:].reshape(N_ch, t_steps)
    tmax_opt = control_Vt_params['tmax']  # Fixed tmax

    # End timing and print results
    execution_time = (time.time() - start_time) / 60.0
    logger.info(f"Optimization completed in {execution_time:.2f} minutes.")
    logger.info(f"Fixed tmax: {tmax_opt:.6f} us")
    logger.info(f"Final optimized fidelity: {fidelity_gens[-1]:.10f}")

    # Save optimized control signals
    pd.DataFrame(V0_opt).to_csv('V0_optimal_multi_qubit_withCT_3sQ_dwg_0p6um.csv', index_label='Channel')
    pd.DataFrame(V1_opt).to_csv('V1_optimal_multi_qubit_withCT_3sQ_dwg_0p6um.csv', index_label='Channel')

    _save_fidelity_history(fidelity_gens)


def _crosstalk_params(N_ch, enable_crosstalk):
    """Build waveguide-crosstalk parameters for ``N_ch`` channels, logging them.

    Returns ``(distances, coupling_lengths, n_eff_list, kappa0, alpha)``; all
    five are ``None`` when crosstalk is disabled.
    """
    if not enable_crosstalk:
        logger.info("Crosstalk is disabled.")
        return None, None, None, None, None

    # Inter-channel waveguide distances
    base_distance = 1.0  # µm (nominal distance between adjacent waveguides)
    distance_variation = 0.1  # Random variation in distances (± µm)
    distances = generate_distances(N_ch, base_distance, random_variation=distance_variation, seed=42)

    # Coupling lengths
    base_coupling_length = 600.0  # µm (nominal coupling length for adjacent waveguides)
    length_variation = 60.0  # Random variation in coupling lengths (± µm)
    coupling_lengths = generate_coupling_lengths(N_ch, base_coupling_length, scaling_factor=1.1,
                                                  random_variation=length_variation, seed=42)

    # Effective refractive indices
    base_n_eff = 1.75  # Nominal effective refractive index
    n_eff_variation = 0.05  # Random variation in effective index (±)
    n_eff_list = generate_n_eff_list(N_ch, base_n_eff, random_variation=n_eff_variation, seed=42)

    # Waveguide-Coupling Parameters
    kappa0 = 10.145  # Maximum coupling coefficient at zero separation
    alpha = 6.934    # Coupling decay rate per µm

    logger.info("Crosstalk parameters initialized:")
    logger.info(f"Distances: {distances}")
    logger.info(f"Coupling lengths: {coupling_lengths}")
    logger.info(f"Effective refractive indices: {n_eff_list}")
    logger.info(f"kappa0: {kappa0}, alpha: {alpha}")
    return distances, coupling_lengths, n_eff_list, kappa0, alpha


def _initial_control_signals(N_ch, tmin, tmax, t_steps, omega_0, min_V_level, max_V_level, seed=256):
    """Draw random voltage levels per channel and build the V0/V1 carrier signals.

    Uses ``t_steps`` uniform levels in [min_V_level, max_V_level) per signal,
    from a JAX PRNG seeded with ``seed``. Returns the per-channel V0 and V1
    signals stacked into two JAX arrays.
    """
    key = random.PRNGKey(seed)
    V0_t_list, V1_t_list = [], []
    for _ in range(N_ch):
        # Generate voltage levels for V0
        key, subkey = random.split(key)
        voltage_levels_V0 = random.uniform(subkey, shape=(t_steps,),
                                           minval=float(min_V_level), maxval=float(max_V_level))
        V0_t_list.append(construct_V_smooth_with_carrier(tmin, tmax, t_steps, voltage_levels_V0, omega_0))

        # Generate voltage levels for V1
        key, subkey = random.split(key)
        voltage_levels_V1 = random.uniform(subkey, shape=(t_steps,),
                                           minval=float(min_V_level), maxval=float(max_V_level))
        V1_t_list.append(construct_V_smooth_with_carrier(tmin, tmax, t_steps, voltage_levels_V1, omega_0))

    return jnp.array(V0_t_list), jnp.array(V1_t_list)


def _save_fidelity_history(fidelity_gens):
    """Write the fidelity history (steps numbered from 1) to CSV and plot it to PNG."""
    steps = list(range(1, len(fidelity_gens) + 1))  # Starting steps from 1 instead of 0
    df = pd.DataFrame({
        'Step': steps,
        'Fidelity': fidelity_gens
    })

    # NOTE(review): this name ("1sQ_XII") does not follow the "3sQ_dwg_0p6um"
    # suffix used by the other outputs of this run -- confirm it is intentional.
    csv_filename = "fidelity_over_generations_multi_qubit_withCT_1sQ_XII.csv"
    df.to_csv(csv_filename, index=False)
    logger.info(f"Data successfully saved to {csv_filename}")

    # Plot fidelity over generations
    plt.figure(figsize=(10, 6))
    plt.plot(df['Step'], df['Fidelity'], label="Multi-Qubit Fidelity")
    plt.title("Fidelity over Generations (SaDE + Adam)")
    plt.xlabel("Step")
    plt.ylabel("Fidelity")
    plt.legend()
    plt.savefig("fidelity_over_generations_multi_qubit_withCT_3sQ_dwg_0p6um.png")
    plt.show()

# ========================================
# Entry Point
# ========================================
# Runs the full optimization only when this file is executed as a script;
# importing the module just defines the functions above.
if __name__ == "__main__":
    run_multi_qubit_optimization()
