In [None]:
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np
from tqdm import tqdm
import os
from glob import glob
from datetime import datetime

In [None]:
# Absolute path of the current working directory at the time this cell runs.
HERE = os.path.abspath(os.getcwd())

In [None]:
from AnalyticalLabware.analysis.spinsolve_spectrum import SpinsolveNMRSpectrum


# Shared spectrum object: later cells either load data into it or rebind this
# name to a freshly unpickled spectrum.
spectrum = SpinsolveNMRSpectrum()

In [None]:
# Experiment identifier: used as the data sub-folder name, in plot titles and
# in the names of the files written by this notebook.
exp_code = "terp-GA-acid-v1"

In [None]:
def get_folder_paths(parent_folder):
    """Collect every sub-folder of ``parent_folder`` that may hold raw data.

    Folders whose name contains "processed" or "Enhanced" are skipped, and
    the walk does not descend into them, so nothing stored below an excluded
    folder is ever reported as a data folder.

    Args:
        parent_folder: Root directory to search recursively.

    Returns:
        List of folder paths (``root`` joined with the folder name), in the
        order produced by ``os.walk``.
    """
    excluded_markers = ("processed", "Enhanced")
    folder_paths = []
    for root, dirs, _ in tqdm(os.walk(parent_folder)):
        # Pruning ``dirs`` in place tells os.walk (top-down mode) not to
        # recurse into the excluded folders.
        dirs[:] = [
            directory
            for directory in dirs
            if not any(marker in directory for marker in excluded_markers)
        ]
        folder_paths.extend(os.path.join(root, directory) for directory in dirs)
    return folder_paths


# Root folder holding the raw NMR data of all experiments. It can be overridden
# with the NMR_RAW_DATA_DIR environment variable so the notebook also runs on
# other machines; the fallback is the originally hard-coded local path.
raw_data_root = os.environ.get(
    "NMR_RAW_DATA_DIR", r"C:\Users\rober\Downloads\raw spectra\raw spectra\NMR"
)
parent_folder = os.path.join(raw_data_root, exp_code)
datapaths = get_folder_paths(parent_folder)
print(f"Found {len(datapaths)} data files in {parent_folder}")

In [None]:
# Only used to report the total monitoring time.
# The first 15 characters of each folder name are parsed as a
# "YYYYmmdd_HHMMSS" timestamp. os.path.basename is used instead of splitting
# on a backslash so this also works with non-Windows path separators.
timestamps = [os.path.basename(filepath)[:15] for filepath in datapaths]
times = sorted(
    [datetime.strptime(timestamp, "%Y%m%d_%H%M%S") for timestamp in timestamps]
)
print("Total experiment time: ", times[-1] - times[0])

In [None]:
# Output folder for the processed spectra, located inside the experiment folder.
PROCESSED = os.path.join(parent_folder, "processed")

# exist_ok=True replaces the separate "exists?" check and is safe to re-run.
os.makedirs(PROCESSED, exist_ok=True)

for datapath in tqdm(datapaths):
    spectrum.load_spectrum(data_path=datapath)

    # Processing chain applied to every spectrum. The following steps are
    # currently disabled on purpose and kept here for reference:
    #   spectrum.default_processing()
    #   spectrum.reference_spectrum(new_position=1.8, reference='highest')
    #   spectrum.correct_baseline()
    #   spectrum.autophase()
    spectrum.apodization(function="gm", g1=1.2, g2=4.5)
    spectrum.zero_fill()
    spectrum.fft()

    spectrum.find_peaks(decimals=1, threshold=0.1)
    # One output per input folder, named after that folder.
    spectrum.save_pickle(os.path.join(PROCESSED, os.path.basename(datapath)))

In [None]:
# All pickled spectra in the processed folder, sorted by file name.
processed_datafiles = sorted(glob(f"{PROCESSED}/*"))
processed_datafiles = [file for file in processed_datafiles if file.endswith(".pkl")]
# The first 15 characters of the file name are the "YYYYmmdd_HHMMSS" timestamp.
# os.path.basename handles whatever path separator glob returned.
timestamps = [os.path.basename(filepath)[:15] for filepath in processed_datafiles]

In [None]:
# for datafile in tqdm(processed_datafiles):
#   spectrum.load_data(datafile)
#   spectrum.highpass_filter(threshold=100)
#   spectrum.find_peaks(decimals=1, threshold=0.1)
#   spectrum.save_data(datafile.removesuffix(".pickle"))

In [None]:
# Overlay of the processed spectra; only every PLOT_EVERY-th file is drawn to
# keep the figure readable.
PLOT_EVERY = 4
fig, ax = plt.subplots(figsize=(12, 8))
shown_datafiles = processed_datafiles[::PLOT_EVERY]
for datafile in shown_datafiles:
    # NOTE: this rebinds the notebook-wide `spectrum` name.
    spectrum = SpinsolveNMRSpectrum().from_pickle(datafile)
    real_intensity = spectrum.y_data.real
    ax.plot(spectrum.x_data, real_intensity)
# Optional zoom, disabled: ax.set_xlim(7, 8) / ax.set_ylim(top=5e4)
# The x axis is inverted so that the chemical shift decreases to the right.
ax.invert_xaxis()
ax.set_xlabel("Chemical Shift (ppm)")
ax.set_ylabel("Signal Intensity (a.u.)")
# Optional export, disabled: plt.savefig("nmr_moni.svg", transparent=True)

In [None]:
# Maps timestamp -> {peak position: integrated area} for every processed file.
peak_data = {}

for datafile, timestamp in tqdm(zip(processed_datafiles, timestamps)):
    spectrum.load_data(datafile)
    spectrum.find_peaks(decimals=1, threshold=500)
    spectrum.reference_spectrum(new_position=-60, reference="closest")

    # One entry per detected peak, keyed by the real part of its position.
    # (A filter skipping peaks below 2 used to sit here and is disabled.)
    # NOTE(review): peaks are found before the spectrum is re-referenced;
    # confirm that `spectrum.peaks` is still valid after referencing.
    peak_data[timestamp] = {
        position.real: spectrum.integrate_peak(position)
        for position in spectrum.peaks[:, 0]
    }

# Rows are timestamps, columns are peak positions.
peak_areas = pd.DataFrame.from_dict(peak_data, orient="index")

In [None]:
# norm_peak_areas = peak_areas.div(peak_areas[0.3], axis=0)

In [None]:
# Scatter plot of every integrated peak area over time.
df = peak_areas
# df = norm_peak_areas
# Parse the timestamps in row order so that each time stays paired with its
# own row even if the index is not sorted chronologically.
times = [datetime.strptime(timestamp, "%Y%m%d_%H%M%S") for timestamp in df.index]
print("Total experiment time: ", max(times) - min(times))
for column in df.columns:
    plt.scatter(times, df[column], marker="o", label=column)
plt.xticks(rotation=45)
plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
plt.xlabel("Time")
# `df` holds the raw areas here; change the label back if the normalised
# frame is selected above.
plt.ylabel("Peak area")
plt.title(f"Automatic peak integration for {exp_code}")
# Optional axis limits, disabled: plt.xlim(times[0], times[-1]) / plt.ylim(0, 1.0)
plt.show()

In [None]:
from scipy.stats import linregress


def detect_plateau_from_slope(
    times, values, num_datapoints: int = 5, threshold: float = 1e-4
) -> tuple:
    """Find the first point at which a series has levelled off.

    The values are normalised by their maximum. For every candidate index
    ``i`` a straight line is fitted to the ``num_datapoints`` values that
    precede ``i`` (x is simply 0, 1, 2, ..., so even spacing is assumed);
    the first ``i`` whose absolute slope is below ``threshold`` is reported.
    Candidates run from ``num_datapoints`` up to, but not including,
    ``len(values) - num_datapoints``.

    Args:
        times: Timestamps as "YYYYmmdd_HHMMSS" strings, indexable by position.
        values: Series to analyse (list, array or pandas Series).
        num_datapoints: Number of points in each regression window.
        threshold: Maximum absolute slope of the normalised values that
            still counts as a plateau.

    Returns:
        ``(elapsed_time, index)`` where ``elapsed_time`` is the timedelta
        between ``times[0]`` and ``times[index]``; ``(None, len(values) - 1)``
        if no plateau is found.
    """
    no_plateau = (None, len(values) - 1)
    stop = len(values) - num_datapoints
    if stop <= num_datapoints:
        # Too few points for a single candidate window.
        return no_plateau

    # Normalise once instead of recomputing max(values) in every iteration.
    normalised = np.asarray(values, dtype=float) / max(values)
    x = np.arange(num_datapoints)

    for i in range(num_datapoints, stop):
        slope = linregress(x, normalised[i - num_datapoints : i]).slope
        if abs(slope) < threshold:
            elapsed_time = datetime.strptime(
                times[i], "%Y%m%d_%H%M%S"
            ) - datetime.strptime(times[0], "%Y%m%d_%H%M%S")
            print(f"Reaction has reached plateau after {elapsed_time}")
            return (elapsed_time, i)
    return no_plateau

In [None]:
# Centres (ppm) of the two integration windows. The trailing values are
# alternatives left by the author.
peak_1 = 7.6  # -125
peak_2 = 7.3  # -124.3

# Maps timestamp -> {window centre: integrated area}.
peak_data = {}

for datafile, timestamp in tqdm(zip(processed_datafiles, timestamps)):
    # NOTE: this rebinds the notebook-wide `spectrum` name.
    spectrum = SpinsolveNMRSpectrum().from_pickle(datafile)

    # Window half-widths differ: 0.1 ppm around peak_1, 0.2 ppm around peak_2.
    peak_data[timestamp] = {
        peak_1: spectrum.integrate_area((peak_1 + 0.1, peak_1 - 0.1)),
        peak_2: spectrum.integrate_area((peak_2 + 0.2, peak_2 - 0.2)),
    }

# Rows are timestamps, columns are the two window centres.
peak_areas = pd.DataFrame.from_dict(peak_data, orient="index")

In [None]:
# Publication-style defaults: large font, thick axes frame and thick ticks.
frame_width = 3
tick_width_settings = {
    f"{axis}tick.{level}.width": frame_width
    for level in ("major", "minor")
    for axis in ("x", "y")
}
plt.rcParams.update(
    {"font.size": 24, "axes.linewidth": frame_width, **tick_width_settings}
)

In [None]:
# ppm positions of the two diagnostic peaks. The trailing values are
# alternatives left by the author.
peak1_ppm = 7.6  # -125#-60#10.0#-61#7.8#4.88#2.2#7.6#11.5
peak2_ppm = 7.3  # -124.3#-58#8.6#-59.5#7.5#2.6#4.6#7.3#9.6
slope_threshold = 1e-3

# Keep only the two diagnostic columns and drop timestamps with a missing area.
peak_areas_diagnostic = peak_areas[[peak1_ppm, peak2_ppm]]
peak_areas_diagnostic = peak_areas_diagnostic.dropna()

# Areas of the two diagnostic peaks.
peak1_column = peak_areas_diagnostic[peak1_ppm]
peak2_column = peak_areas_diagnostic[peak2_ppm]

# Fraction of the combined area that belongs to peak 2 at each timestamp.
peak_ratio = peak2_column / (peak1_column + peak2_column)

elapsed_time, plateau_index = detect_plateau_from_slope(
    peak_areas_diagnostic.index, peak_ratio, num_datapoints=5, threshold=slope_threshold
)

# Minutes since the earliest measurement, kept in row order so that every
# time stays paired with its own ratio and `plateau_index` points at the
# right entry even if the index is not sorted chronologically.
measurement_times = [
    datetime.strptime(timestamp, "%Y%m%d_%H%M%S")
    for timestamp in peak_areas_diagnostic.index
]
first_measurement = min(measurement_times)
formatted_times = [
    (timestamp - first_measurement).total_seconds() / 60
    for timestamp in measurement_times
]

# Plot the peak ratio over time and mark the detected plateau.
plt.figure(figsize=(8, 8))
plt.plot(formatted_times, peak_ratio, marker="o", linestyle="None", markersize=10)
plt.axvline(
    x=formatted_times[plateau_index],
    color="r",
    linestyle="--",
    label=f"Plateau after {elapsed_time}",
)
plt.xlabel("Time / min")
plt.ylabel("Conversion")
# Fixed tick positions (0.70, 0.75, ...); the visible range follows the data.
plt.yticks(np.round(np.arange(0.7, 0.9, 0.05), 2))
plt.ylim(min(peak_ratio) - 0.01, max(peak_ratio) + 0.01)
plt.legend()
# The figure is written to the current working directory.
plt.savefig(
    exp_code + "_conversion.svg", bbox_inches="tight", dpi=300, transparent=True
)
plt.show()

In [None]:
# Number of decimals the common ppm axis is rounded to.
PPM_PRECISION = 4
# Centres (ppm) of the regions removed from the common axis.
SOLVENT_PEAKS = [1.16, 3.44]
# HIGHPASS = 1000

# ppm in reverse order because NMR convention is to have ppm decrease from left to right
# start_ppm, end_ppm = 0.0, -200.0
start_ppm, end_ppm = 12.0, 0.0

# The first 15 characters of the file name are the "YYYYmmdd_HHMMSS" timestamp.
# os.path.basename is used instead of splitting on a backslash so this also
# works with non-Windows path separators.
timestamps = [os.path.basename(filepath)[:15] for filepath in processed_datafiles]

# Save spectrum objects in a dict
spectra = {}


def highpass_filter(spectrum: SpinsolveNMRSpectrum, threshold: float, inplace=True):
    """Zero every point whose real intensity is not above ``threshold``.

    Args:
        spectrum: Object exposing a ``y_data`` array.
        threshold: Points with ``y_data.real <= threshold`` are set to 0.
        inplace: If True (default) ``spectrum`` itself is modified. If False
            a shallow copy is modified and ``spectrum`` is left unchanged.

    Returns:
        The filtered spectrum; its ``y_data`` holds real values only.
    """
    import copy

    # A shallow copy is enough: the attribute is rebound to a new array and
    # the original array is never written to.
    s = spectrum if inplace else copy.copy(spectrum)
    real_intensity = s.y_data.real
    s.y_data = np.where(real_intensity > threshold, real_intensity, 0)
    return s


def cut(spectrum: SpinsolveNMRSpectrum, low_ppm: float, high_ppm: float, inplace=True):
    """Remove the region ``[low_ppm, high_ppm]`` from a spectrum.

    Only points with ``x_data < low_ppm`` or ``x_data > high_ppm`` are kept.

    Args:
        spectrum: Object exposing ``x_data`` and ``y_data`` arrays of equal length.
        low_ppm: Lower edge of the removed region (inclusive).
        high_ppm: Upper edge of the removed region (inclusive).
        inplace: If True (default) ``spectrum`` itself is modified. If False
            a shallow copy is modified and ``spectrum`` is left unchanged.

    Returns:
        The spectrum without the removed region.
    """
    import copy

    # A shallow copy is enough: both attributes are rebound to new arrays.
    s = spectrum if inplace else copy.copy(spectrum)
    selector = (s.x_data < low_ppm) | (s.x_data > high_ppm)
    s.x_data = s.x_data[selector]
    s.y_data = s.y_data[selector]
    return s


def trim(spectrum: SpinsolveNMRSpectrum, low_ppm: float, high_ppm: float, inplace=True):
    """Restrict a spectrum to the region ``[low_ppm, high_ppm]``.

    Only points with ``low_ppm <= x_data <= high_ppm`` are kept.

    Args:
        spectrum: Object exposing ``x_data`` and ``y_data`` arrays of equal length.
        low_ppm: Lower edge of the kept region (inclusive).
        high_ppm: Upper edge of the kept region (inclusive).
        inplace: If True (default) ``spectrum`` itself is modified. If False
            a shallow copy is modified and ``spectrum`` is left unchanged.

    Returns:
        The spectrum restricted to the kept region.
    """
    import copy

    # A shallow copy is enough: both attributes are rebound to new arrays.
    s = spectrum if inplace else copy.copy(spectrum)
    selector = (s.x_data >= low_ppm) & (s.x_data <= high_ppm)
    s.x_data = s.x_data[selector]
    s.y_data = s.y_data[selector]
    return s


# The first spectrum, trimmed to the ppm range of interest and with the
# solvent regions (+/- 0.5 ppm around each entry) removed, defines the ppm
# axis shared by all spectra.
spectrum = SpinsolveNMRSpectrum().from_pickle(processed_datafiles[0])
trim(spectrum, end_ppm, start_ppm)
for solvent_peak in SOLVENT_PEAKS:
    region_low, region_high = solvent_peak - 0.5, solvent_peak + 0.5
    cut(spectrum, region_low, region_high)

common_x_axis = np.round(spectrum.x_data, decimals=PPM_PRECISION)

print("Creating array of spectra...")
# One row per processed file, one column per point of the common axis.
n_spectra = len(processed_datafiles)
spectra_array = np.zeros((n_spectra, len(common_x_axis)))
for row, datafile in enumerate(tqdm(processed_datafiles)):
    spectrum = SpinsolveNMRSpectrum().from_pickle(datafile)
    # Optional, disabled: highpass_filter(spectrum, HIGHPASS)
    # Look up, for every common-axis position, an index into this spectrum.
    # NOTE(review): np.digitize returns bin indices, not nearest-point indices,
    # and needs a monotonic x_data; confirm the selected points are intended.
    indices = np.digitize(common_x_axis, spectrum.x_data)
    spectra_array[row] = spectrum.y_data.real[indices]

In [None]:
from scipy.integrate import trapezoid

# Intensity tolerance: two points whose values differ by less than this are
# treated as equal by jaccard_next_two_spectra.
nmr_error = 1e3


# Calculate Jaccard index for pair of spectra
def jaccard_next_two_spectra(index_1, index_2, reference_spectrum=None):
    """Area-based Jaccard index of two rows of ``spectra_array``.

    Point by point, the "union" is the larger and the "intersection" the
    smaller of the two intensities. Where the two intensities differ by less
    than ``nmr_error`` they are treated as equal and the second spectrum's
    value is used for both. Both curves are integrated over
    ``common_x_axis`` with the trapezoidal rule.

    Args:
        index_1: Row of the first spectrum (ignored if ``reference_spectrum``
            is given).
        index_2: Row of the second spectrum.
        reference_spectrum: Optional intensities, on the common axis, to use
            instead of row ``index_1``.

    Returns:
        Intersection area divided by union area.
    """
    if reference_spectrum is None:
        spectrum1_y = np.asarray(spectra_array[index_1])
    else:
        spectrum1_y = np.asarray(reference_spectrum)
    spectrum2_y = np.asarray(spectra_array[index_2])

    # Vectorised replacement for the former per-point Python loop.
    within_noise = np.abs(spectrum1_y - spectrum2_y) < nmr_error
    union = np.where(
        within_noise, spectrum2_y, np.maximum(spectrum1_y, spectrum2_y)
    )
    intersection = np.where(
        within_noise, spectrum2_y, np.minimum(spectrum1_y, spectrum2_y)
    )

    intersection_area = trapezoid(intersection, common_x_axis)
    union_area = trapezoid(union, common_x_axis)
    return intersection_area / union_area


n_timestamps = len(timestamps)

# Every spectrum compared with the first one.
print("Calculating Jaccard indices...")
jaccards = {
    timestamps[k]: jaccard_next_two_spectra(0, k) for k in tqdm(range(n_timestamps))
}

# Every spectrum compared with the last one.
print("Calculating reverse Jaccard indices...")
reverse_jaccards = {
    timestamps[k]: jaccard_next_two_spectra(k, n_timestamps - 1)
    for k in tqdm(range(n_timestamps))
}

# Every spectrum compared with its successor (one entry fewer).
print("Calclating neighbor Jaccard indices...")
neighbor_jaccards = {
    timestamps[k]: jaccard_next_two_spectra(k, k + 1)
    for k in tqdm(range(n_timestamps - 1))
}

In [None]:
# Overview of the three Jaccard series on an absolute time axis.
times = [datetime.strptime(stamp, "%Y%m%d_%H%M%S") for stamp in jaccards.keys()]
plt.figure(figsize=(24, 6), dpi=100)

ppm_range = f"{start_ppm} to {end_ppm} ppm"
# The neighbour series has one entry fewer, hence the shortened time axis.
for x_values, series, series_name in (
    (times, jaccards, "Jaccard index"),
    (times, reverse_jaccards, "Reverse Jaccard index"),
    (times[:-1], neighbor_jaccards, "Neighbor Jaccard index"),
):
    plt.plot(
        x_values, series.values(), marker="o", label=f"{series_name} {ppm_range}"
    )

# Label every second time point.
tick_times = times[::2]
plt.xticks(tick_times, tick_times, rotation=90)
plt.legend()
plt.xlabel("Time")
plt.ylabel("Jaccard index")
plt.title(f"Jaccard index for {exp_code}")
plt.show()

In [None]:
# Plateau detection on the Jaccard series (spectra compared with the first one).
jaccard_timestamps = list(jaccards)
jaccard_values = list(jaccards.values())
elapsed_time, plateau_index = detect_plateau_from_slope(
    jaccard_timestamps, jaccard_values, num_datapoints=5, threshold=1e-3
)

In [None]:
# Jaccard index against elapsed minutes, with the detected plateau marked.
times = [datetime.strptime(timestamp, "%Y%m%d_%H%M%S") for timestamp in jaccards.keys()]
times_min = [(time - times[0]).total_seconds() / 60 for time in times]

plt.figure(figsize=(8, 8))
plt.plot(
    times_min,
    jaccards.values(),
    marker="o",
    label=f"Jaccard index {start_ppm} to {end_ppm} ppm",
    linestyle="None",
    markersize=10,
)
# Vertical line at the plateau index found by the previous cell.
plt.axvline(
    x=times_min[plateau_index],
    color="r",
    linestyle="--",
    label=f"Plateau after {elapsed_time}",
)

plt.legend().get_frame().set_linewidth(3)

plt.xlabel("Time / min")
plt.ylabel("Jaccard index")
plt.yticks(np.round(np.arange(0.7, 1.0, 0.1), 1))
plt.ylim(min(jaccards.values()) - 0.01, max(jaccards.values()) + 0.01)

# savefig does not create missing folders, so make sure the target exists.
os.makedirs("figures", exist_ok=True)
plt.savefig(
    f"figures/{exp_code}_jaccard.svg", bbox_inches="tight", dpi=300, transparent=True
)
plt.show()

In [None]:
# Export elapsed minutes and Jaccard index as two columns.
# np.savetxt separates columns with a space by default; the file is a .csv
# with a comma-separated header, so the delimiter is set to match.
np.savetxt(
    exp_code + "plotdata_jaccard.csv",
    np.column_stack((times_min, list(jaccards.values()))),
    delimiter=",",
    header="Time, Jaccard index",
)

In [None]:
from scipy.optimize import curve_fit

# Maximum absolute distance from the fitted offset that counts as "on the plateau".
PLATEAU_THRESHOLD_MARGIN = 0.01
# Number of consecutive points that must lie within that margin.
PLATEAU_DATAPOINTS = 5

# Keep only entries whose Jaccard index is at least 0.0 (negative values are
# dropped).
jaccards_filtered = {
    stamp: value for stamp, value in jaccards.items() if value >= 0.0
}
times = [datetime.strptime(stamp, "%Y%m%d_%H%M%S") for stamp in jaccards_filtered]

# Filtered Jaccard values as an array, in the same order as `times`.
filtered_values = list(jaccards_filtered.values())
datapoints = np.array(filtered_values)


# Fit model for a first-order reaction
def exp_decay(x, a, b, c):
    """Exponential decay ``a * exp(-b * x) + c``.

    ``a`` is the amplitude, ``b`` the rate constant and ``c`` the offset
    approached for large ``x``.
    """
    decay = np.exp(-b * x)
    return a * decay + c


# Fit model for a second-order reaction
def second_order(x, a, b, c):
    """Second-order model ``1 / (1 / a + b * x) + c``.

    At ``x == 0`` the result is ``a + c``; ``b`` is the rate constant and
    ``c`` the offset approached for large ``x``.
    """
    denominator = 1 / a + b * x
    return 1 / denominator + c


# Elapsed time of every (filtered) data point, in seconds and in minutes.
elapsed_time_sec = np.array([(time - times[0]).total_seconds() for time in times])
fit_times_min = elapsed_time_sec / 60

# Fit the exponential decay function to the Jaccard index time series
popt1, pcov1 = curve_fit(exp_decay, elapsed_time_sec, datapoints, p0=(1, 1e-3, 0.5))

# Fit the second order function to the Jaccard index time series
popt2, pcov2 = curve_fit(second_order, elapsed_time_sec, datapoints, p0=(1, 1e-3, 0.5))

# Find out where the plateau starts: the first index whose preceding
# PLATEAU_DATAPOINTS values all lie within PLATEAU_THRESHOLD_MARGIN (an
# absolute margin) of the fitted offset `c` of the exponential model.
plateau_index = len(datapoints) - 1
elapsed_time = None
filtered_timestamps = list(jaccards_filtered.keys())
for i in range(PLATEAU_DATAPOINTS, len(datapoints) - PLATEAU_DATAPOINTS):
    window = datapoints[i - PLATEAU_DATAPOINTS : i]
    if np.all(np.abs(window - popt1[2]) < PLATEAU_THRESHOLD_MARGIN):
        print(i)
        # Index into the filtered series so the timestamp matches `datapoints`.
        print(f"Reaction has reached plateau at {filtered_timestamps[i]}")
        plateau_index = i
        elapsed_time = times[plateau_index] - times[0]
        break

# Plot the fit, the data it was fitted to and the plateau marker. All three
# use the filtered series so their lengths and indices always agree.
plt.plot(
    fit_times_min,
    exp_decay(elapsed_time_sec, *popt1),
    "r-",
    label="fit 1st: a=%.3f, b=%.3e, c=%.3f" % tuple(popt1),
)
plt.plot(fit_times_min, datapoints, marker="o", linestyle="None")
plt.axvline(
    x=fit_times_min[plateau_index],
    color="r",
    linestyle="--",
    label=f"Plateau after {elapsed_time}",
)
plt.xticks(rotation=45)
plt.legend()
plt.xlabel("Time / min")
plt.ylabel("Jaccard index")
plt.title(f"Jaccard index for {exp_code}")
plt.show()

In [None]:
from sklearn.metrics import r2_score

# Coefficient of determination of both fits against the fitted data.
first_order_prediction = exp_decay(elapsed_time_sec, *popt1)
second_order_prediction = second_order(elapsed_time_sec, *popt2)
r2_1 = r2_score(datapoints, first_order_prediction)
r2_2 = r2_score(datapoints, second_order_prediction)

print("R-squared value for the first fit:", r2_1)
# Parameters of the first fit with their standard errors (square root of the
# diagonal of the covariance matrix).
for name, value, variance in zip("abc", popt1, np.diag(pcov1)):
    print(f"{name} = ", value, "+/-", np.sqrt(variance))

print("R-squared value for the second fit:", r2_2)

In [None]:
# Per-point spread across all spectra: each row of `transposed` holds the
# intensities of one ppm position over the whole series.
transposed = spectra_array.T
n_positions = len(transposed)
nmr_stds = np.zeros(n_positions)
nmr_cvs = np.zeros(n_positions)
nmr_rel_stds = np.zeros(n_positions)

for i, y_values in enumerate(transposed):
    spread = np.std(y_values)
    nmr_stds[i] = spread
    nmr_cvs[i] = spread / np.mean(y_values)  # coefficient of variation
    nmr_rel_stds[i] = spread / np.max(y_values)  # relative to the largest value

# Summary of the three measures: mean, median, max and min over all positions.
for label, reducer in (
    ("Mean", np.mean),
    ("Median", np.median),
    ("Max", np.max),
    ("Min", np.min),
):
    lower = label.lower()
    print(
        f"{label} NMR std: {reducer(nmr_stds)}, {lower} NMR cv: {reducer(nmr_cvs)}, {lower} NMR rel std: {reducer(nmr_rel_stds)}"
    )

overall_max = np.max(spectra_array)
print(
    f"Max std as percentage of overall max value: {np.max(nmr_stds) / overall_max * 100}%"
)

In [None]:
import matplotlib.pyplot as plt

# Distribution of the intensities at the first ppm position over all spectra.
row_index = 0
row_values = transposed[row_index]
plt.hist(row_values, bins=10)
plt.xlabel("Value")
plt.ylabel("Frequency")
# The title counts rows from 1.
plt.title(f"Histogram of Row {row_index + 1}")
plt.show()

In [None]:
# Rebuild the time axis and values from the complete (unfiltered) Jaccard series.
times = [datetime.strptime(stamp, "%Y%m%d_%H%M%S") for stamp in jaccards]
first_time = times[0]
times_min = [(moment - first_time).total_seconds() / 60 for moment in times]
datapoints = np.array(list(jaccards.values()))

In [None]:
# Moving average of the Jaccard series over `window_size` consecutive points.
window_size = 20
averaging_kernel = np.ones(window_size) / window_size
moving_avg = np.convolve(datapoints, averaging_kernel, mode="valid")

# "valid" mode returns fewer points than the input; each average is plotted
# at the time of the first point of its window.
n_smoothed = len(moving_avg)
plt.plot(times_min[:n_smoothed], moving_avg)
plt.xlabel("X-axis")
plt.ylabel("Smoothed Y-axis")
plt.title("Moving Average")
plt.show()

In [None]:
from statsmodels.nonparametric.smoothers_lowess import lowess

# LOWESS smoothing of the Jaccard series against elapsed minutes; each local
# fit uses 10 % of the points (frac=0.1).
smooth_data = lowess(datapoints, times_min, frac=0.1)
# The smoothed values are taken from the second column of the result.
smoothed_values = smooth_data[:, 1]

plt.plot(times_min, smoothed_values)
plt.xlabel("X-axis")
plt.ylabel("Smoothed Y-axis")
plt.title("LOESS Smoothing")
plt.show()

In [None]:
from scipy.stats import spearmanr

# Rank correlation between elapsed time and Jaccard index over the whole
# series; the p-value is not used.
correlation, _pvalue = spearmanr(times_min, datapoints)
print(f"Spearman's Rank Correlation: {correlation}")

In [None]:
# Expanding-window Spearman correlation: entry k is computed from the first k
# points only, i.e. the point at position k itself is not included.
# NOTE(review): the first entries come from fewer than two points; confirm
# that excluding the current point is intended.
spear_corr = []
for n_points in range(len(times_min)):
    rho, _pvalue = spearmanr(times_min[:n_points], datapoints[:n_points])
    spear_corr.append(rho)

plt.plot(
    times_min,
    spear_corr,
    marker="o",
    label=f"Spearman correlation {start_ppm} to {end_ppm} ppm",
)
plt.xlabel("Time / min")
plt.ylabel("Spearman correlation")
plt.title(f"Spearman correlation for {exp_code}")
plt.legend()
plt.show()

In [None]:
# Export elapsed minutes and Spearman correlation as two columns.
# np.savetxt separates columns with a space by default; the file is a .csv
# with a comma-separated header, so the delimiter is set to match.
np.savetxt(
    "plotdata.csv",
    np.column_stack((times_min, spear_corr)),
    delimiter=",",
    header="Time, Spearman",
)

In [None]:
# Rolling Spearman correlation over windows of `num_datapoints` points.
num_datapoints = 5

# Elapsed seconds derived from `times`, which was rebuilt together with
# `datapoints`. The array left over from the curve-fit cell is based on the
# filtered series and may differ in length from `datapoints`.
rolling_time_sec = np.array([(time - times[0]).total_seconds() for time in times])

spear_corr = []
for index in range(num_datapoints, len(rolling_time_sec)):
    # Window of the `num_datapoints` points preceding `index`.
    window = slice(index - num_datapoints, index)
    correlation, _ = spearmanr(rolling_time_sec[window], datapoints[window])
    spear_corr.append(correlation)

plt.plot(
    rolling_time_sec[num_datapoints:],
    spear_corr,
    marker="o",
    label=f"Spearman correlation {start_ppm} to {end_ppm} ppm",
)
plt.xlabel("Time")
plt.ylabel("Spearman correlation")
plt.title(f"Spearman correlation for {exp_code}")
plt.xticks(rotation=45)
plt.legend()
plt.show()

In [None]:
import numpy as np

# Reduced singular value decomposition of the (spectra x ppm positions) matrix.
# Note that numpy returns the right singular vectors as rows, so `V` holds the
# transposed matrix (Vh).
svd_factors = np.linalg.svd(spectra_array, full_matrices=False)
U, S, V = svd_factors

print(*(factor.shape for factor in svd_factors))

In [None]:
# Spectral components as columns, each scaled by its singular value.
reconstructed_spectra = V.T * S

In [None]:
# Spectral components as columns, each scaled by its singular value.
# `V` is numpy's Vh (components as rows), so it has to be transposed before
# the leading components are selected; slicing the columns of `V` itself
# would pick ppm positions and give an array that does not match
# `common_x_axis`.
reconstructed_spectra = V.T[:, : S.shape[0]] * S

In [None]:
# Contribution of every component to every spectrum: the columns of U scaled
# by the singular values.
reconstructed_timepoints = np.multiply(U, S)

In [None]:
# First two SVD components over the course of the experiment.
fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True)

for component, axis in enumerate((ax1, ax2)):
    axis.plot(reconstructed_timepoints[:, component])
    axis.set_ylabel(f"SVD {component}")

plt.xlabel("experiment number")
plt.show()

In [None]:
# First two SVD components along the ppm axis.
fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True)

for component, axis in enumerate((ax1, ax2)):
    axis.plot(common_x_axis, reconstructed_spectra[:, component])
    axis.set_ylabel(f"SVD {component}")

# Reverse the x limits so that the chemical shift decreases to the right.
current_limits = ax1.get_xlim()
ax1.set_xlim(current_limits[::-1])
plt.xlabel("chemical shift / ppm")
plt.show()

In [None]:
# Singular values on a logarithmic y axis.
singular_value_lines = plt.plot(S)
plt.yscale("log")

In [None]:
def find_peak(
    wavelengths: np.ndarray, intensities: np.ndarray, peak: float, width: float
) -> tuple[int, float, float]:
    """Locate the highest point within ``peak - width`` .. ``peak + width``.

    Args:
        wavelengths: x values (numpy array).
        intensities: y values (numpy array of the same length).
        peak: Centre of the search window.
        width: Half-width of the search window (both edges inclusive).

    Returns:
        ``(index, position, intensity)`` of the maximum. ``index`` counts
        from the first point inside the search window, not from the start of
        ``wavelengths``.

    Raises:
        ValueError: If the two arrays differ in length or no data point lies
            inside the search window.
    """
    if len(wavelengths) != len(intensities):
        raise ValueError("wavelengths and intensities must have the same length")
    selector = (wavelengths >= peak - width) & (wavelengths <= peak + width)
    if not np.any(selector):
        # argmax on an empty selection would fail with a less helpful message.
        raise ValueError(f"no data points within {width} of {peak}")
    window_wavelengths = wavelengths[selector]
    window_intensities = intensities[selector]
    index_of_max_intensity = window_intensities.argmax()
    max_wavelength = window_wavelengths[index_of_max_intensity]
    max_intensity = window_intensities[index_of_max_intensity]
    return index_of_max_intensity, max_wavelength, max_intensity

In [None]:
# Position of the strongest signal between 3.0 and 4.0 ppm in every spectrum.
# NOTE(review): this cell reads `spectrum.x` / `spectrum.y` while other cells
# use `x_data` / `y_data`; confirm that both attribute names exist.
peak_shifts = []
for datafile in tqdm(processed_datafiles):
    spectrum.load_data(datafile)
    peak_position = find_peak(spectrum.x, spectrum.y, peak=3.5, width=0.5)[1]
    peak_shifts.append(peak_position)

In [None]:
# Drift of the tracked peak position over the course of the experiment.
marker_style = {"marker": ".", "linestyle": "None"}
plt.plot(times_min, peak_shifts, **marker_style)
plt.xlabel("Time / min")
plt.ylabel("Peak position / ppm")
plt.title("Tracking the peak between 3 and 4 ppm")