In [2]:
!pip install optax pandas gymnasium torch stable-baselines3



In [None]:
import jax.numpy as jnp
import jax.random as jrandom
import jax.scipy.linalg as jla
import jax
import matplotlib.pyplot as plt
from jax import vmap
from jax import random
import numpy as np
import optax
import time
import pandas as pd
from jax import grad
import math
import logging
from collections import deque, defaultdict

import os
import sys
from typing import Optional

import gymnasium as gym  # Updated import for Gymnasium
from gymnasium import spaces
import torch
import torch.nn as nn  # Ensure nn is imported correctly
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.env_checker import check_env

# ========================================
# Logging Configuration
# ========================================
# ----- Logging Configuration (Print to Console) -----
# Silence TensorFlow/XLA C++ log output.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

# The root logger emits INFO and above.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Start from a clean slate so that re-running this setup (e.g. re-executing
# the notebook cell) does not attach duplicate handlers.
logger.handlers.clear()

formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

# Single handler writing formatted INFO+ records to standard output.
stream_handler = logging.StreamHandler(sys.stdout)
stream_handler.setLevel(logging.INFO)
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)


# ========================================
# Part 1: Control System Model (PIC + SLM)
# ========================================

# Constants and Parameters
def calculate_Omega_rabi_prefactor(I_mW_per_cm2, Detuning_MHz):
    """
    Return the effective Rabi frequency prefactor (MHz) for a given intensity and detuning.

    The intensity is converted to a peak field amplitude
    E_0 = sqrt(2 I / (c * epsilon_0)); each single-photon Rabi frequency is
    mu * E_0 / hbar, and the effective coupling is
    Omega_0e * Omega_1e / (2 * Detuning_MHz * 1e6), finally divided by 1e6.

    Parameters:
    - I_mW_per_cm2: Laser intensity in mW/cm².
    - Detuning_MHz: Detuning in MHz.

    Returns:
    - Effective Rabi frequency prefactor in MHz.

    NOTE(review): mu * E_0 / hbar is an angular frequency (rad/s), while the
    detuning is only scaled by 1e6 (no factor 2*pi) — confirm the intended
    unit convention.
    """
    # Physical constants (SI units)
    hbar = 1.0545718e-34         # reduced Planck constant (J·s)
    c = 3e8                      # speed of light in vacuum (m/s)
    epsilon_0 = 8.854187817e-12  # vacuum permittivity (F/m)
    # Transition dipole moments (C·m)
    mu0e = 3.336e-29             # |0> <-> |e>
    mu1e = 3.400e-29             # |1> <-> |e>

    intensity_W_per_m2 = I_mW_per_cm2 * 10  # 1 mW/cm² == 10 W/m²
    field_amplitude = jnp.sqrt(2 * intensity_W_per_m2 / (c * epsilon_0))  # V/m

    rabi_0e = (mu0e * field_amplitude) / hbar
    rabi_1e = (mu1e * field_amplitude) / hbar
    detuning = 2 * Detuning_MHz * 1e6
    return (rabi_0e * rabi_1e) / detuning / 1e6


# Grid and atom positions
def generate_grid(grid_size=600, grid_range=(-6, 6)):
    """
    Build a square 2D coordinate grid with jnp.meshgrid.

    Each axis holds ``grid_size - 1`` evenly spaced samples (599 with the
    default) spanning ``grid_range`` with both end points included.

    Args:
    grid_size (int): One more than the number of samples per axis (default 600).
    grid_range (tuple): (low, high) bounds shared by the x and y axes (default (-6, 6)).

    Returns:
    tuple (X, Y): arrays of shape (grid_size - 1, grid_size - 1); X varies
    along columns and Y along rows.
    """
    low, high = grid_range[0], grid_range[1]
    axis = jnp.linspace(low, high, grid_size - 1)
    X, Y = jnp.meshgrid(axis, axis)
    return X, Y

# Function to generate atom positions
def generate_atom_positions_equilateral(N_a, side_length=3.0, center=(0, 0), triangle_spacing_factor=2.0):
    """
    Generate atom positions for any integer number of atoms (N_a) in equilateral triangular patterns.

    Atoms fill triangles in groups of three, in the order bottom-left,
    bottom-right, top vertex. Triangle ``k`` is centred at
    ``(center[0] + k * triangle_spacing_factor * side_length, center[1])``.
    When N_a is not a multiple of 3, the leftover atoms start a new, partially
    filled triangle instead of reusing corners of the previous one, so every
    returned position is distinct.

    Parameters:
    - N_a (int): Total number of atoms (0 or negative yields an empty list).
    - side_length (float): Length of each side of the equilateral triangle.
    - center (tuple): Center of the first triangle (x_center, y_center).
    - triangle_spacing_factor (float): Distance between triangle centers, as a multiple of side_length.

    Returns:
    - positions (list of tuples): List of (x, y) positions for the atoms.
    """
    x_center, y_center = center
    height = (math.sqrt(3) / 2) * side_length  # Height of the equilateral triangle

    # Corner offsets relative to a triangle's centre, in fill order.
    corner_offsets = (
        (-side_length / 2, -height / 2),  # Bottom left
        (side_length / 2, -height / 2),   # Bottom right
        (0, height / 2),                  # Top vertex
    )

    positions = []
    for atom_idx in range(N_a):
        triangle_idx, corner = divmod(atom_idx, 3)
        current_x_center = x_center + triangle_idx * triangle_spacing_factor * side_length
        dx, dy = corner_offsets[corner]
        positions.append((current_x_center + dx, y_center + dy))
    return positions

# Dipoles for each atom
def generate_dipoles(N_a):
    """Return a list of N_a dipole vectors, each a fresh jnp.array([1.0, 0]) (x-polarised)."""
    dipoles = []
    for _ in range(N_a):
        dipoles.append(jnp.array([1.0, 0]))
    return dipoles

# Function to generate SLM modulations
def generate_slm_mod(N_slm):
    """
    Return the neutral SLM modulation for N_slm entries.

    Returns:
        (phase_mod, amp_mod): zero phase and unit amplitude for every entry.
    """
    return jnp.zeros(N_slm), jnp.ones(N_slm)

# Functions for crosstalk model setup (with fabrication variations considered)
def generate_distances(num_channels, pitch, random_variation=0.0, seed=None):
    """
    Build the pairwise waveguide-distance matrix for a linear array.

    Entry (i, j), i != j, is |i - j| * pitch plus a uniform perturbation drawn
    from [-random_variation, random_variation]; diagonal entries are 0. The
    (i, j) and (j, i) perturbations are drawn independently, so the matrix is
    only symmetric when random_variation is 0.

    Parameters:
        num_channels (int): Number of waveguides.
        pitch (float): Spacing between adjacent waveguides (µm).
        random_variation (float): Maximum random variation (±) for distances.
        seed (int or None): If given, reseeds NumPy's global RNG (np.random.seed).

    Returns:
        jax.numpy.array: Pairwise distances matrix (µm), float32.
    """
    if seed is not None:
        np.random.seed(seed)

    matrix = np.zeros((num_channels, num_channels), dtype=np.float32)
    # Row-major walk with one draw per off-diagonal entry, so seeded output is
    # reproducible entry-by-entry.
    for flat in range(num_channels * num_channels):
        row, col = divmod(flat, num_channels)
        if row == col:
            continue
        nominal = abs(row - col) * pitch
        matrix[row, col] = nominal + np.random.uniform(-random_variation, random_variation)
    return jnp.array(matrix)

def generate_coupling_lengths(num_channels, base_length, scaling_factor=1.0, random_variation=0.0, seed=None):
    """
    Generate coupling lengths with optional randomness.

    Parameters:
        num_channels (int): Number of waveguides.
        base_length (float): Base coupling length for adjacent waveguides (µm).
        scaling_factor (float): Multiplier to scale coupling length with increasing separation.
        random_variation (float): Maximum random variation (±) for coupling lengths.
        seed (int or None): Seed for reproducibility of randomness.

    Returns:
        jax.numpy.array: Coupling lengths matrix (µm).
    """
    if seed is not None:
        np.random.seed(seed)

    coupling_lengths = np.zeros((num_channels, num_channels), dtype=np.float32)
    for i in range(num_channels):
        for j in range(num_channels):
            if i != j:
                base_length_ij = base_length * (scaling_factor**abs(i - j))
                variation = np.random.uniform(-random_variation, random_variation)
                coupling_lengths[i, j] = base_length_ij + variation
    return jnp.array(coupling_lengths)

def generate_n_eff_list(num_channels, base_n_eff, random_variation=0.0, seed=None):
    """
    Return per-waveguide effective refractive indices with optional randomness.

    Each entry is base_n_eff plus an independent uniform offset drawn from
    [-random_variation, random_variation].

    Parameters:
        num_channels (int): Number of waveguides.
        base_n_eff (float): Base effective refractive index for all waveguides.
        random_variation (float): Maximum random variation (±) around the base value.
        seed (int or None): If given, reseeds NumPy's global RNG (np.random.seed).

    Returns:
        jax.numpy.array: Effective refractive indices, shape (num_channels,).
    """
    if seed is not None:
        np.random.seed(seed)
    offsets = np.random.uniform(-random_variation, random_variation, num_channels)
    return jnp.array(base_n_eff + offsets)


# 1.1 Control Signal Construction

def construct_V_smooth_with_carrier(tmin, tmax, t_steps, voltage_levels, omega_0, phi=0, max_step=30):
    """
    Build a piecewise-constant, slew-limited voltage envelope times a cosine carrier.

    The ``t_steps`` samples on [tmin, tmax] (jnp.linspace, both ends included)
    are split into ``len(voltage_levels)`` consecutive segments of
    ``t_steps // len(voltage_levels)`` samples; the last segment also absorbs
    any remainder (and covers everything if there are more segments than
    samples). The levels are slew-limited in order: a level that differs from
    the previous *already-limited* level by more than ``max_step`` is moved to
    exactly ``max_step`` away from it.

    Parameters:
    - tmin (float): Start time.
    - tmax (float): End time.
    - t_steps (int): Number of time steps.
    - voltage_levels (array-like): Voltage level for each piece; must be non-empty.
    - omega_0 (float): Carrier angular frequency (rad per unit of time).
    - phi (float, optional): Initial phase of the carrier (radians). Default is 0.
    - max_step (float, optional): Maximum allowed voltage change between consecutive pieces. Default is 30.

    Returns:
    - V_t (array, shape (t_steps,)): envelope(t) * cos(omega_0 * t + phi).

    Raises:
    - ValueError: if ``voltage_levels`` is empty.
    """
    num_pieces = len(voltage_levels)
    if num_pieces == 0:
        raise ValueError("voltage_levels must contain at least one level")

    time_points = jnp.linspace(tmin, tmax, t_steps)
    levels = jnp.array(voltage_levels, dtype=jnp.float32)

    # Sequential slew limiting: each step compares against the previous
    # *limited* value, so this cannot be a single independent clip per entry.
    limited = [levels[0]]
    for level in levels[1:]:
        delta_V = level - limited[-1]
        if jnp.abs(delta_V) > max_step:
            level = (limited[-1] + jnp.sign(delta_V) * max_step).astype(levels.dtype)
        limited.append(level)
    levels = jnp.stack(limited)

    # Map every sample to its segment in one vectorized gather instead of one
    # full-array copy per segment.
    piece_duration = t_steps // num_pieces
    sample_idx = jnp.arange(t_steps)
    if piece_duration > 0:
        # Samples past the last full segment are folded into the final one.
        segment_idx = jnp.minimum(sample_idx // piece_duration, num_pieces - 1)
    else:
        # More segments than samples: only the last segment is non-empty and
        # it spans the whole window.
        segment_idx = jnp.full(t_steps, num_pieces - 1)
    V_piecewise = levels[segment_idx]

    # Carrier modulation
    carrier = jnp.cos(omega_0 * time_points + phi)
    return V_piecewise * carrier


# 1.2 Unitary Matrix Construction
def dn(V):
    """
    Return the refractive-index change induced by voltage V.

    Linear model with coefficient 4e-5 per unit of V (presumably per volt —
    confirm against the modulator characterization).
    """
    coefficient = 4e-5
    return coefficient * V

def phi_func(L, n, dn_val, lambda_):
    """
    Return the optical phase 2*pi*L*(n + dn_val)/lambda_ accumulated over length L.

    L and lambda_ must be expressed in the same length unit.
    """
    two_pi_L = 2 * jnp.pi * L
    effective_index = n + dn_val
    return two_pi_L * effective_index / lambda_

def D_func(a, t_val, phi_val):
    """
    Return the complex transmission of a single ring element:

        D = e^{i*pi} * (a - t * e^{-i*phi}) / (1 - t * a * e^{-i*phi})

    Parameters:
        a: Round-trip amplitude factor (presumably loss; 1 = lossless — confirm).
        t_val: Self-coupling coefficient.
        phi_val: Round-trip phase (radians).
    """
    round_trip = jnp.exp(-1j * phi_val)
    numerator = a - t_val * round_trip
    denominator = 1 - t_val * a * round_trip
    return jnp.exp(1j * jnp.pi) * numerator / denominator

def U_drmzm_single_channel(V0, V1, L, n0, lambda_0, a0, t0, a1, t1, psi_0=0):
    """
    Return the 2x2 transfer matrix of one modulator channel.

    Each arm i carries a ring element D_i = D_func(a_i, t_i, phi_i) whose
    round-trip phase phi_i = phi_func(L, n0, dn(V_i), lambda_0) is set by the
    arm voltage. The two arms sit between identical 50:50 splitters:

        U = B @ diag(D_0 * exp(1j * psi_0), D_1) @ B,  B = [[1, 1], [1, -1]] / sqrt(2)

    Parameters:
        V0, V1: Voltages applied to arm 0 and arm 1.
        L, n0, lambda_0: Ring length, base refractive index, wavelength
            (L and lambda_0 in the same length unit).
        a0, t0, a1, t1: Ring parameters passed to D_func for each arm.
        psi_0 (float): Extra static phase (radians) on arm 0. Default 0.

    Returns:
        2x2 complex JAX array.
    """
    phi0 = phi_func(L, n0, dn(V0), lambda_0)
    phi1 = phi_func(L, n0, dn(V1), lambda_0)

    # psi_0 is a phase offset, so it must enter as the unit-modulus factor
    # e^{i*psi_0}; a real exp(psi_0) would rescale the amplitude instead.
    D_00 = D_func(a0, t0, phi0) * jnp.exp(1j * psi_0)
    D_11 = D_func(a1, t1, phi1)

    U_drmzm_matrix = jnp.array([[D_00, jnp.zeros_like(D_00)], [jnp.zeros_like(D_11), D_11]])
    U_bs = (1 / jnp.sqrt(2)) * jnp.array([[1, 1], [1, -1]])
    return U_bs @ (U_drmzm_matrix @ U_bs)

# Add Crosstalk
def U_drmzm_multi_channel(
    V0_t_list, V1_t_list, L, n0, lambda_0, a0, t0, a1, t1,
    N_ch, distances, coupling_lengths, n_eff_list, kappa0, alpha, enable_crosstalk=True
):
    """
    Constructs the multi-channel transfer matrix with optional inter-channel coupling (amplitude and phase crosstalk).

    Without crosstalk the result is diagonal, with entry i equal to element
    [0, 0] of channel i's U_drmzm_single_channel matrix. With crosstalk it is
    multiplied on the right by a coupling matrix that is the identity plus, for
    every pair i < j with distances[i, j] > 0, the cross term of a two-waveguide
    coupled-mode transfer matrix at [i, j] and its complex conjugate at [j, i].

    Parameters:
    V0_t_list, V1_t_list: Control voltages for each channel (one value per channel).
    L, n0, lambda_0, a0, t0, a1, t1: Physical parameters of the system.
    N_ch (int): Number of waveguides (channels).
    distances (jax.numpy.array): Pairwise distances between channels (µm).
    coupling_lengths (jax.numpy.array): Pairwise coupling lengths (µm).
    n_eff_list (jax.numpy.array): Effective refractive indices for each waveguide.
    kappa0 (float): Maximum coupling coefficient.
    alpha (float): Decay rate of coupling with distance (kappa = kappa0 * exp(-alpha * d)).
    enable_crosstalk (bool): If False, disables all crosstalk and returns the diagonal matrix.

    Returns:
    U_multi_channel_with_ct: (N_ch, N_ch) complex matrix including (or excluding) crosstalk effects.

    NOTE(review): the coupled-mode transfer matrix is symmetric (M[1, 0] == M[0, 1]),
    but the [j, i] entry here uses conj(M[0, 1]); the diagonal also stays 1
    rather than M[0, 0]. Confirm this simplified model is intended.
    """
    # Step 1: diagonal matrix without crosstalk
    U_multi_channel_no_ct = jnp.zeros((N_ch, N_ch), dtype=jnp.complex64)
    for i in range(N_ch):
        U_single_channel = U_drmzm_single_channel(V0_t_list[i], V1_t_list[i], L, n0, lambda_0, a0, t0, a1, t1)
        U_multi_channel_no_ct = U_multi_channel_no_ct.at[i, i].set(U_single_channel[0, 0])

    # If crosstalk is disabled, return the diagonal matrix only
    if not enable_crosstalk:
        return U_multi_channel_no_ct

    # Step 2: crosstalk matrix, starting from the identity
    U_wg_coupling_ct = jnp.eye(N_ch, dtype=jnp.complex64)

    # Vacuum wave number (k = 2π / λ)
    k = 2 * jnp.pi / lambda_0

    def compute_transfer_matrix(L_val, beta_1, beta_2, kappa):
        """
        Coupled-mode transfer matrix of two co-directional coupled waveguides.

        With delta = (beta_1 - beta_2) / 2 and s = sqrt(kappa**2 + delta**2):
            M = [[cos(sL) - i(delta/s) sin(sL), -i(kappa/s) sin(sL)],
                 [-i(kappa/s) sin(sL),           cos(sL) + i(delta/s) sin(sL)]]
        The kappa/s factor on the cross terms keeps M unitary when the
        waveguides are phase-mismatched (delta != 0); for delta == 0 it is 1.
        """
        delta_beta = (beta_1 - beta_2) / 2.0
        kappa_eff = jnp.sqrt(kappa**2 + delta_beta**2)  # Effective coupling coefficient
        cos_term = jnp.cos(kappa_eff * L_val)
        sin_term = jnp.sin(kappa_eff * L_val)
        if kappa_eff != 0:
            delta_term = delta_beta / kappa_eff
            coupling_term = kappa / kappa_eff
        else:
            # kappa == 0 and delta == 0: no coupling (sin_term is 0 anyway)
            delta_term = 0
            coupling_term = 0
        cross_term = -1j * coupling_term * sin_term

        return jnp.array([
            [cos_term - 1j * delta_term * sin_term, cross_term],
            [cross_term, cos_term + 1j * delta_term * sin_term]
        ])

    # Step 3: direct coupling contributions (upper triangle, mirrored)
    for i in range(N_ch):
        beta_i = k * n_eff_list[i]
        for j in range(i + 1, N_ch):
            if distances[i, j] > 0:
                beta_j = k * n_eff_list[j]
                kappa_ij = kappa0 * jnp.exp(-alpha * distances[i, j])
                L_ij = coupling_lengths[i, j]

                M = compute_transfer_matrix(L_ij, beta_i, beta_j, kappa_ij)

                U_wg_coupling_ct = U_wg_coupling_ct.at[i, j].set(M[0, 1])
                U_wg_coupling_ct = U_wg_coupling_ct.at[j, i].set(jnp.conj(M[0, 1]))

    # Step 4: combined matrix with crosstalk
    return jnp.matmul(U_multi_channel_no_ct, U_wg_coupling_ct)

# 1.3 SLM and Scattering Matrices
def construct_U_multi_channel_slm(N_slm, phase_mod, amp_mod, t_steps):
    """
    Build the diagonal SLM transfer matrix, repeated for every time step.

    Channel i is multiplied by amp_mod[i] * exp(1j * phase_mod[i]); the
    modulation does not depend on time.

    Parameters:
        N_slm (int): Number of SLM channels.
        phase_mod (array-like): Phase per channel (radians); at least N_slm entries.
        amp_mod (array-like): Amplitude factor per channel; at least N_slm entries.
        t_steps (int): Number of time steps.

    Returns:
        complex64 array of shape (t_steps, N_slm, N_slm) with the same diagonal
        matrix in every time slice.

    Raises:
        ValueError: if phase_mod or amp_mod has fewer than N_slm entries.
    """
    phase_mod = jnp.asarray(phase_mod)
    amp_mod = jnp.asarray(amp_mod)
    if phase_mod.shape[0] < N_slm or amp_mod.shape[0] < N_slm:
        raise ValueError(
            f"phase_mod and amp_mod need at least N_slm={N_slm} entries "
            f"(got {phase_mod.shape[0]} and {amp_mod.shape[0]})"
        )
    # Build one diagonal and broadcast it, instead of writing t_steps * N_slm
    # elements one copy at a time.
    diagonal = amp_mod[:N_slm] * jnp.exp(1j * phase_mod[:N_slm])
    U_single = jnp.diag(diagonal).astype(jnp.complex64)
    return jnp.broadcast_to(U_single, (t_steps, N_slm, N_slm))

def construct_I_prime(N_scat, delta, t_steps):
    """
    Return t_steps copies of the scattering matrix I + diag(delta, ..., delta).

    That is (1 + delta) times the N_scat x N_scat identity, as complex64. The
    matrix does not depend on time, so it is built once and broadcast rather
    than written slice by slice.

    Returns:
        complex64 array of shape (t_steps, N_scat, N_scat).
    """
    single = (jnp.eye(N_scat) + jnp.diag(jnp.full(N_scat, delta))).astype(jnp.complex64)
    return jnp.broadcast_to(single, (t_steps, N_scat, N_scat))

# 1.4 E-field Calculations
def lg00_mode_profile(X, Y, beam_center, beam_waist):
    """
    Evaluate the Gaussian spot exp(-r^2 / (2 * beam_waist^2)) centred at beam_center on (X, Y).

    NOTE(review): a textbook LG00 field amplitude is exp(-r^2 / w^2); the
    factor 2 here makes beam_waist behave like an rms width — confirm the
    intended convention.
    """
    x0, y0 = beam_center
    dx = X - x0
    dy = Y - y0
    return jnp.exp(-(dx**2 + dy**2) / (2 * beam_waist**2))

def compute_E_field_for_channel(X, Y, E_t, beam_center, beam_waist, t_idx):
    """
    Return one beam's complex field on the grid: lg00 profile times amplitude E_t.

    Parameters:
        X, Y: Grid coordinates.
        E_t: Amplitude of this channel at the current time step (real or complex).
        beam_center: (x, y) centre of the beam.
        beam_waist: Width parameter passed to lg00_mode_profile.
        t_idx: Unused; kept so existing call sites keep working.

    Returns:
        Complex array shaped like X.
    """
    # Rebuilding from real/imag parts promotes a real E_t to complex.
    amplitude = jnp.real(E_t) + 1j * jnp.imag(E_t)
    return lg00_mode_profile(X, Y, beam_center, beam_waist) * amplitude

def compute_total_E_field_profile(X, Y, b_slm_out, beam_centers, beam_waist):
    """
    Superpose all beams on the grid, separately for every time step.

    Parameters:
        X, Y: Grid coordinates.
        b_slm_out: Array indexed [t, channel]; channel k drives the beam at beam_centers[k].
        beam_centers: Sequence of (x, y) beam centres.
        beam_waist: Width parameter shared by all beams.

    Returns:
        Complex array with one summed field per time step (T = b_slm_out.shape[0]).
    """
    empty_field = jnp.zeros_like(X, dtype=jnp.complex64)
    frames = []
    for t_idx in range(b_slm_out.shape[0]):
        contributions = (
            compute_E_field_for_channel(X, Y, b_slm_out[t_idx, k], center, beam_waist, t_idx)
            for k, center in enumerate(beam_centers)
        )
        frames.append(sum(contributions, empty_field))
    return jnp.array(frames)

def extract_E_field_at_atoms(E_field_profiles, atom_positions, X, Y):
    """
    Sample the field at each atom from the nearest grid point.

    The x axis is read from X[0, :] and the y axis from Y[:, 0] (meshgrid 'xy'
    layout), and fields are indexed [t, y_idx, x_idx].

    Parameters:
        E_field_profiles: Array indexed [t, row, col] on the (X, Y) grid.
        atom_positions: Sequence of (x, y) atom coordinates.
        X, Y: Grid coordinates.

    Returns:
        Array of shape (T, N_atoms): the field at each atom for every time step.
    """
    xs = jnp.asarray([x0 for x0, _ in atom_positions])
    ys = jnp.asarray([y0 for _, y0 in atom_positions])
    # Nearest-grid indices do not depend on time: compute them once for all
    # atoms (argmin keeps the first minimum on ties).
    x_idx = jnp.argmin(jnp.abs(X[0, :][None, :] - xs[:, None]), axis=1)
    y_idx = jnp.argmin(jnp.abs(Y[:, 0][None, :] - ys[:, None]), axis=1)
    return E_field_profiles[:, y_idx, x_idx]

def compute_alpha_t(E_fields_at_atoms, dipoles, Omega_prefactor_MHz):
    """
    Convert sampled fields into per-atom coupling strengths alpha(t).

    alpha[t, k] = Omega_prefactor_MHz * dipoles[k][0] * E_fields_at_atoms[t, k];
    only the first (x) component of each dipole is used.

    Returns:
        Array of shape (T, len(dipoles)).
    """
    dipole_scales = [Omega_prefactor_MHz * dipole[0] for dipole in dipoles]
    return jnp.array([
        jnp.array([scale * E_fields_at_atoms[t_idx, k] for k, scale in enumerate(dipole_scales)])
        for t_idx in range(E_fields_at_atoms.shape[0])
    ])


# ========================================
# Part 2: Atomic Model
# ========================================

# Two levels per qubit
N_qubit_level = 2

# Single-qubit operators (complex64) in the basis where s_z = diag(+1, -1):
# s_plus maps the second basis state onto the first, s_minus is its transpose.
s_z = jnp.array([[1, 0], [0, -1]], dtype=jnp.complex64)
s_plus = jnp.array([[0, 1], [0, 0]], dtype=jnp.complex64)
s_minus = s_plus.T

def construct_multi_qubit_operator(single_qubit_op, N_a, N_qubit_level, qubit_idx):
    """
    Embed a single-qubit operator into an N_a-qubit register.

    Returns I ⊗ ... ⊗ single_qubit_op ⊗ ... ⊗ I with single_qubit_op in slot
    qubit_idx (slot 0 is the leftmost Kronecker factor; the slot is chosen with
    Python list indexing). The identities are complex64 N_qubit_level x N_qubit_level.
    """
    identity = jnp.eye(N_qubit_level, dtype=jnp.complex64)
    factors = [identity] * N_a
    factors[qubit_idx] = single_qubit_op

    result = factors[0]
    position = 1
    while position < len(factors):
        result = jnp.kron(result, factors[position])
        position += 1
    return result

# Creation and annihilation operators for the quantized field (Fock mode)
def construct_annihilation_operator(fock_dim):
    """
    Return the truncated bosonic ladder operators (a, a_dag) on a fock_dim-dimensional space.

    a holds sqrt(n) at [n - 1, n] for n = 1 .. fock_dim - 1 (first super-diagonal);
    a_dag is its conjugate transpose. Both are complex64.
    """
    n = jnp.arange(1, fock_dim)
    a = jnp.zeros((fock_dim, fock_dim), dtype=jnp.complex64).at[n - 1, n].set(jnp.sqrt(n))
    a_dag = a.conj().T
    return a, a_dag

# Truncated Fock space of the quantized field mode: ladder operators and identity
fock_dim = 5
a, a_dag = construct_annihilation_operator(fock_dim)
I_fock = jnp.identity(fock_dim, dtype=jnp.complex64)

# Drift Hamiltonian (H_0)
def construct_H_0(N_a, omega_0, omega_r):  # omega_0/r in MHz
    """
    Build the drift Hamiltonian on the qubits ⊗ field space.

        H_0 = sum_i (omega_0 / 2) * s_z^(i) ⊗ I_fock + omega_r * I_qubits ⊗ (a_dag @ a)

    Qubits are the left Kronecker factor; each qubit is two-level.

    Returns:
        Square matrix of size 2**N_a * fock_dim.
    """
    qubit_terms = [
        0.5 * omega_0 * construct_multi_qubit_operator(s_z, N_a, 2, i)
        for i in range(N_a)
    ]
    H_qubits = sum(qubit_terms)
    photon_number = a_dag @ a
    H_field = omega_r * jnp.kron(jnp.eye(2 ** N_a), photon_number)
    return jnp.kron(H_qubits, I_fock) + H_field

# Control Hamiltonian
def construct_H_control(N_a, N_qubit_level, g_real_t, g_imag_t):
    """
    Build the atom–field coupling Hamiltonian for one time step.

    For atom i, with A_i = s_plus^(i) ⊗ a and B_i = s_minus^(i) ⊗ a_dag:
        H += g_real_t[i] * (A_i + B_i) + g_imag_t[i] * (1j * A_i - 1j * B_i)

    Parameters:
        N_a: Number of atoms.
        N_qubit_level: Levels per atom.
        g_real_t, g_imag_t: Per-atom real / imaginary coupling parts at this step.

    Returns:
        complex64 square matrix of size N_qubit_level**N_a * fock_dim.
    """
    dim = N_qubit_level ** N_a * fock_dim
    H_control = jnp.zeros((dim, dim), dtype=jnp.complex64)
    for i in range(N_a):
        # Each coupling operator is built once and reused for both terms.
        A_i = jnp.kron(construct_multi_qubit_operator(s_plus, N_a, N_qubit_level, i), a)
        B_i = jnp.kron(construct_multi_qubit_operator(s_minus, N_a, N_qubit_level, i), a_dag)
        H_control = H_control + g_real_t[i] * (A_i + B_i)
        H_control = H_control + g_imag_t[i] * (1j * A_i - 1j * B_i)
    return H_control

# Time Evolution Hamiltonian (H(t))
def construct_H_time(N_a, N_qubit_level, omega_0, omega_r, g_real_t, g_imag_t,
                   atom_positions, gate_type='single'):
    """
    Return the stacked Hamiltonians H(t) = H_0 + H_control(t), one per time step.

    Parameters:
        g_real_t, g_imag_t: Indexed [t, atom]; the number of steps is len(g_real_t).
        atom_positions, gate_type: Currently unused.

    Returns:
        JAX array of shape (len(g_real_t), D, D).
    """
    H_drift = construct_H_0(N_a, omega_0, omega_r)
    return jnp.array([
        H_drift + construct_H_control(N_a, N_qubit_level, g_real_t[step], g_imag_t[step])
        for step in range(len(g_real_t))
    ])

# Compute accumulated propagator (U(t)) over time
def compute_accumulated_propagator(H_t, dt, N_a, N_qubit_level):
    """
    Time-order the step propagators and return every partial product.

    With U_k = expm(-1j * H_t[k] * dt), output k is U_k @ ... @ U_1 @ U_0
    (later steps multiply from the left), starting from the identity.

    Returns:
        Array of shape (len(H_t), D, D), D = N_qubit_level**N_a * fock_dim.
    """
    dim = N_qubit_level ** N_a * fock_dim
    running = jnp.eye(dim, dtype=jnp.complex64)
    history = []
    for H in H_t:
        step_propagator = jla.expm(-1j * H * dt)
        running = step_propagator @ running
        history.append(running)
    return jnp.array(history)

# Trace out field from unitary matrix
def trace_out_field_from_unitary(U_t, N_a, N_qubit_level, field_dim=5, normalize=False):
    """
    Partial trace of a qubits ⊗ field operator over the field factor.

    U_t is viewed as U[(i, k), (j, l)] with qubit indices i, j and field indices
    k, l (qubits are the left Kronecker factor). Only field-diagonal blocks
    contribute: result[i, j] = sum_k U[(i, k), (j, k)].

    Parameters:
        U_t: Square matrix of size N_qubit_level**N_a * field_dim.
        N_a: Number of qubits.
        N_qubit_level: Levels per qubit.
        field_dim: Dimension of the field space.
        normalize: If True, divide by field_dim so that U_q ⊗ I_field maps back
            to U_q. By default the un-normalized partial trace is returned,
            which for U_q ⊗ I_field equals field_dim * U_q.

    Returns:
        (N_qubit_level**N_a, N_qubit_level**N_a) reduced operator.
    """
    dim_qubits = N_qubit_level ** N_a
    U_t_reshaped = U_t.reshape(dim_qubits, field_dim, dim_qubits, field_dim)
    # Trace over k == l only; summing every (k, l) pair would mix in the
    # field-off-diagonal blocks and is not a partial trace.
    U_t_traced = jnp.trace(U_t_reshaped, axis1=1, axis2=3)
    if normalize:
        U_t_traced = U_t_traced / field_dim
    return U_t_traced

def trace_out_field_from_unitary_t_all(U_t_all, N_a, N_qubit_level, field_dim=5, normalize=False):
    """
    Partial trace over the field for every time step of a propagator series.

    Each U_t_all[t] is viewed as U[(i, k), (j, l)] with qubit indices i, j and
    field indices k, l (qubits are the left Kronecker factor). Only
    field-diagonal blocks contribute: result[t, i, j] = sum_k U[t, (i, k), (j, k)].

    Parameters:
    - U_t_all: Time series of matrices with shape (time_steps, dim_total, dim_total),
      where dim_total = dim_qubits * field_dim.
    - N_a: Number of atoms (qubits).
    - N_qubit_level: Number of levels per qubit (typically 2).
    - field_dim: Dimension of the field Hilbert space.
    - normalize: If True, divide by field_dim so that U_q ⊗ I_field maps back to
      U_q. By default the un-normalized partial trace is returned.

    Returns:
    - U_t_traced_all: shape (time_steps, dim_qubits, dim_qubits).

    Raises:
    - ValueError: if dim_total != N_qubit_level**N_a * field_dim.
    """
    time_steps = U_t_all.shape[0]
    dim_qubits = N_qubit_level**N_a
    total_dim = U_t_all.shape[-1]

    # Explicit check (an assert would be stripped under python -O).
    if dim_qubits * field_dim != total_dim:
        raise ValueError(
            f"Mismatch in total dimensions! expected {dim_qubits} * {field_dim}, got {total_dim}"
        )

    U_t_reshaped = U_t_all.reshape(time_steps, dim_qubits, field_dim, dim_qubits, field_dim)

    # Trace over k == l only; summing every (k, l) pair would mix in the
    # field-off-diagonal blocks and is not a partial trace.
    U_t_traced_all = jnp.trace(U_t_reshaped, axis1=2, axis2=4)
    if normalize:
        U_t_traced_all = U_t_traced_all / field_dim
    return U_t_traced_all

# Fidelity Calculation
def compute_fidelity_unitary(U_t_traced, U_target):
    """
    Return the gate overlap |Tr(U_target^† U_t_traced)|^2 / d^2, with d = U_target.shape[0].

    For unitary inputs this equals 1 when U_t_traced matches U_target up to a
    global phase; non-unitary (e.g. scaled) inputs can exceed 1.

    Parameters:
    - U_t_traced: Reduced matrix for the qubit system.
    - U_target: Target unitary matrix.

    Returns:
    - fidelity: real scalar.
    """
    dim = U_target.shape[0]
    overlap = jnp.trace(U_target.conj().T @ U_t_traced)
    return jnp.real(jnp.abs(overlap) ** 2 / (dim ** 2))

# record F(t)
def compute_fidelity_unitary_t_all(U_t_traced_all, U_target):
    """
    Return |Tr(U_target^† U_t)|^2 / d^2 for every matrix U_t in U_t_traced_all.

    Parameters:
    - U_t_traced_all: Stacked reduced matrices, shape (T, d, d).
    - U_target: Target unitary matrix, shape (d, d).

    Returns:
    - fidelity_all: Array of shape (T,), computed with jax.vmap over the time axis.
    """
    dim = U_target.shape[0]
    target_dagger = U_target.conj().T

    def _fidelity(U):
        return jnp.abs(jnp.trace(target_dagger @ U)) ** 2 / (dim ** 2)

    return jax.vmap(_fidelity)(U_t_traced_all)


# Clifford gates
def clifford_group_and_t_gate():
    """
    Return the named single-qubit gate library used for random target gates.

    The list holds 24 named products of I, H, S, X, Y, Z, followed by ("T", T).
    Each name spells its factors left to right, so "HSX" is H @ S @ X.
    NOTE(review): these 24 products are not the 24 distinct single-qubit
    Clifford elements, because some coincide (e.g. "YS" equals "SX"). The order
    is significant: program_instruction picks entries by index.

    Returns:
    - list of (name, 2x2 complex64 matrix) tuples, 25 entries.
    """
    primitives = {
        "I": jnp.array([[1, 0], [0, 1]], dtype=jnp.complex64),                       # Identity
        "H": (1 / jnp.sqrt(2)) * jnp.array([[1, 1], [1, -1]], dtype=jnp.complex64),  # Hadamard
        "S": jnp.array([[1, 0], [0, 1j]], dtype=jnp.complex64),                      # Phase gate
        "X": jnp.array([[0, 1], [1, 0]], dtype=jnp.complex64),                       # Pauli-X
        "Y": jnp.array([[0, -1j], [1j, 0]], dtype=jnp.complex64),                    # Pauli-Y
        "Z": jnp.array([[1, 0], [0, -1]], dtype=jnp.complex64),                      # Pauli-Z
    }
    T = jnp.array([[1, 0], [0, jnp.exp(1j * jnp.pi / 4)]], dtype=jnp.complex64)  # T gate (π/8 rotation)

    names = [
        "X", "I", "H", "S", "Y", "Z",
        "HS", "SH", "HX", "SX", "SY", "SZ",
        "XH", "YH", "ZH", "XS", "YS", "ZS",
        "HSX", "HSY", "HSZ", "XHS", "YHS", "ZHS",
    ]

    gates = []
    for name in names:
        # Left-to-right product of the named factors.
        matrix = primitives[name[0]]
        for symbol in name[1:]:
            matrix = matrix @ primitives[symbol]
        gates.append((name, matrix))

    gates.append(("T", T))
    return gates


# ========================================
# Part 3: Function to Compute Gate Fidelity
# ========================================

def compute_multi_qubit_fidelity_closed_system(
    V0_t_list, V1_t_list, L, n0, lambda_0, a0, t0, a1, t1, phase_mod, amp_mod, delta,
    atom_positions, dipoles, beam_centers, beam_waist, X, Y, Omega_prefactor_MHz,
    t_steps, dt,
    N_ch, distances, coupling_lengths, n_eff_list, kappa0, alpha, enable_crosstalk,
    N_slm, N_ch_slm_in, N_scat_1, N_scat_2, N_a, N_qubit_level, omega_0, omega_r, a_pic, a_scat_1,
    U_target, gate_type='single'
):
    """
    Simulate the photonic control chain and atom–field evolution; return F(t) against U_target.

    Pipeline per time step: PIC transfer matrix (with optional crosstalk) ⊗
    first scattering matrix acting on the input modes, routing of modes to the
    SLM / second scattering region, SLM ⊗ second scattering matrix, Gaussian
    beam superposition on the atom plane, nearest-grid sampling at the atoms,
    coupling g(t) = alpha(t) / 2, Hamiltonian H(t), and time-ordered propagation.

    Parameters:
        V0_t_list, V1_t_list: Control voltage time-series per channel (indexed [channel][t]).
        L, n0, lambda_0, a0, t0, a1, t1: Physical parameters of the modulators.
        phase_mod, amp_mod: SLM phase and amplitude per channel.
        delta: Value passed to construct_I_prime for both scattering regions.
        atom_positions, dipoles: Atom positions and dipole moments.
        beam_centers, beam_waist: Beam centres and width on the atom plane.
        X, Y: 2D grid for the beam field.
        Omega_prefactor_MHz: Prefactor for the Rabi frequency in MHz.
        t_steps, dt: Number of time steps and step size.
        N_ch: Number of photonic channels.
        distances, coupling_lengths, n_eff_list, kappa0, alpha, enable_crosstalk:
            Crosstalk model inputs (see U_drmzm_multi_channel).
        N_slm, N_ch_slm_in, N_scat_1, N_scat_2, N_a, N_qubit_level: System configuration.
        omega_0, omega_r: Qubit and field frequencies in MHz.
        a_pic, a_scat_1: Input mode amplitude vectors (combined with a Kronecker product).
        U_target: Target unitary matrix on the qubits.
        gate_type: Passed through to construct_H_time.

    Returns:
        Array of shape (t_steps,): fidelity of the accumulated propagator (field
        traced out) with U_target after each time step; the last entry is the
        fidelity of the complete evolution.
    """
    # Stage 1: PIC transfer matrix at every time step
    U_system_multi_channel = jnp.array([
        U_drmzm_multi_channel(
            [V0_t[i] for V0_t in V0_t_list],
            [V1_t[i] for V1_t in V1_t_list],
            L, n0, lambda_0, a0, t0, a1, t1,
            N_ch, distances, coupling_lengths, n_eff_list,
            kappa0, alpha, enable_crosstalk
        )
        for i in range(t_steps)
    ])

    # Stage 2: PIC ⊗ first scattering region applied to the input modes
    I1_prime = construct_I_prime(N_scat_1, delta, t_steps)
    U_tensor_product = jnp.array([jnp.kron(U_system_multi_channel[i], I1_prime[i]) for i in range(t_steps)])

    a_total = jnp.kron(a_pic, a_scat_1)
    output_modes = jnp.array([jnp.dot(U_tensor_product[i], a_total) for i in range(t_steps)])
    output_modes_reshaped = output_modes.reshape(t_steps, N_ch, N_scat_1)
    a_pic_out = jnp.sum(output_modes_reshaped, axis=-1)
    a_scat_1_out = jnp.sum(output_modes_reshaped, axis=-2)

    # Stage 3: route modes. The first N_ch_slm_in PIC outputs plus scattered
    # light feed the SLM; the remaining PIC outputs plus scattered light feed
    # the second scattering region. The slices are identical at every time
    # step, so all steps are assigned at once.
    n_pic_to_scat2 = N_ch - N_ch_slm_in

    b_slm_in = jnp.zeros((t_steps, N_slm), dtype=jnp.complex64)
    b_slm_in = b_slm_in.at[:, :N_ch_slm_in].set(a_pic_out[:, :N_ch_slm_in])
    b_slm_in = b_slm_in.at[:, N_ch_slm_in:].set(a_scat_1_out[:, :(N_slm - N_ch_slm_in)])

    b_scat2_in = jnp.zeros((t_steps, N_scat_2), dtype=jnp.complex64)
    b_scat2_in = b_scat2_in.at[:, :n_pic_to_scat2].set(a_pic_out[:, N_ch_slm_in:])
    b_scat2_in = b_scat2_in.at[:, n_pic_to_scat2:].set(a_scat_1_out[:, (N_scat_2 - n_pic_to_scat2):])

    # Stage 4: SLM ⊗ second scattering region
    b_total_in = jnp.array([jnp.kron(b_slm_in[t], b_scat2_in[t]) for t in range(t_steps)])
    U_multi_channel_slm = construct_U_multi_channel_slm(N_slm, phase_mod, amp_mod, t_steps)
    I2_prime = construct_I_prime(N_scat_2, delta, t_steps)
    U_tensor_product_stage_2 = jnp.array([jnp.kron(U_multi_channel_slm[t], I2_prime[t]) for t in range(t_steps)])
    b_total_out = jnp.array([jnp.dot(U_tensor_product_stage_2[t], b_total_in[t]) for t in range(t_steps)])

    b_total_out_reshaped = b_total_out.reshape(t_steps, N_slm, N_scat_2)
    b_slm_out = jnp.sum(b_total_out_reshaped, axis=-1)

    # Stage 5: field on the atom plane, sampled at the atoms
    E_field_profiles = compute_total_E_field_profile(X, Y, b_slm_out, beam_centers, beam_waist)
    E_fields_at_atoms = extract_E_field_at_atoms(E_field_profiles, atom_positions, X, Y)

    # Interaction strength α(t) for each atom; g(t) = eff_Rabi(t) / 2 in the
    # Jaynes-Cummings-type coupling
    alpha_t = compute_alpha_t(E_fields_at_atoms, dipoles, Omega_prefactor_MHz)
    g_real_t = alpha_t.real / 2
    g_imag_t = alpha_t.imag / 2

    # Stage 6: Hamiltonian, propagation, and fidelity at every step
    H_t = construct_H_time(N_a, N_qubit_level, omega_0, omega_r, g_real_t, g_imag_t, atom_positions, gate_type)
    U_t_all = compute_accumulated_propagator(H_t, dt, N_a, N_qubit_level)
    U_t_traced_all = trace_out_field_from_unitary_t_all(U_t_all, N_a, N_qubit_level)
    return compute_fidelity_unitary_t_all(U_t_traced_all, U_target)

# ========================================
# Part 4: Program Instruction function
# ========================================
def program_instruction(N_a, key_number, gate_type='single', selected_atoms=None, control_atom=None, target_atom=None):
    """
    Construct the target unitary for the requested gate program.

    For ``gate_type == 'single'``, every selected atom receives a gate drawn at
    random from ``clifford_group_and_t_gate()`` (a list of ``(name, gate)``
    pairs), using a JAX PRNG seeded with ``key_number``. Every unselected atom
    receives the 2x2 identity. The per-atom gates are combined with Kronecker
    products in atom order 0..N_a-1, and the gate chosen for each atom is logged.

    Parameters:
    - N_a (int): Number of atoms; must be >= 1 for 'single' gates.
    - key_number (int): Seed for ``jax.random.PRNGKey``; the same seed reproduces
      the same gate choices.
    - gate_type (str): Only 'single' is currently supported.
    - selected_atoms (container of int or None): Atoms that receive a random gate;
      None selects every atom.
    - control_atom, target_atom: Currently unused (reserved for multi-qubit gates).

    Returns:
    - The Kronecker product of the per-atom gates for 'single' (presumably
      (2**N_a, 2**N_a) if every gate from clifford_group_and_t_gate() is 2x2 —
      confirm); None for any other gate_type (a warning is logged).

    Raises:
    - ValueError: if N_a < 1 for 'single' gates.
    """
    if gate_type != 'single':
        # Callers check for None, so keep returning it, but do not fail silently.
        logger.warning(f"Unsupported gate_type '{gate_type}': no target unitary constructed.")
        return None

    if N_a < 1:
        # Without at least one atom there is no gate to start the Kronecker product from.
        raise ValueError(f"N_a must be at least 1, got {N_a}.")

    # Define the Clifford gates for single-qubit gates
    clifford_t_gates = clifford_group_and_t_gate()
    U_target_multi_qubit = []
    key = jrandom.PRNGKey(key_number)  # Initialize the random key for gate selection

    for atom_idx in range(N_a):
        if selected_atoms is None or atom_idx in selected_atoms:
            # Split the key once per selected atom so each draw is independent
            # (the split sequence must stay as-is to reproduce the same gates).
            key, subkey = jrandom.split(key)
            gate_idx = int(jrandom.randint(subkey, (), minval=0, maxval=len(clifford_t_gates)))
            gate_name, gate = clifford_t_gates[gate_idx]
            logger.info(f"Applying {gate_name} gate to Atom {atom_idx}")
            U_target_multi_qubit.append(gate)
        else:
            logger.info(f"Applying Identity gate to Atom {atom_idx}")
            U_target_multi_qubit.append(jnp.eye(2))  # Identity gate for unselected atoms

    # Combine the gates using the Kronecker product (left fold, atom 0 first)
    U_target = U_target_multi_qubit[0]
    for gate in U_target_multi_qubit[1:]:
        U_target = jnp.kron(U_target, gate)

    return U_target


# ========================================
# Part 5: Modified Custom Gymnasium Environment with Warm-Up Start
# ========================================

class QOCEnv(gym.Env):
    """
    Custom Environment for Quantum Optimal Control using Gymnasium interface.
    Enhanced with warm-up start for better initial performance.

    An episode walks through the ``piecewise_segments`` voltage segments in order.
    Each step adds a scaled action (one value per channel) to the current
    segment's voltages. It then re-evaluates the gate fidelity of the voltage
    prefix up to and including that segment, and returns a shaped reward.

    Observation: the last ``history_length`` frames stacked along axis 0. Each
    frame has shape (N_ch, piecewise_segments + 2) and is laid out as
    ``[voltages | fidelity column | segment-index column]``.
    """
    # 'render_modes' is the key Gymnasium reads; 'render.modes' is kept for older code.
    metadata = {'render.modes': ['human'], 'render_modes': ['human']}

    # Lower bound on the argument of log() in the 'log' reward scaling (avoids log(0)).
    _LOG_EPS = 1e-12

    def __init__(self, config):
        super().__init__()

        # Store the entire config for access in other methods
        self.config = config

        # Configuration parameters
        self.t_steps = self.config.get('t_steps', 500)
        self.N_ch = self.config.get('N_ch', 6)
        self.N_a = self.config.get('N_a', 3)
        self.U_target = self.config.get('system_params', {}).get('U_target')

        if self.U_target is None:
            raise ValueError("U_target is not defined in the configuration.")

        # Control voltage parameters
        self.piecewise_segments = self.config.get('piecewise_segments', 50)
        self.min_voltage = self.config.get('min_voltage', -15.0)
        self.max_voltage = self.config.get('max_voltage', 15.0)
        self.max_delta_voltage = self.config.get('max_delta_voltage', 1.0)

        # History parameters
        self.history_length = self.config.get('history_length', 5)

        # Actions are normalized per-channel voltage deltas in [-1, 1]
        self.action_space = spaces.Box(
            low=-1.0,
            high=1.0,
            shape=(self.N_ch,),
            dtype=np.float32
        )

        # Observation bounds: voltage columns use the voltage range. The last two
        # columns hold the fidelity (in [0, 1]) and the segment index (in
        # [0, piecewise_segments]). The segment index would otherwise exceed
        # max_voltage once more than max_voltage segments have been processed.
        obs_shape = (self.history_length, self.N_ch, self.piecewise_segments + 2)
        obs_low = np.full(obs_shape, self.min_voltage, dtype=np.float32)
        obs_high = np.full(obs_shape, self.max_voltage, dtype=np.float32)
        obs_low[..., -2] = 0.0
        obs_high[..., -2] = 1.0
        obs_low[..., -1] = 0.0
        obs_high[..., -1] = float(self.piecewise_segments)
        self.observation_space = spaces.Box(
            low=obs_low,
            high=obs_high,
            shape=obs_shape,
            dtype=np.float32
        )

        # Episode bookkeeping
        self.current_step = 0
        self.done = False
        self.fidelity = 0.0
        self.max_steps = self.config.get('max_steps', 100)

        # Track best fidelity and corresponding voltage
        self.best_fidelity = 0.0
        self.best_voltage = np.zeros((self.N_ch, self.piecewise_segments), dtype=np.float32)

        # Store system parameters
        self.system_params = self.config.get('system_params', {})
        self.APIC_params = self.config.get('APIC_params', {})
        self.atom_beam_params = self.config.get('atom_beam_params', {})
        self.control_Vt_params = self.config.get('control_Vt_params', {})

        # Initial voltages drawn uniformly over the allowed range (reset() re-draws them)
        self.current_voltage = self.np_random.uniform(
            low=self.min_voltage,
            high=self.max_voltage,
            size=(self.N_ch, self.piecewise_segments)
        ).astype(np.float32)

        # Initialize history buffer
        self.current_segment = 0
        self._reset_history()

        # Initialize logging
        self.logger = logging.getLogger(__name__)

        # Initialize random number generator
        self.seed()

        reward_params = self.config.get('reward_scaling_params', {})

        # Scaling methods keyed by best_fidelity / target_fidelity ratio ranges
        self.scaling_methods = reward_params.get('scaling_methods', [
            {'method': 'linear', 'min_ratio': 0.0, 'max_ratio': 0.6},
            {'method': 'quadratic', 'min_ratio': 0.6, 'max_ratio': 0.9},
            {'method': 'exponential', 'min_ratio': 0.9, 'max_ratio': 1.0},
        ])

        # Extract target fidelity
        self.target_fidelity = self.config.get('target_fidelity', 0.999)

        # Other reward scaling parameters
        self.a = reward_params.get('a', 1.0)  # Scaling factor for absolute fidelity
        self.b = reward_params.get('b', 1.0)  # Scaling factor for fidelity improvement
        self.method = 'linear'  # Initial scaling method
        self.k = reward_params.get('k', 5.0)  # Used by exponential/log/sigmoid scaling
        self.thresholds = reward_params.get('thresholds', [0.8, 0.9])  # Fidelity thresholds for bonuses
        self.bonus = reward_params.get('bonus', 5.0)  # Bonus reward for crossing thresholds
        self.clip_min = reward_params.get('clip_min', -10.0)  # Minimum reward
        self.clip_max = reward_params.get('clip_max', 10.0)  # Maximum reward
        # NOTE(review): step_penalty is read from config but not applied to the reward anywhere below.
        self.step_penalty = reward_params.get('step_penalty', 0.01)  # Penalty per step

    def _observation_component(self):
        """Return one history frame: voltages plus a fidelity and a segment-index column."""
        fidelity_array = np.full((self.N_ch, 1), self.fidelity, dtype=np.float32)
        segment_array = np.full((self.N_ch, 1), self.current_segment, dtype=np.float32)
        return np.hstack((self.current_voltage, fidelity_array, segment_array))

    def _reset_history(self):
        """Refill the history buffer with copies of the current frame."""
        self.voltage_history = deque(maxlen=self.history_length)
        frame = self._observation_component()
        for _ in range(self.history_length):
            self.voltage_history.append(frame.copy())

    def _extract_fidelity(self, fidelity_all):
        """
        Return the last entry of ``fidelity_all`` as a finite float.

        Falls back to 0.0 (and logs an error) when the entry cannot be read or is
        NaN/inf, so a bad simulation result never produces a NaN reward.
        """
        try:
            value = float(jnp.real(fidelity_all[-1]))
        except (IndexError, TypeError) as e:
            self.logger.error(f"Error computing fidelity: {e}")
            return 0.0
        if not math.isfinite(value):
            self.logger.error(f"Non-finite fidelity encountered: {value}")
            return 0.0
        return value

    def seed(self, seed: Optional[int] = None):
        """
        Set the seed for this environment's random number generator(s).
        """
        self.np_random, seed = gym.utils.seeding.np_random(seed)
        return [seed]

    def select_scaling_method(self, best_fidelity):
        """
        Select the scaling method from the best_fidelity / target_fidelity ratio.

        The first entry of ``scaling_methods`` whose [min_ratio, max_ratio) range
        contains the ratio is used. Ratios at or above the last entry's max_ratio
        use the last entry. Otherwise the current method is kept.
        """
        ratio = best_fidelity / self.target_fidelity
        new_method = None

        for scaling in self.scaling_methods:
            if scaling['min_ratio'] <= ratio < scaling['max_ratio']:
                new_method = scaling['method']
                break

        # If ratio exceeds the highest threshold
        if new_method is None and ratio >= self.scaling_methods[-1]['max_ratio']:
            new_method = self.scaling_methods[-1]['method']

        if new_method and new_method != self.method:
            self.logger.info(
                f"Scaling method changed from {self.method} to {new_method} based on best fidelity ratio {ratio:.2f}."
            )
            self.method = new_method

    def compute_scaled_reward(self, final_fidelity, improvement):
        """
        Compute the shaped reward for ``final_fidelity``.

        The reward is the fidelity transformed by the current scaling method. A
        positive ``improvement`` adds ``b * improvement``. A ``bonus`` is added for
        every relative threshold crossed for the first time.

        Raises:
        - ValueError: if the selected scaling method is unknown.
        """
        # Select the appropriate scaling method based on current best fidelity
        self.select_scaling_method(self.best_fidelity)

        # Scale raw fidelity based on the chosen method
        if self.method == 'linear':
            scaled_fidelity = self.a * final_fidelity
        elif self.method == 'quadratic':
            scaled_fidelity = self.a * (final_fidelity ** 2)
        elif self.method == 'exponential':
            scaled_fidelity = self.a * (math.exp(self.k * final_fidelity) - 1)
        elif self.method == 'log':
            # Fidelity is clipped to [0, 1] and can be exactly 0; clamp to avoid log(0).
            scaled_fidelity = self.a * (math.log(max(self.k * final_fidelity, self._LOG_EPS)) - 1)
        elif self.method == 'sigmoid':
            # Sigmoid scaling: reward increases sharply around a central fidelity value
            c = 0.95  # Central point where sigmoid is steepest (adjust as needed)
            scaled_fidelity = self.a * (1 / (1 + math.exp(-self.k * (final_fidelity - c))))
        else:
            raise ValueError("Unsupported scaling method")

        # Add improvement reward
        if improvement > 0:
            scaled_fidelity += self.b * improvement

        # Add a bonus for each threshold (relative to target) crossed by this step
        threshold_bonus = 0.0
        for threshold in self.thresholds:
            relative_threshold = threshold * self.target_fidelity
            if self.best_fidelity < relative_threshold <= final_fidelity:
                threshold_bonus += self.bonus
        scaled_fidelity += threshold_bonus

        return scaled_fidelity

    def step(self, action):
        """
        Apply voltage adjustments to the current segment and score the result.

        Returns ``(observation, reward, terminated, truncated, info)``:
        - terminated: True once every segment has been controlled, or when a
          fidelity above 1 was produced (the action is rejected with a -1000
          penalty and the voltages revert to ``best_voltage``).
        - truncated: True when ``max_steps`` steps were taken without terminating.

        Raises:
        - ValueError: if called after the episode has ended without reset().
        """
        if self.done:
            raise ValueError("Episode has ended. Call reset() to start a new episode.")

        # Clip the action to ensure it's within [-1, 1]
        action = np.clip(action, -1.0, 1.0)

        # Scale actions to the desired delta voltage range
        scaled_action = action * self.max_delta_voltage

        # Apply delta V to the current segment and keep it inside the voltage range
        self.current_voltage[:, self.current_segment] += scaled_action
        self.current_voltage[:, self.current_segment] = np.clip(
            self.current_voltage[:, self.current_segment],
            self.min_voltage,
            self.max_voltage
        )

        # Compute fidelity up to the current segment
        # NOTE(review): V0 and V1 receive identical voltages — confirm this is intended.
        V0_t_list = jnp.array(self.current_voltage[:, :self.current_segment + 1])
        V1_t_list = jnp.array(self.current_voltage[:, :self.current_segment + 1])

        fidelity_all = compute_multi_qubit_fidelity_closed_system(
            V0_t_list, V1_t_list,
            self.APIC_params.get('L'), self.APIC_params.get('n0'),
            self.APIC_params.get('lambda_0'), self.APIC_params.get('a0'), self.APIC_params.get('t0'),
            self.APIC_params.get('a1'), self.APIC_params.get('t1'), self.APIC_params.get('phase_mod'),
            self.APIC_params.get('amp_mod'), self.system_params.get('delta'),
            self.atom_beam_params.get('atom_positions'), self.atom_beam_params.get('dipoles'),
            self.atom_beam_params.get('beam_centers'), self.atom_beam_params.get('beam_waist'),
            self.atom_beam_params.get('X'), self.atom_beam_params.get('Y'), self.atom_beam_params.get('Omega_prefactor_MHz'),
            self.control_Vt_params.get('t_steps'), self.control_Vt_params.get('dt'),
            self.system_params.get('N_ch'), self.system_params.get('distances'), self.system_params.get('coupling_lengths'),
            self.system_params.get('n_eff_list'), self.system_params.get('kappa0'), self.system_params.get('alpha'),
            self.system_params.get('enable_crosstalk'),
            self.system_params.get('N_slm'), self.system_params.get('N_ch_slm_in'), self.system_params.get('N_scat_1'),
            self.system_params.get('N_scat_2'), self.system_params.get('N_a'),
            self.system_params.get('N_qubit_level'), self.system_params.get('omega_0'),
            self.system_params.get('omega_r'), self.system_params.get('a_pic'),
            self.system_params.get('a_scat_1'),
            self.U_target, self.system_params.get('gate_type')
        )

        # Unclipped fidelity (finite, 0.0 on failure)
        unclipped_fidelity = self._extract_fidelity(fidelity_all)

        # Clip fidelity to [0,1]
        final_fidelity = np.clip(unclipped_fidelity, 0.0, 1.0)

        # Calculate fidelity improvement over the best seen this episode
        improvement = final_fidelity - self.best_fidelity

        # Compute and clip the scaled reward
        scaled_reward = self.compute_scaled_reward(final_fidelity, improvement)
        reward = np.clip(scaled_reward, self.clip_min, self.clip_max)

        rejected = unclipped_fidelity > 1.0
        best_updated = False
        if rejected:
            # Unphysical fidelity: revert to best_voltage, penalize, and end the episode.
            self.current_voltage = self.best_voltage.copy()
            reward = -1000.0  # Adjust the penalty value as needed
            final_fidelity = self.best_fidelity
            self.fidelity = final_fidelity
        else:
            # Update best_fidelity and best_voltage if improvement occurred
            if improvement > 0:
                self.best_fidelity = final_fidelity
                self.best_voltage = self.current_voltage.copy()
                best_updated = True
            self.fidelity = final_fidelity

        # Advance step counter and segment
        self.current_step += 1
        self.current_segment += 1

        # Reaching the last segment (or rejection) is a true end of the MDP, so it is
        # reported as 'terminated'. Only the max_steps cap is a truncation (SB3
        # bootstraps value estimates on truncation, not on termination).
        terminated = rejected or self.current_segment >= self.piecewise_segments
        truncated = (not terminated) and self.current_step >= self.max_steps
        self.done = terminated or truncated

        info = {
            'final_fidelity': final_fidelity,
            'current_step': self.current_step,
            'current_segment': self.current_segment
        }

        # Only report best_voltage when it was actually updated on this step
        if best_updated:
            info['best_voltage'] = self.best_voltage.copy()

        # Update history buffer and stack it into the observation
        self.voltage_history.append(self._observation_component())
        observation = np.stack(self.voltage_history, axis=0)

        self.logger.debug(f"Segment: {self.current_segment}, Reward: {reward:.4f}, Fidelity: {self.fidelity:.4f}, Done: {self.done}")

        return observation.astype(np.float32), reward, terminated, truncated, info

    def render(self, mode='human'):
        """
        Render the environment to the screen.
        """
        if mode == 'human':
            print(f"Segment: {self.current_segment} | Fidelity: {self.fidelity:.4f}")

    def close(self):
        """
        Clean up resources.
        """
        pass

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        """
        Start a new episode with a warm-up voltage initialization.

        Voltages are drawn from N(0, 1), clipped to the voltage range. The first
        ``config['fixed_segments']`` segments (default 5) are then set to zero.

        Returns:
        - (observation, {}) where observation is the history buffer filled with
          copies of the initial frame.
        """
        try:
            super().reset(seed=seed)
            self.current_step = 0
            self.done = False
            self.fidelity = 0.0
            self.best_fidelity = 0.0
            self.best_voltage = np.zeros((self.N_ch, self.piecewise_segments), dtype=np.float32)

            # Warm-up initialization: small random values around zero
            self.current_voltage = self.np_random.normal(
                loc=0.0,
                scale=1.0,
                size=(self.N_ch, self.piecewise_segments)
            ).astype(np.float32)
            self.current_voltage = np.clip(
                self.current_voltage,
                self.min_voltage,
                self.max_voltage
            )

            # Fix the first few segments to zero for smoother starts
            fixed_segments = self.config.get('fixed_segments', 5)
            self.current_voltage[:, :fixed_segments] = 0.0

            # Reset to first segment and refill the history buffer
            self.current_segment = 0
            self._reset_history()

            observation = np.stack(self.voltage_history, axis=0)
            return observation.astype(np.float32), {}
        except Exception as e:
            self.logger.exception(f"Error during reset: {e}")
            raise  # Re-raise (preserving the traceback) to notify the worker

# ========================================
# Part 6: Modified Configuration Setup
# ========================================

def create_config(simple=False):
    """
    Build the default configuration dictionary for QOCEnv.

    The returned dict holds the flat environment settings read directly by
    QOCEnv (segment count, voltage range, history length, ...). It also holds
    the parameter groups 'APIC_params', 'system_params' (including the generated
    'U_target'), 'atom_beam_params', 'control_Vt_params' and
    'reward_scaling_params'.

    Parameters:
    - simple (bool): Currently unused; kept for interface compatibility.

    Returns:
    - dict: The configuration.

    Raises:
    - ValueError: if program_instruction does not return a target unitary.
    """
    # Shared sizes and ranges. Values repeated across the parameter groups are
    # derived from these so the groups stay consistent when one is changed.
    n_ch = 6
    n_slm = 6
    n_ch_slm_in = 4
    n_scat_1 = 4
    n_a = 3
    min_voltage = -15.0
    max_voltage = 15.0
    piecewise_segments = 50

    # ----- APIC parameters -----
    lambda_0 = 780e-9
    n0 = 1.95
    APIC_params = {
        'L': 600 * lambda_0 / n0,  # Same value as 600 * 780e-9 / 1.95
        'n0': n0,
        'lambda_0': lambda_0,
        'a0': 0.998,
        't0': 0.998,
        'a1': 0.998,
        't1': 0.998,
        'phase_mod': jnp.zeros(n_slm),  # Sized by N_slm
        'amp_mod': jnp.ones(n_slm),     # Sized by N_slm
    }

    # ----- Atom / beam parameters -----
    X, Y = generate_grid(grid_size=600, grid_range=(-6, 6))
    atom_positions = generate_atom_positions_equilateral(n_a, side_length=3.0, center=(0, 0))
    dipoles = generate_dipoles(n_a)
    beam_centers = atom_positions  # Beams are centred on the atom positions
    beam_waist = 2.0  # µm
    Omega_prefactor_MHz = calculate_Omega_rabi_prefactor(20, 1000)  # I_mW_per_cm2=20, Detuning_MHz=1000
    a_pic = jnp.array([1.0] * n_ch)          # Sized by N_ch
    a_scat_1 = jnp.array([1.0] * n_scat_1)   # Sized by N_scat_1

    atom_beam_params = {
        'atom_positions': atom_positions,
        'dipoles': dipoles,
        'beam_centers': beam_centers,
        'beam_waist': beam_waist,
        'X': X,
        'Y': Y,
        'Omega_prefactor_MHz': Omega_prefactor_MHz,
    }

    # ----- Control voltage timing -----
    t_steps = 100  # Number of simulation time steps
    tmax_fixed = 0.1  # Fixed tmax in microseconds
    dt = tmax_fixed / t_steps
    control_Vt_params = {
        'tmin': 0.0,
        'tmax': tmax_fixed,
        't_steps': t_steps,
        'dt': dt,
        'min_V_level': min_voltage,
        'max_V_level': max_voltage,
    }

    # ----- System parameters (each key appears once) -----
    system_params = {
        'a_pic': a_pic,
        'a_scat_1': a_scat_1,
        'delta': 0.001,
        'N_ch': n_ch,
        'distances': generate_distances(n_ch, 1.0, random_variation=0.1, seed=42),
        'coupling_lengths': generate_coupling_lengths(n_ch, 600.0, scaling_factor=1.1, random_variation=60.0, seed=42),
        'n_eff_list': generate_n_eff_list(n_ch, 1.75, random_variation=0.05, seed=42),
        'kappa0': 10.145,
        'alpha': 6.934,
        'enable_crosstalk': True,
        'N_slm': n_slm,
        'N_ch_slm_in': n_ch_slm_in,
        'N_scat_1': n_scat_1,
        # N_total - N_slm (= 4). NOTE(review): which sizes make up N_total is not
        # visible here, so the original literal expression is kept.
        'N_scat_2': 6 + 4 - 6,
        'N_a': n_a,
        'N_qubit_level': 2,
        'omega_0': 6.835e3,  # MHz
        'omega_r': 6.835e3,  # MHz
        'gate_type': 'single',
    }

    # Generate the target unitary with a random gate on every atom
    selected_atoms = list(range(n_a))
    system_params['U_target'] = program_instruction(
        N_a=n_a,
        key_number=12,
        gate_type='single',
        selected_atoms=selected_atoms
    )

    # Defensive check to ensure U_target is not None
    if system_params['U_target'] is None:
        raise ValueError("U_target was not correctly assigned by program_instruction.")

    # Reward scaling with hybrid methods selected by best/target fidelity ratio
    reward_scaling_params = {
        'a': 1.0,  # Scaling factor for absolute fidelity
        'b': 1.0,  # Scaling factor for fidelity improvement
        'scaling_methods': [  # Define scaling methods with relative thresholds
            {'method': 'log', 'min_ratio': 0.0, 'max_ratio': 0.5},
            {'method': 'linear', 'min_ratio': 0.5, 'max_ratio': 0.9},
            {'method': 'quadratic', 'min_ratio': 0.9, 'max_ratio': 1.0},
        ],
        'k': 5.0,  # Used by exponential/log/sigmoid scaling
        'thresholds': [0.8, 0.9],  # Fidelity thresholds (relative to target) for bonuses
        'bonus': 5.0,  # Bonus reward for crossing thresholds
        'clip_min': -10.0,  # Minimum reward
        'clip_max': 10.0,  # Maximum reward
        'step_penalty': 0.01  # Penalty per step
    }

    config = {
        't_steps': t_steps,
        'N_ch': n_ch,
        'N_a': n_a,
        'piecewise_segments': piecewise_segments,  # Number of voltage segments
        'min_voltage': min_voltage,
        'max_voltage': max_voltage,
        'history_length': 5,  # Number of past states to include
        'max_delta_voltage': 1.0,  # Action scaling
        'APIC_params': APIC_params,
        'system_params': system_params,
        'atom_beam_params': atom_beam_params,
        'control_Vt_params': control_Vt_params,
        'reward_scaling_params': reward_scaling_params,
        'stagnant_threshold': 10,           # Number of stagnant episodes before termination
        'stagnant_fidelity_min': 0.99,      # Minimum fidelity to consider stagnation
        'target_fidelity': 0.999            # Desired fidelity threshold to achieve
    }

    return config

# ========================================
# Part 7: Define the Custom Policy
# ========================================

class CorrelationFeatureExtractor(nn.Module):
    """
    Custom feature extractor using convolutional layers to capture spatial and temporal correlations.

    The history axis of the observation is used as the Conv2d channel axis, and
    each (N_ch, piecewise_segments + 2) frame as the 2-D spatial plane. Two 3x3
    convolutions (stride 1, padding 1, so the spatial size is preserved) with ReLU
    are followed by a flatten and one ReLU-activated linear layer of width
    ``features_dim``.

    NOTE(review): SB3 documents custom extractors as subclasses of
    ``BaseFeaturesExtractor``; this class instead exposes ``features_dim`` as a
    plain attribute.
    """
    def __init__(self, observation_space: gym.Space, *, features_dim: int = 256):
        """
        Initializes the feature extractor.

        Parameters:
        - observation_space (gym.Space): A Box of shape
          (history_length, N_ch, piecewise_segments + 2).
        - features_dim (int): The dimension of the extracted features.

        Raises:
        - TypeError: if observation_space is not a Box.
        - ValueError: if the observation space is not 3-dimensional.
        """
        super().__init__()

        # Explicit checks (an assert would be stripped under `python -O`)
        if not isinstance(observation_space, spaces.Box):
            raise TypeError("Observation space must be of type Box")
        obs_shape = observation_space.shape
        if len(obs_shape) != 3:
            raise ValueError(
                f"Expected a 3-D observation space (history_length, N_ch, segments + 2), got shape {obs_shape}"
            )
        history_length, N_ch, piecewise_segments_plus_2 = obs_shape

        # Define convolutional layers
        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels=history_length, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Flatten()
        )

        # Infer the flattened conv output size with a dummy forward pass
        with torch.no_grad():
            dummy_input = torch.zeros(1, history_length, N_ch, piecewise_segments_plus_2)
            conv_output_size = self.conv_layers(dummy_input).shape[1]

        # Define fully connected layer
        self.fc = nn.Sequential(
            nn.Linear(conv_output_size, features_dim),
            nn.ReLU()
        )

        # SB3 reads `features_dim` from the extractor to size the policy/value heads
        self.features_dim = features_dim

    def forward(self, x):
        """
        Forward pass for the feature extractor.

        Parameters:
        - x (torch.Tensor): Tensor of shape (batch_size, history_length, N_ch, piecewise_segments + 2)

        Returns:
        - torch.Tensor: Extracted features of shape (batch_size, features_dim)
        """
        x = self.conv_layers(x)
        x = self.fc(x)
        return x

class CustomActorCriticPolicy(ActorCriticPolicy):
    """
    Custom Actor-Critic Policy with convolutional feature extractor to capture correlations.
    """
    def __init__(self, *args, **kwargs):
        """
        Initializes the custom policy.

        Defaults to CorrelationFeatureExtractor (features_dim=256),
        net_arch=[256, 256, 256] and ReLU activations. They are applied with
        ``setdefault``, so values the caller supplies take precedence. Such
        values come from PPO's ``policy_kwargs`` or from the constructor
        parameters SB3 stores and passes back when a saved policy is loaded.
        Previously, any caller-supplied value raised "got multiple values for
        keyword argument".

        Parameters:
        - *args, **kwargs: Arguments and keyword arguments passed to the parent class.
        """
        kwargs.setdefault('features_extractor_class', CorrelationFeatureExtractor)
        kwargs.setdefault('features_extractor_kwargs', dict(features_dim=256))
        kwargs.setdefault('net_arch', [256, 256, 256])  # Network architecture for actor and critic
        kwargs.setdefault('activation_fn', nn.ReLU)
        super().__init__(*args, **kwargs)

    def forward(self, obs, deterministic=False):
        """
        Forward pass to get actions, values, and log probabilities.

        Delegates to ActorCriticPolicy.forward unchanged.

        Parameters:
        - obs (torch.Tensor): Observation tensor.
        - deterministic (bool): Whether to sample actions deterministically.

        Returns:
        - (actions, values, log_probs) as returned by the parent class.
        """
        return super().forward(obs, deterministic=deterministic)

# ========================================
# Part 8: Modified Custom Callback
# ========================================

class RLLoggingCallback(BaseCallback):
    """
    Custom callback for logging additional metrics during training.
    Logs to both console and TensorBoard.
    Implements early termination if best_fidelity is stuck between stagnant_fidelity_min and target_fidelity for too many episodes.
    Also saves the final optimal V(t) and plots upon termination.
    """
    def __init__(self, stagnant_threshold=100,
                 stagnant_fidelity_min=0.99,
                 target_fidelity=0.999,
                 results_dir='results',
                 verbose=0):
        """
        Set up the stopping criteria, running statistics and output directory.

        Parameters:
        - stagnant_threshold (int): How many consecutive non-improving checks with
          best fidelity in [stagnant_fidelity_min, target_fidelity) end training.
        - stagnant_fidelity_min (float): Lower edge of the stagnation band.
        - target_fidelity (float): Fidelity at which training stops successfully.
        - results_dir (str): Directory for saved results; created if missing.
        - verbose (int): Verbosity level forwarded to BaseCallback.
        """
        super().__init__(verbose)

        # Stopping criteria
        self.stagnant_threshold = stagnant_threshold
        self.stagnant_fidelity_min = stagnant_fidelity_min
        self.target_fidelity = target_fidelity

        # Running statistics
        self.best_fidelity = 0.0
        self.best_voltage = None          # Voltage array that produced best_fidelity
        self.best_fidelity_history = []   # Global best fidelity recorded per callback step
        self.episode_count = 0
        self.stagnant_counter = 0

        # Wall-clock reference and output location
        self.start_time = time.time()
        self.results_dir = results_dir
        os.makedirs(self.results_dir, exist_ok=True)

    def _on_step(self) -> bool:
        # Access the current episode's info
        infos = self.locals.get('infos', [])
        improved = False  # Flag to check if any env improved best_fidelity
        new_best_voltage = None  # To store the new best_voltage if any

        # Track the maximum fidelity achieved in this step across all envs
        max_fidelity_this_step = self.best_fidelity

        for env_id, info in enumerate(infos):
            if 'final_fidelity' in info:
                fidelity = info['final_fidelity']
                self.episode_count += 1

                # Check if this episode has a new best fidelity
                if fidelity > self.best_fidelity:
                    self.best_fidelity = fidelity
                    self.stagnant_counter = 0
                    improved = True
                    # Save the corresponding best_voltage
                    new_best_voltage = info.get('best_voltage', None)
                    if new_best_voltage is not None:
                        self.best_voltage = new_best_voltage

                    # Log the improvement
                    logger.info(
                        f"New best fidelity: {self.best_fidelity:.4f} achieved in Env {env_id} | Episode: {self.episode_count}"
                    )

                # Update the maximum fidelity for this step
                if fidelity > max_fidelity_this_step:
                    max_fidelity_this_step = fidelity

        # Append the global best_fidelity to history
        self.best_fidelity_history.append(self.best_fidelity)

        # After processing all envs in this step
        # If no improvement was made in this step
        if not improved:
            # Check if best_fidelity is within the stagnation range
            if self.stagnant_fidelity_min <= self.best_fidelity < self.target_fidelity:
                self.stagnant_counter += 1
                logger.info(
                    f"Stagnant Counter Incremented: {self.stagnant_counter}/{self.stagnant_threshold} | Best Fidelity: {self.best_fidelity:.4f}"
                )
            else:
                # Reset the stagnant_counter if best_fidelity is below the min or already reached the target
                self.stagnant_counter = 0

        # Check for termination based on target fidelity
        if self.best_fidelity >= self.target_fidelity:
            elapsed_time = time.time() - self.start_time
            logger.info(
                f"Target fidelity of {self.best_fidelity:.4f} achieved in {self.episode_count} episodes. "
                f"Training terminated."
            )
            # Save the best_voltage before terminating
            if self.best_voltage is not None:
                best_voltage_final_path = os.path.join(self.results_dir, "best_voltage_final.npy")
                np.save(best_voltage_final_path, self.best_voltage)
                logger.info(f"Final best voltage saved to '{best_voltage_final_path}'.")


            # Save Best Fidelity
            best_fidelity_path = os.path.join(self.results_dir, "best_fidelity.txt")
            with open(best_fidelity_path, "w") as f:
                f.write(f"Best Fidelity: {self.best_fidelity:.6f}\n")
            logger.info(f"Best fidelity saved to '{best_fidelity_path}'.")

            # Save Best Fidelity History as .npy and .csv
            best_fidelity_history_path_npy = os.path.join(self.results_dir, "best_fidelity_history.npy")
            best_fidelity_history_path_csv = os.path.join(self.results_dir, "best_fidelity_history.csv")
            np.save(best_fidelity_history_path_npy, np.array(self.best_fidelity_history))
            np.savetxt(best_fidelity_history_path_csv, np.array(self.best_fidelity_history), delimiter=",", header="Episode,Best_Fidelity")
            logger.info(f"Best fidelity history saved to '{best_fidelity_history_path_npy}' and '{best_fidelity_history_path_csv}'.")

            # Plot Fidelity Progress and Save as .png and .csv
            self._plot_fidelity_progress()

            return False  # Returning False stops the training

        # Check for termination based on stagnation
        if self.stagnant_counter >= self.stagnant_threshold:
            elapsed_time = time.time() - self.start_time
            logger.warning(
                f"Training is stagnating after {self.episode_count} episodes. "
                f"Best Fidelity: {self.best_fidelity:.4f}. Terminating training."
            )

            # Save the best_voltage before terminating
            if self.best_voltage is not None:
                best_voltage_final_path = os.path.join(self.results_dir, "best_voltage_final.npy")
                np.save(best_voltage_final_path, self.best_voltage)
                logger.info(f"Final best voltage saved to '{best_voltage_final_path}'.")


            # Save Best Fidelity
            best_fidelity_path = os.path.join(self.results_dir, "best_fidelity.txt")
            with open(best_fidelity_path, "w") as f:
                f.write(f"Best Fidelity: {self.best_fidelity:.6f}\n")
            logger.info(f"Best fidelity saved to '{best_fidelity_path}'.")

            # Save Best Fidelity History as .npy and .csv
            best_fidelity_history_path_npy = os.path.join(self.results_dir, "best_fidelity_history.npy")
            best_fidelity_history_path_csv = os.path.join(self.results_dir, "best_fidelity_history.csv")
            np.save(best_fidelity_history_path_npy, np.array(self.best_fidelity_history))
            np.savetxt(best_fidelity_history_path_csv, np.array(self.best_fidelity_history), delimiter=",", header="Episode,Best_Fidelity")
            logger.info(f"Best fidelity history saved to '{best_fidelity_history_path_npy}' and '{best_fidelity_history_path_csv}'.")

            # Plot Fidelity Progress and Save as .png and .csv
            self._plot_fidelity_progress()

            return False  # Returning False stops the training

        # Log to console and file
        elapsed_time = time.time() - self.start_time
        logger.info(
            f"Training Status | Episodes: {self.episode_count} | Best Fidelity: {self.best_fidelity:.4f} | "
            f"Stagnant Counter: {self.stagnant_counter}/{self.stagnant_threshold} | "
            f"Elapsed Time: {elapsed_time/60:.2f} min"
        )

        return True  # Continue training


    def _plot_fidelity_progress(self):
        """
        Plot the global best fidelity per episode and save the underlying data.

        Reads ``self.best_fidelity_history`` (assumed to hold one value per
        episode; episodes are numbered from 1) and writes into
        ``self.results_dir``:

        * ``fidelity_progress.png`` -- line plot of best fidelity vs. episode;
        * ``fidelity_progress.csv`` -- columns ``Episode`` and ``Best_Fidelity``.

        Failures are logged (with traceback) instead of raised, so a plotting
        or I/O problem never aborts training. The plot and the CSV are written
        independently, so a matplotlib failure does not prevent the CSV from
        being saved.
        """
        try:
            episodes = np.arange(1, len(self.best_fidelity_history) + 1)
            best_fidelity = np.array(self.best_fidelity_history)
        except Exception as e:
            logger.error(f"Error while plotting or saving fidelity progress: {e}", exc_info=True)
            return

        # --- Plot ---------------------------------------------------------
        try:
            fidelity_plot_path = os.path.join(self.results_dir, "fidelity_progress.png")
            fig, ax = plt.subplots(figsize=(12, 6))
            try:
                ax.plot(episodes, best_fidelity, label='Best Fidelity')
                ax.set_title('Global Best Fidelity Progress Over Episodes')
                ax.set_xlabel('Episode')
                ax.set_ylabel('Best Fidelity')
                ax.legend()
                ax.grid(True)
                fig.tight_layout()
                fig.savefig(fidelity_plot_path)
            finally:
                # pyplot keeps every figure alive until it is closed; close it
                # even when saving fails so repeated calls do not leak figures.
                plt.close(fig)
            logger.info(f"Fidelity progress plot saved as '{fidelity_plot_path}'.")
        except Exception as e:
            logger.error(f"Error while plotting or saving fidelity progress: {e}", exc_info=True)

        # --- CSV ----------------------------------------------------------
        try:
            fidelity_progress_csv = os.path.join(self.results_dir, "fidelity_progress.csv")
            df = pd.DataFrame({
                'Episode': episodes,
                'Best_Fidelity': best_fidelity,
            })
            df.to_csv(fidelity_progress_csv, index=False)
            logger.info(f"Fidelity progress data saved as '{fidelity_progress_csv}'.")
        except Exception as e:
            logger.error(f"Error while plotting or saving fidelity progress: {e}", exc_info=True)

# ========================================
# Part 9: Modified Environment Creation Function
# ========================================
def make_env(config, seed):
    """
    Return a zero-argument factory that builds a seeded ``QOCEnv``.

    Stable-Baselines3's ``SubprocVecEnv`` / ``DummyVecEnv`` take a list of such
    factories and call each one to create its environment (inside the worker
    process for ``SubprocVecEnv``).

    Args:
        config: configuration object passed unchanged to ``QOCEnv``.
        seed: seed applied to the newly created environment.

    Returns:
        A callable taking no arguments that returns the seeded environment.
    """
    def _init():
        env = QOCEnv(config)
        if callable(getattr(env, "seed", None)):
            # Keep using a custom seeding hook if QOCEnv provides one.
            env.seed(seed)
        else:
            # Gymnasium environments have no ``seed()`` method; seeding is done
            # through ``reset(seed=...)``, which initialises ``env.np_random``.
            # Later resets with ``seed=None`` keep that RNG stream.
            env.reset(seed=seed)
        return env
    return _init

# ========================================
# Part 10: Modified Main Function
# ========================================
def main(total_timesteps: int = 10000, num_envs: int = 4, n_test_episodes: int = 5):
    """
    Train a PPO agent on ``QOCEnv`` and run a short deterministic evaluation.

    Steps:
      1. Create ``results/`` and mirror all log records into
         ``results/training_log.txt``.
      2. Configure JAX (CPU platform, float64) for this process and, via
         environment variables, for the SubprocVecEnv worker processes.
      3. Train PPO on ``num_envs`` parallel environments with
         ``RLLoggingCallback`` (which can stop training early).
      4. Save the model to ``results/ppo_qoc_model.zip``, reload it and run
         ``n_test_episodes`` deterministic test episodes.
      5. Report whether the best-fidelity / best-voltage files exist.

    Args:
        total_timesteps: training budget passed to ``model.learn``.
        num_envs: number of parallel training environments (>= 1).
        n_test_episodes: number of evaluation episodes run after training.

    Raises:
        ValueError: if ``num_envs`` is smaller than 1.
    """
    if num_envs < 1:
        raise ValueError(f"num_envs must be >= 1, got {num_envs}")

    # Define the results directory first so the log file can live inside it.
    results_dir = "results"
    os.makedirs(results_dir, exist_ok=True)

    # Mirror log records into results/training_log.txt. The handler goes on the
    # ROOT logger: the module-level ``logger`` used throughout this file (and by
    # the training callback) is the root logger, so this also captures the
    # callback's progress messages. A handler on a separate named logger would
    # only see main()'s own messages. The handler is removed and closed in
    # ``finally`` so re-running main() neither leaks file handles nor writes
    # each record twice.
    root_logger = logging.getLogger()
    log_file = os.path.join(results_dir, "training_log.txt")
    file_handler = logging.FileHandler(log_file)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
    root_logger.addHandler(file_handler)

    vec_env = None
    test_env = None
    try:
        start_time = time.time()

        # --- JAX configuration ---------------------------------------------
        # Split the CPU cores between the parallel environment workers.
        # os.cpu_count() may return None, hence the fallback to 1.
        cpu_count = os.cpu_count() or 1
        os.environ["JAX_NUM_THREADS"] = str(max(1, cpu_count // num_envs))
        # NOTE(review): JAX is already imported at this point; confirm that
        # JAX/XLA honours JAX_NUM_THREADS at all (XLA thread counts are usually
        # controlled through XLA_FLAGS).

        jax.config.update("jax_platform_name", "cpu")
        jax.config.update("jax_enable_x64", True)  # Use float64 for numerical stability
        # jax.config.update() only affects the current process. SubprocVecEnv
        # starts its workers with 'forkserver'/'spawn' where available, so they
        # import JAX afresh; export the equivalent environment variables (read
        # by JAX at import time) so the workers presumably run on the same
        # platform with the same precision.
        os.environ["JAX_PLATFORMS"] = "cpu"
        os.environ["JAX_ENABLE_X64"] = "true"

        # --- Training --------------------------------------------------------
        config = create_config()

        # Instantiate multiple environments for parallel training
        env_fns = [make_env(config, seed=i) for i in range(num_envs)]
        vec_env = SubprocVecEnv(env_fns)

        # To validate the environment against the Gymnasium API, check a single
        # (non-vectorized) instance, e.g. check_env(QOCEnv(config), warn=True).

        model = PPO(
            CustomActorCriticPolicy,  # Custom policy class defined elsewhere in this file
            vec_env,
            verbose=0,  # Suppress default SB3 logs; the custom callback logs progress
            tensorboard_log=os.path.join(results_dir, "ppo_qoc_tensorboard/"),
            device='cpu',
            ent_coef=0.1,  # Higher entropy coefficient for exploration
            learning_rate=1e-4,  # Reduced from 3e-4 for stability
            clip_range=0.2,
            n_steps=2048,
            gamma=0.99,
            gae_lambda=0.95,
            batch_size=64,
            max_grad_norm=0.5,  # Gradient clipping to prevent explosion
        )

        # Callback that stops training on target fidelity or stagnation.
        rl_logging_callback = RLLoggingCallback(
            stagnant_threshold=config.get('stagnant_threshold', 100),
            stagnant_fidelity_min=config.get('stagnant_fidelity_min', 0.99),
            target_fidelity=config.get('target_fidelity', 0.999),
            results_dir=results_dir,
        )

        logger.info("Commencing PPO training...")
        model.learn(total_timesteps=total_timesteps, callback=rl_logging_callback)

        total_run_time_minutes = (time.time() - start_time) / 60
        logger.info(f"Total run time: {total_run_time_minutes:.2f} minutes")

        model_save_path = os.path.join(results_dir, "ppo_qoc_model.zip")
        model.save(model_save_path)
        logger.info(f"Trained PPO model saved as '{model_save_path}'.")

        # --- Evaluation ------------------------------------------------------
        test_env = DummyVecEnv([make_env(config, seed=100)])
        # Load on CPU, matching the device used for training.
        test_model = PPO.load(model_save_path, device='cpu')

        test_results_dir = os.path.join(results_dir, "test_results")
        os.makedirs(test_results_dir, exist_ok=True)

        for episode in range(1, n_test_episodes + 1):
            # SB3 VecEnv API: reset() returns only the batched observation and
            # step() returns (obs, rewards, dones, infos). Finished sub-envs are
            # auto-reset, so dones[0] marks the end of the single test episode.
            obs = test_env.reset()
            done = False
            total_reward = 0.0
            while not done:
                action, _states = test_model.predict(obs, deterministic=True)
                obs, rewards, dones, _infos = test_env.step(action)
                total_reward += float(rewards[0])
                done = bool(dones[0])
            # NOTE(review): this logs the summed episode reward under the label
            # "Final Fidelity"; that only holds if QOCEnv's reward is the fidelity.
            logger.info(f"Test Episode {episode}: Final Fidelity = {total_reward:.4f}")

        # The best fidelity/voltage files are expected in results_dir; the
        # visible callback code writes them when it terminates training early.
        best_fidelity_path = os.path.join(results_dir, "best_fidelity.txt")
        best_voltage_final_path = os.path.join(results_dir, "best_voltage_final.npy")
        if os.path.exists(best_fidelity_path) and os.path.exists(best_voltage_final_path):
            logger.info("Final best fidelity and voltage have been saved.")
        else:
            logger.warning("Best fidelity and/or best voltage files are missing.")
    finally:
        # Release worker processes and the log file even if training fails.
        if vec_env is not None:
            vec_env.close()
        if test_env is not None:
            test_env.close()
        root_logger.removeHandler(file_handler)
        file_handler.close()

# ========================================
# Entry Point
# ========================================
if __name__ == "__main__":
    # Run training only when executed as a script / notebook cell, not on import.
    main()
