"""
HDF5 XPCS FitLab - Script for analyzing X-ray Photon Correlation Spectroscopy (XPCS) data from HDF5 files.

Dependencies:
- Python (version >= 3.6)
- h5py (to work with HDF5 files)
- numpy (for numerical calculations)
- pandas (imported by the script)
- scipy (for curve fitting)
- scikit-learn (for calculating R2 score)
- matplotlib (for data visualization)
- qtpy (for GUI-based directory selection; needs a Qt binding such as PyQt5 to be installed)

Installation:
1. Make sure you have Python 3.6 or later installed. If not, download and install Python from https://www.python.org/downloads/

2. Install the required packages using pip. Open a terminal or command prompt and run the following command:

pip install h5py numpy pandas scipy scikit-learn matplotlib qtpy pyqt5

"""

In [1]:
import os
import h5py
import numpy as np
import pandas as pd
from scipy.optimize import curve_fit
from sklearn.metrics import r2_score
from qtpy.QtWidgets import QApplication, QFileDialog
import sys
import re
import warnings

import matplotlib
matplotlib.use('Agg')  # Configuration for generating PNG images
import matplotlib.pyplot as plt

# Directory where the .hdf5 files are located.
# Leave it empty to be prompted with a directory-selection dialog.
directory = ''

# Initialize the QtPy application (created before the dialog below is shown)
app = QApplication([])

# Prompt for directory if not defined
if not directory:
    directory = QFileDialog.getExistingDirectory(None, "Select directory")

# Exit if no directory is selected (e.g. the dialog was cancelled)
if not directory:
    print("No directory selected.")
    sys.exit()

# Get a list of .hdf5 files in the directory.
# os.listdir() returns entries in arbitrary order, so the names are sorted to
# make the processing order -- and therefore the row order of the result
# tables -- reproducible from run to run.
hdf5_files = sorted(name for name in os.listdir(directory) if name.endswith('.hdf5'))

# Exit if no .hdf5 files are found
if not hdf5_files:
    print("No .hdf5 files found in the directory.")
    sys.exit()

## Initialize counters
# Files
hdf5_files_count = 0     # .hdf5 files visited
multi_tau_count = 0      # datasets visited inside the 'Multi-tau' groups
invalid_data_count = 0   # datasets rejected before fitting
# Fits
success_count = 0
failure_count = 0

# Initialize lists (names reported in the final summary)
invalid_files = []
failed_files = []

# Initialize data for the table (one row per successfully fitted dataset)
table_data_single = []
table_data_stretched = []

# Columns every Multi-tau dataset must provide in order to be fitted
REQUIRED_FIELDS = ('delay time (s)', 'g2')

# Minimum number of points required after trimming: the stretched model has
# four free parameters, and an empty series would break the initial-guess
# estimate further below.
MIN_POINTS = 4


def single_exponential(t, A, B, C):
    """Single exponential model: g2 = A + B * exp(-2 * C * t)."""
    return A + B * np.exp(-2 * C * t)


def stretched_exponential(t, A, B, C, gamma):
    """Model fitted as "stretched": g2 = A + B * exp(-2 * C * t) ** gamma.

    NOTE(review): this expression is mathematically equal to
    A + B * exp(-2 * C * gamma * t), so only the product C * gamma is
    constrained by the fit. The usual stretched (KWW) form would be
    A + B * exp(-2 * (C * t) ** gamma). The original expression is kept so
    results stay comparable with earlier runs -- confirm which is intended.
    """
    return A + B * np.exp(-2 * C * t) ** gamma


# Loop over each .hdf5 file
for hdf5_file in hdf5_files:
    file_path = os.path.join(directory, hdf5_file)
    hdf5_files_count += 1

    # Get the base name for output file naming
    base_name = hdf5_file.replace('_saxs_0000_RESULTS.hdf5', '')

    with h5py.File(file_path, 'r') as hdf:
        if 'Multi-tau' not in hdf.keys():
            print(f"Skipping file: {hdf5_file}. 'Multi-tau' key not found.")
            continue

        multi_tau = hdf['Multi-tau']

        # Read the keys of the datasets within Multi-tau
        dataset_keys = list(multi_tau.keys())

        # Create a new figure for each HDF5 file: single exponential fits on
        # the left axis, stretched exponential fits on the right axis
        fig, (ax_single, ax_stretched) = plt.subplots(1, 2, figsize=(20, 8))

        # Set the title of the figure to include the file name
        fig.suptitle(f"File: {base_name}", fontsize=16)

        # Initialize a color map for different q values
        cmap = plt.get_cmap('tab10')

        # Lines and labels for the legends; experimental points and fitted
        # curves are kept apart because each axis gets two separate legends
        exp_lines_single = []
        exp_labels_single = []
        fit_lines_single = []
        fit_labels_single = []

        exp_lines_stretched = []
        exp_labels_stretched = []
        fit_lines_stretched = []
        fit_labels_stretched = []

        # Loop over each dataset within Multi-tau
        for i, dataset_key in enumerate(dataset_keys):
            multi_tau_count += 1
            # Remove any leading/trailing spaces from the dataset key name
            dataset_name = dataset_key.strip()

            # One export file per dataset, suffixed a, b, c, ... by position.
            # NOTE(review): past 26 datasets the suffix is no longer a letter.
            output_name = f"{base_name}_export_{chr(ord('a')+i)}.dat"
            output_path = os.path.join(directory, output_name)

            # Extract data from the dataset within Multi-tau
            dataset = multi_tau.get(dataset_key)
            if dataset is None:
                continue

            column = dataset[:]

            # Only one-dimensional record arrays are processed
            if column.ndim != 1:
                continue

            # Skip datasets that lack the required named columns instead of
            # crashing on the field lookups below
            field_names = column.dtype.names or ()
            if any(field not in field_names for field in REQUIRED_FIELDS):
                print(f"Skipping dataset: {base_name}_{dataset_name}. Required columns not found.")
                continue

            # Get the values of the columns, dropping the first data point
            t = column['delay time (s)'][1:]
            g2 = column['g2'][1:]

            # Series too short to fit are reported as invalid data
            if t.size < MIN_POINTS:
                invalid_data_count += 1
                invalid_files.append(f"{base_name}_{dataset_name}")
                continue

            # Check for invalid values in the data
            if np.isnan(t).any() or np.isnan(g2).any() or np.isinf(t).any() or np.isinf(g2).any():
                invalid_data_count += 1
                invalid_files.append(f"{base_name}_{dataset_name}")
                continue

            #### Single Exponential Model

            ## Fit the curve g2 = A + B * exp(-2C * t)

            # Initial guess for parameters A, B, and C:
            A0 = np.mean(g2[-5:])  # Average of the last 5 points
            B0 = np.mean(g2[:2]) - 1  # Average of the first 2 points minus 1

            # Initial guess for C: inverse of the delay time at which the
            # measured g2 is closest to halfway between its early-time and
            # late-time averages
            y = (np.mean(g2[:2]) + np.mean(g2[-5:])) / 2
            closest_index = np.abs(g2 - y).argmin()
            closest_x = t[closest_index]
            C0 = 1 / closest_x

            p0 = [A0, B0, C0]

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")  # Ignore warnings
                try:
                    fit_params, _ = curve_fit(single_exponential, t, g2, p0=p0)
                    A, B, C = fit_params
                    fitted_curve_single = single_exponential(t, A, B, C)
                    r2_single = r2_score(g2, fitted_curve_single)

                    # Extract q value from the dataset name
                    match = re.search(r"q = (\S+)", dataset_name)
                    if match:
                        q_value = float(match.group(1))
                    else:
                        q_value = np.nan

                    # Calculate Relaxation time and Diffusion coefficient for Single Exponential
                    relax_time_single = 1 / C
                    diffusion_coef_single = (C / (q_value ** 2)) * 0.00000001

                    #### Stretched Exponential Model

                    ## Fit the curve g2 = A + B * np.exp(-2C * t)**gamma

                    # Start from the single exponential solution with gamma = 1
                    gamma0 = 1
                    p0_stretched = [A, B, C, gamma0]

                    fit_params_stretched, _ = curve_fit(stretched_exponential, t, g2, p0=p0_stretched)
                    A_stretched, B_stretched, C_stretched, gamma = fit_params_stretched
                    fitted_curve_stretched = stretched_exponential(t, A_stretched, B_stretched, C_stretched, gamma)
                    r2_stretched = r2_score(g2, fitted_curve_stretched)

                    # Calculate Relaxation time and Diffusion coefficient for Stretched Exponential
                    relax_time_stretched = 1 / C_stretched
                    diffusion_coef_stretched = (C_stretched / (q_value ** 2)) * 0.00000001

                except (RuntimeError, ValueError):
                    # Failed to fit the curve or encountered invalid values
                    failure_count += 1
                    failed_files.append(f"{base_name}_{dataset_name}")

                else:
                    # Both fits succeeded: plot, tabulate and export.
                    # The same color is used on both axes for a given dataset.
                    color = cmap(i % 10)

                    ## Plot the experimental data and the fitted curve for Single Exponential
                    line_exp_single = ax_single.semilogx(t, g2, 'o', color=color)
                    exp_lines_single.append(line_exp_single[0])
                    exp_labels_single.append(f"q = {q_value}")

                    line_fit_single = ax_single.semilogx(t, fitted_curve_single, color=color)
                    fit_lines_single.append(line_fit_single[0])
                    fit_labels_single.append(f"R2: {r2_single:.2f}")

                    ## Plot the experimental data and the fitted curve for Stretched Exponential
                    line_exp_stretched = ax_stretched.semilogx(t, g2, 'o', color=color)
                    exp_lines_stretched.append(line_exp_stretched[0])
                    exp_labels_stretched.append(f"q = {q_value}")

                    line_fit_stretched = ax_stretched.semilogx(t, fitted_curve_stretched, color=color, linestyle='--')
                    fit_lines_stretched.append(line_fit_stretched[0])
                    fit_labels_stretched.append(f"R2: {r2_stretched:.2f}")

                    # Append data to the table for Single Exponential
                    table_data_single.append(
                        [base_name, q_value, A, B, C, relax_time_single, diffusion_coef_single, r2_single]
                    )

                    # Append data to the table for Stretched Exponential
                    table_data_stretched.append(
                        [base_name, q_value, A_stretched, B_stretched, C_stretched, gamma,
                         relax_time_stretched, diffusion_coef_stretched, r2_stretched]
                    )

                    # The "std" column holds the standard deviation of the whole
                    # g2 series, repeated on every row; compute it once.
                    g2_std = np.std(g2)

                    # Write data to .dat file
                    with open(output_path, 'w') as dat_file:
                        # Write comments for Single Exponential
                        dat_file.write("###Fit model: Single Exponential\n")
                        dat_file.write(f"#Fit parameters: baseline={A}, beta={B}, relax rate (1/s)={C}\n")
                        dat_file.write(f"#R2: {r2_single}\n")
                        dat_file.write(f"#Derived parameters: Relax. time (s) = {relax_time_single}, Diffusion coef (u2/s) = {diffusion_coef_single}\n")
                        # Write comments for Stretched Exponential
                        dat_file.write("##Fit model: Stretched Exponential\n")
                        dat_file.write(f"#Fit parameters: baseline={A_stretched}, beta={B_stretched}, relax rate (1/s)={C_stretched} gamma={gamma}\n")
                        dat_file.write(f"#R2: {r2_stretched}\n")
                        dat_file.write(f"#Derived parameters: Relax. time (s) = {relax_time_stretched}, Diffusion coef (u2/s) = {diffusion_coef_stretched}\n")
                        dat_file.write("#delay time (s)\tg2\tstd\tFitted g2 Single\tFitted g2 Stretched\n")

                        # Write data and fitted curves for both models
                        dat_file.writelines(
                            f"{t_i}\t{g2_i}\t{g2_std}\t{single_i}\t{stretched_i}\n"
                            for t_i, g2_i, single_i, stretched_i
                            in zip(t, g2, fitted_curve_single, fitted_curve_stretched)
                        )

                        success_count += 1

        # Set the title and labels for the subplot (for Single Exponential)
        ax_single.set_title("Single Exponential")
        ax_single.set_xlabel('Delay Time (s)')
        ax_single.set_ylabel('g2')

        # Set the title and labels for the subplot (for Stretched Exponential)
        ax_stretched.set_title("Stretched Exponential")
        ax_stretched.set_xlabel('Delay Time (s)')
        ax_stretched.set_ylabel('g2')

        # Each axis carries two legends (experimental data, fitted curves).
        # A later legend() call replaces the axis legend, so the first one is
        # re-attached with add_artist() to keep both visible.
        legend_exp_single = ax_single.legend(exp_lines_single, exp_labels_single, loc='upper right', bbox_to_anchor=(0.84, 1), borderaxespad=0)
        ax_single.add_artist(legend_exp_single)

        legend_fit_single = ax_single.legend(fit_lines_single, fit_labels_single, loc='upper right', bbox_to_anchor=(1, 1), borderaxespad=0)
        ax_single.add_artist(legend_fit_single)

        legend_exp_stretched = ax_stretched.legend(exp_lines_stretched, exp_labels_stretched, loc='upper right', bbox_to_anchor=(0.84, 1), borderaxespad=0)
        ax_stretched.add_artist(legend_exp_stretched)

        legend_fit_stretched = ax_stretched.legend(fit_lines_stretched, fit_labels_stretched, loc='upper right', bbox_to_anchor=(1, 1), borderaxespad=0)
        ax_stretched.add_artist(legend_fit_stretched)

        # Save the plot as a PNG file (explicitly this figure, rather than
        # whichever figure pyplot currently considers active)
        png_name = f"{base_name}.png"
        png_path = os.path.join(directory, png_name)
        fig.savefig(png_path)

        # Close the figure to free up memory
        plt.close(fig)
    
# Exit the QtPy application
app.exit()

# Print the counts
print(f"Total HDF5 files: {hdf5_files_count}")
print(f"Total x files: {multi_tau_count}")
print(f"Total invalid files: {invalid_data_count}")
print(f"Successful fits: {success_count}")
print(f"Failed fits: {failure_count}")

# Print the names of the datasets whose fit failed
print("Failed files:")
for file_name in failed_files:
    print(file_name)

# Print the names of the datasets rejected for invalid data
print("Invalid files:")
for file_name in invalid_files:
    print(file_name)

# Generate the new .dat files with the table for each model
table_output_path_single = os.path.join(directory, "fit_results_single.dat")
table_output_path_stretched = os.path.join(directory, "fit_results_stretched.dat")

with open(table_output_path_single, 'w') as table_file_single, open(table_output_path_stretched, 'w') as table_file_stretched:
    # Write header for Single Exponential.
    # The rows collected in table_data_single hold the relaxation rate (C)
    # before the relaxation time (1 / C), so the header lists the columns in
    # that order; it previously labeled the two columns the wrong way round.
    header_single = "Filename\tq (A^-1)\tBaseline\tbeta\tRelax. rate (s-1)\tRelax. time (s)\tDiffusion coefficient (u2/s)\tR2\n"
    table_file_single.write(header_single)

    # Write header for Stretched Exponential
    header_stretched = "Filename\tq (A^-1)\tBaseline\tbeta\tRelax. rate (s-1)\t Gamma \tRelax. time (s)\tDiffusion coefficient (u2/s)\tR2\n"
    table_file_stretched.write(header_stretched)

    # Write table data for Single Exponential: one tab-separated line per fit
    table_file_single.writelines(
        '\t'.join(str(item) for item in row_single) + "\n"
        for row_single in table_data_single
    )

    # Write table data for Stretched Exponential: one tab-separated line per fit
    table_file_stretched.writelines(
        '\t'.join(str(item) for item in row_stretched) + "\n"
        for row_stretched in table_data_stretched
    )

print("Fit results tables generated successfully.")

Total HDF5 files: 69
Total x files: 414
Total invalid files: 7
Successful fits: 398
Failed fits: 9
Failed files:
XPCS_SiNP400_FBS_2_q = 4.67e-04 A^-1
XPCS_SiNP400_FBS_2_q = 6.84e-04 A^-1
XPCS_SiNP400_FBS_2_q = 9.01e-04 A^-1
XPCS_mv70p1000_10Hz_q = 5.01e-04 A^-1
XPCS_SiNP_270_PBS1x_1_q = 6.84e-04 A^-1
XPCS_mv70p1000_10Hz_8_q = 7.35e-04 A^-1
XPCS_mv70p1000_10Hz_8_q = 9.68e-04 A^-1
XPCS_mv70p1000_100Hz_q = 5.01e-04 A^-1
XPCS_mv70p1000_100Hz_q = 7.35e-04 A^-1
Invalid files:
XPCS_SiNP_PEG_1_q = 1.12e-03 A^-1
XPCS_SiNP_PEG_1_q = 1.33e-03 A^-1
XPCS_SiNP_PEG_1_q = 2.50e-04 A^-1
XPCS_SiNP_PEG_1_q = 4.67e-04 A^-1
XPCS_SiNP_PEG_1_q = 6.84e-04 A^-1
XPCS_SiNP_PEG_1_q = 9.01e-04 A^-1
XPCS_mv70p1000_1000Hz_q = 1.44e-03 A^-1
Fit results tables generated successfully.
