In [None]:
from AnalyticalLabware.analysis.spinsolve_spectrum import SpinsolveNMRSpectrum
from matplotlib import pyplot as plt
import numpy as np
from tqdm import tqdm
import os

In [None]:
# Absolute path of the notebook's working directory.
HERE = os.path.abspath(os.curdir)

In [None]:
# Shared spectrum object: later cells reload it with `load_spectrum` (and
# deepcopy it when a per-file snapshot must be kept).
spectrum = SpinsolveNMRSpectrum()

In [None]:
# Experiment folder name under DATAPATH; also used as the prefix of every saved figure.
exp_code = "RR-M3-5bb-moni-v3"

In [None]:
from scipy.stats import linregress


def detect_plateau_from_slope(
    times, values, num_datapoints: int = 5, threshold: float = 1e-4
) -> tuple:
    """
    Detect a plateau in ``values`` from the slope of a sliding linear fit.

    For every index ``i`` from ``num_datapoints`` onwards, a straight line is
    fitted to the preceding ``num_datapoints`` values (``values[i - n:i]``,
    normalised by ``max(values)``) against their position in the window. The
    first window whose absolute slope is below ``threshold`` marks the plateau.

    Args:
        times: Acquisition times aligned with ``values``; only
            ``times[i] - times[0]`` is used.
        values: Monitored quantity (sequence of numbers).
        num_datapoints: Number of points in each fit window.
        threshold: Largest absolute normalised slope (per point) that still
            counts as flat.

    Returns:
        ``(elapsed_time, i)`` for the first flat window, or
        ``(None, len(values) - 1)`` when none is found (including when there
        are not enough points for a single window).
    """
    if len(values) <= num_datapoints:
        return (None, len(values) - 1)

    # Loop-invariant normalisation. An all-zero series is perfectly flat, so
    # fall back to 1 instead of producing 0/0 = NaN slopes.
    scale = max(values)
    if scale == 0:
        scale = 1.0
    positions = np.arange(num_datapoints)

    # The window for index i ends just before i, so every i up to
    # len(values) - 1 has a full window. (Stopping at len(values) - n would
    # silently skip the last n candidates.)
    for i in range(num_datapoints, len(values)):
        window = np.asarray(values[i - num_datapoints : i]) / scale
        slope = linregress(positions, window).slope
        if abs(slope) < threshold:
            elapsed_time = times[i] - times[0]
            print(f"Reaction has reached plateau after {elapsed_time}")
            return (elapsed_time, i)
    return (None, len(values) - 1)

In [None]:
def get_folder_paths(parent_folder):
    """
    List the acquisition sub-folders of ``parent_folder``.

    Plain files, the ``processed`` and ``averaged`` folders, and any folder
    whose name contains ``STANDBY`` are skipped.

    Returns:
        Full paths of the remaining sub-folders, sorted by name. ``os.scandir``
        yields entries in arbitrary order, while callers treat the result as an
        ordered series (``datapaths[0]`` is the reference), so a deterministic
        order is required. NOTE(review): name order is chronological only if
        folder names encode acquisition time — confirm for Spinsolve output.
    """
    excluded = {"processed", "averaged"}
    # The context manager closes the directory handle once iteration is done.
    with os.scandir(parent_folder) as entries:
        return sorted(
            entry.path
            for entry in entries
            if entry.is_dir()
            and entry.name not in excluded
            and "STANDBY" not in entry.name
        )


# specify the parent folder containing the experiment data
parent_folder = os.path.join("DATAPATH", exp_code)
# Sub-folders whose name contains "RM"; sorted so the pick is deterministic
# when more than one matches (os.listdir order is arbitrary).
rm_candidates = sorted(
    name
    for name in os.listdir(parent_folder)
    if "RM" in name and os.path.isdir(os.path.join(parent_folder, name))
)
if not rm_candidates:
    raise FileNotFoundError(f"No sub-folder containing 'RM' found in {parent_folder}")
rm_folder = rm_candidates[0]
datapaths = get_folder_paths(os.path.join(parent_folder, rm_folder))
print(f"Found {len(datapaths)} datapaths in {rm_folder}")

In [None]:
# Load the first acquisition as the reference. NOTE(review): `preprocessed=False`
# presumably leaves the raw FID in y_data (the phase analysis below relies on
# that) — confirm against SpinsolveNMRSpectrum.load_spectrum.
spectrum.load_spectrum(datapaths[0], preprocessed=False)

# Reference FID; every later FID is compared against it by cross-correlation.
fid_ref = spectrum.y_data
# "sw" entry of the nmrglue universal dictionary for the first dimension.
spectral_width = spectrum.udic[0]["sw"]
print("Spectral width (Hz): ", spectral_width)

# calculate sampling rate in Hz from x axis
# NOTE(review): assumes x_data is a uniformly spaced FID time axis in seconds.
sampling_rate = 1 / (spectrum.x_data[1] - spectrum.x_data[0])
print("Sampling rate (Hz): ", sampling_rate)

# Assume fid is complex FID of the solvent peak
# A constant frequency offset f makes the FID phase grow linearly,
# phi(t) = 2*pi*f*t + phi0, so the slope of a straight-line fit to the unwrapped
# phase is 2*pi*f. Only the first 10000 points are fitted (presumably to stay
# where the signal dominates the noise — TODO confirm the choice of 10000).
phase = np.unwrap(np.angle(fid_ref))[:10000]
time = np.arange(len(fid_ref))[:10000] / sampling_rate
slope, intercept = np.polyfit(time, phase, 1)

# Convert the angular slope (rad per time unit) to a frequency offset.
drift_Hz = slope / (2 * np.pi)
print("Drift (Hz): ", drift_Hz)

# Process the reference in place with the library's default pipeline. Its
# x-axis is reused below as the ppm axis for every spectrum.
# NOTE(review): later cells process FIDs with gm + zf_double + fft rather than
# default_processing(); confirm default_ppm has the same length and ordering.
spectrum.default_processing()
default_ppm = spectrum.x_data
spectrum.find_peaks()
spectrum.show_spectrum()

In [None]:
import numpy as np
from scipy.signal import correlate


def cross_correlation_similarity(fid1, fid2):
    """
    Normalised peak cross-correlation between two (possibly complex) FIDs.

    Both signals are mean-centred first. The largest magnitude of their full
    cross-correlation is divided by sqrt(E1 * E2), where E is the energy
    (sum of squared magnitudes) of each centred signal. By Cauchy-Schwarz the
    result lies in [0, 1]; 1 means one signal is a scaled/shifted copy of the other.
    """
    centred_a = fid1 - np.mean(fid1)
    centred_b = fid2 - np.mean(fid2)

    peak_corr = np.abs(correlate(centred_a, centred_b, mode="full")).max()

    energy_a = np.sum(np.abs(centred_a) ** 2)
    energy_b = np.sum(np.abs(centred_b) ** 2)
    return peak_corr / np.sqrt(energy_a * energy_b)

In [None]:
from copy import deepcopy


timestamps: list[int] = []
raw_spectra: list[SpinsolveNMRSpectrum] = []
fids: list[np.ndarray] = []
similarities: list[float] = []

for datapath in tqdm(datapaths):
    # `spectrum` is reused for every file, so keep a deep copy of each loaded state.
    spectrum.load_spectrum(datapath, preprocessed=False)
    timestamps.append(spectrum.timestamp)
    raw_spectra.append(deepcopy(spectrum))
    fids.append(spectrum.y_data)
    # Similarity of this raw FID to the reference FID (fid_ref).
    similarities.append(cross_correlation_similarity(fid_ref, spectrum.y_data))

# `datapaths` comes from os.scandir(), whose order is arbitrary. The moving
# averages further down combine neighbouring list entries, so put every
# per-acquisition list into chronological order (stable sort: ties keep load
# order). `datapaths` itself is left in load order.
chronological = sorted(range(len(timestamps)), key=timestamps.__getitem__)
timestamps = [timestamps[k] for k in chronological]
raw_spectra = [raw_spectra[k] for k in chronological]
fids = [fids[k] for k in chronological]
similarities = [similarities[k] for k in chronological]

In [None]:
# Elapsed time of each acquisition in minutes, measured from the earliest one.
# NOTE(review): assumes `spectrum.timestamp` is in seconds — confirm.
# min() rather than timestamps[0] keeps the axis non-negative even if the
# spectra were not loaded in chronological order.
t_start = min(timestamps)
times_min = [(ts - t_start) / 60 for ts in timestamps]
print(f"Total experiment time: {max(times_min)} min")
if len(times_min) > 1:
    # Sort before differencing so the interval is meaningful for any load order.
    print(f"Mean sampling interval: {np.mean(np.diff(sorted(times_min)))*60} sec")

In [None]:
import nmrglue as ng

# Conventional processing: every FID is processed on its own (no averaging):
# Lorentz-to-Gauss apodization (nmrglue proc_base.gm), one zero-fill doubling,
# then FFT. The gm widths are divided by the spectral width, presumably to turn
# Hz values into nmrglue's per-point units — confirm.
processed_data = []
for raw_fid in fids:
    apodized = ng.proc_base.gm(
        data=raw_fid, g1=1.2 / spectral_width, g2=4.5 / spectral_width
    )
    zero_filled = ng.proc_base.zf_double(data=apodized, n=1)
    processed_data.append(ng.proc_base.fft(data=zero_filled))

In [None]:
from AnalyticalLabware.analysis.base_spectrum import GenericSpectrum
from AnalyticalLabware.analysis.spec_utils import jaccard_two_spectra


def trim_cut(
    x_axis,
    data,
    start_ppm: float,
    end_ppm: float,
    peaks: list[float],
    half_width: float = 0.5,
):
    """
    Restrict a spectrum to ``end_ppm <= x <= start_ppm`` and cut out peak regions.

    Args:
        x_axis: Chemical-shift axis (ppm), same length as ``data``.
        data: Intensities aligned with ``x_axis``.
        start_ppm: Upper bound of the kept region (NMR convention: start > end).
        end_ppm: Lower bound of the kept region.
        peaks: Peak positions (ppm) to remove, e.g. solvent signals.
        half_width: Points within ``peak +/- half_width`` (inclusive) are removed.

    Returns:
        ``(x_axis, data)`` restricted to the kept points.

    Raises:
        ValueError: if ``x_axis`` and ``data`` differ in length.
    """
    x_axis = np.asarray(x_axis)
    data = np.asarray(data)
    if len(x_axis) != len(data):
        raise ValueError(
            f"x_axis and data must have the same length ({len(x_axis)} != {len(data)})"
        )
    # Build a single mask: equivalent to filtering the range and then each peak.
    keep = (x_axis >= end_ppm) & (x_axis <= start_ppm)
    for peak in peaks:
        keep &= (x_axis < peak - half_width) | (x_axis > peak + half_width)
    return x_axis[keep], data[keep]


def get_jaccard_list(
    x_axis, data, ref_index=0, start_ppm=15, end_ppm=0, solvent_peaks=None
):
    """
    Jaccard index of every spectrum in ``data`` against ``data[ref_index]``.

    Each spectrum (including the reference itself) is trimmed to the
    ``[end_ppm, start_ppm]`` window with ``solvent_peaks`` cut out. It is then
    wrapped in a ``GenericSpectrum`` and compared with
    AnalyticalLabware's ``jaccard_two_spectra``.

    Args:
        x_axis: Shared ppm axis for all spectra.
        data: Sequence of intensity arrays, each aligned with ``x_axis``.
        ref_index: Index of the reference spectrum in ``data``.
        start_ppm: Upper bound of the compared region.
        end_ppm: Lower bound of the compared region.
        solvent_peaks: Peak positions (ppm) to exclude; none by default.

    Returns:
        List of Jaccard indices, one per entry of ``data``.
    """
    # Passed straight to jaccard_two_spectra; their meaning is defined there.
    # NOTE(review): confirm nmr_error's units against that function.
    PPM_PRECISION = 4
    nmr_error = 1e3
    if solvent_peaks is None:
        solvent_peaks = []

    # process reference spectrum individually
    ref_x, ref_y = trim_cut(x_axis, data[ref_index], start_ppm, end_ppm, solvent_peaks)
    reference_spectrum = GenericSpectrum()
    reference_spectrum.load_data(ref_x, ref_y, None)

    print("Calculating Jaccard indices...")
    jaccards = []
    for dataset in tqdm(data):
        spectrum_x, spectrum_y = trim_cut(
            x_axis, dataset, start_ppm, end_ppm, solvent_peaks
        )
        candidate = GenericSpectrum()
        candidate.load_data(spectrum_x, spectrum_y, None)
        jaccards.append(
            jaccard_two_spectra(reference_spectrum, candidate, PPM_PRECISION, nmr_error)
        )

    return jaccards

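Moving average of single FIDs: each averaged spectrum below is processed from the mean of up to `2 * window_size` neighbouring FIDs rather than from a single FID.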

In [None]:
# Half-width of the moving-average window, in acquisitions.
window_size = 16

# Averaged processing: for spectrum i, average the raw FIDs with indices in
# [i - window_size, i + window_size), clipped at both ends. Then apply the same
# apodization / zero-filling / FFT as the conventional pipeline.
# NOTE(review): the window is asymmetric — up to window_size FIDs before i but
# only window_size - 1 after it (2 * window_size in total); confirm intended.
n_fids = len(fids)
averaged_data = []
for i in tqdm(range(n_fids)):
    lo = max(0, i - window_size)
    hi = min(n_fids, i + window_size)
    fid_ma = np.mean(fids[lo:hi], axis=0)

    apodized = ng.proc_base.gm(
        data=fid_ma, g1=1.2 / spectral_width, g2=4.5 / spectral_width
    )
    zero_filled = ng.proc_base.zf_double(data=apodized, n=1)
    averaged_data.append(ng.proc_base.fft(data=zero_filled))

In [None]:
# Jaccard index of every spectrum vs. the first one, over the default 0-15 ppm window.
jacc_con = get_jaccard_list(x_axis=default_ppm, data=processed_data)
jacc_ave = get_jaccard_list(x_axis=default_ppm, data=averaged_data)

In [None]:
# Publication-style defaults: large font and 3-pt axis / tick lines.
plt.rcParams.update(
    {
        "font.size": 24,
        **{
            key: 3
            for key in (
                "axes.linewidth",
                "xtick.major.width",
                "ytick.major.width",
                "xtick.minor.width",
                "ytick.minor.width",
            )
        },
    }
)

In [None]:
fig, ax = plt.subplots(figsize=(8, 6))
x_axis = times_min

# One scatter series per processing pipeline.
ax.plot(x_axis, jacc_con, marker="o", linestyle="None", label="Conventional processing")
ax.plot(x_axis, jacc_ave, marker="o", linestyle="None", label="Averaged processing")

legend = ax.legend()
legend.get_frame().set_linewidth(3)

ax.set_xlabel("Time / min")
ax.set_ylabel("Jaccard index")
ax.set_title("Jaccard index from 0 to 15 ppm")
fig.savefig(
    f"{exp_code}_jaccard_full.svg",
    dpi=300,
    bbox_inches="tight",
    transparent=True,
)
plt.show()

In [None]:
# Same comparison restricted to the 5-7 ppm region.
jacc_con_trim = get_jaccard_list(
    x_axis=default_ppm, data=processed_data, start_ppm=7, end_ppm=5
)
jacc_ave_trim = get_jaccard_list(
    x_axis=default_ppm, data=averaged_data, start_ppm=7, end_ppm=5
)

In [None]:
fig, ax = plt.subplots(figsize=(8, 6))
x_axis = times_min

series = (
    (jacc_con_trim, "Conventional processing"),
    (jacc_ave_trim, "Averaged processing"),
)
for values, label in series:
    ax.plot(x_axis, values, marker="o", linestyle="None", label=label)

ax.legend().get_frame().set_linewidth(3)

ax.set_xlabel("Time / min")
ax.set_ylabel("Jaccard index")
ax.set_title("Jaccard index from 5 to 7 ppm")
fig.savefig(
    f"{exp_code}_jaccard_5to7.svg",
    dpi=300,
    bbox_inches="tight",
    transparent=True,
)
plt.show()

In [None]:
def find_peak(
    wavelengths: np.ndarray, intensities: np.ndarray, peak: float, width: float
) -> tuple[int, float, float]:
    """
    Locate the most intense point within ``peak +/- width``.

    Any x-axis works despite the parameter name (the notebook passes a ppm axis).

    Args:
        wavelengths: x-axis values, same length as ``intensities``.
        intensities: y values. NOTE(review): complex input is compared with
            NumPy's complex ordering (real part first) — confirm that is intended.
        peak: Centre of the search window, in x-axis units.
        width: Half-width of the search window (bounds inclusive).

    Returns:
        ``(index, x, y)``: ``index`` is the position of the maximum in the full
        input arrays (not in the window), ``x`` and ``y`` are the axis value and
        intensity at that position.

    Raises:
        ValueError: if the lengths differ or no point lies inside the window.
    """
    wavelengths = np.asarray(wavelengths)
    intensities = np.asarray(intensities)
    if len(wavelengths) != len(intensities):
        raise ValueError("wavelengths and intensities must have the same length")
    in_window = np.flatnonzero(
        (wavelengths >= peak - width) & (wavelengths <= peak + width)
    )
    if in_window.size == 0:
        raise ValueError(f"no data points within {peak} +/- {width}")
    # Map the argmax inside the window back to an index into the full arrays.
    index_of_max_intensity = int(in_window[intensities[in_window].argmax()])
    max_wavelength: float = wavelengths[index_of_max_intensity]
    max_intensity: float = intensities[index_of_max_intensity]
    return index_of_max_intensity, max_wavelength, max_intensity

In [None]:
# Position (ppm) of the most intense point within 11 +/- 2 ppm in each averaged spectrum.
peak_shifts = [
    find_peak(default_ppm, averaged_spectrum, peak=11, width=2)[1]
    for averaged_spectrum in tqdm(averaged_data)
]

In [None]:
fig, ax = plt.subplots(figsize=(8, 6))

ax.plot(times_min, peak_shifts, marker="o", linestyle="None")
# NMR convention: chemical shift increases downward on this axis.
ax.invert_yaxis()
ax.set_xlabel("Time / min")
ax.set_ylabel("Peak position / ppm")
ax.set_title("TCA peak position")
fig.savefig(
    f"{exp_code}_peak_shifts.svg",
    dpi=300,
    bbox_inches="tight",
    transparent=True,
)
plt.show()

In [None]:
# Similarity of each moving-average FID to the first moving-average FID,
# using the same window as the averaged processing above.
ave_similarities = []
ave_fid_ref = None

n_fids = len(fids)
for i in tqdm(range(n_fids)):
    window = fids[max(0, i - window_size) : min(n_fids, i + window_size)]
    fid_ma = np.mean(window, axis=0)
    if i == 0:
        ave_fid_ref = fid_ma
    ave_similarities.append(cross_correlation_similarity(ave_fid_ref, fid_ma))

In [None]:
fig, ax = plt.subplots(figsize=(8, 6))

# conventional processing: each raw FID vs. the reference FID
ax.plot(
    times_min,
    similarities,
    marker="o",
    linestyle="None",
    label="Conventional processing",
)

# averaged processing: each moving-average FID vs. the first moving-average FID
ax.plot(
    times_min,
    ave_similarities,
    marker="o",
    linestyle="None",
    label="Averaged processing",
)

ax.legend().get_frame().set_linewidth(3)

# Same time-axis label style as the other figures in this notebook.
ax.set_xlabel("Time / min")
ax.set_ylabel("Similarity")
ax.set_title("Cross correlation of FID data")
fig.savefig(
    f"{exp_code}_ave_FID_similarity.svg",
    dpi=300,
    bbox_inches="tight",
    transparent=True,
)
plt.show()