"""

HDF5 XPCS FitLab - Script for analyzing X-ray Photon Correlation Spectroscopy (XPCS) data from HDF5 files.

Author: Marilina Cathcarth [mcathcarth@gmail.com]

Version: 2.1

Date: September 18, 2023

**Important Note:**

- This script relies on functions defined in the 'XPCS_functions.py' file. Please ensure that 'XPCS_functions.py' is located in the same directory as this script for proper execution.

Dependencies:

- Python (version >= 3.6)
- h5py (to work with HDF5 files)
- numpy (for numerical calculations)
- scipy (for curve fitting)
- scikit-learn (for calculating R2 score)
- matplotlib (for data visualization)
- qtpy (for GUI-based directory selection)

Installation:

1. Make sure you have Python 3.6 or later installed. If not, download and install Python from https://www.python.org/downloads/


2. Install the required packages using pip. Open a terminal or command prompt and run the following command:

   pip install h5py numpy scipy scikit-learn matplotlib qtpy


"""

In [1]:
import os
import sys
import h5py
import numpy as np
import re
import matplotlib.pyplot as plt
import math
from scipy.optimize import curve_fit
from XPCS_functions import select_synchrotron, select_directory, initialize_error_and_success_counters
from XPCS_functions import generate_base_name, process_multi_tau_sirius, process_multi_tau_aps
from XPCS_functions import fit_model, single_exponential, stretched_exponential, calculate_relaxation_and_diffusion, fit_linear_model
from XPCS_functions import custom_cmap, initialize_plot, plot_data_and_curves, configure_subplot
from XPCS_functions import print_summary, add_data_to_table, write_dat_file, generate_fit_results_tables

# Default directory (can be left empty)
directory = ''

# Get the selected synchrotron
selected_synchrotron = select_synchrotron()

# Exit if no directory selected
if selected_synchrotron is None:
    print("User closed the window.")
    sys.exit()

# Call the select_directory function only if the directory is not defined
if not directory:
    directory = select_directory()

# Get a list of .hdf5 files in the directory
hdf5_files = [file for file in os.listdir(directory) if file.endswith('.hdf5')]

# Exit if no .hdf5 files are found
if not hdf5_files:
    print("No .hdf5 files found in the directory.")
    sys.exit()

# Initialize error and success counters using the dictionary
counters = initialize_error_and_success_counters()

# Initialize empty lists for table data
table_data = {
    "single": [],     # Empty list for single table data
    "stretched": []   # Empty list for stretched table data
}

#################################
### Loop over each .hdf5 file ###
#################################

for hdf5_file in hdf5_files:
    file_path = os.path.join(directory, hdf5_file)
    counters['hdf5_files'] += 1  # Increment the hdf5_file counter

    # Generate the appropriate base name for output file naming based on selected synchrotron
    base_name = generate_base_name(selected_synchrotron, hdf5_file)

    ### Open the HDF5 file
    with h5py.File(file_path, 'r') as hdf:
        # Check if the selected synchrotron is 'Sirius'
        if selected_synchrotron == 'Sirius':
            # Call the function to process 'Multi-tau' data for Sirius
            dataset_keys, multi_tau = process_multi_tau_sirius(hdf)
            q_len = len(dataset_keys)

        # Check if the selected synchrotron is 'APS'
        elif selected_synchrotron == 'APS':
            # Call the function to process 'Multi-tau' data for APS
            dataset_keys = process_multi_tau_APS(hdf)
            q_len = len(dataset_keys)

        # Call the function to initialize the plot
        fig, ax_single, ax_stretched, ax_linear, cmap, lines_labels_dict = initialize_plot(q_len)

        # Initialize empty lists for relax_rate vs q^2 values
        mod_values = {
            "q_square": [],   # Empty list for q_square model values
            "relax_rate": []  # Empty list for relax_rate model values
        }

        ###########################################
        # Loop over each dataset within Multi-tau #
        ###########################################
        
        for i, dataset_key in enumerate(dataset_keys):
            counters['multi_tau'] += 1                   # Increment the multi_tau counter
            dataset_name = dataset_key.strip()           # Remove leading and trailing whitespace

            # Create the output filename
            output_name = f"{base_name}_export_{chr(ord('a')+i)}.dat"
            # Generate the complete output file path
            output_path = os.path.join(directory, output_name)

            # Extract data from the dataset within Multi-tau
            dataset = multi_tau.get(dataset_key)
            if dataset is None:
                continue

            column = dataset[:]

            # Check if the data has a single column
            if column.ndim != 1:
                continue

            # Get the column names from the HDF5 file
            column_names = list(column.dtype.names)

            # Get the values of the columns
            t = column['delay time (s)']
            g2 = column['g2']

            # Check for invalid values in the data using a boolean mask
            invalid_mask = np.isnan(t) | np.isnan(g2) | np.isinf(t) | np.isinf(g2)
            if invalid_mask.any():
                counters['invalid_data'] += 1  # Increment the invalid_data counter
                counters['invalid_files'].append(f"{base_name} - {dataset_name}")
                continue

            # Remove the first data point
            t = t[1:]
            g2 = g2[1:]

            ########################################################
            ###################### Fit models ######################
            ########################################################

            ########### Fit the Single Exponential Model ###########

            # Calculate the initial guess for parameters A and B:
            A0 = np.mean(g2[-5:])     # Average of the last 5 points
            B0 = np.mean(g2[:2]) - 1  # Average of the first 2 points minus 1

            # Calculate the initial guess for parameter C:
            # Obtain y
            y = (np.mean(g2[:2]) + np.mean(g2[-5:])) / 2
            # Find the closest x value in the experimental curve
            closest_index = np.abs(g2 - y).argmin()
            closest_x = t[closest_index]
            # Obtain C0:
            C0 = 1 / closest_x

            # Initial guess for parameters A, B, and C:
            initial_params_single = [A0, B0, C0]

            ## Fit the curve g2 = A + B * exp(-2C * t)

            initial_params_single = [A0, B0, C0]

            fit_params_single, fitted_curve_single, r2_single = fit_model(t, g2, single_exponential, initial_params_single)
            
            if math.isnan(r2_single):
                # Add the filename to the list if fitting fails
                counters['failure'] += 1       # Increment the failure counter
                counters['failed_files'].append(f"{base_name} - {dataset_name}")
            
            # Extract q value from the dataset name Sirius
            match = re.search(r"q = (\S+)", dataset_name)
            if match:
                q_value = float(match.group(1))
            else:
                q_value = np.nan

            # Calculate Relaxation time and Diffusion coefficient for Single Exponential
            relax_time_single, diffusion_coef_single = calculate_relaxation_and_diffusion(fit_params_single, q_value)

            #####################################################
            # Add the data for (relax_rate vs q^2) in the lists #
            #####################################################
            
            mod_values["q_square"].append(q_value**2)
            mod_values["relax_rate"].append(fit_params_single[2])
            
            #####################################################

            ########## Fit the Stretched Exponential model ##########

            ## Fit the curve g2 = A + B * np.exp(-2C * t)**gamma

            # Initial guess for parameter gamma
            gamma0 = 1

            # Use the parameters A, B, C from the Single Exponential fit and gamma0 as initial guess
            initial_params_stretched = [*fit_params_single, gamma0]

            fit_params_stretched, fitted_curve_stretched, r2_stretched = fit_model(t, g2, stretched_exponential, initial_params_stretched)
            
            if math.isnan(r2_stretched):
                # Create the filename string
                failed_file = f"{base_name} - {dataset_name}"
                # Add the filename to the list if it's not already present
                if failed_file not in counters['failed_files']:
                    counters['failure'] += 1       # Increment the failure counter
                    counters['failed_files'].append(failed_file)

            # Calculate Relaxation time and Diffusion coefficient for Stretched Exponential
            relax_time_stretched, diffusion_coef_stretched = calculate_relaxation_and_diffusion(fit_params_stretched, q_value)

            ########################################################
            ### Plot the experimental data and the fitted curves ###
            ########################################################

            # for Single Exponential
            color = cmap(i)  # Get color based on index
            plot_data_and_curves(ax_single, t, g2, fitted_curve_single, q_value, r2_single, lines_labels_dict, color, linestyle='-', model_type = "Single")

            # for Stretched Exponential
            #color = cmap(i)  # Get color based on index
            plot_data_and_curves(ax_stretched, t, g2, fitted_curve_stretched, q_value, r2_stretched, lines_labels_dict, color, linestyle='--', model_type = "Stretched")
            
            # Create a dictionary to map q values to colors
            #q_color_mapping = []
            #q_color_mapping.append(color)
            
            # Associate the color with the q value in the dictionary
            #q_color_mapping[mod_values["q_square"][i]] = color
            
            ################################
            ### Append data to the table ###
            ################################

            # for Single Exponential
            A = fit_params_single[0]
            B = fit_params_single[1]
            C = fit_params_single[2]
            # Append data to the table for Single Exponential
            add_data_to_table(table_data["single"], base_name, q_value, [A, B, C], relax_time_single, diffusion_coef_single, r2_single)

            # for Stretched Exponential
            A_stretched = fit_params_stretched[0]
            B_stretched = fit_params_stretched[1]
            C_stretched = fit_params_stretched[2]
            gamma = fit_params_stretched[3]
            # Append data to the table for Stretched Exponential
            add_data_to_table(table_data["stretched"], base_name, q_value, [A_stretched, B_stretched, C_stretched, gamma], relax_time_stretched, diffusion_coef_stretched, r2_stretched)

            # Write data to .dat file for Single Exponential and Stretched Exponential
            write_dat_file(output_path, q_value, t, g2, A, B, C, A_stretched, B_stretched, C_stretched, gamma,
                           r2_single, r2_stretched, relax_time_single, relax_time_stretched,
                           diffusion_coef_single, diffusion_coef_stretched)
            
            # Increment the hdf5_file counter
            counters['success'] += 1  

    ###############################################################
    ################ Fit linear relax_rate vs q^2  ################
    ###############################################################

    ## Fit the curve C = D * q**2
    
    try: 
        # Call the function to perform the fit
        D, pearson_r = fit_linear_model(mod_values["q_square"], mod_values["relax_rate"])

        # Create a color palette with the number of points
        n = len(mod_values["q_square"])
        #cmap = plt.get_cmap('gist_rainbow', n)
        #cmap = plt.get_cmap('turbo', n)
        
        # Plot the model data
        for i in range(n):
            color = cmap(i)
            line_data = ax_linear.plot(mod_values["q_square"][i], mod_values["relax_rate"][i], 'o', color=color, label="")[0]

        # Generate the linear fit line using the fitted value of D
        q_squared_fit = np.linspace(0, max(mod_values["q_square"]), 100)
        linear_fit_line = D * q_squared_fit

        # Plot the linear fit line
        ax_linear.plot(q_squared_fit, linear_fit_line, linestyle='-', color='black', label=f"Linear Fit")

        # Set the axes to start from zero
        ax_linear.set_xlim(0,(max(mod_values["q_square"])+0.1*max(mod_values["q_square"])))
        ax_linear.set_ylim(0,(max(mod_values["relax_rate"])+0.1*max(mod_values["relax_rate"])))

    except ValueError as e:
        continue

    # Set labels
    ax_linear.set_xlabel("q^2 ($\AA^{-2}$)")
    ax_linear.set_ylabel("Relaxation Rate (s^-1)")
    ax_linear.set_title("Relax rate vs q^2")
    ax_linear.legend()
    
    # Annotate the graph with slope (D) and Pearson correlation coefficient (pearson_r)
    slope_annotation = f"Slope (D) = {D:.2e}"
    r_annotation = f"Pearson r = {pearson_r:.2f}"
    
    # Get the coordinates of the graph label
    label_x, label_y = 0.02, ax_linear.get_ylim()[1]

    # Add the label in the upper-left corner without a border
    ax_linear.annotate(slope_annotation, (0.02, 0.90), xycoords='axes fraction', color="black", fontsize=12)

    # Add the annotations below the label
    ax_linear.annotate(r_annotation, (0.02, 0.85), xycoords='axes fraction', color="black", fontsize=12)

    ### Configure the subplots

    # Configure subplot for Single Exponential
    configure_subplot(ax_single, "Single Exponential", "Delay Time (s)", "(g2-1)",
                      lines_labels_dict["exp_lines_single"],
                      lines_labels_dict["exp_labels_single"],
                      lines_labels_dict["fit_lines_single"],
                      lines_labels_dict["fit_labels_single"])

    # Configure subplot for Stretched Exponential
    configure_subplot(ax_stretched, "Stretched Exponential", "Delay Time (s)", "(g2-1)",
                      lines_labels_dict["exp_lines_stretched"],
                      lines_labels_dict["exp_labels_stretched"],
                      lines_labels_dict["fit_lines_stretched"],
                      lines_labels_dict["fit_labels_stretched"])

    # Save the figure to a PDF file
    pdf_name = f"{base_name}.pdf"
    output_pdf_path = os.path.join(directory, pdf_name)
    plt.tight_layout()
    plt.savefig(output_pdf_path, format='pdf')
    #plt.savefig(output_pdf_path, bbox_inches='tight', format='pdf')

    # Close the figure to free up memory
    plt.close(fig)

plt.close()
    
# Print the counts, filed files and invalid files
print_summary(counters)

# Generate the new .dat files with the fit results tables for each model
generate_fit_results_tables(directory, table_data["single"], table_data["stretched"])

Total HDF5 files: 69
Total x files: 414
Total invalid files: 7
Invalid files:
XPCS_SiNP_PEG_1 - q = 1.12e-03 A^-1
XPCS_SiNP_PEG_1 - q = 1.33e-03 A^-1
XPCS_SiNP_PEG_1 - q = 2.50e-04 A^-1
XPCS_SiNP_PEG_1 - q = 4.67e-04 A^-1
XPCS_SiNP_PEG_1 - q = 6.84e-04 A^-1
XPCS_SiNP_PEG_1 - q = 9.01e-04 A^-1
XPCS_mv70p1000_1000Hz - q = 1.44e-03 A^-1
Successful fits: 407
Failed fits: 9
Failed fits files:
XPCS_SiNP400_FBS_2 - q = 4.67e-04 A^-1
XPCS_SiNP400_FBS_2 - q = 6.84e-04 A^-1
XPCS_SiNP400_FBS_2 - q = 9.01e-04 A^-1
XPCS_mv70p1000_10Hz - q = 5.01e-04 A^-1
XPCS_SiNP_270_PBS1x_1 - q = 6.84e-04 A^-1
XPCS_mv70p1000_10Hz_8 - q = 7.35e-04 A^-1
XPCS_mv70p1000_10Hz_8 - q = 9.68e-04 A^-1
XPCS_mv70p1000_100Hz - q = 5.01e-04 A^-1
XPCS_mv70p1000_100Hz - q = 7.35e-04 A^-1
Fit results tables generated successfully.
