#### Note: change to a GPU runtime type

In [1]:
!pip install optax pandas gymnasium torch stable-baselines3

Collecting stable-baselines3
  Downloading stable_baselines3-2.5.0-py3-none-any.whl.metadata (4.8 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (

In [None]:
import jax.numpy as jnp
import jax.random as jrandom
import jax.scipy.linalg as jla
import jax
import matplotlib.pyplot as plt
from jax import vmap
from jax import random
import numpy as np
import optax
import time
import pandas as pd
from jax import grad
import math
import logging
from collections import deque, defaultdict

import os
import sys
from typing import Optional

import gymnasium as gym  # Updated import for Gymnasium
from gymnasium import spaces
import torch
import torch.nn as nn  # Ensure nn is imported correctly
import optax
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.env_checker import check_env

import pickle  # For saving model parameters

# ========================================
# Logging Configuration
# ========================================

# ----- Console logging setup (everything goes to stdout) -----
# Level "3" silences the TensorFlow/XLA C++ log output.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

# Layout shared by every record written to the console.
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

# Console handler: INFO and above, written to sys.stdout.
stream_handler = logging.StreamHandler(stream=sys.stdout)
stream_handler.setFormatter(formatter)
stream_handler.setLevel(logging.INFO)

# Root logger: drop any handlers that are already attached (so records are
# not emitted twice), then install the console handler at INFO level.
logger = logging.getLogger()
logger.handlers.clear()
logger.setLevel(logging.INFO)
logger.addHandler(stream_handler)

# ========================================
# Part 1: Control System Model (PIC + SLM)
# ========================================

# Constants and Parameters
def calculate_Omega_rabi_prefactor(I_mW_per_cm2, Detuning_MHz):
    """
    Return the effective two-photon Rabi-frequency prefactor.

    The laser intensity is converted to a field amplitude
    E_0 = sqrt(2 I / (c * epsilon_0)), the two single-photon prefactors
    mu * E_0 / hbar are formed for the |0>-|e> and |1>-|e> transitions, and
    their product is divided by twice the detuning.

    Parameters:
    - I_mW_per_cm2: Laser intensity in mW/cm².
    - Detuning_MHz: Detuning in MHz (multiplied by 1e6 before use).

    Returns:
    - The effective Rabi frequency prefactor divided by 1e6 (i.e. in MHz).
    """
    # Physical constants (SI units).
    reduced_planck = 1.0545718e-34        # J·s
    light_speed = 3e8                     # m/s
    vacuum_permittivity = 8.854187817e-12  # F/m

    # Transition dipole moments (C·m).
    dipole_0e = 3.336e-29  # |0> <-> |e>
    dipole_1e = 3.400e-29  # |1> <-> |e>

    # 1 mW/cm² equals 10 W/m².
    intensity_si = I_mW_per_cm2 * 10
    field_amplitude = jnp.sqrt(2 * intensity_si / (light_speed * vacuum_permittivity))

    rabi_0e = (dipole_0e * field_amplitude) / reduced_planck
    rabi_1e = (dipole_1e * field_amplitude) / reduced_planck

    effective_rabi = (rabi_0e * rabi_1e) / (2 * Detuning_MHz * 1e6)
    return effective_rabi / 1e6


# Grid and atom positions
def generate_grid(grid_size=600, grid_range=(-6, 6)):
    """
    Build the square 2D coordinate grid used for the atom plane.

    Args:
    grid_size (int): Grid size parameter (default 600). Note that each axis
        is sampled at ``grid_size - 1`` evenly spaced points.
    grid_range (tuple): (min, max) coordinate shared by the x and y axes
        (default (-6, 6)); both end points are included.

    Returns:
    X, Y: Meshgrid arrays of shape (grid_size - 1, grid_size - 1); X varies
        along columns and Y along rows.
    """
    lower, upper = grid_range[0], grid_range[1]
    samples_per_axis = grid_size - 1

    # The x and y axes use the same sampling, so one axis vector serves both.
    axis = jnp.linspace(lower, upper, samples_per_axis)
    X, Y = jnp.meshgrid(axis, axis)
    return X, Y

# Function to generate atom positions
def generate_atom_positions_equilateral(N_a, side_length=3.0, center=(0, 0), triangle_spacing_factor=2.0):
    """
    Generate atom positions for any number of atoms (N_a), arranged on a row
    of equilateral triangles.

    Atoms fill triangle sites one after another in the order bottom-left,
    bottom-right, top vertex. Triangle ``k`` is centred at
    ``(x_center + k * triangle_spacing_factor * side_length, y_center)``.
    When N_a is not a multiple of 3, the leftover one or two atoms start a
    new (partially filled) triangle site after the last complete one, so no
    two atoms ever share a position.

    Parameters:
    - N_a (int): Total number of atoms. Values <= 0 give an empty list.
    - side_length (float): Length of each side of the equilateral triangle.
    - center (tuple): Center of the first triangle (x_center, y_center).
    - triangle_spacing_factor (float): Distance between triangle centers, as a multiple of side_length.

    Returns:
    - positions (list of tuples): List of (x, y) positions, one per atom.
    """
    x_center, y_center = center
    height = (math.sqrt(3) / 2) * side_length  # Height of the equilateral triangle
    half_side = side_length / 2
    half_height = height / 2

    positions = []
    for atom_idx in range(N_a):
        # Which triangle site this atom belongs to, and which corner of it.
        triangle_idx, corner = divmod(atom_idx, 3)
        current_x_center = x_center + triangle_idx * triangle_spacing_factor * side_length

        if corner == 0:    # Bottom left
            positions.append((current_x_center - half_side, y_center - half_height))
        elif corner == 1:  # Bottom right
            positions.append((current_x_center + half_side, y_center - half_height))
        else:              # Top vertex
            positions.append((current_x_center, y_center + half_height))

    return positions

# Dipoles for each atom
def generate_dipoles(N_a):
    """Return a list of N_a dipole vectors, each equal to [1.0, 0]."""
    dipoles = []
    for _ in range(N_a):
        dipoles.append(jnp.array([1.0, 0]))
    return dipoles

# Function to generate SLM modulations
def generate_slm_mod(N_slm):
    """
    Return the default SLM modulation for N_slm channels.

    Returns:
    - (phase_mod, amp_mod): phases all zero and amplitudes all one, each of
      length N_slm.
    """
    amp_mod = jnp.ones(N_slm)
    phase_mod = jnp.zeros_like(amp_mod)
    return phase_mod, amp_mod

# Functions for crosstalk model setup (with fabrication variations considered)
def generate_distances(num_channels, pitch, random_variation=0.0, seed=None):
    """
    Build the pairwise distance matrix for waveguides laid out on a line.

    Entry (i, j) with i != j is ``|i - j| * pitch`` plus an independent
    uniform draw from [-random_variation, random_variation]; the diagonal is
    zero. Because every off-diagonal entry gets its own draw, entries (i, j)
    and (j, i) differ whenever random_variation is non-zero.

    Parameters:
        num_channels (int): Number of waveguides.
        pitch (float): Spacing between adjacent waveguides (µm).
        random_variation (float): Maximum random variation (±) for distances.
        seed (int or None): If given, re-seeds NumPy's global random state
            (``np.random.seed``) before drawing.

    Returns:
        jax.numpy.array: (num_channels, num_channels) distance matrix (µm),
        built as float32.
    """
    if seed is not None:
        np.random.seed(seed)

    matrix = np.zeros((num_channels, num_channels), dtype=np.float32)
    # np.ndindex walks the matrix in row-major order, which fixes the order
    # in which random numbers are consumed.
    for row, col in np.ndindex(num_channels, num_channels):
        if row == col:
            continue  # the diagonal stays at zero
        jitter = np.random.uniform(-random_variation, random_variation)
        matrix[row, col] = abs(row - col) * pitch + jitter
    return jnp.array(matrix)

def generate_coupling_lengths(num_channels, base_length, scaling_factor=1.0, random_variation=0.0, seed=None):
    """
    Build the pairwise coupling-length matrix for the waveguides.

    Entry (i, j) with i != j is ``base_length * scaling_factor ** |i - j|``
    plus an independent uniform draw from
    [-random_variation, random_variation]; the diagonal is zero.

    Parameters:
        num_channels (int): Number of waveguides.
        base_length (float): Base coupling length (µm).
        scaling_factor (float): Multiplier applied once per unit of channel separation.
        random_variation (float): Maximum random variation (±) for coupling lengths.
        seed (int or None): If given, re-seeds NumPy's global random state
            (``np.random.seed``) before drawing.

    Returns:
        jax.numpy.array: (num_channels, num_channels) coupling-length matrix
        (µm), built as float32.
    """
    if seed is not None:
        np.random.seed(seed)

    matrix = np.zeros((num_channels, num_channels), dtype=np.float32)
    # Row-major traversal keeps the random draws in a fixed order.
    for row, col in np.ndindex(num_channels, num_channels):
        if row == col:
            continue  # no self-coupling entry
        nominal = base_length * (scaling_factor**abs(row - col))
        jitter = np.random.uniform(-random_variation, random_variation)
        matrix[row, col] = nominal + jitter
    return jnp.array(matrix)

def generate_n_eff_list(num_channels, base_n_eff, random_variation=0.0, seed=None):
    """
    Generate one effective refractive index per waveguide.

    Each value is ``base_n_eff`` plus a uniform draw from
    [-random_variation, random_variation].

    Parameters:
        num_channels (int): Number of waveguides.
        base_n_eff (float): Base effective refractive index for all waveguides.
        random_variation (float): Maximum random variation (±) around the base value.
        seed (int or None): If given, re-seeds NumPy's global random state
            (``np.random.seed``) before drawing.

    Returns:
        jax.numpy.array: Array of length num_channels with the effective
        refractive indices.
    """
    if seed is not None:
        np.random.seed(seed)

    # All offsets are drawn in a single call.
    offsets = np.random.uniform(-random_variation, random_variation, num_channels)
    return jnp.array(base_n_eff + offsets)


# 1.1 Control Signal Construction

def construct_V_smooth_with_carrier(tmin, tmax, t_steps, voltage_levels, omega_0, phi=0, max_step=30):
    """
    Build a piecewise-constant, step-limited voltage envelope and multiply it
    by a cosine carrier.

    The time axis is split into ``len(voltage_levels)`` pieces of
    ``t_steps // len(voltage_levels)`` samples each; the last piece also
    takes any leftover samples. Going left to right, a level that differs
    from the (already limited) previous level by more than ``max_step`` is
    pulled back to ``previous ± max_step``.

    Parameters:
    - tmin (float): Start time.
    - tmax (float): End time.
    - t_steps (int): Number of time samples.
    - voltage_levels (sequence): One voltage level per piece (must be non-empty).
    - omega_0 (float): Carrier angular frequency.
    - phi (float, optional): Initial carrier phase (radians). Default is 0.
    - max_step (float, optional): Largest allowed jump between consecutive pieces. Default is 30.

    Returns:
    - V_t (array): Envelope times cos(omega_0 * t + phi), length t_steps.
    """
    t_axis = jnp.linspace(tmin, tmax, t_steps)
    n_pieces = len(voltage_levels)
    samples_per_piece = t_steps // n_pieces

    levels = jnp.array(voltage_levels, dtype=jnp.float32)

    # Pass 1: limit the jump between neighbouring levels. Each level is
    # compared against its already-limited predecessor.
    for idx in range(1, n_pieces):
        jump = levels[idx] - levels[idx - 1]
        if jnp.abs(jump) > max_step:
            levels = levels.at[idx].set(levels[idx - 1] + jnp.sign(jump) * max_step)

    # Pass 2: paint every piece onto the time axis.
    envelope = jnp.zeros_like(t_axis)
    for idx in range(n_pieces):
        first = idx * samples_per_piece
        # The final piece absorbs whatever samples remain.
        stop = t_steps if idx == n_pieces - 1 else first + samples_per_piece
        envelope = envelope.at[first:stop].set(jnp.full(stop - first, levels[idx]))

    # Carrier modulation.
    return envelope * jnp.cos(omega_0 * t_axis + phi)


# 1.2 Unitary Matrix Construction
def dn(V):
    """Return the refractive-index change for drive voltage V (linear, 4e-5 per unit of V)."""
    index_change_per_unit_voltage = 4e-5
    return index_change_per_unit_voltage * V

def phi_func(L, n, dn_val, lambda_):
    """
    Return the phase 2*pi*L*(n + dn_val)/lambda_ accumulated over a length L.

    Parameters:
    - L: Path length.
    - n: Base refractive index.
    - dn_val: Refractive-index change added to n.
    - lambda_: Wavelength (same length unit as L).
    """
    total_index = n + dn_val
    return 2 * jnp.pi * L * total_index / lambda_

def D_func(a, t_val, phi_val):
    """
    Evaluate exp(i*pi) * (a - t*exp(-i*phi)) / (1 - t*a*exp(-i*phi)).

    Parameters:
    - a, t_val: Real amplitude coefficients (presumably the round-trip
      amplitude and self-coupling of a ring resonator — TODO confirm).
    - phi_val: Phase (radians).

    Returns:
    - Complex transmission coefficient.
    """
    phase_factor = jnp.exp(-1j * phi_val)
    numerator = a - t_val * phase_factor
    denominator = 1 - t_val * a * phase_factor
    return jnp.exp(1j * jnp.pi) * numerator / denominator

def U_drmzm_single_channel(V0, V1, L, n0, lambda_0, a0, t0, a1, t1, psi_0=0):
    """
    Build the 2x2 transfer matrix of one modulator channel.

    The two arms are described by D_func(a0, t0, phi0) and
    D_func(a1, t1, phi1), where phi0/phi1 are the phases from phi_func for
    the voltage-shifted indices n0 + dn(V0) and n0 + dn(V1). The diagonal arm
    matrix is sandwiched between two 50/50 beam-splitter matrices.

    Parameters:
    - V0, V1: Drive voltages of arm 0 and arm 1.
    - L, n0, lambda_0: Length, base index and wavelength passed to phi_func.
    - a0, t0, a1, t1: Amplitude coefficients passed to D_func for each arm.
    - psi_0: Extra static phase (radians) applied to arm 0. Default 0.

    Returns:
    - U_total: U_bs @ diag(D_00, D_11) @ U_bs.
    """
    phi0 = phi_func(L, n0, dn(V0), lambda_0)
    phi1 = phi_func(L, n0, dn(V1), lambda_0)

    # psi_0 is a phase, so it enters as exp(i*psi_0). A real exp(psi_0) would
    # act as gain/loss on arm 0 instead of a phase shift. At the default
    # psi_0 = 0 both forms equal 1.
    D_00 = D_func(a0, t0, phi0) * jnp.exp(1j * psi_0)
    D_11 = D_func(a1, t1, phi1)

    U_drmzm_matrix = jnp.array([[D_00, jnp.zeros_like(D_00)], [jnp.zeros_like(D_11), D_11]])
    U_bs = (1 / jnp.sqrt(2)) * jnp.array([[1, 1], [1, -1]])  # 50/50 beam splitter
    U_total = jnp.dot(U_bs, jnp.dot(U_drmzm_matrix, U_bs))

    return U_total

def U_drmzm_multi_channel(
    V0_t_list, V1_t_list, L, n0, lambda_0, a0, t0, a1, t1,
    N_ch, distances, coupling_lengths, n_eff_list, kappa0, alpha, enable_crosstalk=True
):
    """
    Build the N_ch x N_ch channel matrix, optionally including waveguide crosstalk.

    The diagonal holds, for each channel, the (0, 0) element of
    U_drmzm_single_channel. With crosstalk enabled this diagonal matrix is
    right-multiplied by a coupling matrix that is the identity plus, for each
    pair i < j, the off-diagonal element of a two-waveguide transfer matrix
    at [i, j] and its complex conjugate at [j, i].

    Branching on array values uses jnp.where rather than Python ``if`` so the
    function does not trigger TracerBoolConversionError under tracing.

    Parameters:
    - V0_t_list, V1_t_list: Per-channel arm voltages (indexable by channel).
    - L, n0, lambda_0, a0, t0, a1, t1: Forwarded to U_drmzm_single_channel.
    - N_ch: Number of channels.
    - distances, coupling_lengths: Pairwise (N_ch, N_ch) matrices; only i < j entries are read.
    - n_eff_list: Effective index of each waveguide.
    - kappa0, alpha: Coupling strength kappa = kappa0 * exp(-alpha * distance).
    - enable_crosstalk: If False, return only the diagonal matrix.

    Returns:
    - complex64 array of shape (N_ch, N_ch).
    """

    def pair_transfer_matrix(length, beta_a, beta_b, kappa):
        """Transfer matrix of two coupled waveguides over the given length."""
        half_mismatch = (beta_a - beta_b) / 2.0
        kappa_eff = jnp.sqrt(kappa**2 + half_mismatch**2)  # effective coupling
        cos_part = jnp.cos(kappa_eff * length)
        sin_part = jnp.sin(kappa_eff * length)
        # Guard the division with jnp.where (no Python branch on a traced value).
        safe_denominator = jnp.where(kappa_eff != 0, kappa_eff, 1e-12)
        mismatch_ratio = half_mismatch / safe_denominator

        return jnp.array([
            [cos_part - 1j * mismatch_ratio * sin_part, -1j * sin_part],
            [-1j * sin_part, cos_part + 1j * mismatch_ratio * sin_part]
        ], dtype=jnp.complex64)

    # Diagonal part: one single-channel response per channel, no crosstalk.
    channel_response = jnp.zeros((N_ch, N_ch), dtype=jnp.complex64)
    for ch in range(N_ch):
        single = U_drmzm_single_channel(
            V0_t_list[ch], V1_t_list[ch], L, n0, lambda_0, a0, t0, a1, t1
        )
        channel_response = channel_response.at[ch, ch].set(single[0, 0])

    if not enable_crosstalk:
        return channel_response

    # Crosstalk part: start from the identity and fill the off-diagonals.
    wavenumber = (2.0 * jnp.pi) / lambda_0  # k = 2π / λ
    coupling = jnp.eye(N_ch, dtype=jnp.complex64)

    for i in range(N_ch):
        beta_i = wavenumber * n_eff_list[i]
        for j in range(i + 1, N_ch):
            beta_j = wavenumber * n_eff_list[j]

            separation = distances[i, j]
            # Pairs with a non-positive distance get zero length and zero coupling.
            length_ij = jnp.where(separation > 0.0, coupling_lengths[i, j], 0.0)
            kappa_ij = jnp.where(separation > 0.0,
                                 kappa0 * jnp.exp(-alpha * separation),
                                 0.0)

            cross_term = pair_transfer_matrix(length_ij, beta_i, beta_j, kappa_ij)[0, 1]

            # The (0, 1) element goes to [i, j]; its conjugate goes to [j, i].
            coupling = coupling.at[i, j].set(cross_term)
            coupling = coupling.at[j, i].set(jnp.conj(cross_term))

    return jnp.matmul(channel_response, coupling)

# 1.3 SLM and Scattering Matrices
def construct_U_multi_channel_slm(N_slm, phase_mod, amp_mod, t_steps):
    """
    Build the SLM matrix for every time step.

    The SLM acts diagonally: channel i is multiplied by
    ``amp_mod[i] * exp(1j * phase_mod[i])``. The modulation does not depend
    on time, so the diagonal matrix is built once and broadcast along the
    time axis instead of being filled element by element for each step.

    Parameters:
    - N_slm: Number of SLM channels.
    - phase_mod: Per-channel phases (radians); the first N_slm entries are used.
    - amp_mod: Per-channel amplitudes; the first N_slm entries are used.
    - t_steps: Number of time steps.

    Returns:
    - complex64 array of shape (t_steps, N_slm, N_slm).
    """
    phases = jnp.asarray(phase_mod)[:N_slm]
    amplitudes = jnp.asarray(amp_mod)[:N_slm]

    diagonal = (amplitudes * jnp.exp(1j * phases)).astype(jnp.complex64)
    U_single_step = jnp.diag(diagonal)

    return jnp.broadcast_to(U_single_step, (t_steps, N_slm, N_slm))

def construct_I_prime(N_scat, delta, t_steps):
    """
    Build the perturbed identity ``I + delta * I`` for every time step.

    The matrix is the same at every step, so it is built once and broadcast
    along the time axis rather than written once per step.

    Parameters:
    - N_scat: Matrix dimension.
    - delta: Value added to every diagonal entry of the identity.
    - t_steps: Number of time steps.

    Returns:
    - complex64 array of shape (t_steps, N_scat, N_scat).
    """
    single_step = jnp.eye(N_scat) + jnp.diag(jnp.full(N_scat, delta))
    single_step = single_step.astype(jnp.complex64)
    return jnp.broadcast_to(single_step, (t_steps, N_scat, N_scat))

# 1.4 E-field Calculations
def lg00_mode_profile(X, Y, beam_center, beam_waist):
    """
    Evaluate the Gaussian profile exp(-r^2 / (2 * beam_waist^2)) on a grid.

    Parameters:
    - X, Y: Coordinates at which to evaluate (scalars or arrays).
    - beam_center: (cx, cy) centre of the beam.
    - beam_waist: Width parameter of the Gaussian.

    Returns:
    - Profile values with the broadcast shape of X and Y; 1 at the beam centre.
    """
    center_x, center_y = beam_center
    dx = X - center_x
    dy = Y - center_y
    radius_sq = dx**2 + dy**2
    return jnp.exp(-radius_sq / (2 * beam_waist**2))

def compute_E_field_for_channel(X, Y, E_t, beam_center, beam_waist, t_idx):
    """
    Return the field of one beam on the grid: its Gaussian profile scaled by
    the complex amplitude E_t.

    Parameters:
    - X, Y: Grid coordinates.
    - E_t: Complex amplitude of this channel at the current time step.
    - beam_center, beam_waist: Forwarded to lg00_mode_profile.
    - t_idx: Not used by this function.
    """
    # Rebuild the complex amplitude from its real and imaginary parts.
    amplitude = jnp.real(E_t) + 1j * jnp.imag(E_t)
    return lg00_mode_profile(X, Y, beam_center, beam_waist) * amplitude

def compute_total_E_field_profile(X, Y, b_slm_out, beam_centers, beam_waist):
    """
    Sum the fields of all beams on the grid for every time step.

    Parameters:
    - X, Y: Grid coordinates.
    - b_slm_out: Complex amplitudes indexed as [time step, beam].
    - beam_centers: One (cx, cy) centre per beam.
    - beam_waist: Gaussian width shared by all beams.

    Returns:
    - Array stacking one complex64 grid-shaped field per time step.
    """
    n_steps = b_slm_out.shape[0]
    frames = []
    for step in range(n_steps):
        # Start from an empty complex field and add the beams one at a time.
        total = jnp.zeros_like(X, dtype=jnp.complex64)
        for beam_idx, beam_center in enumerate(beam_centers):
            total = total + compute_E_field_for_channel(
                X, Y, b_slm_out[step, beam_idx], beam_center, beam_waist, step
            )
        frames.append(total)
    return jnp.array(frames)

def extract_E_field_at_atoms(E_field_profiles, atom_positions, X, Y):
    """
    Sample the field at the grid point nearest to each atom, for every time step.

    Parameters:
    - E_field_profiles: Array indexed as [time step, y index, x index].
    - atom_positions: Iterable of (x, y) atom coordinates.
    - X, Y: Meshgrid coordinates; the x axis is read from X[0, :] and the
      y axis from Y[:, 0].

    Returns:
    - Array indexed as [time step, atom] with the sampled field values.
    """
    x_axis = X[0, :]
    y_axis = Y[:, 0]

    # The nearest grid indices depend only on the atom positions, not on
    # time, so they are resolved once instead of once per time step.
    grid_indices = [
        (jnp.argmin(jnp.abs(y_axis - y0)), jnp.argmin(jnp.abs(x_axis - x0)))
        for x0, y0 in atom_positions
    ]

    E_field_at_atoms = []
    for t_idx in range(E_field_profiles.shape[0]):
        frame = E_field_profiles[t_idx]
        E_field_at_atoms.append(
            jnp.array([frame[y_idx, x_idx] for y_idx, x_idx in grid_indices])
        )
    return jnp.array(E_field_at_atoms)

def compute_alpha_t(E_fields_at_atoms, dipoles, Omega_prefactor_MHz):
    """
    Compute the interaction strength of each atom at each time step.

    For atom k at step t the value is
    ``Omega_prefactor_MHz * dipoles[k][0] * E_fields_at_atoms[t, k]``; only
    the first component of each dipole is used.

    Parameters:
    - E_fields_at_atoms: Field samples indexed as [time step, atom].
    - dipoles: One dipole vector per atom.
    - Omega_prefactor_MHz: Scalar prefactor.

    Returns:
    - Array indexed as [time step, atom].
    """
    n_steps = E_fields_at_atoms.shape[0]
    n_atoms = len(dipoles)
    rows = [
        jnp.array([
            Omega_prefactor_MHz * dipoles[atom][0] * E_fields_at_atoms[step, atom]
            for atom in range(n_atoms)
        ])
        for step in range(n_steps)
    ]
    return jnp.array(rows)


# ========================================
# Part 2: Atomic Model
# ========================================

# Every qubit is modelled with two levels.
N_qubit_level = 2

# Single-qubit 2x2 operators (complex64) from which the multi-qubit
# operators and Hamiltonians below are assembled.
s_z = jnp.array([[1, 0], [0, -1]], dtype=jnp.complex64)
s_plus = jnp.array([[0, 1], [0, 0]], dtype=jnp.complex64)
s_minus = s_plus.T  # [[0, 0], [1, 0]]

def construct_multi_qubit_operator(single_qubit_op, N_a, N_qubit_level, qubit_idx):
    """
    Embed a single-qubit operator into the N_a-qubit space.

    The result is the Kronecker product of N_a factors, all identities of
    size N_qubit_level except for ``single_qubit_op`` at position qubit_idx
    (qubit 0 is the left-most factor).

    Returns:
    - Array of shape (N_qubit_level**N_a, N_qubit_level**N_a).
    """
    identity = jnp.eye(N_qubit_level, dtype=jnp.complex64)
    factors = [identity for _ in range(N_a)]
    factors[qubit_idx] = single_qubit_op

    # Fold the factors left to right with Kronecker products.
    operator, *remaining = factors
    for factor in remaining:
        operator = jnp.kron(operator, factor)
    return operator

# Creation and annihilation operators for the quantized field (Fock mode)
def construct_annihilation_operator(fock_dim):
    """
    Build the truncated annihilation and creation operators of one field mode.

    The annihilation operator has sqrt(n) at position (n - 1, n) for
    n = 1 .. fock_dim - 1 and zeros elsewhere; the creation operator is its
    Hermitian conjugate.

    Returns:
    - (a, a_dag): Two complex64 arrays of shape (fock_dim, fock_dim).
    """
    levels = jnp.arange(1, fock_dim)
    lowering = jnp.zeros((fock_dim, fock_dim), dtype=jnp.complex64)
    # Fill the whole super-diagonal in a single indexed update.
    lowering = lowering.at[levels - 1, levels].set(jnp.sqrt(levels))
    raising = lowering.T.conj()
    return lowering, raising

# Truncation of the field mode: all field operators below are
# fock_dim x fock_dim matrices.
fock_dim = 5
I_fock = jnp.eye(fock_dim, dtype=jnp.complex64)      # identity on the field mode
a, a_dag = construct_annihilation_operator(fock_dim)  # annihilation / creation operators

# Drift Hamiltonian (H_0)
def construct_H_0(N_a, omega_0, omega_r):  # omega_0/r in MHz
    """
    Build the drift Hamiltonian of N_a two-level qubits plus one field mode.

    H_0 = sum_i (omega_0 / 2) * sigma_z^(i) (x) I_fock
          + omega_r * I_qubits (x) (a_dag a)

    Parameters:
    - N_a: Number of qubits.
    - omega_0: Qubit frequency (MHz).
    - omega_r: Field-mode frequency (MHz).

    Returns:
    - Array of shape (2**N_a * fock_dim, 2**N_a * fock_dim).
    """
    qubit_terms = (
        0.5 * omega_0 * construct_multi_qubit_operator(s_z, N_a, 2, qubit)
        for qubit in range(N_a)
    )
    H_qubits = sum(qubit_terms)

    number_operator = a_dag @ a
    H_field = omega_r * jnp.kron(jnp.eye(2 ** N_a), number_operator)

    return jnp.kron(H_qubits, I_fock) + H_field

# Control Hamiltonian
def construct_H_control(N_a, N_qubit_level, g_real_t, g_imag_t):
    """
    Build the control Hamiltonian for a single time step.

    For every atom i the coupling g_real_t[i] multiplies
    (sigma_plus a + sigma_minus a_dag) and g_imag_t[i] multiplies
    i * (sigma_plus a - sigma_minus a_dag).

    Parameters:
    - N_a: Number of atoms.
    - N_qubit_level: Levels per qubit.
    - g_real_t, g_imag_t: Per-atom real and imaginary coupling at this step.

    Returns:
    - complex64 array of shape (N_qubit_level**N_a * fock_dim,) * 2.
    """
    dim = N_qubit_level ** N_a * fock_dim
    H_control = jnp.zeros((dim, dim), dtype=jnp.complex64)

    for atom in range(N_a):
        # Atom raised while a photon is absorbed, and the reverse process.
        raise_absorb = jnp.kron(construct_multi_qubit_operator(s_plus, N_a, N_qubit_level, atom), a)
        lower_emit = jnp.kron(construct_multi_qubit_operator(s_minus, N_a, N_qubit_level, atom), a_dag)

        H_control = H_control + g_real_t[atom] * (raise_absorb + lower_emit)
        H_control = H_control + g_imag_t[atom] * (1j * raise_absorb - 1j * lower_emit)

    return H_control

# Time Evolution Hamiltonian (H(t))
def construct_H_time(N_a, N_qubit_level, omega_0, omega_r, g_real_t, g_imag_t,
                   atom_positions, gate_type='single'):
    """
    Stack the Hamiltonian H(t) = H_0 + H_control(t) for every time step.

    Parameters:
    - N_a, N_qubit_level: System size.
    - omega_0, omega_r: Forwarded to construct_H_0.
    - g_real_t, g_imag_t: Couplings indexed as [time step, atom].
    - atom_positions, gate_type: Accepted but not used by this function.

    Returns:
    - JAX array with one Hamiltonian per time step along the first axis.
    """
    # The drift part does not change over time, so build it once.
    H_drift = construct_H_0(N_a, omega_0, omega_r)

    return jnp.array([
        H_drift + construct_H_control(N_a, N_qubit_level, g_real_t[step], g_imag_t[step])
        for step in range(len(g_real_t))
    ])

# Compute accumulated propagator (U(t)) over time
def compute_accumulated_propagator(H_t, dt, N_a, N_qubit_level):
    """
    Accumulate the propagator over all time steps.

    Each step contributes expm(-1j * H * dt), applied on the left of the
    running product, so entry k of the result is U_k @ ... @ U_1 @ U_0.

    Parameters:
    - H_t: Iterable of Hamiltonians, one per time step.
    - dt: Time-step length.
    - N_a, N_qubit_level: Set the dimension N_qubit_level**N_a * fock_dim.

    Returns:
    - Array with the accumulated propagator after each step.
    """
    dim = N_qubit_level ** N_a * fock_dim
    running = jnp.eye(dim, dtype=jnp.complex64)

    history = []
    for H_step in H_t:
        running = jla.expm(-1j * H_step * dt) @ running
        history.append(running)
    return jnp.array(history)

# Trace out field from unitary matrix
def trace_out_field_from_unitary(U_t, N_a, N_qubit_level, field_dim=5):
    """
    Reduce a (qubits x field) matrix to the qubit space.

    U_t is viewed as [qubit row, field row, qubit col, field col] and summed
    over both field axes.

    NOTE(review): the sum runs over every (field row, field col) pair, not
    only the diagonal ones, and is not normalised; for U_t = identity the
    result is field_dim * I. Confirm this is the intended reduction.

    Returns:
    - Array of shape (N_qubit_level**N_a, N_qubit_level**N_a).
    """
    n_qubit_states = N_qubit_level ** N_a
    split_view = U_t.reshape(n_qubit_states, field_dim, n_qubit_states, field_dim)
    return split_view.sum(axis=(1, 3))

def trace_out_field_from_unitary_t_all(U_t_all, N_a, N_qubit_level, field_dim=5):
    """
    Reduce a time series of (qubits x field) matrices to the qubit space.

    Each matrix is viewed as [qubit row, field row, qubit col, field col] and
    summed over both field axes.

    NOTE(review): the sum runs over every (field row, field col) pair, not
    only the diagonal ones, and is not normalised — confirm this is the
    intended reduction.

    Parameters:
    - U_t_all: Array of shape (time_steps, dim_total, dim_total) with
      dim_total = N_qubit_level**N_a * field_dim.
    - N_a: Number of atoms (qubits).
    - N_qubit_level: Number of levels per qubit.
    - field_dim: Dimension of the field Hilbert space.

    Returns:
    - Array of shape (time_steps, dim_qubits, dim_qubits).

    Raises:
    - AssertionError: if the last dimension of U_t_all is not
      dim_qubits * field_dim.
    """
    n_steps, total_dim = U_t_all.shape[0], U_t_all.shape[-1]
    n_qubit_states = N_qubit_level**N_a

    # The qubit and field dimensions must factor the total dimension.
    assert n_qubit_states * field_dim == total_dim, "Mismatch in total dimensions!"

    split_view = U_t_all.reshape(n_steps, n_qubit_states, field_dim, n_qubit_states, field_dim)
    return split_view.sum(axis=(2, 4))

# Fidelity Calculation
def compute_fidelity_unitary(U_t_traced, U_target):
    """
    Return the gate fidelity |Tr(U_target^dagger U_t_traced)|^2 / d^2.

    Parameters:
    - U_t_traced: Reduced matrix on the qubit space.
    - U_target: Target unitary; d is its number of rows (2^N for N qubits).

    Returns:
    - fidelity: Real scalar.
    """
    dim = U_target.shape[0]
    overlap = jnp.trace(U_target.conj().T @ U_t_traced)
    return jnp.real(jnp.abs(overlap) ** 2 / (dim**2))

# record F(t)
def compute_fidelity_unitary_t_all(U_t_traced_all, U_target):
    """
    Return |Tr(U_target^dagger U_t)|^2 / d^2 for every time step.

    Parameters:
    - U_t_traced_all: Array of reduced matrices, one per time step along the first axis.
    - U_target: Target unitary; d is its number of rows (2^N for N qubits).

    Returns:
    - fidelity_all: Array with one fidelity per time step.
    """
    dim_squared = U_target.shape[0] ** 2
    target_dagger = U_target.conj().T

    def _step_fidelity(U_step):
        return jnp.abs(jnp.trace(target_dagger @ U_step)) ** 2 / dim_squared

    # vmap maps the per-step formula over the leading (time) axis.
    return jax.vmap(_step_fidelity)(U_t_traced_all)


# Clifford gates
def clifford_group_and_t_gate():
    """
    Return a list of named single-qubit gates: 24 products of I, H, S, X, Y, Z
    followed by the T gate.

    A multi-letter name is the left-to-right matrix product of its letters,
    e.g. "HSX" is H @ S @ X.

    NOTE(review): the 24 products are not all distinct. S and Z are both
    diagonal, so "SZ" and "ZS" are the same matrix, and H @ X equals Z @ H,
    so "HX" and "ZH" coincide. The list therefore is not the full 24-element
    single-qubit Clifford group; confirm whether that matters to callers.

    Returns:
    - List of 25 (gate name, 2x2 complex64 matrix) tuples.
    """
    # Generators from which every listed gate is multiplied together.
    generators = {
        "I": jnp.array([[1, 0], [0, 1]], dtype=jnp.complex64),  # Identity
        "H": (1 / jnp.sqrt(2)) * jnp.array([[1, 1], [1, -1]], dtype=jnp.complex64),  # Hadamard
        "S": jnp.array([[1, 0], [0, 1j]], dtype=jnp.complex64),  # Phase gate
        "X": jnp.array([[0, 1], [1, 0]], dtype=jnp.complex64),  # Pauli-X
        "Y": jnp.array([[0, -1j], [1j, 0]], dtype=jnp.complex64),  # Pauli-Y
        "Z": jnp.array([[1, 0], [0, -1]], dtype=jnp.complex64),  # Pauli-Z
    }
    # T gate (π/8 rotation), appended after the 24 products.
    t_gate = jnp.array([[1, 0], [0, jnp.exp(1j * jnp.pi / 4)]], dtype=jnp.complex64)

    # The order matters: callers select gates by index into the returned list.
    gate_names = (
        "X", "I", "H", "S", "Y", "Z",
        "HS", "SH", "HX", "SX", "SY", "SZ",
        "XH", "YH", "ZH", "XS", "YS", "ZS",
        "HSX", "HSY", "HSZ", "XHS", "YHS", "ZHS",
    )

    gates = []
    for name in gate_names:
        matrix = generators[name[0]]
        for letter in name[1:]:
            matrix = matrix @ generators[letter]
        gates.append((name, matrix))

    gates.append(("T", t_gate))
    return gates


# ========================================
# Part 3: Function to Compute Gate Fidelity
# ========================================

def compute_multi_qubit_fidelity_closed_system(
    V0_t_list, V1_t_list, L, n0, lambda_0, a0, t0, a1, t1, phase_mod, amp_mod, delta,
    atom_positions, dipoles, beam_centers, beam_waist, X, Y, Omega_prefactor_MHz,
    t_steps, dt,
    N_ch, distances, coupling_lengths, n_eff_list, kappa0, alpha, enable_crosstalk,
    N_slm, N_ch_slm_in, N_scat_1, N_scat_2, N_a, N_qubit_level, omega_0, omega_r, a_pic, a_scat_1,
    U_target, gate_type='single'
):
    """
    Run the full control chain (modulator channels -> SLM -> field at the
    atoms -> Hamiltonian -> propagator) and return the gate fidelity at every
    time step.

    Pipeline, in order:
      1. Build the multi-channel matrix for every time step with
         U_drmzm_multi_channel.
      2. Tensor it with construct_I_prime(N_scat_1, ...) and apply it to
         kron(a_pic, a_scat_1); sum the result over each axis to obtain the
         per-channel and per-scattering-mode outputs.
      3. Route those outputs into the SLM input and the second scattering
         input, apply the SLM matrix tensored with
         construct_I_prime(N_scat_2, ...), and sum over the scattering axis.
      4. Turn the SLM outputs into Gaussian beams on the (X, Y) grid, sample
         the field at the atoms and scale it to the coupling alpha(t).
      5. Build H(t) with g(t) = alpha(t) / 2, accumulate the propagator,
         reduce it to the qubit space and compare with U_target.

    Parameters:
        V0_t_list, V1_t_list: Control voltage time-series for each channel,
            indexed as [channel][time step].
        L, n0, lambda_0, a0, t0, a1, t1: Physical parameters of the MZM.
        phase_mod, amp_mod: Phase and amplitude modulation of the SLM.
        delta: Value passed to construct_I_prime for both scattering stages.
        atom_positions, dipoles: Atom positions and dipole moments.
        beam_centers, beam_waist: Centres and width of the beams on the atom plane.
        X, Y: 2D grid for the beam field.
        Omega_prefactor_MHz: Prefactor for Rabi frequency in MHz.
        t_steps, dt: Number of time steps and step length.
        N_ch: Number of photonic channels.
        distances: Pairwise distances between channels (µm).
        coupling_lengths: Pairwise coupling lengths (µm).
        n_eff_list: Effective refractive indices for the waveguides.
        kappa0, alpha: Crosstalk parameters.
        enable_crosstalk (bool): If False, crosstalk is disabled.
        N_slm, N_ch_slm_in, N_scat_1, N_scat_2, N_a, N_qubit_level: System configuration.
        omega_0, omega_r: Frequencies in MHz.
        a_pic, a_scat_1: Initial fields.
        U_target: Target unitary matrix.
        gate_type: Forwarded to construct_H_time.

    Returns:
        fidelity_all: Array with the fidelity against U_target after each
        time step. The final-step fidelity is also computed internally but is
        not returned.
    """
    # Compute Unitary Matrix for Multi-Channel System
    # (one matrix per time step; the voltage lists are sliced at step i).
    U_system_multi_channel = jnp.array([
        U_drmzm_multi_channel(
            [V0_t[i] for V0_t in V0_t_list],
            [V1_t[i] for V1_t in V1_t_list],
            L, n0, lambda_0, a0, t0, a1, t1,
            N_ch, distances, coupling_lengths, n_eff_list,
            kappa0, alpha, enable_crosstalk
        )
        for i in range(t_steps)
    ])

    # Compute scattering and SLM matrices
    I1_prime = construct_I_prime(N_scat_1, delta, t_steps)
    U_tensor_product = jnp.array([jnp.kron(U_system_multi_channel[i], I1_prime[i]) for i in range(t_steps)])

    # Output modes calculations
    a_total = jnp.kron(a_pic, a_scat_1)
    output_modes = jnp.array([jnp.dot(U_tensor_product[i], a_total) for i in range(t_steps)])
    output_modes_reshaped = output_modes.reshape(t_steps, N_ch, N_scat_1)
    # Sum over the scattering axis / the channel axis respectively.
    a_pic_out = jnp.sum(output_modes_reshaped, axis=-1)
    a_scat_1_out = jnp.sum(output_modes_reshaped, axis=-2)

    # Compute final output modes after SLM and second scattering
    b_slm_in = jnp.zeros((t_steps, N_slm), dtype=jnp.complex64)
    b_scat2_in = jnp.zeros((t_steps, N_scat_2), dtype=jnp.complex64)

    for t in range(t_steps):
        # SLM input: the first N_ch_slm_in channel outputs, then scattering outputs.
        b_slm_in = b_slm_in.at[t, :N_ch_slm_in].set(a_pic_out[t, :N_ch_slm_in])
        b_slm_in = b_slm_in.at[t, N_ch_slm_in:].set(a_scat_1_out[t, :(N_slm - N_ch_slm_in)])

        # Second scattering input: the remaining channel outputs, then scattering outputs.
        b_scat2_in = b_scat2_in.at[t, :(N_ch - N_ch_slm_in)].set(a_pic_out[t, N_ch_slm_in:])
        b_scat2_in = b_scat2_in.at[t, (N_ch - N_ch_slm_in):].set(a_scat_1_out[t, (N_scat_2 - (N_ch - N_ch_slm_in)):])

    b_total_in = jnp.array([jnp.kron(b_slm_in[t], b_scat2_in[t]) for t in range(t_steps)])
    U_multi_channel_slm = construct_U_multi_channel_slm(N_slm, phase_mod, amp_mod, t_steps)
    I2_prime = construct_I_prime(N_scat_2, delta, t_steps)
    U_tensor_product_stage_2 = jnp.array([jnp.kron(U_multi_channel_slm[t], I2_prime[t]) for t in range(t_steps)])
    b_total_out = jnp.array([jnp.dot(U_tensor_product_stage_2[t], b_total_in[t]) for t in range(t_steps)])

    b_total_out_reshaped = b_total_out.reshape(t_steps, N_slm, N_scat_2)
    # Sum over the second scattering axis to get one amplitude per SLM channel.
    b_slm_out = jnp.sum(b_total_out_reshaped, axis=-1)

    # Compute the total E-field on the atom plane
    E_field_profiles = compute_total_E_field_profile(X, Y, b_slm_out, beam_centers, beam_waist)

    # Extract E-fields at the atom positions
    E_fields_at_atoms = extract_E_field_at_atoms(E_field_profiles, atom_positions, X, Y)

    # Compute interaction strength α(t) for each atom
    alpha_t = compute_alpha_t(E_fields_at_atoms, dipoles, Omega_prefactor_MHz)

    # g_real_t and g_imag_t represent g(t) in the Jaynes-Cummings model, g(t) = eff_Rabi(t)/2
    g_real_t = alpha_t.real / 2
    g_imag_t = alpha_t.imag / 2

    # Construct the time-dependent Hamiltonian
    H_t = construct_H_time(N_a, N_qubit_level, omega_0, omega_r, g_real_t, g_imag_t, atom_positions, gate_type)

    # Evolve the unitary matrix (propagator)
    U_t_all = compute_accumulated_propagator(H_t, dt, N_a, N_qubit_level)

    # Reduce the propagator to the qubit space: once for the final step and
    # once for the whole time series.
    U_t_traced = trace_out_field_from_unitary(U_t_all[-1], N_a, N_qubit_level)
    U_t_traced_all = trace_out_field_from_unitary_t_all(U_t_all, N_a, N_qubit_level)

    # Compute fidelity with the target gate
    # (`fidelity` is the final-step value; only `fidelity_all` is returned).
    fidelity = compute_fidelity_unitary(U_t_traced, U_target)
    fidelity_all = compute_fidelity_unitary_t_all(U_t_traced_all, U_target)

    return fidelity_all

# ========================================
# Part 4: Program Instruction function
# ========================================
def program_instruction(N_a, key_number, gate_type='single', selected_atoms=None, control_atom=None, target_atom=None):
    """
    Build the target unitary for a register of `N_a` atoms.

    For `gate_type == 'single'`, every selected atom receives a gate drawn at random from the
    list returned by `clifford_group_and_t_gate()`, every unselected atom receives the 2x2
    identity, and the per-atom gates are combined with a Kronecker product in atom order.
    The gate chosen for each atom is logged.

    Parameters:
        N_a (int): Number of atoms; must be at least 1.
        key_number (int): Seed for the JAX PRNG key that drives the gate selection.
        gate_type (str): Kind of gate to build. Only 'single' is implemented.
        selected_atoms (container of int or None): Atoms that get a random gate; None selects all.
        control_atom, target_atom: Accepted for interface compatibility; not used by this function.

    Returns:
        The combined target unitary, or None when `gate_type` is not supported.

    Raises:
        ValueError: If `gate_type` is 'single' and `N_a` is smaller than 1.
    """
    if gate_type != 'single':
        # Callers (e.g. create_config) test for None, so keep returning None, but say why.
        logger.warning(f"Unsupported gate_type '{gate_type}'; no target unitary was constructed.")
        return None

    if N_a < 1:
        raise ValueError(f"N_a must be at least 1 to build a target unitary, got {N_a}.")

    clifford_t_gates = clifford_group_and_t_gate()
    key = jrandom.PRNGKey(key_number)  # Random key for gate selection
    per_atom_gates = []

    for atom_idx in range(N_a):
        if selected_atoms is None or atom_idx in selected_atoms:
            # A fresh subkey per selected atom so that atoms draw independent gates
            key, subkey = jrandom.split(key)
            gate_idx = int(jrandom.randint(subkey, (), minval=0, maxval=len(clifford_t_gates)))
            gate_name, gate = clifford_t_gates[gate_idx]
            logger.info(f"Applying {gate_name} gate to Atom {atom_idx}")
            per_atom_gates.append(gate)
        else:
            logger.info(f"Applying Identity gate to Atom {atom_idx}")
            per_atom_gates.append(jnp.eye(2))  # Identity gate for unselected atoms

    # Combine the per-atom gates, atom 0 being the left-most Kronecker factor
    U_target = per_atom_gates[0]
    for gate in per_atom_gates[1:]:
        U_target = jnp.kron(U_target, gate)

    return U_target

# ========================================
# Part 5: Modified Configuration Setup
# ========================================
def create_config(simple=False):
    """
    Build the nested configuration dictionary used by the training code.

    The result bundles the APIC parameters, the atom/beam geometry, the control-voltage time
    grid, the system parameters (including the randomly drawn target unitary `U_target`),
    the reward-scaling settings and a handful of top-level training options.

    Parameters:
        simple (bool): Currently unused; accepted so existing callers keep working.

    Returns:
        dict: Configuration with the keys 't_steps', 'N_ch', 'N_a', 'piecewise_segments',
        'min_voltage', 'max_voltage', 'history_length', 'max_delta_voltage', 'APIC_params',
        'system_params', 'atom_beam_params', 'control_Vt_params', 'reward_scaling_params',
        'stagnant_threshold', 'stagnant_fidelity_min' and 'target_fidelity'.

    Raises:
        ValueError: If `program_instruction` does not return a target unitary.
    """
    # Sizes and limits that several entries below must agree on
    n_ch = 6         # stored as 'N_ch'
    n_slm = 6        # stored as 'N_slm'
    n_scat_1 = 4     # stored as 'N_scat_1'
    n_atoms = 3      # stored as 'N_a'
    min_voltage, max_voltage = -15.0, 15.0

    # Define APIC parameters
    APIC_params = {
        'L': 600 * 780e-9 / 1.95,
        'n0': 1.95,
        'lambda_0': 780e-9,
        'a0': 0.998,
        't0': 0.998,
        'a1': 0.998,
        't1': 0.998,
        'phase_mod': jnp.zeros(n_slm),  # one entry per SLM channel
        'amp_mod': jnp.ones(n_slm)      # one entry per SLM channel
    }

    # Define atom beam parameters
    X, Y = generate_grid(grid_size=600, grid_range=(-6, 6))
    atom_positions = generate_atom_positions_equilateral(n_atoms, side_length=3.0, center=(0, 0))
    dipoles = generate_dipoles(n_atoms)
    beam_centers = atom_positions  # the beam centers reuse the atom positions
    beam_waist = 2.0  # µm
    Omega_prefactor_MHz = calculate_Omega_rabi_prefactor(20, 1000)  # I_mW_per_cm2=20, Detuning_MHz=1000
    a_pic = jnp.array([1.0] * n_ch)
    a_scat_1 = jnp.array([1.0] * n_scat_1)

    atom_beam_params = {
        'atom_positions': atom_positions,
        'dipoles': dipoles,
        'beam_centers': beam_centers,
        'beam_waist': beam_waist,
        'X': X,
        'Y': Y,
        'Omega_prefactor_MHz': Omega_prefactor_MHz,
    }

    # Define control voltage parameters
    t_steps = 10
    tmax_fixed = 0.1  # Fixed tmax in microseconds
    dt = tmax_fixed / t_steps
    control_Vt_params = {
        'tmin': 0.0,
        'tmax': tmax_fixed,
        't_steps': t_steps,
        'dt': dt,
        'min_V_level': min_voltage,
        'max_V_level': max_voltage
    }

    # Define system parameters.
    # NOTE(review): the first argument of the three generate_* helpers was the literal 6;
    # it is presumably the channel count, confirm against the helper definitions.
    system_params = {
        'a_pic': a_pic,
        'a_scat_1': a_scat_1,
        'delta': 0.001,
        'N_ch': n_ch,
        'distances': generate_distances(n_ch, 1.0, random_variation=0.1, seed=42),
        'coupling_lengths': generate_coupling_lengths(n_ch, 600.0, scaling_factor=1.1, random_variation=60.0, seed=42),
        'n_eff_list': generate_n_eff_list(n_ch, 1.75, random_variation=0.05, seed=42),
        'kappa0': 10.145,
        'alpha': 6.934,
        'enable_crosstalk': True,
        'N_slm': n_slm,
        'N_ch_slm_in': 4,
        'N_scat_1': n_scat_1,
        'N_scat_2': n_ch + n_scat_1 - n_slm,  # N_total - N_slm
        'N_a': n_atoms,
        'N_qubit_level': 2,
        'omega_0': 6.835e3,  # MHz
        'omega_r': 6.835e3,  # MHz
        'gate_type': 'single'
    }

    # Generate the target unitary; every atom receives a random single-qubit gate
    selected_atoms = list(range(n_atoms))
    system_params['U_target'] = program_instruction(
        N_a=n_atoms,
        key_number=105,
        gate_type='single',
        selected_atoms=selected_atoms
    )

    # Defensive check to ensure U_target is not None
    if system_params['U_target'] is None:
        raise ValueError("U_target was not correctly assigned by program_instruction.")

    # Define reward scaling parameters with hybrid scaling methods
    reward_scaling_params = {
        'a': 1.0,  # Scaling factor for absolute fidelity
        'b': 1.0,  # Scaling factor for fidelity improvement
        'scaling_methods': [  # Scaling methods with relative thresholds
            {'method': 'log', 'min_ratio': 0.0, 'max_ratio': 0.5},
            {'method': 'linear', 'min_ratio': 0.5, 'max_ratio': 0.9},
            {'method': 'quadratic', 'min_ratio': 0.9, 'max_ratio': 1.0},
        ],
        'k': 5.0,  # Relevant for exponential scaling
        'thresholds': [0.8, 0.9],  # Fidelity thresholds for bonuses
        'bonus': 5.0,  # Bonus reward for crossing thresholds
        'clip_min': -10.0,  # Minimum reward
        'clip_max': 10.0,  # Maximum reward
        'step_penalty': 0.01  # Penalty per step
    }

    config = {
        't_steps': t_steps,
        'N_ch': n_ch,
        'N_a': n_atoms,
        'piecewise_segments': 10,
        'min_voltage': min_voltage,
        'max_voltage': max_voltage,
        'history_length': 5,  # Number of past states to include
        'max_delta_voltage': 1.0,  # Parameter for action scaling
        'APIC_params': APIC_params,
        'system_params': system_params,
        'atom_beam_params': atom_beam_params,
        'control_Vt_params': control_Vt_params,
        'reward_scaling_params': reward_scaling_params,
        'stagnant_threshold': 10,           # Number of stagnant episodes before termination
        'stagnant_fidelity_min': 0.99,      # Minimum fidelity to consider stagnation
        'target_fidelity': 0.999            # Desired fidelity threshold to achieve
    }

    return config

# ========================================
# Part 7: Modified Custom Callback (With Additional Logging)
# ========================================
class RLLoggingCallback(BaseCallback):
    """
    Custom callback that tracks the best final fidelity reported during PPO training.

    On every step it scans the `infos` list for 'final_fidelity' entries, keeps the best value
    (and the matching 'best_voltage', when provided) and logs progress through the module
    logger. Training is stopped, and the results are written to `results_dir`, when either
    the best fidelity reaches `target_fidelity`, or the best fidelity has stayed inside
    [`stagnant_fidelity_min`, `target_fidelity`) for `stagnant_threshold` consecutive steps
    without improvement.
    """
    def __init__(self, stagnant_threshold=100,
                 stagnant_fidelity_min=0.99,
                 target_fidelity=0.999,
                 results_dir='results',
                 verbose=0):
        """
        Parameters:
            stagnant_threshold (int): Number of non-improving steps tolerated before stopping.
            stagnant_fidelity_min (float): Lower bound of the fidelity band counted as stagnation.
            target_fidelity (float): Fidelity at which training is stopped as successful.
            results_dir (str): Directory for the saved artefacts; created if missing.
            verbose (int): Forwarded to `BaseCallback`.
        """
        super().__init__(verbose)
        self.best_fidelity = 0.0
        self.best_voltage = None
        self.best_fidelity_history = []
        self.episode_count = 0
        self.stagnant_counter = 0
        self.stagnant_threshold = stagnant_threshold
        self.stagnant_fidelity_min = stagnant_fidelity_min
        self.target_fidelity = target_fidelity
        self.start_time = time.perf_counter()
        self.results_dir = results_dir

        os.makedirs(self.results_dir, exist_ok=True)

    def _on_step(self) -> bool:
        """
        Process the infos of the current step.

        Returns:
            bool: False to stop training (target reached or stagnation), True to continue.
        """
        infos = self.locals.get('infos', [])
        improved = False

        # Additional log line for debugging how often this is called
        logger.info(f"[RLLoggingCallback] _on_step called. Infos count={len(infos)}. Current best fidelity={self.best_fidelity:.4f}")

        for env_id, info in enumerate(infos):
            if 'final_fidelity' not in info:
                continue

            fidelity = info['final_fidelity']
            self.episode_count += 1

            # Values outside [0, 1] are counted as episodes but otherwise ignored
            if fidelity < 0.0 or fidelity > 1.0:
                continue

            logger.info(f"[RLLoggingCallback] Env={env_id}, Episode={self.episode_count}, final_fidelity={fidelity:.6f}")

            # Check improvement
            if fidelity > self.best_fidelity:
                self.best_fidelity = fidelity
                self.stagnant_counter = 0
                improved = True
                new_best_voltage = info.get('best_voltage', None)
                if new_best_voltage is not None:
                    self.best_voltage = new_best_voltage

                logger.info(
                    f"New best fidelity: {self.best_fidelity:.6f} achieved in Env {env_id} | Episode: {self.episode_count}"
                )

        # One history entry per callback step (not per episode)
        self.best_fidelity_history.append(self.best_fidelity)

        if not improved:
            # Only steps spent inside the stagnation band count towards early stopping
            if self.stagnant_fidelity_min <= self.best_fidelity < self.target_fidelity:
                self.stagnant_counter += 1
                logger.info(
                    f"Stagnant Counter Incremented: {self.stagnant_counter}/{self.stagnant_threshold} | Best Fidelity: {self.best_fidelity:.6f}"
                )
            else:
                self.stagnant_counter = 0

        # Early stop if target_fidelity reached
        if self.best_fidelity >= self.target_fidelity:
            elapsed_time = time.perf_counter() - self.start_time
            logger.info(
                f"Target fidelity of {self.best_fidelity:.6f} achieved in {self.episode_count} episodes. "
                f"Training terminated after {elapsed_time:.2f} seconds."
            )
            self._save_final_results()
            return False  # Stop training

        # Early stop if stagnating
        if self.stagnant_counter >= self.stagnant_threshold:
            elapsed_time = time.perf_counter() - self.start_time
            logger.warning(
                f"Training is stagnating after {self.episode_count} episodes. "
                f"Best Fidelity: {self.best_fidelity:.6f}. Terminating training after {elapsed_time:.2f} seconds."
            )
            self._save_final_results()
            return False  # Stop training

        elapsed_time = time.perf_counter() - self.start_time
        logger.info(
            f"Training Status | Episodes: {self.episode_count} | Best Fidelity: {self.best_fidelity:.6f} | "
            f"Stagnant Counter: {self.stagnant_counter}/{self.stagnant_threshold} | "
            f"Elapsed Time: {elapsed_time:.2f} seconds"
        )

        return True  # Continue training

    def _save_final_results(self):
        """
        Write the best voltage (if any), the best fidelity, the fidelity history and the
        progress plot to `results_dir`. Shared by both early-stop paths of `_on_step`.
        """
        # Save final best voltage
        if self.best_voltage is not None:
            best_voltage_final_path = os.path.join(self.results_dir, "best_voltage_final.npy")
            np.save(best_voltage_final_path, self.best_voltage)
            logger.info(f"Final best voltage saved to '{best_voltage_final_path}'.")

        # Save Best Fidelity
        best_fidelity_path = os.path.join(self.results_dir, "best_fidelity.txt")
        with open(best_fidelity_path, "w") as f:
            f.write(f"Best Fidelity: {self.best_fidelity:.6f}\n")
        logger.info(f"Best fidelity saved to '{best_fidelity_path}'.")

        # Save Best Fidelity History
        history = np.array(self.best_fidelity_history)
        best_fidelity_history_path_npy = os.path.join(self.results_dir, "best_fidelity_history.npy")
        best_fidelity_history_path_csv = os.path.join(self.results_dir, "best_fidelity_history.csv")
        np.save(best_fidelity_history_path_npy, history)
        # NOTE(review): the header names two columns but the data is a single column of
        # fidelities; the file format is kept as-is for existing consumers.
        np.savetxt(best_fidelity_history_path_csv, history, delimiter=",",
                   header="Episode,Best_Fidelity")
        logger.info(f"Best fidelity history saved to '{best_fidelity_history_path_npy}' and '{best_fidelity_history_path_csv}'.")

        self._plot_fidelity_progress(self.best_fidelity_history, self.results_dir)

    def _plot_fidelity_progress(self, fidelity_history, results_dir):
        """
        Plot the best-fidelity history and save it as 'fidelity_progress.png' and
        'fidelity_progress.csv' inside `results_dir`. Failures are logged, not raised.
        """
        try:
            episodes = np.arange(1, len(fidelity_history) + 1)
            best_fidelity = np.array(fidelity_history)

            fig, ax = plt.subplots(figsize=(12, 6))
            try:
                ax.plot(episodes, best_fidelity, label='Best Fidelity')
                ax.set_title('Global Best Fidelity Progress Over Episodes')
                ax.set_xlabel('Episode')
                ax.set_ylabel('Best Fidelity')
                ax.legend()
                ax.grid(True)
                fig.tight_layout()
                fidelity_plot_path = os.path.join(results_dir, "fidelity_progress.png")
                fig.savefig(fidelity_plot_path)
            finally:
                # Release the figure even when drawing or saving fails
                plt.close(fig)
            logger.info(f"Fidelity progress plot saved as '{fidelity_plot_path}'.")

            fidelity_progress_csv = os.path.join(results_dir, "fidelity_progress.csv")
            df = pd.DataFrame({
                'Episode': episodes,
                'Best_Fidelity': best_fidelity
            })
            df.to_csv(fidelity_progress_csv, index=False)
            logger.info(f"Fidelity progress data saved as '{fidelity_progress_csv}'.")
        except Exception as e:
            logger.error(f"Error while plotting or saving fidelity progress: {e}")

# ========================================
# Part 8: NEW END-TO-END GRADIENT-BASED RL APPROACH (with batching + GPU)
# ========================================
class MLPPolicy:
    """
    Minimal JAX MLP that maps a random latent vector to an entire control schedule of
    shape (n_ch, piecewise_segments).

    The network has one tanh hidden layer. Parameters live in the `params` dictionary
    (keys 'W1', 'b1', 'W2', 'b2') and are cast to float32 on every forward pass.
    """
    def __init__(self, rng_key, n_ch=6, piecewise_segments=10, hidden_dim=64, latent_dim=8):
        """
        Parameters:
            rng_key: JAX PRNG key used to initialise the weight matrices.
            n_ch (int): Number of control channels (rows of the schedule).
            piecewise_segments (int): Number of segments per channel (columns of the schedule).
            hidden_dim (int): Width of the hidden layer.
            latent_dim (int): Size of the latent vector sampled in `forward`.
        """
        self.n_ch = n_ch
        self.piecewise_segments = piecewise_segments
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim

        rng_key, sub1, sub2 = jrandom.split(rng_key, 3)
        W1 = 0.01 * jrandom.normal(sub1, (latent_dim, hidden_dim))
        b1 = jnp.zeros((hidden_dim,), dtype=jnp.float32)
        W2 = 0.01 * jrandom.normal(sub2, (hidden_dim, n_ch * piecewise_segments))
        b2 = jnp.zeros((n_ch * piecewise_segments,), dtype=jnp.float32)
        self.params = dict(W1=W1, b1=b1, W2=W2, b2=b2)

    def forward(self, params, rng_key):
        """
        Sample a latent vector from `rng_key` and return controls of shape
        (n_ch, piecewise_segments) computed with `params`.
        """
        z = jrandom.normal(rng_key, (self.latent_dim,))
        z = z.astype(jnp.float32)
        W1 = params["W1"].astype(jnp.float32)
        b1 = params["b1"].astype(jnp.float32)
        W2 = params["W2"].astype(jnp.float32)
        b2 = params["b2"].astype(jnp.float32)

        h = jnp.tanh(z @ W1 + b1)
        out = h @ W2 + b2
        # The flat output is laid out channel-major: index = channel * segments + segment
        out = out.reshape((self.n_ch, self.piecewise_segments))
        return out

    def update_piecewise_segments(self, new_piecewise_segments, rng_key):
        """
        Grow the schedule to `new_piecewise_segments` segments per channel.

        Freshly initialised output weights are appended after the existing segments of each
        channel, so the first `piecewise_segments` columns produced by `forward` are unchanged
        by the update. A request that does not increase the segment count is logged and ignored.

        Parameters:
            new_piecewise_segments (int): Target number of segments; must exceed the current one.
            rng_key: JAX PRNG key used to initialise the additional weights.
        """
        if new_piecewise_segments <= self.piecewise_segments:
            logger.warning("New piecewise_segments must be greater than the current value.")
            return

        old_segments = self.piecewise_segments
        additional_segments = new_piecewise_segments - old_segments
        _, weight_key, _ = jrandom.split(rng_key, 3)

        old_W2 = self.params["W2"]
        old_b2 = self.params["b2"]
        hidden = old_W2.shape[0]

        # Initialize new W2 and b2 entries for the additional segments
        new_W2 = 0.01 * jrandom.normal(weight_key, (self.hidden_dim, self.n_ch * additional_segments))
        new_b2 = jnp.zeros((self.n_ch, additional_segments), dtype=jnp.float32)

        # forward() reshapes the flat output row-major to (n_ch, segments). Appending the new
        # columns at the very end of the flat axis would therefore move the existing weights to
        # other (channel, segment) slots. Concatenate along the segment axis of every channel.
        updated_W2 = jnp.concatenate(
            [
                old_W2.reshape(hidden, self.n_ch, old_segments),
                new_W2.reshape(hidden, self.n_ch, additional_segments),
            ],
            axis=2,
        ).reshape(hidden, self.n_ch * new_piecewise_segments)
        updated_b2 = jnp.concatenate(
            [old_b2.reshape(self.n_ch, old_segments), new_b2],
            axis=1,
        ).reshape(self.n_ch * new_piecewise_segments)

        self.params["W2"] = updated_W2
        self.params["b2"] = updated_b2
        self.piecewise_segments = new_piecewise_segments
        logger.info(f"Updated piecewise_segments to {new_piecewise_segments}.")

def single_seed_fidelity(params, rng_key, config, policy):
    """
    Return the last entry of the fidelity trace obtained for one random seed.

    The policy produces one voltage schedule from `params` and `rng_key`; that same schedule
    is passed as both voltage arguments of `compute_multi_qubit_fidelity_closed_system`,
    together with the values read from the four parameter sections of `config`.
    """
    def pick(section, *names):
        # Positional values from one config section, in the order given
        return [config[section][name] for name in names]

    schedule = policy.forward(params, rng_key)

    fidelity_trace = compute_multi_qubit_fidelity_closed_system(
        schedule, schedule,
        *pick("APIC_params",
              "L", "n0", "lambda_0", "a0", "t0", "a1", "t1", "phase_mod", "amp_mod"),
        config["system_params"]["delta"],
        *pick("atom_beam_params",
              "atom_positions", "dipoles", "beam_centers", "beam_waist",
              "X", "Y", "Omega_prefactor_MHz"),
        *pick("control_Vt_params", "t_steps", "dt"),
        *pick("system_params",
              "N_ch", "distances", "coupling_lengths", "n_eff_list",
              "kappa0", "alpha", "enable_crosstalk",
              "N_slm", "N_ch_slm_in", "N_scat_1", "N_scat_2",
              "N_a", "N_qubit_level", "omega_0", "omega_r",
              "a_pic", "a_scat_1", "U_target"),
        config["system_params"].get("gate_type", "single"),
    )

    # Only the fidelity at the final time step is of interest here
    return fidelity_trace[-1]

def build_loss_fn(policy, config):
    """
    Create loss_fn(params, step_rng_key) -> scalar for the given policy and config.

    The returned function splits the step key into `batch_size` seed keys (taken from
    config['batch_size'], default 2, when this builder runs), evaluates the per-seed cost
    for all of them with vmap and returns the mean. The per-seed cost is (1 - fidelity),
    or (fidelity - 1)^2 when the fidelity exceeds 1.
    """
    batch_size = config.get("batch_size", 2)

    def _seed_cost(param, seed_key):
        # Cost of a single seed; policy and config are captured by closure
        fid = single_seed_fidelity(param, seed_key, config, policy)
        return jnp.where(fid <= 1.0, 1.0 - fid, (fid - 1.0) ** 2)

    def loss_fn(params, step_rng_key):
        seed_keys = jrandom.split(step_rng_key, batch_size)
        # Parameters are shared across the batch, the keys are mapped over axis 0
        per_seed_costs = jax.vmap(_seed_cost, in_axes=(None, 0))(params, seed_keys)
        return jnp.mean(per_seed_costs)

    return loss_fn

def train_end_to_end_grad(
    policy, config, num_iterations=500, initial_lr=1e-3, patience=50, min_lr=1e-6, lr_decay_factor=0.5,
    use_mixed_precision=True, use_checkpointing=True
):
    """
    End-to-end gradient-based training of the policy using JAX and optax.

    Each iteration draws a fresh key, takes one clipped Adam step on the loss built by
    `build_loss_fn`, and re-evaluates the loss with the updated parameters and the same key
    to obtain the fidelity (1 - loss). After `patience` iterations without a new best
    fidelity the learning rate is multiplied by `lr_decay_factor` (not below `min_lr`) and
    the optimizer state is re-initialised. After the loop the best-fidelity history, its
    plot and the pickled parameters are written to config['results_dir'] (default 'results').

    Parameters:
        policy (MLPPolicy): The policy instance to train; `policy.params` is updated in place.
        config (dict): Configuration parameters.
        num_iterations (int): Number of training iterations.
        initial_lr (float): Initial learning rate.
        patience (int): Number of non-improving iterations before the learning rate is reduced.
        min_lr (float): Minimum learning rate.
        lr_decay_factor (float): Factor applied to the learning rate on decay.
        use_mixed_precision (bool): Currently unused; kept for interface compatibility.
        use_checkpointing (bool): Whether to wrap the loss in `jax.checkpoint`.

    Returns:
        dict: Trained policy parameters.
    """
    # Resolve and create the output directory up front, so a missing key or directory
    # cannot make the run fail only after all iterations have been spent.
    results_dir = config.get("results_dir", "results")
    if results_dir:
        os.makedirs(results_dir, exist_ok=True)

    key = jrandom.PRNGKey(42)
    loss_fn = build_loss_fn(policy, config)

    def make_optimizer(learning_rate):
        # Gradient clipping followed by Adam
        return optax.chain(
            optax.clip_by_global_norm(1.0),
            optax.adam(learning_rate)
        )

    optimizer = make_optimizer(initial_lr)
    opt_state = optimizer.init(policy.params)

    best_fidelity = 0.0
    fidelity_history = []
    stagnant_counter = 0
    current_lr = initial_lr

    # The reported default matches the one build_loss_fn actually uses
    logger.info(f"[End-to-End Grad] Starting training for {num_iterations} iterations, initial_lr={initial_lr}, batch_size={config.get('batch_size', 2)}")

    # Define the loss function with optional checkpointing
    if use_checkpointing:
        def loss_with_checkpoint(params, step_rng_key):
            return jax.checkpoint(loss_fn, prevent_cse=True)(params, step_rng_key)
        loss_fn_jit = jax.jit(loss_with_checkpoint)
    else:
        loss_fn_jit = jax.jit(loss_fn)

    # Build the gradient function once instead of on every iteration
    grad_fn = jax.grad(loss_fn_jit)

    for i in range(1, num_iterations + 1):
        key, subkey = jrandom.split(key)
        grads = grad_fn(policy.params, subkey)
        updates, opt_state = optimizer.update(grads, opt_state, policy.params)
        policy.params = optax.apply_updates(policy.params, updates)
        # Loss of the updated parameters, evaluated with the same key as the gradient
        current_loss = loss_fn_jit(policy.params, subkey)

        # Compute fidelity from current_loss
        fidelity = 1.0 - float(current_loss)

        if fidelity < 0.0 or fidelity > 1.0:
            continue  # Out-of-range value: skip all bookkeeping for this iteration

        if fidelity > best_fidelity:
            best_fidelity = fidelity
            # Reset stagnant counter since improvement was observed
            stagnant_counter = 0
        else:
            stagnant_counter += 1
            logger.info(f"[End-to-End Grad] Iter={i}, no improvement in fidelity. Stagnant Counter: {stagnant_counter}/{patience}")

        # Record fidelity history
        fidelity_history.append(best_fidelity)

        # Check if patience is exceeded to decay learning rate
        if stagnant_counter >= patience:
            if current_lr > min_lr:
                current_lr = max(current_lr * lr_decay_factor, min_lr)
                # A new optimizer also starts from a fresh Adam state
                optimizer = make_optimizer(current_lr)
                opt_state = optimizer.init(policy.params)
                logger.info(f"[End-to-End Grad] Iter={i}, learning rate decayed to {current_lr}. Resetting stagnant counter.")
                stagnant_counter = 0  # Reset counter after decay
            else:
                logger.warning(f"[End-to-End Grad] Iter={i}, minimum learning rate reached ({min_lr}).")

        # Log every 10 iterations; the early-stop test is only made on these iterations
        if (i % 10) == 0:
            logger.info(f"[End-to-End Grad] Iter={i}, loss={current_loss:.6f}, best_fidelity={best_fidelity:.6f}, current_lr={current_lr:.6f}")

            if best_fidelity > config.get("target_fidelity", 0.999):
                logger.info(f"[End-to-End Grad] Best fidelity > {config.get('target_fidelity', 0.999)} reached at iter={i}. Stopping early!")
                break

    # After training, save fidelity history
    history = np.array(fidelity_history)
    fidelity_history_path_npy = os.path.join(results_dir, "fidelity_history.npy")
    fidelity_history_path_csv = os.path.join(results_dir, "fidelity_history.csv")
    np.save(fidelity_history_path_npy, history)
    np.savetxt(fidelity_history_path_csv, history, delimiter=",", header="Episode,Best_Fidelity")
    logger.info(f"Fidelity history saved to '{fidelity_history_path_npy}' and '{fidelity_history_path_csv}'.")

    # Plot fidelity progress
    _plot_fidelity_progress(fidelity_history, results_dir)

    # Save the trained model parameters
    model_save_path = os.path.join(results_dir, "trained_model_params.pkl")
    with open(model_save_path, "wb") as f:
        pickle.dump(policy.params, f)
    logger.info(f"Trained model parameters saved to '{model_save_path}'.")

    return policy.params


def _plot_fidelity_progress(fidelity_history, results_dir):
    """
    Plot the best-fidelity history and save it as 'fidelity_progress.png' in `results_dir`.

    Parameters:
        fidelity_history (sequence of float): Best fidelity recorded after each entry.
        results_dir (str): Directory that receives the image.

    Any error raised while plotting or saving is logged and not propagated.
    """
    try:
        episodes = np.arange(1, len(fidelity_history) + 1)
        best_fidelity = np.array(fidelity_history)

        fig, ax = plt.subplots(figsize=(12, 6))
        try:
            ax.plot(episodes, best_fidelity, label='Best Fidelity')
            ax.set_title('Global Best Fidelity Progress Over Episodes')
            ax.set_xlabel('Episode')
            ax.set_ylabel('Best Fidelity')
            ax.legend()
            ax.grid(True)
            fig.tight_layout()
            fidelity_plot_path = os.path.join(results_dir, "fidelity_progress.png")
            fig.savefig(fidelity_plot_path)
        finally:
            # Release the figure even when drawing or saving fails
            plt.close(fig)
        logger.info(f"Fidelity progress plot saved as '{fidelity_plot_path}'.")

    except Exception as e:
        logger.error(f"Error while plotting or saving fidelity progress: {e}")

# ========================================
# Part 9: Loading and Utilizing the Saved Model Parameters
# ========================================
def load_trained_model(model_path):
    """
    Read pickled model parameters from `model_path`.

    Returns the unpickled object, or None (after logging the error) when the file cannot
    be opened or unpickled.

    NOTE: unpickling can execute arbitrary code, so only load files from a trusted source.
    """
    try:
        with open(model_path, "rb") as handle:
            loaded_params = pickle.load(handle)
    except Exception as e:
        logger.error(f"Failed to load trained model parameters from '{model_path}': {e}")
        return None
    return loaded_params


# ========================================
# Part 10: Modified Main Function
# ========================================
def main():
    """
    Run the multi-phase training pipeline end to end.

    Steps, in order:
      1. Create the ``results`` directory and configure console and file
         logging (``results/training_log.txt``).
      2. Configure JAX: disable GPU memory preallocation and select the GPU
         platform as default.
      3. Train a single policy through successive phases, each with its own
         ``t_steps``, ``piecewise_segments`` and ``num_iterations``; the
         policy object (and therefore its parameters) is carried from one
         phase to the next.
      4. Reload ``trained_model_params.pkl`` from the results directory, run
         one forward pass with a fresh RNG key and save the output as
         ``optimal_Vt_new_seed.npy``.
      5. Log the total wall-clock time.

    Returns:
        None.
    """
    start_time = time.perf_counter()  # Reference point for the total-time log at the end

    results_dir = "results"
    os.makedirs(results_dir, exist_ok=True)

    # Configure logging
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    # Build the FileHandler only when it will actually be attached.
    # Constructing one opens the log file right away, so creating it
    # unconditionally and then skipping addHandler (as happens when main()
    # is called a second time) would leak an open file handle.
    if not logger.handlers:
        log_file = os.path.join(results_dir, "training_log.txt")
        file_handler = logging.FileHandler(log_file)
        file_handler.setLevel(logging.INFO)
        file_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
        file_handler.setFormatter(file_formatter)
        logger.addHandler(file_handler)

    # Ask XLA not to preallocate GPU memory up front.
    # NOTE(review): presumably this only takes effect if it is set before
    # JAX initializes its GPU backend -- confirm no JAX op runs earlier.
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    # Select GPU as JAX's default platform.
    # NOTE(review): this requests GPU unconditionally; it does not look like
    # it falls back to CPU when no GPU backend is present -- confirm.
    jax.config.update("jax_platform_name", "gpu")
    # Optionally limit JAX to use fewer CPU threads if running in HPC environment
    # os.environ["JAX_NUM_THREADS"] = "1"

    # Create configuration
    config = create_config()
    config["batch_size"] = 4
    config["results_dir"] = results_dir  # Downstream code reads results_dir from config

    # Training schedule: each phase increases the time resolution (t_steps)
    # and the number of piecewise segments.
    training_phases = [
        {"t_steps": 20, "piecewise_segments": 10, "num_iterations": 500},
        {"t_steps": 50, "piecewise_segments": 25, "num_iterations": 500},
        {"t_steps": 100, "piecewise_segments": 50, "num_iterations": 1000},
        # Add more phases as needed
    ]

    policy = None  # Created in phase 1, then reused by every later phase
    rng_key = jrandom.PRNGKey(0)

    for phase_idx, phase in enumerate(training_phases, 1):
        t_steps = int(phase["t_steps"])
        piecewise_segments = int(phase["piecewise_segments"])
        num_iterations = phase["num_iterations"]

        # Keep dt consistent with the new t_steps: dt = tmax / t_steps
        config["control_Vt_params"]["t_steps"] = t_steps
        config["control_Vt_params"]["dt"] = config["control_Vt_params"]["tmax"] / t_steps

        logger.info(f"Starting Training Phase {phase_idx}: t_steps={t_steps}, piecewise_segments={piecewise_segments}, num_iterations={num_iterations}")

        if phase_idx == 1:
            # First phase: build the policy from scratch
            policy = MLPPolicy(rng_key, piecewise_segments=piecewise_segments)
        else:
            # Later phases: resize the existing policy, using a fresh subkey
            rng_key, subkey = jrandom.split(rng_key)
            policy.update_piecewise_segments(piecewise_segments, subkey)

        # Train the model for the current phase
        trained_params = train_end_to_end_grad(
            policy=policy,
            config=config,
            num_iterations=num_iterations,
            initial_lr=1e-3,
            patience=500,
            min_lr=1e-6,
            lr_decay_factor=0.5,
            use_mixed_precision=True,
            use_checkpointing=True
        )

        # Carry the trained parameters into the next phase via the policy
        policy.params = trained_params

    logger.info("Adaptive Training with Increasing t_steps and piecewise_segments Completed.")

    # Reload the saved parameters and use them for one forward pass
    trained_model_path = os.path.join(config["results_dir"], "trained_model_params.pkl")
    if os.path.exists(trained_model_path):
        loaded_params = load_trained_model(trained_model_path)
        if loaded_params is not None:
            logger.info(f"Loaded trained model parameters from '{trained_model_path}'.")

            # Assign loaded parameters to the existing policy
            policy.params = loaded_params

            # Generate V(t) from the loaded parameters with a new RNG key
            new_rng_key = jrandom.PRNGKey(123)  # Example seed
            optimal_Vt = policy.forward(loaded_params, new_rng_key)

            # Save the generated V(t)
            optimal_Vt_path = os.path.join(config["results_dir"], "optimal_Vt_new_seed.npy")
            np.save(optimal_Vt_path, optimal_Vt)
            logger.info(f"Optimal V(t) for new seed saved to '{optimal_Vt_path}'.")

        else:
            logger.error(f"Failed to load trained model parameters from '{trained_model_path}'.")
    else:
        logger.error(f"Trained model parameters file '{trained_model_path}' not found.")

    # Log total execution time for adaptive training approach
    end_time = time.perf_counter()
    total_time = end_time - start_time
    logger.info(f"Adaptive training with increasing t_steps and piecewise_segments completed in {total_time/60:.2f} minutes.")

# ========================================
# Entry Point
# ========================================
# Start the training pipeline only when this file is run as the main program.
if __name__ == "__main__":
    main()
