In [None]:
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np
from tqdm import tqdm
import os
from glob import glob
from datetime import datetime

In [None]:
# Absolute path of the working directory at the time this cell is executed.
HERE = os.path.abspath(os.curdir)

In [None]:
from AnalyticalLabware.analysis.spinsolve_spectrum import SpinsolveNMRSpectrum


# Spectrum object that the processing loop further down reuses for every dataset
# (each load_spectrum call is made on this same instance).
spectrum = SpinsolveNMRSpectrum()

In [None]:
# Experiment identifier: names the data sub-folder and appears in plot titles
# and in the names of the exported figure/CSV files.
exp_code = "BH_v7_1H"

In [None]:
def get_folder_paths(parent_folder):
    """
    Recursively list the sub-folders of ``parent_folder``.

    Every directory found by ``os.walk`` is returned as ``root/name`` unless
    its own name contains "processed" or "Enhanced". Only the directory name
    is checked, so folders nested inside an excluded folder are still listed.
    """
    excluded_markers = ("processed", "Enhanced")
    return [
        os.path.join(current_root, subfolder)
        for current_root, subfolders, _files in tqdm(os.walk(parent_folder))
        for subfolder in subfolders
        if not any(marker in subfolder for marker in excluded_markers)
    ]


# Folder holding the raw spectra of the selected experiment.
# NOTE(review): machine-specific absolute path; adjust before running elsewhere.
parent_folder = os.path.join(
    r"C:\Users\rober\Downloads\raw spectra\raw spectra\NMR", exp_code
)
# One entry per sub-folder below parent_folder (folders whose name contains
# "processed" or "Enhanced" are skipped by get_folder_paths).
datapaths = get_folder_paths(parent_folder)
print(f"Found {len(datapaths)} data files in {parent_folder}")

In [None]:
# Just for getting the total monitoring time.
# The first 15 characters of each folder name are parsed as a
# "%Y%m%d_%H%M%S" timestamp. os.path.basename is used instead of splitting on
# "\\" so the folder name is extracted correctly whatever separator the
# platform uses.
timestamps = [os.path.basename(filepath)[:15] for filepath in datapaths]
times = sorted(
    datetime.strptime(timestamp, "%Y%m%d_%H%M%S") for timestamp in timestamps
)
if times:
    print("Total experiment time: ", times[-1] - times[0])
else:
    # Without any data folder there is no first/last timestamp to subtract.
    print("No data folders found; cannot compute the total experiment time.")

In [None]:
# Output folder for the pickled, processed spectra.
PROCESSED = os.path.join(parent_folder, "processed")

# exist_ok=True creates the folder only when it is missing, without a separate
# (racy) existence check.
os.makedirs(PROCESSED, exist_ok=True)

# Process every dataset with the same explicit pipeline and pickle the result
# under the dataset's folder name.
for datapath in tqdm(datapaths):
    spectrum.load_spectrum(data_path=datapath)
    # spectrum.default_processing()

    spectrum.apodization(function="gm", g1=1.2, g2=4.5)
    spectrum.zero_fill()
    spectrum.fft()
    # NOTE(review): presumably moves the highest signal to 3.2 ppm - confirm
    # against the AnalyticalLabware documentation.
    spectrum.reference_spectrum(new_position=3.2, reference="highest")
    # spectrum.correct_baseline()
    spectrum.autophase()
    spectrum.correct_baseline()

    spectrum.find_peaks(decimals=1, threshold=0.1)
    spectrum.save_pickle(os.path.join(PROCESSED, os.path.basename(datapath)))

In [None]:
# Visual sanity check of the spectrum currently held in `spectrum`
# (when run top to bottom: the dataset processed last by the loop above).
spectrum.show_spectrum()

In [None]:
# All pickled spectra in the output folder, sorted by path.
processed_datafiles = sorted(glob(f"{PROCESSED}/*"))
processed_datafiles = [file for file in processed_datafiles if file.endswith(".pkl")]
# The first 15 characters of each file name are the "%Y%m%d_%H%M%S" timestamp.
# os.path.basename replaces the Windows-only split("\\") so the file name is
# extracted correctly on every platform.
timestamps = [os.path.basename(filepath)[:15] for filepath in processed_datafiles]

In [None]:
from scipy.stats import linregress


def detect_plateau_from_slope(
    times, values, num_datapoints: int = 5, threshold: float = 1e-4
) -> tuple:
    """
    Detect a plateau by sliding a linear-regression window over ``values``.

    For every index ``i`` with
    ``num_datapoints <= i < len(values) - num_datapoints`` a straight line is
    fitted through the ``num_datapoints`` values preceding ``i`` (normalised by
    ``max(values)`` and regressed against the point index 0, 1, 2, ...). The
    first ``i`` whose window has an absolute slope below ``threshold`` is
    reported as the start of the plateau.

    Args:
        times: Timestamp strings in ``"%Y%m%d_%H%M%S"`` format, parallel to
            ``values``.
        values: Sequence of numbers to inspect.
        num_datapoints: Number of points in each regression window.
        threshold: Largest absolute slope (normalised value per data point)
            that still counts as flat.

    Returns:
        ``(elapsed_time, i)`` where ``elapsed_time`` is the ``timedelta``
        between ``times[0]`` and ``times[i]``; ``(None, len(values) - 1)``
        when no window is flat enough or the series is too short.
    """
    window_ends = range(max(num_datapoints, 0), len(values) - num_datapoints)
    if len(window_ends) > 0:
        # Normalise once instead of re-evaluating max(values) for every window.
        peak = max(values)
        # A maximum of exactly zero cannot be used as divisor; the values are
        # then left un-normalised instead of turning every slope into NaN.
        scale = peak if peak != 0 else 1.0
        normalised = np.asarray(values, dtype=float) / scale
        positions = np.arange(num_datapoints)

        for i in window_ends:
            slope = linregress(positions, normalised[i - num_datapoints : i]).slope
            if abs(slope) < threshold:
                elapsed_time = datetime.strptime(
                    times[i], "%Y%m%d_%H%M%S"
                ) - datetime.strptime(times[0], "%Y%m%d_%H%M%S")
                print(f"Reaction has reached plateau after {elapsed_time}")
                return (elapsed_time, i)
    return (None, len(values) - 1)

In [None]:
# Global matplotlib styling: large font, heavy axes frame and heavy tick marks.
for rc_key, rc_value in (
    ("font.size", 24),
    ("axes.linewidth", 3),
    ("xtick.major.width", 3),
    ("ytick.major.width", 3),
    ("xtick.minor.width", 3),
    ("ytick.minor.width", 3),
):
    plt.rcParams[rc_key] = rc_value

In [None]:
# Number of decimals the common ppm axis is rounded to.
PPM_PRECISION = 4
# ppm positions of solvent signals; the region +/-0.5 ppm around each one is
# cut out of the common axis further down in this cell.
SOLVENT_PEAKS = [3.2]
# HIGHPASS = 1000

# ppm in reverse order because NMR convention is to have ppm decrease from left to right
# start_ppm, end_ppm = 0.0, -200.0
start_ppm, end_ppm = 12.0, 0.0

# The first 15 characters of each file name are the "%Y%m%d_%H%M%S" timestamp.
# os.path.basename replaces the Windows-only split("\\") so the file name is
# extracted correctly on every platform.
timestamps = [os.path.basename(filepath)[:15] for filepath in processed_datafiles]

# Save spectrum objects in a dict
spectra = {}


def highpass_filter(spectrum: "SpinsolveNMRSpectrum", threshold: float, inplace=True):
    """
    Zero every point whose real intensity is not above ``threshold``.

    ``y_data`` is replaced by its real part where that real part exceeds
    ``threshold`` and by 0 elsewhere, so the result is real-valued.

    Args:
        spectrum: Object exposing a ``y_data`` array.
        threshold: Intensity that a point must exceed to be kept.
        inplace: If True (default) ``spectrum`` itself is modified; if False a
            shallow copy is modified and ``spectrum`` is left untouched.

    Returns:
        The filtered spectrum (``spectrum`` itself when ``inplace`` is True).
    """
    import copy

    # Only the attribute is re-bound below, so a shallow copy is enough to
    # protect the caller's object.
    s = spectrum if inplace else copy.copy(spectrum)
    s.y_data = np.where(s.y_data.real > threshold, s.y_data.real, 0)
    return s


def cut(spectrum: "SpinsolveNMRSpectrum", low_ppm: float, high_ppm: float, inplace=True):
    """
    Remove the region ``low_ppm <= x <= high_ppm`` from a spectrum.

    Only points with ``x_data < low_ppm`` or ``x_data > high_ppm`` are kept;
    ``x_data`` and ``y_data`` are filtered with the same mask.

    Args:
        spectrum: Object exposing ``x_data`` and ``y_data`` arrays of equal length.
        low_ppm: Lower edge of the region to remove.
        high_ppm: Upper edge of the region to remove.
        inplace: If True (default) ``spectrum`` itself is modified; if False a
            shallow copy is modified and ``spectrum`` is left untouched.

    Returns:
        The shortened spectrum (``spectrum`` itself when ``inplace`` is True).
    """
    import copy

    # Only the attributes are re-bound below, so a shallow copy is enough to
    # protect the caller's object.
    s = spectrum if inplace else copy.copy(spectrum)
    selector = (s.x_data < low_ppm) | (s.x_data > high_ppm)
    s.x_data = s.x_data[selector]
    s.y_data = s.y_data[selector]
    return s


def trim(spectrum: "SpinsolveNMRSpectrum", low_ppm: float, high_ppm: float, inplace=True):
    """
    Restrict a spectrum to the window ``low_ppm <= x <= high_ppm``.

    Points outside the window are dropped; ``x_data`` and ``y_data`` are
    filtered with the same mask.

    Args:
        spectrum: Object exposing ``x_data`` and ``y_data`` arrays of equal length.
        low_ppm: Lower edge of the window to keep (inclusive).
        high_ppm: Upper edge of the window to keep (inclusive).
        inplace: If True (default) ``spectrum`` itself is modified; if False a
            shallow copy is modified and ``spectrum`` is left untouched.

    Returns:
        The trimmed spectrum (``spectrum`` itself when ``inplace`` is True).
    """
    import copy

    # Only the attributes are re-bound below, so a shallow copy is enough to
    # protect the caller's object.
    s = spectrum if inplace else copy.copy(spectrum)
    selector = (s.x_data >= low_ppm) & (s.x_data <= high_ppm)
    s.x_data = s.x_data[selector]
    s.y_data = s.y_data[selector]
    return s


# The first spectrum defines the ppm axis shared by all spectra: it is limited
# to the window between end_ppm and start_ppm, and +/-0.5 ppm around every
# solvent peak is removed.
spectrum = SpinsolveNMRSpectrum().from_pickle(processed_datafiles[0])
trim(spectrum, low_ppm=end_ppm, high_ppm=start_ppm)
for solvent_peak in SOLVENT_PEAKS:
    cut(spectrum, low_ppm=solvent_peak - 0.5, high_ppm=solvent_peak + 0.5)

common_x_axis = np.round(spectrum.x_data, decimals=PPM_PRECISION)

print("Creating array of spectra...")
# One row per processed file, one column per point of the common axis.
spectra_array = np.zeros((len(processed_datafiles), len(common_x_axis)))
for i, datafile in enumerate(tqdm(processed_datafiles)):
    spectrum = SpinsolveNMRSpectrum().from_pickle(datafile)
    # highpass_filter(spectrum, HIGHPASS)
    # For each point of the common axis, np.digitize gives the index of the bin
    # of this spectrum's own axis it falls into; the real intensity at that
    # index is stored. np.digitize requires spectrum.x_data to be monotonic.
    indices = np.digitize(common_x_axis, spectrum.x_data)
    spectra_array[i] = spectrum.y_data.real[indices]

In [None]:
from scipy.integrate import trapezoid

# Intensity differences below this value are treated as "no difference".
nmr_error = 1e3


# Calculate Jaccard index for pair of spectra
def jaccard_next_two_spectra(index_1, index_2, reference_spectrum=None):
    """
    Jaccard-style similarity (intersection area / union area) of two spectra.

    Point by point, intensities that differ by less than ``nmr_error`` are
    treated as equal (the second spectrum's value goes into both curves);
    otherwise the larger value goes into the union curve and the smaller into
    the intersection curve. Both curves are integrated over ``common_x_axis``
    with the trapezoidal rule.

    Args:
        index_1: Row of ``spectra_array`` used as first spectrum (ignored when
            ``reference_spectrum`` is given).
        index_2: Row of ``spectra_array`` used as second spectrum.
        reference_spectrum: Optional intensity array used instead of row
            ``index_1``; must have the same length as a row of ``spectra_array``.

    Returns:
        Ratio of the intersection area to the union area.

    Raises:
        ValueError: If the two intensity arrays differ in shape.
    """
    if reference_spectrum is None:
        spectrum1_y = np.asarray(spectra_array[index_1])
    else:
        spectrum1_y = np.asarray(reference_spectrum)
    spectrum2_y = np.asarray(spectra_array[index_2])
    if spectrum1_y.shape != spectrum2_y.shape:
        raise ValueError(
            f"Spectra must have the same shape, got {spectrum1_y.shape} "
            f"and {spectrum2_y.shape}"
        )

    # Vectorised form of the per-point comparison.
    within_error = np.abs(spectrum1_y - spectrum2_y) < nmr_error
    union = np.where(within_error, spectrum2_y, np.maximum(spectrum1_y, spectrum2_y))
    intersection = np.where(
        within_error, spectrum2_y, np.minimum(spectrum1_y, spectrum2_y)
    )

    intersection_area = trapezoid(intersection, common_x_axis)
    union_area = trapezoid(union, common_x_axis)
    return intersection_area / union_area


# Similarity of every spectrum to the first spectrum.
print("Calculating Jaccard indices...")
jaccards = {
    timestamps[i]: jaccard_next_two_spectra(0, i)
    for i in tqdm(range(len(timestamps)))
}

# Similarity of every spectrum to the last spectrum.
print("Calculating reverse Jaccard indices...")
last_index = len(timestamps) - 1
reverse_jaccards = {
    timestamps[i]: jaccard_next_two_spectra(i, last_index)
    for i in tqdm(range(len(timestamps)))
}

# Similarity of every spectrum to its successor (one entry fewer than spectra).
print("Calculating neighbor Jaccard indices...")
neighbor_jaccards = {
    timestamps[i]: jaccard_next_two_spectra(i, i + 1)
    for i in tqdm(range(len(timestamps) - 1))
}

In [None]:
# Overview plot of the three Jaccard series against the acquisition time.
times = [datetime.strptime(stamp, "%Y%m%d_%H%M%S") for stamp in jaccards]
plt.figure(figsize=(24, 6), dpi=100)

# (x values, series, legend label); the neighbor series has one point fewer.
for x_values, series, legend_label in (
    (times, jaccards, f"Jaccard index {start_ppm} to {end_ppm} ppm"),
    (times, reverse_jaccards, f"Reverse Jaccard index {start_ppm} to {end_ppm} ppm"),
    (times[:-1], neighbor_jaccards, f"Neighbor Jaccard index {start_ppm} to {end_ppm} ppm"),
):
    plt.plot(x_values, series.values(), marker="o", label=legend_label)

# Label every second acquisition time on the x axis.
every_second_time = times[::2]
plt.xticks(every_second_time, every_second_time, rotation=90)
plt.legend()
plt.xlabel("Time")
plt.ylabel("Jaccard index")
plt.title(f"Jaccard index for {exp_code}")
plt.show()

In [None]:
# Slope-based plateau detection on the Jaccard series (timestamps are the dict keys).
elapsed_time, plateau_index = detect_plateau_from_slope(
    times=list(jaccards),
    values=list(jaccards.values()),
    num_datapoints=5,
    threshold=1e-3,
)

In [None]:
# Time axis: absolute datetimes and minutes elapsed since the first spectrum.
times = [datetime.strptime(stamp, "%Y%m%d_%H%M%S") for stamp in jaccards]
times_min = [(moment - times[0]).total_seconds() / 60 for moment in times]

plt.figure(figsize=(12, 6))
plt.plot(
    times_min,
    jaccards.values(),
    linestyle="None",
    marker="o",
    markersize=10,
    label=f"Jaccard index {start_ppm} to {end_ppm} ppm",
)
# plt.plot(times_min, reverse_jaccards.values(), marker='o', label=f'Reverse Jaccard index {start_ppm} to {end_ppm} ppm')
# Dashed vertical line at the index returned by the plateau detection.
plateau_time_min = times_min[plateau_index]
plt.axvline(
    plateau_time_min,
    linestyle="--",
    color="r",
    label=f"Plateau after {elapsed_time}",
)

legend = plt.legend()
legend.get_frame().set_linewidth(3)

plt.xlabel("Time / min")
plt.ylabel("Jaccard index")
# plt.yticks(np.round(np.arange(0.7, 1.0, 0.1), 1))
# plt.ylim(min(jaccards.values()) - 0.01, max(jaccards.values()) + 0.01)

plt.title(f"Jaccard index for {exp_code}\n excluding signals from {SOLVENT_PEAKS} ppm")
plt.savefig(f"{exp_code}_jaccard.svg", bbox_inches="tight", dpi=300, transparent=True)
plt.show()

In [None]:
# Export the plotted series (elapsed minutes, Jaccard index).
# delimiter="," makes the rows comma-separated, matching the ".csv" extension
# and the comma-separated header; np.savetxt separates columns with a single
# space by default.
np.savetxt(
    exp_code + "_plotdata_jaccard.csv",
    np.column_stack((times_min, list(jaccards.values()))),
    delimiter=",",
    header="Time, Jaccard index",
)

In [None]:
from scipy.optimize import curve_fit

# Absolute distance from the fitted offset within which a point counts as "on the plateau".
PLATEAU_THRESHOLD_MARGIN = 0.01
# Number of consecutive points that must lie within that margin.
PLATEAU_DATAPOINTS = 5

# Keep only entries whose Jaccard index is non-negative; negative values are
# treated as outliers.
jaccards_filtered = {
    stamp: jaccard_value
    for stamp, jaccard_value in jaccards.items()
    if jaccard_value >= 0.0
}
times = [datetime.strptime(stamp, "%Y%m%d_%H%M%S") for stamp in jaccards_filtered]

datapoints = np.array([*jaccards_filtered.values()])

# Fit model for a first-order reaction: exponential decay towards an offset.
def exp_decay(x, a, b, c):
    """Return ``a * exp(-b * x) + c`` (amplitude ``a``, rate ``b``, offset ``c``)."""
    decay = np.exp(-b * x)
    return a * decay + c


# Convert the timestamps of the filtered series to elapsed time in seconds
# (used for the fit) and in minutes (used for plotting). Everything below is
# derived from the filtered series so that all arrays share one length and one
# indexing, even when the outlier filter removed entries.
elapsed_time_sec = np.array([(time - times[0]).total_seconds() for time in times])
elapsed_time_min = elapsed_time_sec / 60
filtered_timestamps = list(jaccards_filtered.keys())

# Fit the exponential decay function to the Jaccard index time series
popt1, pcov1 = curve_fit(exp_decay, elapsed_time_sec, datapoints, p0=(1, 1e-3, 0.5))

# Find out where the plateau starts: the plateau is reached at the first index
# whose PLATEAU_DATAPOINTS preceding points all lie within
# PLATEAU_THRESHOLD_MARGIN (an absolute margin) of the fitted offset c.
# If no such index exists the last data point is used and elapsed_time stays None.
plateau_index = len(datapoints) - 1
elapsed_time = None
for i in range(PLATEAU_DATAPOINTS, len(datapoints) - PLATEAU_DATAPOINTS):
    window = datapoints[i - PLATEAU_DATAPOINTS : i]
    if np.all(np.abs(window - popt1[2]) < PLATEAU_THRESHOLD_MARGIN):
        print(f"Reaction has reached plateau at {filtered_timestamps[i]}")
        plateau_index = i
        elapsed_time = times[plateau_index] - times[0]
        break

# Plot the fitted function together with the data it was fitted to
plt.figure(figsize=(12, 6))
plt.plot(
    elapsed_time_min,
    exp_decay(elapsed_time_sec, *popt1),
    "r-",
    label="fit 1st: a=%.3f, b=%.3e, c=%.3f" % tuple(popt1),
)
plt.plot(
    elapsed_time_min, datapoints, marker="o", linestyle="None"
)  # , label=f'Jaccard index {start_ppm} to {end_ppm} ppm')
plt.axvline(
    x=elapsed_time_min[plateau_index],
    color="r",
    linestyle="--",
    label=f"Plateau after {elapsed_time}",
)
plt.legend()
plt.xlabel("Time / min")
plt.ylabel("Jaccard index")
plt.title(f"Jaccard index for {exp_code}")
plt.savefig(
    f"{exp_code}_jaccard_exp.svg", bbox_inches="tight", dpi=300, transparent=True
)
plt.show()

In [None]:
# Rebuild the time axis and the data array from the complete (unfiltered) Jaccard series.
times = [datetime.strptime(stamp, "%Y%m%d_%H%M%S") for stamp in jaccards]
times_min = [(moment - times[0]).total_seconds() / 60 for moment in times]
datapoints = np.array([*jaccards.values()])

In [None]:
from scipy.stats import spearmanr

# Spearman rank correlation between elapsed time (minutes) and the Jaccard
# index over the whole series; the second return value is discarded.
correlation, _ = spearmanr(times_min, datapoints)
print(f"Spearman's Rank Correlation: {correlation}")

In [None]:
# Cumulative Spearman correlation: entry k is the rank correlation between
# time and Jaccard index over the points 0..k (inclusive), so each value is
# plotted at the time of the last point it contains and the final entry covers
# the whole series. Slicing with [:index] instead would leave out the point
# the value is plotted at and never include the last measurement.
spear_corr = []
for index in range(len(times_min)):
    n_points = index + 1
    if n_points < 2:
        # A rank correlation is undefined for fewer than two points.
        spear_corr.append(np.nan)
        continue
    correlation, _ = spearmanr(times_min[:n_points], datapoints[:n_points])
    spear_corr.append(correlation)

plt.figure(figsize=(12, 6))
plt.plot(times_min, spear_corr, marker="o", linestyle="None")
plt.xlabel("Time / min")
plt.ylabel("Spearman correlation")
plt.title(f"Spearman correlation for {exp_code}")
plt.savefig(f"{exp_code}_spearman.svg", bbox_inches="tight", dpi=300, transparent=True)
plt.show()

In [None]:
# Export the plotted series (elapsed minutes, cumulative Spearman correlation).
# delimiter="," makes the rows comma-separated, matching the ".csv" extension
# and the comma-separated header; np.savetxt separates columns with a single
# space by default.
np.savetxt(
    exp_code + "_plotdata_spearman.csv",
    np.column_stack((times_min, spear_corr)),
    delimiter=",",
    header="Time, Spearman",
)