In [None]:
import numpy as np
from numpy import linalg as la
from numpy import *
from numpy import random as rrr
import matplotlib.pyplot as plt
from matplotlib import cm
from tqdm import tqdm
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

from time import time
import joblib
from collections import namedtuple, deque
import math
import random as rnd
from scipy import stats
# importing movie py libraries
# from moviepy.editor import VideoClip
# from moviepy.video.io.bindings import mplfig_to_npimage
from fhn_system_2d import simulate_trajectory

import gc

In [None]:
# Report the PyTorch build and choose the compute device used by the rest of the notebook.
print(torch.__version__, torch.cuda.is_available())
print(torch.version.cuda)
if torch.cuda.is_available():
    # get_device_name() raises on machines without a usable GPU, so only query it here.
    print(torch.cuda.get_device_name())
    device = 'cuda'
else:
    # Fall back to the CPU instead of failing later when tensors are moved to 'cuda'.
    print("CUDA is not available - falling back to CPU")
    device = 'cpu'

In [None]:
## Choose a system. The string is only used to build file and directory names:
## the PPO history file that gets loaded, the figure folder, and the checkpoint /
## coefficient file names. It does not switch the simulated dynamics by itself.
system = 'fhn' 
# system = 'double_well'

In [None]:
# Both values are interpolated into the name of the PPO history file that is loaded
# with joblib further down (ppo_history{k_grid}x{k_grid}_{num_steps}steps_...).
# NOTE(review): presumably the number of PPO steps and the grid resolution used when
# the history was generated - confirm against the training script.
num_steps = 10000
k_grid = 32

In [None]:
def true_evals(n_max=3, mu1=0.05, omega1=0.588):
    """
    Return the reference eigenvalues lambda_n = -n**2 * mu1 + 1j * n * omega1.

    The values are produced for every integer n in [-n_max, n_max], in increasing
    order of n. The n = 0 mode (lambda = 0) IS included; an earlier comment claimed
    it was skipped, which did not match the code.

    Args:
        n_max: Largest mode index; 2 * n_max + 1 eigenvalues are returned.
        mu1: Coefficient of the -n**2 real part (default 0.05, the value that was
            previously hard-coded).
        omega1: Coefficient of the imaginary part (default 0.588, the value that
            was previously hard-coded).

    Returns:
        1-D complex numpy array of length 2 * n_max + 1 (empty when n_max < 0).
    """
    modes = np.arange(-n_max, n_max + 1)
    # Multiplying by 1j promotes the result to a complex array, even when it is empty.
    return -(modes ** 2) * mu1 + 1j * modes * omega1
    

In [None]:
# Reference eigenvalues for n = -3..3 (7 values); plotted in the next cell.
true_evalues = true_evals(3)

In [None]:

# Scatter the reference eigenvalues in the complex plane (explicit Axes interface)
fig, ax = plt.subplots()
ax.scatter(true_evalues.real, true_evalues.imag)

# Thin guide lines marking the real and imaginary axes
ax.axhline(0, linewidth=0.5)
ax.axvline(0, linewidth=0.5)

# Fixed viewing window: real part in [-2, 1], imaginary part in [-1.5, 1.5]
ax.set_xlim(-2, 1)
ax.set_ylim(-1.5, 1.5)

# Title and axis labels
ax.set_title("Eigenvalues in the Complex Plane")
ax.set_xlabel("Real Part")
ax.set_ylabel("Imaginary Part")

# Same scale on both axes so distances in the complex plane are not distorted
ax.set_aspect('equal', 'box')

# Render the figure
plt.show()


In [None]:


"""
Main script for FitzHugh-Nagumo system simulation
"""

# Set parameters for the FHN system
# NOTE(review): gamma is not referenced anywhere else in the visible notebook;
# beta, delta and epsilon are only forwarded to simulate_trajectory.
gamma = 1
beta = 1
delta = 0.25
epsilon = 0.05


# Coefficients of the vector field plotted below, plus the noise intensities.
# All of them are also forwarded to simulate_trajectory by get_single_trajectory.
a1= 1/3
b1= 0.5
b2= 0
eta1= 0.1
DX= 0.05**2
DY= (eta1*0.05)**2
# Create a grid of points for phase space visualization
x_plot = np.linspace(-3, 3, 100)
y_plot = np.linspace(-3, 3, 100)
XX, YY = np.meshgrid(x_plot, y_plot)

# Deterministic vector field shown in the phase portrait (no explicit time dependence).
# NOTE(review): presumably the same drift that simulate_trajectory integrates - confirm.
U = (XX - a1*XX**3-YY )  # dx/dt
V = eta1*(XX+ b1)        # dy/dt

# Normalize the arrows to unit length. Where the field vanishes (a fixed point that
# happens to lie on a grid node) divide by 1 instead of 0 so no NaNs reach quiver.
magnitude = np.sqrt(U**2 + V**2)
safe_magnitude = np.where(magnitude > 0, magnitude, 1.0)
U_norm = U / safe_magnitude
V_norm = V / safe_magnitude

# Plot the phase space: streamlines of the raw field plus a coarse grid of unit arrows
plt.figure(figsize=(6, 4))
plt.streamplot(XX, YY, U, V, density=1.5, color='darkblue', linewidth=0.7)
plt.quiver(XX[::10, ::10], YY[::10, ::10], U_norm[::10, ::10], V_norm[::10, ::10], 
           color='red', scale=25)
plt.title('FitzHugh-Nagumo Phase Space')
# dy/dt carries the small factor eta1, so y is the slow variable of the pair.
plt.xlabel('x (fast variable)')
plt.ylabel('y (slow variable)')
plt.grid(alpha=0.3)
plt.show()

def get_single_trajectory(x0, y0, T):
    """
    Simulate one SDE trajectory launched from (x0, y0).

    Thin wrapper around ``simulate_trajectory`` (imported from ``fhn_system_2d``)
    that pins the integrator settings (h=1e-4, n_steps=100) and forwards the
    module-level model parameters.

    Args:
        x0: Initial x-coordinate.
        y0: Initial y-coordinate.
        T: Simulation length, forwarded unchanged as the third positional argument.

    Returns:
        Whatever ``simulate_trajectory`` returns; callers in this notebook unpack
        it as ``(data_matrix_single, lag_time)``.
    """
    integrator_settings = dict(h=1e-4, n_steps=100)
    model_parameters = dict(
        beta=beta, delta=delta, epsilon=epsilon,
        a1=a1, b1=b1, b2=b2, eta1=eta1,
        DX=DX, DY=DY,
    )
    return simulate_trajectory(x0, y0, T, **integrator_settings, **model_parameters)

In [None]:
# #define the domain of the dynamical system
# x_min= -3
# x_max= 3
# y_min = -4
# y_max= 4

# from double_well_potential_2d import potential, simulate_trajectory

# """
# Main script for plotting and trajectory generation.
# Imports potential and simulate_trajectory from double_well.py.
# """

# # Plot the potential landscape
# x_plot = np.linspace(-2, 2, 100)
# y_plot = np.linspace(-3, 3, 100)
# XX, YY = np.meshgrid(x_plot, y_plot)
# ZZ = potential(XX, YY)

# plt.figure(figsize=(6, 4))
# plt.contourf(XX, YY, ZZ, levels=20, cmap='coolwarm')
# plt.colorbar(label='Potential')
# plt.title('2D Double Well Potential Landscape')
# plt.xlabel('x')
# plt.ylabel('y')
# plt.show()

# # Generate a single SDE trajectory

# def get_single_trajectory(x0, y0, T):
#     """
#     Produce SDE trajectory starting at (x0, y0) over T steps.
#     This function delegates the simulation to double_well.simulate_trajectory.
#     """
#     return simulate_trajectory(x0, y0, T)

In [None]:
# Earlier history files that were used with this notebook (kept for reference):
#workspace/rl_koopman_dn/ppo_history32x32_10000steps_example_fhn.jbl
#workspace/rl_koopman_dn/ppo_history32x32_100steps_example_fhn_test.jbl
# Load the PPO history from the working directory; the file name is built from
# k_grid, num_steps and system. Later cells index each entry with [0], [1], [2], [-1].
# NOTE(security): joblib.load unpickles the file and can execute arbitrary code -
# only load history files from a trusted source.
reward_hist= joblib.load (f'ppo_history{k_grid}x{k_grid}_{num_steps}steps_example_{system}_final.jbl')
# Root folder for all figures written by the cells below.
os.makedirs(f'figures_ppo_{system}', exist_ok=True)

In [None]:
def make_trajectory_from_state(state, chunk_len=100):
    """
    Simulate one trajectory per column of ``state`` and join them with ``hstack``.

    Args:
        state: 2-D indexable object with a ``shape`` attribute; ``state[0, i]`` and
            ``state[1, i]`` are used as the initial x and y of the i-th trajectory.
        chunk_len: Each trajectory is simulated with T = 3 * chunk_len (default 100).

    Returns:
        Tuple ``(data_matrix, lag_time)``: the horizontally stacked trajectories and
        the lag time reported by the last simulation.
    """
    horizon = 3 * chunk_len
    segments = []
    for col in range(state.shape[1]):
        segment, lag_time = get_single_trajectory(state[0, col], state[1, col], horizon)
        segments.append(segment)
    return np.hstack(segments), lag_time

In [None]:
# Split every history entry into its components: element 0 is stored as the state,
# element 1 as the action, element 2 as the next state and the last element as the reward.
rewards_all = [entry[-1] for entry in reward_hist]
states_all = [entry[0] for entry in reward_hist]
actions_all = [entry[1] for entry in reward_hist]
next_states_all = [entry[2] for entry in reward_hist]

In [None]:
def make_trajectory_from_state_list (state_list):
    """
    Build one long data matrix from a list of states.

    Each state is expanded with ``make_trajectory_from_state`` (default chunk
    length) and the results are joined with ``np.hstack``.

    Args:
        state_list: Non-empty sequence of states accepted by
            ``make_trajectory_from_state``.

    Returns:
        Tuple ``(trajectory_final, lag_time)``: the stacked trajectories and the
        lag time reported for the last state.

    Raises:
        ValueError: If ``state_list`` is empty (previously this ended in an
            UnboundLocalError at the return statement).
    """
    if len(state_list) == 0:
        raise ValueError("state_list is empty: no trajectories to build")

    traj_list = []
    for state in tqdm(state_list):
        trajectory, lag_time = make_trajectory_from_state(state)
        traj_list.append(trajectory)

    # Stack once after the loop; re-stacking inside the loop copied all data on every pass.
    trajectory_final = np.hstack(traj_list)
    return trajectory_final, lag_time

In [None]:
from solver_sdmd_torch_gpu5 import KoopmanNNTorch, KoopmanSolverTorch
"""
Refactored SDMD‑DQN script (solver‑5 version)
------------------------------------------------
• Same optimisation style as the previous refactor
• All comments are in English
• y‑axis for eigenvalue plots fixed at –3…3
• Unit‑circle helper included but **NOT** called (uncomment if needed)

Assumed globals already defined elsewhere:
    states_all, system, device,
    make_trajectory_from_state, true_evals,
    KoopmanNNTorch, KoopmanSolverTorch
"""

import os, gc, torch
import numpy as np
import matplotlib.pyplot as plt
from numpy import arange, sign
from tqdm import tqdm

# ─────────────────────────── CONFIG ──────────────────────────── #
# Start indices into states_all; each one selects the slice of states used for one fit.
STEP_VIEWS = [100, 500, 1000, 2000, 4000]
# Fraction of the (X, Y) sample pairs used for training; the remainder is validation.
TRAIN_SPLIT = 0.7
# Training settings forwarded to solver.build_with_generator
# (epochs, batch_size, lr, lr_decay_factor).
EPOCHS = 6
BATCH_SIZE = 256
LR = 1e-5
LR_DECAY = 0.8
# Passed as `reg` to KoopmanSolverTorch.
REG = 0.1
# Second and third positional arguments of KoopmanNNTorch.
LAYER_SIZES = [17, 17]
N_PSI_TRAIN = 15
TOP_K = 7        # how many eigenvalues to keep / plot
# NOTE(review): Y_LIM is only referenced by the commented-out plot; the active
# eigenvalue plot hard-codes the same (-3, 3) range.
Y_LIM = (-3, 3)  # fixed y‑axis range

# ─────────────────────────── HELPERS ─────────────────────────── #
def dynamic_params(step_view: int) -> tuple[int, int]:
    """
    Derive ``(chunk_len, n_traj)`` for a given step view.

    ``chunk_len`` grows by one for every 10 steps, starting at 20 and never
    dropping below 10. ``n_traj`` is 2, 3 or 5 for step views below 100, 500
    and 1000 respectively, and ``8 + step_view // 1200`` from 1000 upwards.
    """
    chunk_len = max(10, 20 + step_view // 10)

    # (exclusive upper bound, trajectory count) pairs, checked in ascending order
    small_view_counts = ((100, 2), (500, 3), (1000, 5))
    n_traj = next(
        (count for bound, count in small_view_counts if step_view < bound),
        None,
    )
    if n_traj is None:
        n_traj = 8 + step_view // 1200

    return chunk_len, n_traj


def build_trajectory(state_slice, chunk_len: int):
    """
    Concatenate trajectories generated from a collection of initial states.

    Args:
        state_slice: Non-empty iterable of states accepted by
            ``make_trajectory_from_state``.
        chunk_len: Forwarded to ``make_trajectory_from_state`` as ``chunk_len``.

    Returns:
        Tuple ``(data, lag)``: the ``np.hstack`` of all generated trajectories and
        the lag time reported for the last state.

    Raises:
        ValueError: If ``state_slice`` is empty. A Python slice taken past the end
            of a list is silently empty, so this gives a readable error instead of
            numpy's "need at least one array to concatenate".
    """
    states = list(state_slice)
    if not states:
        raise ValueError(
            "build_trajectory received an empty state slice; "
            "check that the slice indices fall inside the available states"
        )

    trajs = []
    for st in tqdm(states, desc="building‑traj", leave=False):
        traj, lag = make_trajectory_from_state(st, chunk_len=chunk_len)
        trajs.append(traj)
    return np.hstack(trajs), lag


def viz_multiplier(step_view: int) -> float:
    """
    Factor applied to the trajectory count when building the eigenfunction plot data.

    Returns 1.0 for step views below 100, 1.5 below 500 and 2.0 otherwise.
    """
    # (exclusive upper bound, factor) pairs, checked in ascending order
    for bound, factor in ((100, 1.0), (500, 1.5)):
        if step_view < bound:
            return factor
    return 2.0


def ensure_dir(path: str):
    """Create ``path`` (including missing parents) if needed and return it unchanged."""
    # exist_ok=True makes repeated calls on an existing directory a no-op.
    os.makedirs(name=path, exist_ok=True)
    return path


def unit_circle(ax):
    """Draw a thin dashed unit circle on ``ax`` (optional helper, not called by the active code)."""
    angles = np.linspace(0, 2 * np.pi, 200)
    xs, ys = np.cos(angles), np.sin(angles)
    ax.plot(xs, ys, "--", lw=0.5)


# ─────────────────────────── MAIN LOOP ────────────────────────── #
# Number of independent repetitions of the whole pipeline. No random seed is set in the
# visible code, so repetitions differ only through whatever randomness the simulator
# and the solver use internally.
NUM_ITER = 10
for it in range(NUM_ITER):
    print(f"\n{'='*50}\nStarting iteration {it+1}/{NUM_ITER}\n{'='*50}")
    # One output folder per repetition.
    out_dir = ensure_dir(f"figures_ppo_{system}/iteration51_{it+1}")

    X_eval_accum = []  # hold evaluation points across step_views

    for step in STEP_VIEWS:
        chunk_len, n_traj = dynamic_params(step)
        print(f"step={step}  n_traj={n_traj}  chunk_len={chunk_len}")

        # -- build training data --
        # NOTE(review): list slicing silently truncates, so if states_all has fewer
        # than step + n_traj entries the slice is shorter than requested (or empty).
        traj_slice = states_all[step : step + n_traj]
        data_mat, lag = build_trajectory(traj_slice, chunk_len)

        # Pair every sample along axis 1 with its successor, then flatten to 2-D rows.
        # NOTE(review): presumably axis 1 is time. The trajectories are joined along
        # that same axis by np.hstack, so pairs at the junctions would combine the end
        # of one trajectory with the start of the next - confirm this is acceptable.
        X = data_mat[:, :-1, :].reshape(-1, data_mat.shape[2])
        Y = data_mat[:, 1:, :].reshape(-1, data_mat.shape[2])

        # Ordered (unshuffled) split: first TRAIN_SPLIT of the rows train, the rest validate.
        split = int(TRAIN_SPLIT * len(X))
        train = [X[:split], Y[:split]]
        valid = [X[split:], Y[split:]]

        # -- initialise solver --
        # File names depend on the iteration only, so every step view within one
        # iteration uses the same checkpoint / coefficient file names.
        # NOTE(review): confirm the solver does not resume from a previous step's file.
        ckpt        = f"example_{system}_dqn_iter{it+1}_ckpt.torch"
        fnn_ckpt    = f"example_{system}_fnn_dqn_iter{it+1}.torch"
        coeff_file  = f"sde_coefficients_example_{system}_dqn_iter{it+1}.jbl"

        # Fresh dictionary network and solver for every step view.
        basis = KoopmanNNTorch(2, LAYER_SIZES, N_PSI_TRAIN).to(device)
        solver = KoopmanSolverTorch(
            dic=basis,
            target_dim=X.shape[1],
            reg=REG,
            checkpoint_file=ckpt,
            fnn_checkpoint_file=fnn_ckpt,
            a_b_file=coeff_file,
            generator_batch_size=2,
            fnn_batch_size=32,
            delta_t=lag
        )

        solver.build_with_generator(
            data_train=train,
            data_valid=valid,
            epochs=EPOCHS,
            batch_size=BATCH_SIZE,
            lr=LR,
            log_interval=10,
            lr_decay_factor=LR_DECAY
        )

        # -- eigenvalues --
        eigs = solver.eigenvalues.T
        # Finite-difference map (lambda - 1) / lag.
        # NOTE(review): presumably solver.eigenvalues belong to the one-step operator
        # and this converts them to generator eigenvalues - confirm with the solver.
        eigs_gen = (eigs - 1) / lag
        # Negating before and after np.sort_complex gives descending order (largest real part first).
        eigs_sorted = -np.sort_complex(-eigs_gen)

        print(f"   top‑{TOP_K} eigs:", eigs_sorted[:TOP_K])

        # Previous version of the eigenvalue plot, kept for reference:
        # # ── eigenvalue plot ─────────────────────────────────────── #
        # plt.figure(figsize=(6, 4))
        # lead = eigs_sorted[:TOP_K]
        # plt.scatter(lead.real, lead.imag, label="Estimated")
        # true = true_evals(3)
        # plt.scatter(true.real, true.imag, c="r", marker="x", label="True")
        # # unit_circle(plt.gca())  # uncomment if you want the reference circle

        # plt.title(f"Eigenvalues (iter={it+1}, step={step})")
        # plt.xlabel("Real");  plt.ylabel("Imag")
        # plt.xlim(-4, 2);     plt.ylim(*Y_LIM)
        # plt.axis("equal")
        # plt.grid(ls="--", lw=0.5)
        # plt.legend(loc="upper left")

        # plt.savefig(f"{out_dir}/eigs_{system}_iter{it+1}_step{step}.png")
        # plt.close()

        # -- eigenvalue plot (fixed axes & scale) --
        fig, ax = plt.subplots(figsize=(6, 4))

        # 1) plot the reference eigenvalues first; they are identical for every step
        true = true_evals(3)
        ax.scatter(true.real, true.imag, c="r", marker="x", label="True")

        # Optional: uncomment to draw a reference unit circle that never moves
        # theta = np.linspace(0, 2 * np.pi, 200)
        # ax.plot(np.cos(theta), np.sin(theta), "--", lw=0.5, color="gray")

        # 2) plot the TOP_K leading estimated eigenvalues (they may fall outside the limits)
        lead = eigs_sorted[:TOP_K]
        ax.scatter(lead.real, lead.imag, color="tab:blue", label="Estimated")

        # 3) fix axes & aspect every time so figures from different steps are comparable
        #    (the y-range repeats the value stored in Y_LIM)
        ax.set_xlim(-4, 2)
        ax.set_ylim(-3, 3)
        ax.set_aspect("equal", adjustable="box")   # keeps circle & grid square
        ax.set_title(f"Eigenvalues (step={step})")
        ax.set_xlabel("Real part");  ax.set_ylabel("Imaginary part")
        ax.grid(ls="--", lw=0.5)
        ax.legend(loc="upper left")

        fig.tight_layout()
        fig.savefig(f"{out_dir}/eigs_{system}_iter{it+1}_step{step}.png")
        plt.close(fig)

        # eigenvalue TXT: one row per leading eigenvalue (real, imag, modulus, angle in degrees)
        txt = f"{out_dir}/top{TOP_K}_eigs_iter{it+1}_step{step}.txt"
        with open(txt, "w") as f:
            f.write(f"Top {TOP_K} eigenvalues (generator)  iter={it+1}  step={step}\n")
            f.write(f"{'i':<3} {'Real':>12} {'Imag':>12} {'|λ|':>12} {'θ°':>10}\n")
            for i, ev in enumerate(lead, 1):
                f.write(f"{i:<3} {ev.real:12.6f} {ev.imag:12.6f} "
                        f"{abs(ev):12.6f} {np.angle(ev, deg=True):10.2f}\n")

        # -- eigenfunction evaluation & plot --
        # Simulate a (possibly larger) set of trajectories for visualisation and add its
        # points to the pool collected from earlier step views of this iteration.
        n_big = max(2, int(viz_multiplier(step) * n_traj))
        traj_big, _ = build_trajectory(states_all[step : step + n_big], chunk_len)
        X_eval_accum.append(traj_big.reshape(-1, traj_big.shape[2]))
        X_all = np.vstack(X_eval_accum)

        efuns = solver.eigenfunctions(X_all)
        # Sign of each eigenfunction's real part at the reference point (0, 1); multiplying
        # by it fixes the arbitrary sign so colours are comparable between plots.
        # NOTE(review): the sign is 0 if an eigenfunction vanishes at that point.
        ref_sign = sign(solver.eigenfunctions(np.array([[0., 1.]]))[0].real)

        fig, axs = plt.subplots(1, 2, figsize=(16, 6))
        vals = np.real(efuns) * ref_sign

        # NOTE(review): columns 0 and 1 are taken in the order the solver returns them,
        # which is not necessarily the order of eigs_sorted - confirm.
        sc1 = axs[0].scatter(X_all[:, 0], X_all[:, 1], c=vals[:, 0],
                             cmap="coolwarm", alpha=0.6)
        axs[0].set(title=f"1st Eigenfunction (step={step})",
                   xlim=(-3, 3), ylim=(-3, 3))
        fig.colorbar(sc1, ax=axs[0], shrink=0.7)

        sc2 = axs[1].scatter(X_all[:, 0], X_all[:, 1], c=vals[:, 1],
                             cmap="coolwarm", alpha=0.6)
        axs[1].set(title=f"2nd Eigenfunction (step={step})",
                   xlim=(-3, 3), ylim=(-3, 3))
        fig.colorbar(sc2, ax=axs[1])

        fig.tight_layout()
        fig.savefig(f"{out_dir}/efuns_{system}_iter{it+1}_step{step}.png")
        plt.close(fig)

        # -- memory cleanup --
        # Drop references to the large arrays / figure objects before the next fit.
        del efuns, fig, axs, sc1, sc2
        if 'cuda' in str(device):
            torch.cuda.empty_cache()
        gc.collect()

    print(f"Iteration {it+1} finished — figures saved to {out_dir}")

print("\nAll iterations completed!")


In [None]:
# from solver_sdmd_torch_gpu5 import KoopmanNNTorch, KoopmanSolverTorch
# # Main loop to run the entire process multiple times
# num_iterations = 1  # Number of times to run the full process

# for iteration in range(num_iterations):
#     print(f"\n{'='*50}\nStarting iteration {iteration+1}/{num_iterations}\n{'='*50}")
    
#     # Create a separate directory for each iteration
#     iteration_dir = f"figures_ppo_{system}/iteration5_{iteration+1}"
#     os.makedirs(iteration_dir, exist_ok=True)
    
#     # 0. Prepare a list to accumulate all evaluation points across iterations
#     X_eval_list = []
#     for step_view in [100, 500, 1000, 2000, 4000, 6000, 8000]:
#     # for step_view in [25, 100, 500, 1000]:
#         # Calculate dynamic chunk length based on step_view - starts smaller, grows with step_view
#         dynamic_chunk_len = max(10, 20 + step_view // 10)
        
#         # Determine number of trajectories - very few at the beginning, more later
#         if step_view < 100:
#             n_traj = 2  # Just 2 trajectory for very small step_view
#         elif step_view < 500:
#             n_traj = 3  # 3 trajectories for medium-small step_view
#         elif step_view < 1000:
#             n_traj = 5  # 5 trajectories for medium step_view
#         else:
#             n_traj = 8 + step_view // 1200  # More trajectories for larger step_view
        
#         print(f"Iteration {iteration+1}, Step view: {step_view}, Trajectories: {n_traj}, Chunk length: {dynamic_chunk_len}")
        
#         # Define a local function that uses the dynamic chunk length
#         def make_trajectory_with_dynamic_length(state_list):
#             n_traj_local = len(state_list)
#             traj_list = []
#             for kk in tqdm(arange(n_traj_local)):
#                 # Pass the dynamic chunk length to the inner function
#                 trajectory, lag_time = make_trajectory_from_state(state_list[kk], chunk_len=dynamic_chunk_len)
#                 traj_list.append(trajectory)
#             trajectory_final = np.hstack(traj_list)
#             return trajectory_final, lag_time
        
#         # Build the small trajectory to train the solver using the dynamic length
#         trajectory_final, lag_time = make_trajectory_with_dynamic_length(
#             states_all[step_view:step_view + n_traj]
#         )
        
#         data_matrix_single = trajectory_final

#         # 3. Prepare X and Y for training/validation
#         data_X = data_matrix_single[:, :-1, :]
#         data_Y = data_matrix_single[:, 1:, :]
#         X = data_X.reshape(-1, data_X.shape[2])
#         Y = data_Y.reshape(-1, data_Y.shape[2])
        
#         # 4. Split into 70% train / 30% valid
#         len_all = X.shape[0]
#         split_idx = int(0.7 * len_all)
#         data_x_train = X[:split_idx]
#         data_x_valid = X[split_idx + 1:]
#         data_y_train = Y[:split_idx]
#         data_y_valid = Y[split_idx + 1:]
#         data_train = [data_x_train, data_y_train]
#         data_valid = [data_x_valid, data_y_valid]

#         # 5. Initialize solver with neural-network dictionary - add iteration to checkpoint
#         checkpoint_file = f'example_{system}_dqn_iter_{iteration+1}_ckpt.torch'
#         basis_function = KoopmanNNTorch(
#             input_size=2,
#             layer_sizes=[15],
#             n_psi_train=12
#         ).to(device)
#         solver = KoopmanSolverTorch(
#             dic=basis_function,
#             target_dim=np.shape(data_x_train)[-1],
#             reg=0.1,
#             checkpoint_file=checkpoint_file,
#             fnn_checkpoint_file=f'example_{system}_fnn_dqn_iter_{iteration+1}.torch',
#             a_b_file=f'sde_coefficients_example_{system}_dqn_iter_{iteration+1}.jbl',
#             generator_batch_size=2,
#             fnn_batch_size=32,
#             delta_t=lag_time
#         )

#         # 6. Train the solver
#         solver.build_with_generator(
#             data_train=data_train,
#             data_valid=data_valid,
#             epochs=6,
#             batch_size=256,
#             lr=1e-5,
#             log_interval=10,
#             lr_decay_factor=0.8
#         )

#         # 7. Extract eigenvalues (for logging or unit-circle plot)
#         evalues = solver.eigenvalues.T
        
#         ev_cont = (evalues - 1) / lag_time
#         ev_sorted = -np.sort_complex(-ev_cont)
#         print(f"Iteration {iteration+1}, Step {step_view} - Eigenvalues of generator : {ev_sorted[:7]}")
#         # ev_plot= ev_sorted[(abs(ev_sorted.imag)>1) | (ev_sorted.real >-1)]
#         leading_evalues = ev_sorted[:7] # leading 7 eigenvalues by real part value
#         ev_plot= leading_evalues
#         # Plot eigenvalues on unit circle
#         plt.figure(figsize=(6, 4))
#         #plt.scatter(ev_sorted.real, ev_sorted.imag, label='Eigenvalues: PPO')
#         plt.scatter(ev_plot.real, ev_plot.imag, label='Eigenvalues: PPO')
#         true_eigs= true_evals (3)
#         plt.scatter(true_eigs.real, true_eigs.imag, c='r', label='True Eigenvalues', marker='x')
#         theta = np.linspace(0, 2*np.pi, 100)
#         #plt.plot(np.cos(theta), np.sin(theta), '--', label='Unit Circle')
#         plt.title(f'Eigenvalues (iter={iteration+1}, step_view={step_view})')
#         plt.xlabel('Real Part')
#         plt.ylabel('Imaginary Part')
#         plt.axis('equal')
       
#         plt.xlim(-4, 2)
#         plt.ylim(-2.5, 2.5)
#         plt.grid(True, linestyle='--', linewidth=0.5)
#         plt.legend(loc= 'upper left')
#         eigenvalue_filename = f'{iteration_dir}/eigenvalues_{system}_dqn_test_{iteration+1}_step_{step_view}.png'
#         plt.savefig(eigenvalue_filename)
#         plt.close()  # Close to free memory
        
#         # Save top 10 eigenvalues sorted by real part (largest to smallest)
#         # Convert complex eigenvalues to a numpy array if not already
#         # evalues_np = np.array(evalues) if not isinstance(evalues, np.ndarray) else evalues
#         # # Sort eigenvalues by real part in descending order
#         # sorted_indices = np.argsort(-evalues_np.real)  # Negative sign for descending order
#         # leading_evalues = evalues_np[sorted_indices[:10]]  # Take top 10
#         # leading_evalues = ev_sorted[:7] # leading 7 eigenvalues by real part value
        
#         # Create a text file to store the eigenvalues
#         eigenvalues_txt_filename = f'{iteration_dir}/top10_eigenvalues_test_{iteration+1}_step_{step_view}.txt'
#         with open(eigenvalues_txt_filename, 'w') as f:
#             f.write(f"Top 10 Eigenvalues for Iteration {iteration+1}, Step {step_view}\n")
#             f.write("Sorted by real part (largest to smallest)\n")
#             f.write("-" * 50 + "\n")
#             f.write(f"{'Index':<6} {'Real Part':<15} {'Imaginary Part':<15} {'Magnitude':<15} {'Phase':<15}\n")
#             f.write("-" * 70 + "\n")
            
#             for i, eig in enumerate(leading_evalues):
#                 # Calculate magnitude and phase (in degrees)
#                 magnitude = np.abs(eig)
#                 phase = np.angle(eig, deg=True)
#                 f.write(f"{i+1:<6} {eig.real:<15.6f} {eig.imag:<15.6f} {magnitude:<15.6f} {phase:<15.6f}\n")

#         # 8. For visualization, also scale the number of trajectories with step_view
#         if step_view < 100:
#             viz_multiplier = 1.0  # Just slightly more for visualization at low step_view
#         elif step_view < 500:
#             viz_multiplier = 1.5  # More for medium step_view
#         else:
#             viz_multiplier = 2.0  # Full multiplier for larger step_view
        
#         n_traj_big = max(2, int(viz_multiplier * n_traj))
#         trajectory_final_big, _ = make_trajectory_with_dynamic_length(
#             states_all[step_view:step_view + n_traj_big]
#         )
#         X_big = trajectory_final_big.reshape(-1, trajectory_final_big.shape[2])

#         # 9. Accumulate current evaluation points
#         X_eval_list.append(X_big)
#         X_all = np.vstack(X_eval_list)

#         # 10. Evaluate eigenfunctions on the entire accumulated dataset
#         efuns_all = solver.eigenfunctions(X_all)

#         # 11. Compute reference sign to ensure consistent coloring
#         reference_efun = solver.eigenfunctions(np.array([[0.0, 1.0]]))
#         reference_sign = sign(reference_efun[0].real)

#         # 12. Plot the first two eigenfunctions over the cumulative data
#         fig, axs = plt.subplots(1, 2, figsize=(16, 6))

#         # First eigenfunction
#         sc1 = axs[0].scatter(
#             X_all[:, 0], X_all[:, 1],
#             c=np.real(efuns_all)[:, 0] * reference_sign[0],
#             cmap='coolwarm',
#             alpha=0.6,
#             vmin=np.real(efuns_all)[:, 0].min(),
#             vmax=np.real(efuns_all)[:, 0].max()
#         )
#         axs[0].set_title(f'1st Eigenfunction (step={step_view})', fontsize=18)
#         axs[0].set_xlim(-3, 3)
#         axs[0].set_ylim(-3, 3)
#         fig.colorbar(sc1, ax=axs[0], shrink=0.7, aspect=20, format='%.2f')
        

#         # Second eigenfunction
#         sc2 = axs[1].scatter(
#             X_all[:, 0], X_all[:, 1],
#             c=np.real(efuns_all)[:, 1] * reference_sign[1],
#             cmap='coolwarm',
#             alpha=0.6
#         )
#         axs[1].set_title(f'2nd Eigenfunction (step={step_view})', fontsize=18)
#         axs[1].set_xlim(-3, 3)
#         axs[1].set_ylim(-3, 3)
#         fig.colorbar(sc2, ax=axs[1], format='%.2f')

#         plt.tight_layout()
#         eigenfunction_filename = f'{iteration_dir}/eigenfunction_{system}_dqn_test_{iteration+1}_step_{step_view}.png'
#         fig.savefig(eigenfunction_filename)
#         plt.close()  # Close to free memory

#         del efuns_all, fig, axs, sc1, sc2
#         if 'cuda' in str(device):
#             torch.cuda.empty_cache()  # Free GPU memory
#         gc.collect()  
        
#     print(f"Completed iteration {iteration+1}. Saved figures to {iteration_dir}/")

# print(f"\nAll {num_iterations} iterations completed successfully!")
# print(f"Check the 'figures/' directory to find subdirectories for each iteration.")