# Import

In [None]:
#from dataloaders import Dataloader as DL
import csv
import os
from pathlib import Path
import h5py
import pandas as pd
import numpy as np
import arpespythontools as arp
import matplotlib.pyplot as plt
from ipywidgets import widgets, Layout, interactive_output
import astropy 
from mpl_toolkits.mplot3d import Axes3D
import pickle
import matplotlib.cm as cm
import matplotlib.colors as mcolors
from matplotlib import colors, cm, animation
from matplotlib.ticker import MultipleLocator, FormatStrFormatter
import os
import sys
import time
import itertools
import shutil
import pandas as pd
from scipy.interpolate import interp1d
import PyMca5.PyMcaCore.SpecFileDataSource as SpecFileDataSource
# for the fitting procedures
from lmfit import *

In [None]:
# Run the notebook script
# NOTE(review): presumably data_loader.ipynb defines map_data2, phi, energy,
# theta and theta_v (and the data_file_path read further down) in this
# kernel's namespace — confirm against that notebook.
%run data_loader.ipynb

# Access the cached data directly by variable names
# Each line below re-binds a name to itself: nothing new is created, and a
# NameError is raised here if the name was not defined by the run above.
map_data2 = map_data2
phi = phi
energy = energy
theta = theta
theta_v = theta_v


In [None]:
def angle_to_wavevector_binding(E_pho, work_f, theta, phi):
    """Convert two emission angles to wavevector components.

    The magnitude is ``k = 0.512 * sqrt(E_pho - work_f)``; each component is
    that magnitude times the sine of the corresponding angle.

    Parameters
    ----------
    E_pho : float
        Photon energy.
    work_f : float
        Work function, subtracted from ``E_pho`` to get the kinetic energy.
    theta, phi : float or array-like
        Angles in degrees (converted with ``np.radians``).

    Returns
    -------
    tuple
        ``(k_x, k_y)`` with ``k_x = k * sin(theta)`` and ``k_y = k * sin(phi)``.

    Notes
    -----
    NOTE(review): the 0.512 prefactor presumably assumes energies in eV and
    gives wavevectors in inverse angstroms — confirm.
    """
    E_k = E_pho - work_f
    k = 0.512 * np.sqrt(E_k)

    # Use the ``theta`` argument. The previous version ignored it and always
    # read the global ``theta_v``, so any other input was silently discarded.
    k_x = k * np.sin(np.radians(theta))
    k_y = k * np.sin(np.radians(phi))
    return k_x, k_y
        
    
# Convert the loaded angle axes once. k_x and k_y are ordinary module-level
# names read by the cells below; the former `global` statements were removed
# because `global` has no effect at module level.
# NOTE(review): 85 and 4 are presumably the photon energy and work function
# in eV — confirm.
k_x, k_y = angle_to_wavevector_binding(85, 4, theta_v, phi)

# Variables

In [None]:
# Set global parameters
# (Assignments in a cell are already module-level; the former `global`
# statement had no effect and was removed.)
emin1, emax2, vmax1 = -0.44, -0.41, 16226
k_x1, k_x2, vmax2 = -1, 360, 50000
k_y1, k_y2 = 0, 121
pmin, pmax, vmax = -1, 2, 35480

# NOTE(review): data_file_path is not defined in the cells shown here;
# presumably it comes from data_loader.ipynb — confirm.
file_path = data_file_path

# Function to extract dataset name under '2D_Data' group
def get_dataset_name(file_path):
    """Return the first key found under the '2D_Data' group of an HDF5 file.

    When the file has no '2D_Data' group, a message string is returned in
    place of a dataset name; no exception is raised for that case.
    """
    group_key = "2D_Data"
    with h5py.File(file_path, "r") as handle:
        if group_key not in handle:
            return "Group '2D_Data' not found in the file."
        names = list(handle[group_key].keys())
        # Only the first entry is used.
        return names[0]

# Get the dataset name
# NOTE: get_dataset_name returns a message string (not an error) when the
# '2D_Data' group is missing, so dataset_name would then hold that message
# and be used as the output sub-folder name by waterfall_save_data.
dataset_name = str(get_dataset_name(file_path))

# Output the dataset name
print(f"The dataset name is: {dataset_name}")

# Slices

In [None]:
%matplotlib inline

# Function to plot waterfall plot
def waterfall_plot(k_y1, k_y2, Ev_Lowerbound, Ev_Upperbound, offset, interval, k_x_min, k_x_max):
    """Plot normalised k_x line profiles of constant-energy slices as a waterfall.

    Reads the module-level ``map_data2``, ``energy``, ``k_x`` and ``k_y``.

    Parameters
    ----------
    k_y1, k_y2 : float
        Bounds handed to ``arp.line_profile`` along ``k_y``.
    Ev_Lowerbound, Ev_Upperbound, interval : float
        Slice energies are ``np.arange(Ev_Lowerbound, Ev_Upperbound, interval)``.
    offset : float
        Vertical shift added per successive slice (slice ``i`` is shifted by
        ``i * offset``).
    k_x_min, k_x_max : float
        Only points with ``k_x_min <= k_x <= k_x_max`` are drawn.
    """
    energies = np.arange(Ev_Lowerbound, Ev_Upperbound, interval)
    # The k_x window does not depend on the slice energy, so build it once.
    mask = (k_x >= k_x_min) & (k_x <= k_x_max)

    plt.figure(figsize=(8, 6))

    for i, e in enumerate(energies):
        surface = arp.plane_slice(map_data2, energy, e, e)  # Slice at constant energy
        kx_line = arp.line_profile(surface, k_y, k_y1, k_y2)

        peak = np.max(kx_line)
        if peak > 0:
            # Out-of-place division: does not write into the array returned
            # by arp and also works when that array has an integer dtype.
            kx_line = kx_line / peak
        else:
            # Nothing positive to normalise against: draw a flat line.
            kx_line = np.zeros_like(kx_line)

        plt.plot(k_x[mask], kx_line[mask] + i * offset, color='green', label=f"E = {e:.2f}")

    plt.ylabel("Intensity (Offset)")
    plt.xlabel("Wavevector (k_x)")
    plt.title("Waterfall Plot of Constant Energy Slices")
    plt.show()

# Function to save waterfall plot data
def waterfall_save_data(k_y1, k_y2, Ev_Lowerbound, Ev_Upperbound, offset, interval, k_x_min, k_x_max):
    """Write the k_x line profiles of constant-energy slices to a CSV file.

    Reads the module-level ``map_data2``, ``energy``, ``k_x``, ``k_y`` and
    ``dataset_name``. The file is written to
    ``ARPES Waterfall Plots/<dataset_name>/batch_fit.csv`` with one
    ``k_x E=..`` / ``Intensity E=..`` column pair per slice energy, in
    ascending energy order.

    Parameters
    ----------
    k_y1, k_y2 : float
        Bounds handed to ``arp.line_profile`` along ``k_y``.
    Ev_Lowerbound, Ev_Upperbound, interval : float
        Slice energies are ``np.arange(Ev_Lowerbound, Ev_Upperbound, interval)``.
    offset : float
        Not used here; accepted so the signature matches ``waterfall_plot``.
    k_x_min, k_x_max : float
        Only points with ``k_x_min <= k_x <= k_x_max`` are written.

    Returns
    -------
    None

    Notes
    -----
    Profiles are divided by their maximum when it is positive. Profiles whose
    maximum is not positive are written unchanged (``waterfall_plot`` draws
    those as zeros instead).
    """
    energies = np.arange(Ev_Lowerbound, Ev_Upperbound, interval)
    if energies.size == 0:
        # Without this guard the max() below runs over an empty mapping and
        # raises ValueError (e.g. lower bound >= upper bound).
        print("No energy slices in the selected range; nothing was saved.")
        return

    # The k_x window is the same for every slice.
    mask = (k_x >= k_x_min) & (k_x <= k_x_max)
    data_dict = {}

    for e in energies:
        surface = arp.plane_slice(map_data2, energy, e, e)
        kx_line = arp.line_profile(surface, k_y, k_y1, k_y2)

        peak = np.max(kx_line)
        if peak > 0:
            # Out-of-place division: leaves the array returned by arp alone
            # and also works for integer-typed profiles.
            kx_line = kx_line / peak

        data_dict[e] = list(zip(k_x[mask], kx_line[mask]))

    ordered_energies = sorted(data_dict)
    max_len = max(len(data_dict[e]) for e in ordered_energies)

    header = []
    for e in ordered_energies:
        header.append(f"k_x E={e:.2f}")
        header.append(f"Intensity E={e:.2f}")

    # Transpose to rows; shorter columns are padded with empty cells.
    formatted_data = []
    for row_idx in range(max_len):
        row = []
        for e in ordered_energies:
            pairs = data_dict[e]
            if row_idx < len(pairs):
                row.extend(pairs[row_idx])
            else:
                row.extend(["", ""])
        formatted_data.append(row)

    # Create subfolder for saving the CSV file
    subfolder = os.path.join("ARPES Waterfall Plots", dataset_name)
    os.makedirs(subfolder, exist_ok=True)
    csv_filename = os.path.join(subfolder, "batch_fit.csv")

    # Save data to CSV
    with open(csv_filename, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(header)
        writer.writerows(formatted_data)

    print(f"CSV data saved successfully to {csv_filename}")

# Create widgets
# Slider limits follow the ends of the loaded k_y / energy axes and the
# extrema of k_x.
_ky_first, _ky_last = k_y[0], k_y[-1]
_energy_first, _energy_last = energy[0], energy[-1]
_kx_low, _kx_high = np.min(k_x), np.max(k_x)

k_y1_slider = widgets.FloatSlider(
    value=_ky_first, min=_ky_first, max=_ky_last, step=0.01, description='k_y1')
k_y2_slider = widgets.FloatSlider(
    value=_ky_last, min=_ky_first, max=_ky_last, step=0.01, description='k_y2')
Ev_Lowerbound_slider = widgets.FloatSlider(
    value=-1.22, min=_energy_first, max=_energy_last, step=0.01, description='Lower Energy')
Ev_Upperbound_slider = widgets.FloatSlider(
    value=-0.38, min=_energy_first, max=_energy_last, step=0.01, description='Upper Energy')
offset_slider = widgets.FloatSlider(
    value=0.2, min=0.01, max=5.0, step=0.001, description='Offset')
interval_slider = widgets.FloatSlider(
    value=0.02, min=0.001, max=0.5, step=0.01, description='Interval')
k_x_min_slider = widgets.FloatSlider(
    value=_kx_low, min=_kx_low, max=_kx_high, step=0.01, description='k_x min')
k_x_max_slider = widgets.FloatSlider(
    value=_kx_high, min=_kx_low, max=_kx_high, step=0.01, description='k_x max')


def _save_current_selection(_):
    """Button callback: export the profiles for the current slider values."""
    waterfall_save_data(
        k_y1_slider.value,
        k_y2_slider.value,
        Ev_Lowerbound_slider.value,
        Ev_Upperbound_slider.value,
        offset_slider.value,
        interval_slider.value,
        k_x_min_slider.value,
        k_x_max_slider.value,
    )


# Button to save data
save_plot_button = widgets.Button(description="Plot Data as CSV")
save_plot_button.on_click(_save_current_selection)

display(widgets.HBox([save_plot_button]))

# Display interactive sliders
widgets.interact(
    waterfall_plot,
    k_y1=k_y1_slider,
    k_y2=k_y2_slider,
    Ev_Lowerbound=Ev_Lowerbound_slider,
    Ev_Upperbound=Ev_Upperbound_slider,
    offset=offset_slider,
    interval=interval_slider,
    k_x_min=k_x_min_slider,
    k_x_max=k_x_max_slider
)


# Read Batch of Plots

In [None]:
df = pd.read_csv(os.path.join("ARPES Waterfall Plots", dataset_name, "batch_fit.csv"))

In [None]:
num_cols = len(df.columns)

# Walk the two-column groups from the last one to the first, keeping the
# column order inside each group.
reversed_positions = []
for pair_start in range(num_cols - 2, -1, -2):
    reversed_positions.extend((pair_start, pair_start + 1))

# With an odd column count, pairing from the right leaves column 0 without a
# partner. The previous code appended the middle column here, which repeated
# an already selected column and dropped column 0.
if num_cols % 2 != 0:
    reversed_positions.append(0)

cols_reversed = [df.columns[pos] for pos in reversed_positions]

# Reorder the DataFrame columns based on the reversed pattern. Selecting by
# position keeps this independent of whether the column labels are unique.
df1 = df.iloc[:, reversed_positions]

In [None]:
df1

In [None]:
df

# Initial Cut - this is our background

In [None]:
# Quick look at the first two columns of df1 (column 0 on x, column 1 on y).
plt.plot(df1.iloc[:, 0],df1.iloc[:, 1]) #plotting the first curve
plt.xlabel('K_x')
plt.ylabel('Intensity')

In [None]:
# Button to cut CSV using existing k_x_min_slider and k_x_max_slider values
cut_button = widgets.Button(description="Cut CSV")
# Dedicated name for this cell's output area: the Batch Fit GUI cell rebinds
# the global `output`, and a callback that looked up `output` at click time
# would then print into that other widget.
cut_output = widgets.Output()

def on_button_click(b):
    """Save the rows of df1 whose first column lies in the selected k_x window.

    The window is taken from ``k_x_min_slider`` / ``k_x_max_slider`` and
    snapped to the nearest values present in the first column of ``df1``.
    The selected rows (all columns) are written to ``fitting_data.csv``.
    ``b`` is the clicked button and is not used.
    """
    with cut_output:
        cut_output.clear_output()

        print("Column Data Types:\n", df1.dtypes)

        # The first column is treated as the k_x axis for every row.
        kx_column = df1.iloc[:, 0]
        if kx_column.dropna().empty:
            print("Error: the first column holds no k_x values to cut.")
            return

        # idxmin() returns an index *label*, so the value is looked up with
        # .loc. The previous positional .iloc lookup was only correct for a
        # default 0..n-1 index.
        min_label = (kx_column - k_x_min_slider.value).abs().idxmin()
        max_label = (kx_column - k_x_max_slider.value).abs().idxmin()

        minvalue = kx_column.loc[min_label]
        maxvalue = kx_column.loc[max_label]

        print(f"Selected Range: {minvalue} to {maxvalue}")

        if minvalue > maxvalue:
            print("Error: Lower bound is greater than upper bound. Adjust the k_x sliders.")
            return

        # Filter rows between the selected k_x values
        sliced_df = df1[(kx_column >= minvalue) & (kx_column <= maxvalue)]

        print(f"Number of rows after filtering: {len(sliced_df)}")

        sliced_df.to_csv('fitting_data.csv', index=False)
        print("Saved fitting_data.csv")

        display(sliced_df.head())

cut_button.on_click(on_button_click)
display(cut_button, cut_output)


In [None]:
df_initial = pd.read_csv("fitting_data.csv")

In [None]:
df_initial

In [None]:
# Scatter of the first two columns of the cut table (column 0 on x, column 1
# on y); the next cell fits a parabola to exactly these two columns.
plt.scatter(df_initial.iloc[:, 0],df_initial.iloc[:, 1])
plt.xlabel('k_x')
plt.ylabel('Intensity')

In [None]:
# Second-degree least-squares fit to the first two columns of df_initial.
background_x = df_initial.iloc[:, 0]
background_y = df_initial.iloc[:, 1]
coefficients = np.polyfit(background_x, background_y, 2)

# np.polyfit returns the highest power first: a*x^2 + b*x + c
a, b, c = coefficients


def parabola(x):
    """Evaluate the fitted quadratic a*x**2 + b*x + c (module-level a, b, c)."""
    return a * x**2 + b * x + c


# Evenly spaced x grid over the fitted range, one point per data row, and
# the parabola sampled on that grid.
x_fit = np.linspace(min(background_x), max(background_x), len(background_x))
y_fit = parabola(x_fit)

In [None]:
# Overlay the fitted parabola (red, sampled on the x_fit grid) on the data
# from the first two columns.
plt.scatter(df_initial.iloc[:, 0],df_initial.iloc[:, 1])
plt.plot(x_fit,y_fit,'red')

Now we subtract this fitted parabola (our background) from the intensity column of every slice.

In [None]:
# Columns come in pairs: even positions hold x values, odd positions hold the
# matching intensities. Evaluate the fitted parabola at each pair's own x
# values. The previous version subtracted y_fit, which is sampled on an
# evenly spaced np.linspace grid and therefore only lines up with the rows
# when the x column itself is evenly spaced and ascending.
# NOTE: this cell modifies df_initial in place, so running it twice subtracts
# the background twice; reload fitting_data.csv before re-running.
for i in range(1, len(df_initial.columns), 2):
    name = df_initial.columns[i]
    df_initial[name] = df_initial[name] - parabola(df_initial.iloc[:, i - 1])

In [None]:
df_initial

In [None]:
#see what first "real data" slice looks after parabola subtraction
plt.scatter(df_initial.iloc[:, 2],df_initial.iloc[:, 3])



# Batch Fit GUI

In [None]:
%matplotlib notebook

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from lmfit import Parameters, minimize, report_fit
import ipywidgets as widgets
from IPython.display import display


slice_index = 0               # slice currently shown; load_slice maps n to columns 2+2n / 3+2n
last_fit_result = None        # set by run_fit after a successful fit, cleared by on_next_slice
fitted_results = []           # one dict of fitted parameters per stored slice (filled by on_next_slice)

active_peaks_count = 4  # number of Gaussian peaks currently in the model
max_peaks = 8  # upper limit checked by add_peak


def load_slice(index):
    """Show slice number ``index`` of ``df_initial`` in the fit figure.

    Slice ``index`` uses column ``2 + 2*index`` as x and ``3 + 2*index`` as y.
    The values are stored in the module-level ``xdata`` / ``ydata``, all
    model curves are reset to zero, each centre slider's limits are moved to
    the new x range (still including its current value) and the model is
    redrawn. If the y column does not exist, a message is printed and
    nothing else changes.
    """
    global xdata, ydata

    col_x = 2 * index + 2
    col_y = col_x + 1
    if col_y >= df_initial.shape[1]:
        print("No more slices to load (index out of range).")
        return

    # Load new x and y data
    xdata = df_initial.iloc[:, col_x].values
    ydata = df_initial.iloc[:, col_y].values

    data_scatter.set_offsets(np.column_stack((xdata, ydata)))

    # Flatten every curve (individual peaks first, then the total) on the new x axis.
    for curve in [*indiv_lines, fit_line]:
        curve.set_xdata(xdata)
        curve.set_ydata(np.zeros_like(xdata))

    for widget_box in peak_widgets:
        cen_slider = widget_box.children[1]
        centre = cen_slider.value
        # The limits span the data range, stretched if needed so that the
        # slider's current value stays inside them.
        cen_slider.min = min(np.min(xdata), centre)
        cen_slider.max = max(np.max(xdata), centre)

    update_model(None)
    ax.relim()
    ax.autoscale_view()
    fig.canvas.draw_idle()
    print(f"Loaded slice #{index} using columns {col_x} (x) and {col_y} (y).")


def gaussian_peaks(x, peaks, a1, b1):
    """Sum of area-normalised Gaussians plus a linear background.

    Parameters
    ----------
    x : numpy array
        Points at which the model is evaluated.
    peaks : iterable of dict
        One dict per peak with the keys ``'area'``, ``'cen'`` and ``'fwhm'``.
    a1, b1 : float
        Slope and intercept of the background ``a1 * x + b1``.

    Returns
    -------
    numpy array
        Model values at ``x``.
    """
    # Float accumulator: zeros_like(x) alone inherits x's dtype, and adding
    # float Gaussians in place to an integer array raises a casting error.
    total = np.zeros_like(x, dtype=float)
    fwhm_to_sigma = 2 * np.sqrt(2 * np.log(2))
    for peak in peaks:
        area = peak['area']
        cen = peak['cen']
        sigma = peak['fwhm'] / fwhm_to_sigma
        total += (area / (np.sqrt(2 * np.pi) * sigma)) * np.exp(-((x - cen)**2) / (2 * sigma**2))
    return total + (a1 * x + b1)

def residuals(params, x, y, n_peaks):
    """Return model minus data for the parameter set ``params``.

    ``params`` is indexed by name (``area<i>``, ``cen<i>``, ``fwhm<i>`` for
    ``i = 1..n_peaks``, plus ``a1`` and ``b1``) and each entry's ``.value``
    is used.
    """
    peak_params = [
        {key: params[f"{key}{i}"].value for key in ('area', 'cen', 'fwhm')}
        for i in range(1, n_peaks + 1)
    ]
    slope = params['a1'].value
    intercept = params['b1'].value
    model = gaussian_peaks(x, peak_params, slope, intercept)
    return model - y

def fit_peaks(init_params=None):
    """Fit the current slice with ``active_peaks_count`` Gaussians plus a line.

    Uses the module-level ``xdata`` / ``ydata`` and minimises ``residuals``
    with lmfit's ``minimize``.

    Parameters
    ----------
    init_params : dict, optional
        Starting values keyed ``area<i>``, ``cen<i>``, ``fwhm<i>``, ``a1``,
        ``b1``. Missing keys fall back to 0.2 (area), the mean of ``xdata``
        (centre), 0.4 (FWHM) and 0.0 (background). ``None`` means "all
        defaults".

    Returns
    -------
    tuple
        ``(fitted_curve, result, report)`` on success, where ``report`` is
        the fit report text; ``(None, None, "Fit failed")`` otherwise.
    """
    # fit_report returns the report text. report_fit (used before) prints the
    # report and returns None, so callers received None as the third element.
    from lmfit import fit_report

    # The declared default None has no .get(); treat it as "no overrides".
    if init_params is None:
        init_params = {}

    params = Parameters()
    for i in range(1, active_peaks_count + 1):
        params.add(f"area{i}",
                   value=init_params.get(f"area{i}", 0.2),
                   min=0, max=9999)
        params.add(f"cen{i}",
                   value=init_params.get(f"cen{i}", np.mean(xdata)))
        params.add(f"fwhm{i}",
                   value=init_params.get(f"fwhm{i}", 0.4),
                   min=0.01, max=9999)
    params.add('a1', value=init_params.get('a1', 0.0))
    params.add('b1', value=init_params.get('b1', 0.0))

    result = minimize(residuals, params, args=(xdata, ydata, active_peaks_count))

    if not result.success:
        print("Fit failed:", result.message)
        return None, None, "Fit failed"

    peak_params = [
        {
            'area': result.params[f"area{i}"].value,
            'cen':  result.params[f"cen{i}"].value,
            'fwhm': result.params[f"fwhm{i}"].value,
        }
        for i in range(1, active_peaks_count + 1)
    ]
    fitted_curve = gaussian_peaks(xdata, peak_params,
                                  result.params['a1'].value,
                                  result.params['b1'].value)
    return fitted_curve, result, fit_report(result.params)


# Figure with empty artists; load_slice / update_model fill them in later.
fig, ax = plt.subplots(figsize=(8, 5))
data_scatter = ax.scatter([], [], color='green', label="Data")
fit_line = ax.plot([], [], 'r-', lw=2, label="Total Fit")[0]

# Colour cycle for the per-peak curves (add_peak reads it as well).
# NOTE: this rebinds `colors`, which the import cell bound to matplotlib.colors.
colors = ['blue', 'orange', 'purple', 'brown', 'magenta', 'cyan', 'gray', 'olive']

# One dashed placeholder line per initially active peak.
indiv_lines = [
    ax.plot([], [], '--', color=colors[n % len(colors)], label=f"Peak {n+1}")[0]
    for n in range(active_peaks_count)
]

ax.legend(loc='best')


def create_peak_widget(i):
    """Build the row of three sliders (area, centre, FWHM) for peak number ``i``.

    Returns an ``HBox`` whose children are, in order, the area, centre and
    FWHM sliders.
    """
    # (description, initial value, lower limit, upper limit)
    slider_specs = (
        (f'Area {i}:', 0.2, 0, 2),
        (f'Center {i}:', 0.0, -5, 5),
        (f'FWHM {i}:', 0.4, 0.01, 2),
    )
    sliders = [
        widgets.FloatSlider(value=start, min=low, max=high, step=0.01,
                            description=label, continuous_update=True)
        for label, start, low, high in slider_specs
    ]
    return widgets.HBox(sliders)

# One slider row per initially active peak (peaks are numbered from 1);
# peak_controls is the container whose children add_peak / remove_peak replace.
peak_widgets = [create_peak_widget(i) for i in range(1, active_peaks_count + 1)]
peak_controls = widgets.VBox(peak_widgets)

# Background parameter sliders
# (slope a1 and intercept b1 of the a1*x + b1 term in gaussian_peaks)
a1_slider = widgets.FloatSlider(value=0.0, min=-2, max=2, step=0.01,
                                description='Slope (a1):', continuous_update=True)
b1_slider = widgets.FloatSlider(value=0.0, min=-2, max=2, step=0.01,
                                description='Intercept (b1):', continuous_update=True)

def update_model(change):
    """Redraw the per-peak curves and the total model from the slider values.

    Used as the slider observer; ``change`` is not used. Reads the
    module-level ``xdata`` and updates ``indiv_lines`` and ``fit_line``.
    """
    # Snapshot of every slider row, then keep the active peaks.
    slider_rows = [box.children for box in peak_widgets]
    peaks = [
        {
            'area': slider_rows[n][0].value,
            'cen':  slider_rows[n][1].value,
            'fwhm': slider_rows[n][2].value,
        }
        for n in range(active_peaks_count)
    ]
    slope = a1_slider.value
    intercept = b1_slider.value

    # Individual peaks (no background term).
    for n, peak in enumerate(peaks):
        sigma = peak['fwhm'] / (2 * np.sqrt(2 * np.log(2)))
        curve = (peak['area'] / (np.sqrt(2 * np.pi)*sigma)) * np.exp(-((xdata - peak['cen'])**2)/(2*sigma**2))
        indiv_lines[n].set_xdata(xdata)
        indiv_lines[n].set_ydata(curve)

    # Total model: all peaks plus the linear background.
    fit_line.set_xdata(xdata)
    fit_line.set_ydata(gaussian_peaks(xdata, peaks, slope, intercept))

    ax.relim()
    ax.autoscale_view()
    fig.canvas.draw_idle()

# Attach the update callback to all slider changes
_all_sliders = [slider for box in peak_widgets for slider in box.children]
_all_sliders += [a1_slider, b1_slider]
for slider in _all_sliders:
    slider.observe(update_model, names='value')


# Buttons that grow / shrink the number of peaks in the model.
add_peak_button = widgets.Button(description="Add Peak", button_style='info')
remove_peak_button = widgets.Button(description="Remove Peak", button_style='warning')

def add_peak(_):
    """Button callback: add one Gaussian peak (slider row + dashed curve).

    Does nothing once ``max_peaks`` peaks are active.
    """
    global active_peaks_count
    if active_peaks_count >= max_peaks:
        return

    active_peaks_count += 1
    peak_number = active_peaks_count

    # New slider row, wired to the same observer as the existing ones.
    row = create_peak_widget(peak_number)
    for slider in row.children:
        slider.observe(update_model, names='value')
    peak_widgets.append(row)
    peak_controls.children = tuple(peak_widgets)

    # Matching curve, flat until update_model fills it in.
    curve, = ax.plot(xdata, np.zeros_like(xdata), '--',
                     color=colors[(peak_number-1) % len(colors)],
                     label=f"Peak {peak_number}")
    indiv_lines.append(curve)
    ax.legend(loc='best')
    update_model(None)

def remove_peak(_):
    """Button callback: drop the last Gaussian peak (slider row + curve).

    Does nothing when only one peak is active.
    """
    global active_peaks_count
    if active_peaks_count <= 1:
        return

    active_peaks_count -= 1
    peak_widgets.pop()
    peak_controls.children = tuple(peak_widgets)
    # Take the last curve off the axes as well as out of the list.
    indiv_lines.pop().remove()
    ax.legend(loc='best')
    update_model(None)

# Connect the peak-count buttons to their callbacks.
add_peak_button.on_click(add_peak)
remove_peak_button.on_click(remove_peak)


# "Fit Data" button and the output area that run_fit clears and prints into.
fit_button = widgets.Button(description="Fit Data", button_style='success')
output = widgets.Output()

def run_fit(_):
    """Button callback: fit the current slice starting from the slider values.

    On success the total and per-peak curves are redrawn from the fitted
    parameters, the result is kept in ``last_fit_result`` and the report is
    printed into ``output``; otherwise "Fitting failed!" is printed.
    """
    global last_fit_result
    with output:
        output.clear_output()

        # Current slider values become the starting point of the fit.
        init_params = {}
        for n, row in enumerate(peak_widgets, start=1):
            for key, slider in zip(('area', 'cen', 'fwhm'), row.children):
                init_params[f"{key}{n}"] = slider.value
        init_params['a1'] = a1_slider.value
        init_params['b1'] = b1_slider.value

        fitted_curve, result, report_str = fit_peaks(init_params)
        if result is None:
            print("Fitting failed!")
            return

        # Redraw the total model and each individual peak from the fit.
        fit_line.set_ydata(fitted_curve)
        for n in range(active_peaks_count):
            area = result.params[f"area{n+1}"].value
            cen = result.params[f"cen{n+1}"].value
            fwhm = result.params[f"fwhm{n+1}"].value
            sigma = fwhm / (2 * np.sqrt(2 * np.log(2)))
            curve = (area / (np.sqrt(2 * np.pi)*sigma)) * np.exp(-((xdata - cen)**2)/(2*sigma**2))
            indiv_lines[n].set_ydata(curve)
        ax.relim()
        ax.autoscale_view()
        fig.canvas.draw_idle()

        # Kept so that on_next_slice can store it and seed the next slice.
        last_fit_result = result

        print("Fitting Results:")
        print(report_str)

# Clicking "Fit Data" runs run_fit.
fit_button.on_click(run_fit)


# Button that stores the current fit (if any) and advances to the next slice.
next_slice_button = widgets.Button(description="Next Slice", button_style='info')

def _set_slider_value(slider, value):
    """Set ``slider.value``, widening its limits first so the value is not clamped."""
    if value < slider.min:
        slider.min = value
    if value > slider.max:
        slider.max = value
    slider.value = value


def on_next_slice(_):
    """Button callback: store the current fit (if any) and load the next slice.

    When ``last_fit_result`` is set, its parameter values and standard errors
    are appended to ``fitted_results`` under the current ``slice_index`` and
    copied into the sliders as the starting point for the next slice. The
    slice index only advances when a further slice exists in ``df_initial``.
    """
    global slice_index, last_fit_result

    if last_fit_result is not None:
        fit_params = last_fit_result.params
        # Take the peak count from the fit itself. Peaks may have been added
        # or removed since the fit, and indexing by active_peaks_count then
        # raised KeyError (peak added) or dropped fitted peaks (peak removed).
        n_fitted = sum(1 for name in fit_params if name.startswith("cen"))

        param_dict = {"slice_index": slice_index}
        for i in range(1, n_fitted + 1):
            for key in ("area", "cen", "fwhm"):
                fitted = fit_params[f"{key}{i}"]
                param_dict[f"{key}{i}"] = fitted.value
                param_dict[f"{key}{i}_err"] = fitted.stderr

            # Update slider values using the fitted parameters (only for
            # peaks that still have a slider row).
            if i <= len(peak_widgets):
                sliders = peak_widgets[i-1].children
                _set_slider_value(sliders[0], fit_params[f"area{i}"].value)
                _set_slider_value(sliders[1], fit_params[f"cen{i}"].value)
                _set_slider_value(sliders[2], fit_params[f"fwhm{i}"].value)

        # Background parameters
        for key in ("a1", "b1"):
            fitted = fit_params[key]
            param_dict[key] = fitted.value
            param_dict[f"{key}_err"] = fitted.stderr
        _set_slider_value(a1_slider, fit_params["a1"].value)
        _set_slider_value(b1_slider, fit_params["b1"].value)

        # A slice can be stored more than once (e.g. the last slice is
        # re-fitted); keep only the newest entry per slice index.
        fitted_results[:] = [r for r in fitted_results if r["slice_index"] != slice_index]
        fitted_results.append(param_dict)
        last_fit_result = None  # Clear fit result

    # Move on only if the next slice exists (load_slice reads column
    # 3 + 2*index as y). Otherwise slice_index would run past the data and
    # later fits of the still-displayed slice would be stored under a wrong index.
    next_index = slice_index + 1
    if 3 + 2 * next_index >= df_initial.shape[1]:
        print("No more slices to load (index out of range).")
        return

    slice_index = next_index
    load_slice(slice_index)

# Clicking "Next Slice" runs on_next_slice.
next_slice_button.on_click(on_next_slice)


# Button that gathers everything stored so far into a DataFrame.
done_button = widgets.Button(description="Done", button_style='danger')

def on_done(_):
    """Button callback: collect the stored fits into the global ``results_df``.

    ``results_df`` gets one row per entry of ``fitted_results``, indexed by
    ``slice_index``, and is displayed. When nothing has been stored yet,
    ``results_df`` is an empty DataFrame and a short message is printed.
    """
    global results_df
    if not fitted_results:
        # An empty list gives a frame without a "slice_index" column, so
        # set_index would raise KeyError.
        results_df = pd.DataFrame()
        print("No fitted results have been stored yet.")
        return

    results_df = pd.DataFrame(fitted_results).set_index("slice_index")
    print("All fitted results collected so far:")
    display(results_df)

# Clicking "Done" runs on_done.
done_button.on_click(on_done)


# Control panel, top to bottom: peak-count buttons, one slider row per peak,
# background sliders, the three action buttons, then the fit output area.
ui = widgets.VBox([
    widgets.HBox([add_peak_button, remove_peak_button]),
    peak_controls,
    widgets.HBox([a1_slider, b1_slider]),
    fit_button,
    next_slice_button,
    done_button,
    output
])
display(ui)

# Show the first slice (slice_index is 0 here, i.e. columns 2 and 3 of df_initial).
load_slice(slice_index)


In [None]:
results_df