#### This code is used by OP-airPLS to batch-process the optimization operations on simulated datasets.

Inquiries regarding the use of this code are welcome. Please contact: jiaheng.cui@uga.edu.

# Grid search and optimization for OP-airPLS

In [None]:
import numpy as np
import pandas as pd
from scipy.sparse import csc_matrix, eye, diags
from scipy.sparse.linalg import spsolve
from scipy.optimize import minimize
import matplotlib.pyplot as plt
from tqdm import tqdm
import warnings


# This code is modified from the airPLS implementation by Yizeng Liang and Zhang Zhimin (https://github.com/zmzhang/airPLS). Their license is preserved below, as they require.
'''
airPLS.py Copyright 2014 Renato Lombardo - renato.lombardo@unipa.it
Baseline correction using adaptive iteratively reweighted penalized least squares

This program is a translation in python of the R source code of airPLS version 2.0
by Yizeng Liang and Zhang Zhimin - https://code.google.com/p/airpls
Reference:
Z.-M. Zhang, S. Chen, and Y.-Z. Liang, Baseline correction using adaptive iteratively reweighted penalized least squares. Analyst 135 (5), 1138-1146 (2010).

Description from the original documentation:

Baseline drift always blurs or even swamps signals and deteriorates analytical results, particularly in multivariate analysis.  It is necessary to correct baseline drift to perform further data analysis. Simple or modified polynomial fitting has been found to be effective in some extent. However, this method requires user intervention and prone to variability especially in low signal-to-noise ratio environments. The proposed adaptive iteratively reweighted Penalized Least Squares (airPLS) algorithm doesn't require any user intervention and prior information, such as detected peaks. It iteratively changes weights of sum squares errors (SSE) between the fitted baseline and original signals, and the weights of SSE are obtained adaptively using between previously fitted baseline and original signals. This baseline estimator is general, fast and flexible in fitting baseline.


LICENCE
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>
'''

def WhittakerSmooth(x, w, lambda_, differences=1):
    """Fit a weighted penalized least-squares (Whittaker) smoother to ``x``.

    Solves the sparse linear system ``(W + lambda_ * D.T @ D) z = W x``,
    where ``W`` is the diagonal matrix built from ``w`` and ``D`` is the
    finite-difference operator obtained by differencing the identity matrix
    ``differences`` times.

    Parameters
    ----------
    x : array-like
        Signal to smooth. It is flattened to one dimension.
    w : array-like
        Per-sample weights placed on the main diagonal of ``W``; one value
        per sample of ``x``.
    lambda_ : float
        Weight of the roughness penalty. Larger values give a smoother fit.
    differences : int, optional
        Order of the difference penalty (default 1).

    Returns
    -------
    numpy.ndarray
        One-dimensional array with the smoothed signal.
    """
    # Plain ndarrays are used instead of the discouraged ``np.matrix`` class.
    signal = np.asarray(x).ravel()
    m = signal.size

    # Build the difference operator by repeatedly differencing identity rows.
    D = eye(m, format='csc')
    for _ in range(differences):
        D = D[1:] - D[:-1]

    W = diags(w, 0, shape=(m, m))
    # Same association order as the original ``lambda_ * E.T * E`` expression.
    A = csc_matrix(W + (lambda_ * D.T) @ D)
    b = W @ signal
    return np.asarray(spsolve(A, b))


def airPLS(x, lambda_=100, porder=1, itermax=15, tol=0.001):
    """Estimate the baseline of ``x`` with adaptive iteratively reweighted PLS.

    Each iteration fits a Whittaker smoother ``z`` with the current weights,
    computes the residual ``d = x - z`` and the absolute sum ``dssn`` of its
    negative part. Iteration stops when ``dssn < tol * sum(|x|)`` or after
    ``itermax`` iterations. Otherwise points with ``d >= 0`` get weight 0 and
    points with ``d < 0`` get weight ``exp(min(i * |d| / dssn, 709))``; the
    first and last weights are then overwritten with
    ``exp(i * max(d[d < 0]) / dssn)``.

    Parameters
    ----------
    x : numpy.ndarray
        One-dimensional signal (``x.shape[0]`` is used as its length).
    lambda_ : float, optional
        Smoothness penalty forwarded to ``WhittakerSmooth``.
    porder : int, optional
        Difference order forwarded to ``WhittakerSmooth``.
    itermax : int, optional
        Maximum number of reweighting iterations.
    tol : float, optional
        Relative tolerance of the stopping test.

    Returns
    -------
    numpy.ndarray or None
        The baseline from the last completed fit, or ``None`` when an
        unexpected exception is raised (the error and traceback are printed).
        One such case is a residual with no negative entries while the
        stopping test has not been met.
    """
    # Imported locally: the handler at the bottom needs ``traceback`` and the
    # file only imports it in a later notebook cell.
    import traceback

    # exp() argument cap: round(np.log(np.finfo(float).max)), avoids overflow.
    max_exp = 709

    try:
        m = x.shape[0]
        weights = np.ones(m)
        # Loop-invariant right-hand side of the convergence test.
        stop_threshold = tol * (abs(x)).sum()

        for i in range(1, itermax + 1):
            z = WhittakerSmooth(x, weights, lambda_, porder)
            if z is None:
                print(f"WhittakerSmooth failed at iteration {i}")
                return None

            d = x - z
            below = d < 0
            dssn = np.abs(d[below].sum())

            if dssn < stop_threshold or i == itermax:
                break

            # Points at or above the fitted baseline are treated as peaks.
            weights[d >= 0] = 0

            try:
                with warnings.catch_warnings(record=True) as caught_warnings:
                    warnings.simplefilter("always")
                    if np.isinf(dssn):
                        dssn = np.finfo(dssn.dtype).max
                    exp_term = np.minimum(i * np.abs(d[below]) / dssn, max_exp)
                    weights[below] = np.exp(exp_term)
                    if caught_warnings:
                        warning = caught_warnings[0]
                        print(f"Warning at iteration {i} in exp calculation: {warning.message}")
            except Exception as e:
                print(f"Exception at iteration {i} in exp calculation: {e}")
                return z

            try:
                # ``dssn`` was already made finite in the block above.
                weights[0] = np.exp(i * (d[below]).max() / dssn)
            except Warning as warn:
                # Only reachable when warnings are configured to raise.
                print(f"Warning at iteration {i} in w[0] calculation: {warn}")
                return z

            weights[-1] = weights[0]
        return z
    except Exception as e:
        print(f"Error in airPLS: {str(e)}")
        print(f"Traceback: {traceback.format_exc()}")
        return None
        
def calculate_mae(true, predicted):
    """Return the mean absolute error between ``true`` and ``predicted``.

    Both arguments must support element-wise subtraction with each other.
    """
    return np.mean(np.abs(true - predicted))

def calculate_mae_for_params(spectrum, true_spectrum, lambda_, tol):
    """Score one ``(lambda_, tol)`` pair for airPLS baseline removal.

    Runs ``airPLS`` on ``spectrum`` with ``porder=2`` and ``itermax=1000``,
    subtracts the resulting baseline and returns the mean absolute error of
    the corrected spectrum against ``true_spectrum``.
    """
    estimated_baseline = airPLS(
        spectrum, lambda_=lambda_, porder=2, itermax=1000, tol=tol
    )
    return calculate_mae(true_spectrum, spectrum - estimated_baseline)

### Loop

In [9]:
import pickle
from matplotlib.colors import LogNorm

def calculate_parameter_space(spectrum, true_spectrum, param_ranges, num_points=50):
    """Evaluate the MAE on a log-spaced ``lambda`` x ``tol`` grid.

    Parameters
    ----------
    spectrum, true_spectrum
        Forwarded unchanged to ``calculate_mae_for_params``.
    param_ranges
        Indexed as ``param_ranges[row, col]``: row 0 holds the lambda bounds,
        row 1 the tol bounds, each as ``(low, high)``.
    num_points : int, optional
        Number of grid points along each axis.

    Returns
    -------
    dict
        The two axes, the ``(num_points, num_points)`` MAE grid (NaN where the
        evaluation raised) and the grid optimum (NaN when no cell succeeded).
    """
    def log_axis(row):
        low, high = param_ranges[row, 0], param_ranges[row, 1]
        return np.logspace(np.log10(low), np.log10(high), num_points)

    lambda_range = log_axis(0)
    tol_range = log_axis(1)

    # Cells whose evaluation fails keep their NaN placeholder.
    mae_values = np.full((num_points, num_points), np.nan)
    for row, col in np.ndindex(num_points, num_points):
        lam, tolerance = lambda_range[row], tol_range[col]
        try:
            mae_values[row, col] = calculate_mae_for_params(spectrum, true_spectrum, lam, tolerance)
        except Exception as e:
            print(f"Error at lambda={lam}, tol={tolerance}: {str(e)}")

    if np.isnan(mae_values).all():
        best_lambda = best_tol = best_mae = np.nan
    else:
        best_row, best_col = np.unravel_index(np.nanargmin(mae_values), mae_values.shape)
        best_lambda = lambda_range[best_row]
        best_tol = tol_range[best_col]
        best_mae = mae_values[best_row, best_col]

    return {
        'lambda_range': lambda_range,
        'tol_range': tol_range,
        'mae_values': mae_values,
        'grid_optimal_lambda': best_lambda,
        'grid_optimal_tol': best_tol,
        'grid_optimal_mae': best_mae,
    }

def save_calculation_results(results, filename):
    """Pickle ``results`` to ``filename`` and print a confirmation message."""
    with open(filename, mode='wb') as handle:
        pickle.dump(results, handle)
    print("Grid search results saved")


In [16]:
from functools import lru_cache
import time

@lru_cache(maxsize=None)
def memoized_objective(spectrum, true_spectrum, lambda_, tol):
    """Cached wrapper around ``calculate_mae_for_params``.

    All arguments must be hashable because they form the cache key; the two
    spectra are converted to NumPy arrays before the MAE is computed.
    """
    return calculate_mae_for_params(
        np.array(spectrum), np.array(true_spectrum), lambda_, tol
    )

def round_to_6(x):
    """Round ``x`` for display.

    NOTE(review): despite the name, the value is rounded to 25 decimals, not
    6 — confirm whether the name or the precision reflects the intent.
    """
    decimals = 25
    return np.round(x, decimals)

def adaptive_grid_search(initial_lambda, initial_tol, initial_resolution, calculate_mae_for_params, spectrum, true_spectrum, param_ranges, min_iterations=5, consecutive_stable_iterations=5, max_non_decreasing=20):
    """Refine airPLS ``lambda``/``tol`` on a repeatedly re-centred log-space grid.

    Every iteration evaluates a 21 x 21 grid (in log10 space, clipped to
    ``param_ranges``) centred on the best point found so far. When the grid
    optimum lies in the inner half of the grid on both axes, the grid spacing
    is halved.

    The search stops when either
    * the relative change of the grid-optimal MAE stays below 0.05 for
      ``consecutive_stable_iterations`` iterations in a row and at least
      ``min_iterations`` iterations have run, or
    * the best MAE has not improved for ``max_non_decreasing`` iterations.

    Parameters
    ----------
    initial_lambda, initial_tol : float
        Starting point of the search (linear scale).
    initial_resolution : array-like of two floats
        Initial grid spacing factors (log10 units) for lambda and tol. The
        grid step is half of this value. The argument is copied and never
        modified.
    calculate_mae_for_params
        Not used by this function (evaluation goes through the module-level
        ``memoized_objective``); kept so existing callers keep working.
    spectrum, true_spectrum : iterable of numbers
        Measured spectrum and reference spectrum; converted to tuples so they
        can serve as cache keys of ``memoized_objective``.
    param_ranges
        Indexed as ``param_ranges[row, col]``: row 0 holds the lambda bounds,
        row 1 the tol bounds, each as ``(low, high)``.
    min_iterations, consecutive_stable_iterations, max_non_decreasing : int
        Stopping-rule settings described above.

    Returns
    -------
    tuple
        ``([best_lambda, best_tol], best_mae, trajectory, cache_info)`` where
        ``trajectory`` is a dict of per-iteration lists and ``cache_info`` is
        ``memoized_objective.cache_info()``.
    """
    # Imported locally so the error handler in ``objective`` works even when
    # the cell holding the file-level ``import traceback`` has not been run.
    import traceback

    start_time = time.time()

    # Hashable cache keys, built once instead of on every objective call.
    spectrum_key = tuple(spectrum)
    true_spectrum_key = tuple(true_spectrum)

    def objective(params):
        # ``params`` holds log10(lambda) and log10(tol).
        try:
            return memoized_objective(spectrum_key, true_spectrum_key, 10**params[0], 10**params[1])
        except Exception as e:
            print(f"Error in objective function: {str(e)}")
            print(f"Parameters: lambda={10**params[0]}, tol={10**params[1]}")
            print("Traceback:")
            traceback.print_exc()
            return np.inf

    center = np.array([np.log10(initial_lambda), np.log10(initial_tol)])
    # Private float copy: the in-place halving below must not leak into the
    # caller's array (and must also work for list or integer input).
    resolution = np.array(initial_resolution, dtype=float)
    num_points = 21
    half_width = num_points // 2
    grid_steps = range(-half_width, half_width + 1)

    iterations = []
    maes = []
    lambdas = []
    tols = []
    relative_errors = []

    initial_mae = objective(center)
    prev_mae = initial_mae
    best_mae = initial_mae
    best_params = center.copy()

    print(f"\nInitial values:")
    print(f"Lambda: {round_to_6(initial_lambda)}")
    print(f"Tolerance: {round_to_6(initial_tol)}")
    print(f"Initial MAE: {round_to_6(initial_mae)}")

    iteration = 0
    stable_iterations = 0
    non_decreasing_count = 0

    while True:
        iteration += 1
        print(f"\nIteration {iteration}")

        # Set the best point as the new center
        center = best_params.copy()

        # Generate grid in log space (the center point is grid index half_width)
        lambda_range = np.array([10**(center[0] + i*resolution[0]/2) for i in grid_steps])
        tol_range = np.array([10**(center[1] + i*resolution[1]/2) for i in grid_steps])

        # Enforce parameter ranges
        lambda_range = np.clip(lambda_range, param_ranges[0, 0], param_ranges[0, 1])
        tol_range = np.clip(tol_range, param_ranges[1, 0], param_ranges[1, 1])

        print(f"Lambda range: {round_to_6(lambda_range[0])} to {round_to_6(lambda_range[-1])}")
        print(f"Tol range: {round_to_6(tol_range[0])} to {round_to_6(tol_range[-1])}")
        print(f"Resolution: {resolution}")

        # Calculate MAE for each point in the grid
        mae_values = np.zeros((num_points, num_points))
        for i, lambda_ in enumerate(lambda_range):
            for j, tol in enumerate(tol_range):
                mae_values[i, j] = objective([np.log10(lambda_), np.log10(tol)])

        # Find the new optimal point, ignoring NaN values
        if np.any(~np.isnan(mae_values)):
            optimal_index = np.unravel_index(np.nanargmin(mae_values), mae_values.shape)
            new_center = np.array([np.log10(lambda_range[optimal_index[0]]), np.log10(tol_range[optimal_index[1]])])
            new_mae = mae_values[optimal_index]
        else:
            print("Warning: All MAE values are NaN. Keeping the previous best point.")
            new_center = best_params.copy()
            new_mae = best_mae
            # The retained best point is the grid center; without this the
            # resolution check below would read an unbound or stale index.
            optimal_index = (half_width, half_width)

        print(f"New center: {round_to_6(10**new_center[0])}, {round_to_6(10**new_center[1])}")
        print(f"New MAE: {round_to_6(new_mae)}")

        # Calculate relative error
        relative_error = (new_mae - prev_mae) / prev_mae if prev_mae != 0 and not np.isnan(prev_mae) else 0
        print(f"Relative error: {round_to_6(relative_error)}")

        # Store results (store the best MAE)
        iterations.append(iteration)
        maes.append(min(new_mae, best_mae))
        lambdas.append(10**new_center[0])
        tols.append(10**new_center[1])
        relative_errors.append(relative_error)

        # Update best MAE and parameters if necessary
        if new_mae < best_mae:
            best_mae = new_mae
            best_params = new_center.copy()
            print("Best MAE is updated")
            non_decreasing_count = 0
        else:
            non_decreasing_count += 1

        # Check stopping criteria
        if abs(relative_error) < 0.05:
            stable_iterations += 1
            if stable_iterations >= consecutive_stable_iterations and iteration >= min_iterations:
                print(f"Stopping criterion met: {stable_iterations} consecutive stable iterations")
                break
        else:
            stable_iterations = 0

        if non_decreasing_count >= max_non_decreasing:
            print(f"Stopping criterion met: {max_non_decreasing} consecutive non-decreasing iterations")
            break

        # Refine the grid when the optimum lies in its inner half on both axes
        distance_from_center = np.array([
            abs(optimal_index[0] - half_width) / half_width,
            abs(optimal_index[1] - half_width) / half_width
        ])
        if np.all(distance_from_center < 0.5):
            resolution *= 0.5
            print("Increasing resolution")

        # Update for next iteration
        prev_mae = new_mae

    end_time = time.time()
    elapsed_time = end_time - start_time

    print(f"\nOptimization completed in {round_to_6(elapsed_time)} seconds")
    print(f"Number of iterations: {iteration}")
    print(f"Best MAE: {round_to_6(best_mae)}")
    print(f"Best parameters: lambda={round_to_6(10**best_params[0])}, tol={round_to_6(10**best_params[1])}")

    trajectory = {
        'iterations': iterations,
        'maes': maes,
        'lambdas': lambdas,
        'tols': tols,
        'relative_errors': relative_errors
    }

    return [10**best_params[0], 10**best_params[1]], best_mae, trajectory, memoized_objective.cache_info()

In [17]:
def visualize_parameter_space_new(calc_results, optimal_params, trajectory, ax, initial_params):
    """Draw the MAE landscape and the optimization path on ``ax``.

    Parameters
    ----------
    calc_results : dict
        Must provide ``'lambda_range'``, ``'tol_range'`` and ``'mae_values'``.
    optimal_params : sequence of two numbers
        ``(lambda, tol)`` marked as "Optimal Parameters".
    trajectory : dict
        Must provide the lists ``'lambdas'``, ``'tols'`` and ``'iterations'``.
    ax
        Axes object that receives the plot.
    initial_params : sequence of two numbers
        ``(lambda, tol)`` the optimization started from.
    """
    mae_grid = calc_results['mae_values']
    log_lambda_axis = np.log10(calc_results['lambda_range'])
    log_tol_axis = np.log10(calc_results['tol_range'])
    mae_low, mae_high = np.nanmin(mae_grid), np.nanmax(mae_grid)

    # MAE heat map: 100-level viridis on a logarithmic colour scale so that
    # the low-MAE region is resolved.
    heatmap = ax.pcolormesh(
        log_lambda_axis, log_tol_axis, mae_grid.T,
        cmap=plt.get_cmap('viridis', 100),
        norm=LogNorm(vmin=mae_low, vmax=mae_high),
        shading='auto',
    )

    # Colour bar labelled with the actual MAE values at six log-spaced ticks.
    mae_bar = plt.colorbar(heatmap, ax=ax, label='MAE (log scale)')
    mae_ticks = np.logspace(np.log10(mae_low), np.log10(mae_high), num=6)
    mae_bar.set_ticks(mae_ticks)
    mae_bar.set_ticklabels([f'{t:.2e}' for t in mae_ticks])

    ax.set_xlabel('log10(lambda)')
    ax.set_ylabel('log10(tol)')
    ax.set_title('MAE Landscape for airPLS Parameters')

    # Optimization path, with the starting point prepended as iteration 0.
    path_log_lambda = np.log10([initial_params[0]] + trajectory['lambdas'])
    path_log_tol = np.log10([initial_params[1]] + trajectory['tols'])
    path_steps = [0] + trajectory['iterations']

    ax.plot(path_log_lambda, path_log_tol, 'r--', linewidth=2, alpha=0.7)
    path_points = ax.scatter(path_log_lambda, path_log_tol,
                             c=path_steps, cmap='cool', s=50, zorder=5)
    plt.colorbar(path_points, ax=ax, label='Iteration', pad=0.01)

    # Star markers for the notable points, drawn in legend order.
    optimal_lambda, optimal_tol = optimal_params
    highlights = [
        (initial_params[0], initial_params[1], 'k*', 'Initial Parameters'),
        (trajectory['lambdas'][0], trajectory['tols'][0], 'g*', 'First Iteration'),
        (trajectory['lambdas'][-1], trajectory['tols'][-1], 'w*', 'Final Parameters'),
        (optimal_lambda, optimal_tol, 'y*', 'Optimal Parameters'),
    ]
    for lam, tolerance, marker, label in highlights:
        ax.plot(np.log10(lam), np.log10(tolerance), marker,
                markersize=15, label=label)

    ax.legend()

### Option 1: process the first n spectra of every CSV

In [18]:
import os
import traceback

def process_all_csvs(base_dir, output_base_dir, start_spectra_index=0, end_spectra_index=4, display_plots=False, single_file=None, single_spectrum_index=None, initial_params=[100, 0.001]):
    """Run the airPLS parameter optimization over the simulated CSV files.

    Two modes are supported:

    * Single-spectrum mode (``single_file`` given and ``single_spectrum_index``
      not ``None``): only that column of that file is processed and the three
      result CSVs carry ``_spectrum_<index>_`` in their names.
    * Batch mode (otherwise): every ``<type>_<shape>.csv`` under ``base_dir``
      is processed, for the types broad/convoluted/distinct and the shapes
      exp/gaussian/poly/sigmoidal, restricted to the spectra selected by
      ``start_spectra_index`` and ``end_spectra_index``.

    Parameters
    ----------
    base_dir : str
        Directory with the input CSVs, their ``*_baseline.csv`` companions
        and a ``True spectra`` sub-directory.
    output_base_dir : str
        Directory under which one result folder per input file is created.
    start_spectra_index, end_spectra_index : int
        Spectrum range forwarded to the per-file processing (batch mode only).
    display_plots : bool
        Forwarded to the per-file processing.
    single_file : str or None
        Path of the CSV to process in single-spectrum mode.
    single_spectrum_index : int or None
        Column position to process in single-spectrum mode.
    initial_params : sequence of two numbers
        Starting ``[lambda, tol]`` of the optimization. It is only read here,
        so the shared list default is not modified.
    """
    # Search bounds shared by both modes.
    param_ranges = np.array([
        [1e0, 1e8],  # lambda range
        [10**(-7.5), 1e-2]  # tol range
    ])

    if single_file and single_spectrum_index is not None:
        # Process a single spectrum from a specific file
        file_name = os.path.basename(single_file)
        type_shape = os.path.splitext(file_name)[0]
        output_dir = os.path.join(output_base_dir, type_shape)
        os.makedirs(output_dir, exist_ok=True)

        true_spectra_path = os.path.join(base_dir, "True spectra", f"{type_shape.split('_')[0]}_df.csv")
        true_baseline_path = os.path.join(base_dir, f"{type_shape}_baseline.csv")

        print(f"\nProcessing single spectrum {single_spectrum_index} from {file_name}")
        # The callee has no ``n_spectra`` parameter; passing it raised a
        # TypeError. ``single_spectrum_index`` alone selects the one column.
        results, predicted_baselines, predicted_spectra = process_csv_with_optimization_and_grid_search(
            single_file, true_spectra_path, true_baseline_path,
            param_ranges, list(initial_params),
            output_dir=output_dir, display_plots=display_plots,
            single_spectrum_index=single_spectrum_index
        )

        # Save results for the single spectrum
        prefix = os.path.join(output_dir, f"{type_shape}_spectrum_{single_spectrum_index}")
        results.to_csv(f"{prefix}_optimization_results.csv", index=False)
        predicted_baselines.to_csv(f"{prefix}_predicted_baseline.csv", index=False)
        predicted_spectra.to_csv(f"{prefix}_predicted_spectra.csv", index=False)

    else:
        # Process the selected spectra of every type/shape combination
        spectral_types = ['broad', 'convoluted', 'distinct']
        baseline_shapes = ['exp', 'gaussian', 'poly', 'sigmoidal']

        for spectral_type in spectral_types:
            for baseline_shape in baseline_shapes:
                tag = f"{spectral_type}_{baseline_shape}"
                print(f"\nProcessing {tag}")

                # Define file paths
                file_path = os.path.join(base_dir, f"{tag}.csv")
                true_spectra_path = os.path.join(base_dir, "True spectra", f"{spectral_type}_df.csv")
                true_baseline_path = os.path.join(base_dir, f"{tag}_baseline.csv")

                # Create output directory
                output_dir = os.path.join(output_base_dir, tag)
                os.makedirs(output_dir, exist_ok=True)

                # Process the CSV
                start_time = time.time()
                print(f"Working on spectral_shape={spectral_type}, baseline_shape={baseline_shape}")
                # Honour the caller's ``initial_params`` (previously ignored
                # in favour of a hard-coded [100, 0.001]).
                results, predicted_baselines, predicted_spectra = process_csv_with_optimization_and_grid_search(
                    file_path, true_spectra_path, true_baseline_path,
                    param_ranges, initial_params=list(initial_params),
                    output_dir=output_dir, start_spectra_index=start_spectra_index, end_spectra_index=end_spectra_index, display_plots=display_plots
                )
                end_time = time.time()

                # Save results
                results.to_csv(os.path.join(output_dir, f"{tag}_optimization_results.csv"), index=False)
                predicted_baselines.to_csv(os.path.join(output_dir, f"{tag}_predicted_baseline.csv"), index=False)
                predicted_spectra.to_csv(os.path.join(output_dir, f"{tag}_predicted_spectra.csv"), index=False)

                print(f"Completed {tag} in {end_time - start_time:.2f} seconds")

    print("All processing completed.")

def _save_spectrum_plot(output_dir, col, wavenumbers, spectrum, true_baseline, baseline, true_spectrum, optimal_params, spectrum_mae):
    """Save the overlay of raw, baseline and corrected curves; return its path."""
    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        ax.plot(wavenumbers, spectrum, label='Original Spectrum')
        ax.plot(wavenumbers, true_baseline, label='True Baseline')
        ax.plot(wavenumbers, baseline, label='Estimated Baseline')
        ax.plot(wavenumbers, spectrum - baseline, label='Corrected Spectrum')
        ax.plot(wavenumbers, true_spectrum, label='True Spectrum')
        ax.set_xlabel('Wavenumber')
        ax.set_ylabel('Intensity')
        ax.set_title(f'Spectrum {col} - AirPLS Baseline Correction\n' +
                     f'Optimal λ: {optimal_params[0]:.2e}, tol: {optimal_params[1]:.2e}, MAE: {spectrum_mae:.4f}')
        ax.legend()
        fig.tight_layout()
        plot_path = os.path.join(output_dir, f'spectrum_{col}_plot.png')
        fig.savefig(plot_path)
    finally:
        # Close only this figure, also when plotting or saving fails.
        plt.close(fig)
    return plot_path


def _save_trajectory_plots(output_dir, col, trajectory):
    """Save the 3x2 panel of per-iteration optimization curves; return its path."""
    fig, axs = plt.subplots(3, 2, figsize=(15, 20))
    try:
        iterations = trajectory['iterations']
        panels = [
            (axs[0, 0], trajectory['maes'], 'MAE'),
            (axs[0, 1], np.log10(trajectory['maes']), 'log10(MAE)'),
            (axs[1, 0], np.log10(trajectory['lambdas']), 'log10(lambda)'),
            (axs[1, 1], np.log10(trajectory['tols']), 'log10(tol)'),
            (axs[2, 0], trajectory['relative_errors'], 'Relative Error'),
        ]
        for ax, values, ylabel in panels:
            ax.plot(iterations, values)
            ax.set_xlabel('Iteration')
            ax.set_ylabel(ylabel)
            ax.set_title(f'{ylabel} vs Iteration')
            ax.grid(True)

        # Sixth cell of the grid is intentionally left empty.
        axs[2, 1].axis('off')

        fig.tight_layout()
        plot_path = os.path.join(output_dir, f'spectrum_{col}_trajectory_plots.png')
        fig.savefig(plot_path)
    finally:
        plt.close(fig)
    return plot_path


def process_csv_with_optimization_and_grid_search(file_path, true_spectra_path, true_baseline_path, param_ranges, initial_params=[1e4, 1e-3], output_dir=None, start_spectra_index=0, end_spectra_index=4, display_plots=False, single_spectrum_index=None):
    """Optimize airPLS parameters for the selected spectra of one CSV file.

    For every selected column the adaptive grid search is started from the
    best parameters of the previous column (``initial_params`` for the first
    one). airPLS is then applied with the optimum, the MAEs against the true
    baseline and true spectrum are computed and, when ``output_dir`` is set,
    two PNG plots are written.

    Parameters
    ----------
    file_path : str
        CSV with the wavenumber axis in the first column and one spectrum per
        following column.
    true_spectra_path : str
        Tab-separated file whose ``'y'`` column is the reference spectrum.
    true_baseline_path : str
        CSV with the true baselines, addressed by the same column names as
        ``file_path``.
    param_ranges
        Bounds forwarded to ``adaptive_grid_search``.
    initial_params : sequence of two numbers
        Starting ``[lambda, tol]``; only read, never modified.
    output_dir : str or None
        Directory for the plots; plotting is skipped when falsy.
    start_spectra_index, end_spectra_index : int
        Inclusive zero-based range of spectra to process (ignored when
        ``single_spectrum_index`` is given).
    display_plots : bool
        Accepted for compatibility; not used by this function.
    single_spectrum_index : int or None
        Column position (in the CSV, where position 0 is the wavenumber
        column) of the only spectrum to process.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(results_df, predicted_baselines, predicted_spectra)``; the last two
        carry an extra ``'wavenumbers'`` column.
    """
    print(f"Processing CSV file: {file_path}")
    print(f"True spectra file: {true_spectra_path}")
    print(f"True baseline file: {true_baseline_path}")

    # Read the CSV files
    df = pd.read_csv(file_path)
    true_df = pd.read_csv(true_spectra_path, sep='\t')
    true_baseline_df = pd.read_csv(true_baseline_path)

    # Extract spectra
    wavenumbers = df.iloc[:, 0]  # Assuming the first column is wavenumber
    if single_spectrum_index is not None:
        spectra = df.iloc[:, single_spectrum_index:single_spectrum_index+1]
    else:
        # +1 skips the wavenumber column; the end index is inclusive
        spectra = df.iloc[:, start_spectra_index+1:(end_spectra_index+1)+1]
    true_spectrum = true_df['y'].values

    print(f"Total number of spectra to process: {len(spectra.columns)}")

    results = []
    predicted_baselines = pd.DataFrame()
    predicted_spectra = pd.DataFrame()

    best_params = initial_params  # Initialize with the provided initial parameters

    for i, col in enumerate(tqdm(spectra.columns, desc="Processing spectra")):
        print(f"\nProcessing spectrum {i+1}/{len(spectra.columns)}: {col}")
        spectrum = spectra[col].values
        true_baseline = true_baseline_df[col].values

        # Use the best parameters from the previous spectrum as the starting point
        initial_optimal_params = best_params

        print(f"Starting optimization with initial parameters: lambda={initial_optimal_params[0]}, tol={initial_optimal_params[1]}")

        adaptive_grid_result = adaptive_grid_search(
            initial_optimal_params[0], initial_optimal_params[1], np.array([0.5, 0.5]),
            calculate_mae_for_params, spectrum, true_spectrum,
            param_ranges, consecutive_stable_iterations=5, max_non_decreasing=20
        )

        # Accept both the 3-value and the 4-value (with cache info) result shape
        if len(adaptive_grid_result) not in (3, 4):
            raise ValueError(f"Unexpected number of return values from adaptive_grid_search: {len(adaptive_grid_result)}")
        optimal_params, _optimal_mae, trajectory = adaptive_grid_result[:3]

        # Update best_params for the next iteration
        best_params = optimal_params

        # Apply airPLS with final optimal parameters
        print("Applying airPLS with final optimal parameters...")
        baseline = airPLS(spectrum, lambda_=optimal_params[0], porder=2, itermax=1000, tol=optimal_params[1])

        # Calculate MAEs
        baseline_mae = calculate_mae(true_baseline, baseline)
        spectrum_mae = calculate_mae(true_spectrum, spectrum - baseline)

        print(f"Results for spectrum {col}:")
        print(f"  Optimal lambda: {optimal_params[0]}")
        print(f"  Optimal tolerance: {optimal_params[1]}")
        print(f"  Spectrum MAE: {spectrum_mae}")

        results.append({
            'spectrum': col,
            'optimal_lambda': optimal_params[0],
            'optimal_tol': optimal_params[1],
            'baseline_mae': baseline_mae,
            'spectrum_mae': spectrum_mae
        })

        # Store predicted baseline and spectrum
        predicted_baselines[col] = baseline
        predicted_spectra[col] = spectrum - baseline

        # Plot and save the results; a plotting failure is reported but does
        # not abort the processing of the remaining spectra
        if output_dir:
            try:
                print(f"Attempting to save plots for spectrum {col}")

                # Ensure output directory exists
                os.makedirs(output_dir, exist_ok=True)
                print(f"Output directory confirmed: {output_dir}")

                spectrum_plot_path = _save_spectrum_plot(
                    output_dir, col, wavenumbers, spectrum, true_baseline,
                    baseline, true_spectrum, optimal_params, spectrum_mae
                )
                print(f"Spectrum plot saved: {spectrum_plot_path}")

                trajectory_plot_path = _save_trajectory_plots(output_dir, col, trajectory)
                print(f"Trajectory plots saved: {trajectory_plot_path}")

                print(f"All plots for spectrum {col} saved successfully")

            except Exception:
                print(f"Error occurred while generating plots for spectrum {col}:")
                print(traceback.format_exc())

        print(f"Completed processing spectrum {col}")
        # The cache is keyed on the spectrum itself, so it is useless for the
        # next spectrum; clear it to release memory
        print(memoized_objective.cache_info())
        memoized_objective.cache_clear()

    # Add wavenumbers to predicted baselines and spectra
    predicted_baselines['wavenumbers'] = wavenumbers
    predicted_spectra['wavenumbers'] = wavenumbers

    # Convert results to DataFrame for easy analysis
    results_df = pd.DataFrame(results)

    print("\nOptimization Results Summary:")
    print(results_df.describe())

    return results_df, predicted_baselines, predicted_spectra

#### Option 1:

To process the i-th to j-th spectra of all CSVs:

Example: start_spectra_index=0: we start from the first spectrum; end_spectra_index=4: we stop after processing the 5th spectrum.

In [None]:
# Main execution (batch mode): optimize spectra 0 through 4 of every CSV.
base_dir = r"E:\python\Jupyter\PhD\Raman Spectrum\Baseline removal\New_simulated\data\one spectral shape"
output_base_dir = r"E:\python\Jupyter\PhD\Raman Spectrum\Baseline removal\New_simulated\results\airpls_optimization"

print("Starting AirPLS optimization pipeline...")
process_all_csvs(
    base_dir=base_dir,
    output_base_dir=output_base_dir,
    start_spectra_index=0,
    end_spectra_index=4,
    display_plots=False,
)
print("\nAll processing completed.")

#### Option 2:

To process a single spectrum from one CSV:

Example: broad_gaussian.csv is the CSV to process; single_spectrum_index=4 processes its 4th spectrum.

In [None]:
# Main execution (single-spectrum mode): optimize one spectrum of one CSV.
base_dir = r"E:\python\Jupyter\PhD\Raman Spectrum\Baseline removal\New_simulated\data\one spectral shape"
output_base_dir = r"E:\python\Jupyter\PhD\Raman Spectrum\Baseline removal\New_simulated\results\airpls_optimization"

print("Starting AirPLS optimization pipeline...")

# single_spectrum_index is a column position in the CSV; position 0 is read as
# the wavenumber axis, so 4 selects the 4th spectrum of broad_gaussian.csv.
single_file = r"E:\python\Jupyter\PhD\Raman Spectrum\Baseline removal\New_simulated\data\one spectral shape\broad_gaussian.csv"
process_all_csvs(
    base_dir=base_dir,
    output_base_dir=output_base_dir,
    single_file=single_file,
    single_spectrum_index=4,
    initial_params=[100, 0.001],
)

print("\nAll processing completed.")