# Preprocessing of the raw training data

## 1. Loading libraries, parameters and data

In [None]:
import json
import os

import matplotlib.pyplot as plt
import numpy as np
from joblib import Parallel, delayed, effective_n_jobs
from scipy.signal import savgol_filter

In [None]:
from ddae1d.utils import despike_iterative, baseline_correction
from ddae1d.paths import PROJECT_ROOT

### Loading parameters

Modify them in `./config.json` if needed.

In [None]:
# Read the preprocessing parameters. JSON is UTF-8 by specification, so decode it
# explicitly rather than relying on the platform's default locale encoding.
with open("config.json", "r", encoding="utf-8") as f:
    config = json.load(f)

In [None]:
# Unpack every preprocessing parameter from the config, grouped by pipeline stage.
raw_trainset_filename = config["raw_trainset_filename"]

# Trimming of the spectral ends
trim_spectra, spectral_ends = config["trim_spectra"], config["spectral_ends"]

# Savitzky-Golay smoothing used to estimate the common offset
apply_savgol, savgol_params = config["apply_savgol"], config["savgol_params"]

# Parallel despiking
apply_despike, n_jobs_despike, despike_params = (
    config["apply_despike"],
    config["n_jobs_despike"],
    config["despike_params"],
)

# Baseline correction (polyfit_degree is forwarded to baseline_correction)
apply_baseline_correction, polyfit_degree = (
    config["apply_baseline_correction"],
    config["polyfit_degree"],
)

### Loading raw training data

In [None]:
# Load the raw training array from <project>/data/raw/trainset/<raw_trainset_filename>.
raw_trainset_path = PROJECT_ROOT / "data" / "raw" / "trainset" / raw_trainset_filename
raw_trainset = np.load(raw_trainset_path)

## 2. Trimming spectral ends and subtracting offset

### *If needed, trim the spectral ends to remove artefacts such as Rayleigh scattering*

Modify the values of `trim_spectra` and `spectral_ends` in `./config.json` if needed.

In [None]:
if trim_spectra:
    # Trim spectral ends: keep channels spectral_ends[0]:spectral_ends[1] of the last axis.
    # Basic slicing only creates a view that keeps the whole raw array alive; copying
    # lets the following `del raw_trainset` actually release the untrimmed data.
    trainset = raw_trainset[:, :, spectral_ends[0]:spectral_ends[1]].copy()
else:
    trainset = raw_trainset

In [None]:
# Drop the name for the untrimmed array. The data is only freed if nothing else still
# references it (e.g. when trimming is disabled, `trainset` is the very same object).
del raw_trainset

***The resulting number of spectral channels must be divisible by a sufficiently large power of 2 — at least 2^k, where k is the number of downsampling steps in the autoencoder — for the architecture to work properly. The encoder downsamples the spectral axis several times and the decoder upsamples it back, typically by a factor of 2 at each step. If the input length is not divisible by 2^k, the reconstruction will not match the input, leading to mismatched dimensions during decoding.***

### *If relevant, subtract an 'offset' baseline - common to the whole dataset. This should be done from the raw spectra before further preprocessing.*

#### Suggested method for computing the offset: Savitzky-Golay smoothed spectrum with lowest mean intensity

*Tune Savitzky-Golay filter parameters in `config.json`*

In [None]:
# If Savitzky-Golay smoothing is enabled, subtract a smoothed baseline from all spectra
if apply_savgol:
    # Mean intensity of every spectrum over the spectral axis (axis 2), ignoring NaNs
    mean_intensity = np.nanmean(trainset, axis=2)

    # Find the index of the spectrum with the lowest mean intensity.
    # nanargmin skips NaN means (all-NaN spectra); comparing against .min() matched
    # nothing as soon as a single mean was NaN and then failed with an IndexError.
    # Ties resolve to the first occurrence in row-major order, as before.
    lowest_mean_idx = tuple(
        int(i)
        for i in np.unravel_index(np.nanargmin(mean_intensity), mean_intensity.shape)
    )
    print(f"Index of spectrum with lowest mean intensity: {lowest_mean_idx}")

    # Extract the spectrum with the lowest mean intensity
    lowest_mean_spectrum = trainset[lowest_mean_idx]

    # Show the Savitzky-Golay filter parameters
    print(f"Savitzky-Golay filter parameters: {savgol_params}")

    # Smooth the lowest mean intensity spectrum using Savitzky-Golay filter
    smoothed_spectrum = savgol_filter(lowest_mean_spectrum, **savgol_params)

    # Plot the original and smoothed spectrum for visual inspection
    plt.figure(figsize=(10, 6))
    plt.plot(lowest_mean_spectrum, label="Original", alpha=0.5)
    plt.plot(smoothed_spectrum, label="Smoothed", linewidth=2)
    plt.xlabel("Spectral Channel")
    plt.ylabel("Intensity (a.u.)")
    plt.title("Lowest Mean Intensity Spectrum (Original vs Smoothed)")
    plt.xlim(0, len(lowest_mean_spectrum) - 1)
    plt.legend()
    plt.tight_layout()
    plt.show()

    # Subtract the smoothed baseline from all spectra (broadcast over the last axis)
    trainset_after_offset = trainset - smoothed_spectrum
else:
    trainset_after_offset = trainset

In [None]:
# Drop the name for the pre-offset array; if offset subtraction was skipped,
# `trainset_after_offset` still references the same object, so nothing is freed.
del trainset

## 3. Despiking (parallel processing)

Tune the number of parallel jobs: `n_jobs_despike` in `config.json`

In [None]:
def despike_task(point, repetition):
    """Despike one spectrum of the offset-corrected training set.

    The spectrum ``trainset_after_offset[point, repetition, :]`` is read from the
    notebook's global scope and passed to ``despike_iterative`` together with the
    keyword arguments in ``despike_params``.

    Returns:
        tuple: ``(point, repetition, despiked_spectrum, spikes_spectrum)``; the
        indices are echoed back so parallel results can be written to the right slot.
    """
    spectrum = trainset_after_offset[point, repetition, :]
    despiked, removed_spikes = despike_iterative(spectrum, **despike_params)
    return point, repetition, despiked, removed_spikes

In [None]:
if apply_despike:
    # effective_n_jobs applies joblib's own n_jobs rules (-1 -> all CPUs, -2 -> all
    # but one, ...), so the reported value matches what Parallel will really use.
    # The former modulo trick printed 0 when n_jobs_despike exceeded the CPU count.
    print(f"Number of jobs for parallel despiking: {effective_n_jobs(n_jobs_despike)}")

    n_points, n_repetitions = trainset_after_offset.shape[:2]
    n_spectra = n_points * n_repetitions  # one task per (point, repetition) spectrum

    # Floating-point output buffers, so despiked values are not truncated in case the
    # input array is integer-typed (possible when no offset was subtracted).
    out_dtype = np.result_type(trainset_after_offset.dtype, np.float32)
    spikes = np.zeros_like(trainset_after_offset, dtype=out_dtype)
    trainset_after_despike = np.zeros_like(trainset_after_offset, dtype=out_dtype)

    print("Starting parallel despiking of final map...")
    print(f"Total spectra to despike (tasks to do): {n_spectra}")
    results = Parallel(n_jobs=n_jobs_despike, verbose=1)(
        delayed(despike_task)(point, repetition)
        for point in range(n_points)
        for repetition in range(n_repetitions)
    )
    for point, repetition, despiked_spectrum, spikes_spectrum in results:
        trainset_after_despike[point, repetition, :] = despiked_spectrum
        spikes[point, repetition, :] = spikes_spectrum
    del results

    # (k, 2) array of (point, repetition) pairs whose spike vector is not all zero.
    # No .squeeze(): with exactly one hit it collapsed to shape (2,), making len()
    # report 2 and breaking the (point, repetition) unpacking in the plot cell.
    nonzero_indices = np.argwhere(np.any(spikes != 0, axis=2))
    percentage_nonzero = len(nonzero_indices) / n_spectra * 100 if n_spectra else 0.0
    print(f"Percentage of spectra despiked: {percentage_nonzero:.2f}%")
    print("Number of spectra despiked:", len(nonzero_indices))
else:
    trainset_after_despike = trainset_after_offset

In [None]:
# Drop the name for the offset-corrected array; if despiking was skipped,
# `trainset_after_despike` still references the same object. `despike_task` reads
# this global, so it can no longer be called after this cell.
del trainset_after_offset

### Plot up to 20 random examples of despiked spectra vs removed spikes
Verify that despiking is effective and does not remove relevant signal.

In [None]:
if apply_despike:
    # Normalise to a (k, 2) array of (point, repetition) rows; this also accepts an
    # index pair that was squeezed down to shape (2,) when only one spectrum had spikes.
    despiked_pairs = np.asarray(nonzero_indices).reshape(-1, 2)
    n_plot = min(20, len(despiked_pairs))
    chosen_rows = np.random.choice(len(despiked_pairs), n_plot, replace=False)
    selected_indices = despiked_pairs[chosen_rows]

    fig, axes = plt.subplots(4, 5, figsize=(20, 12), sharex=True, sharey=False)
    axes = axes.flatten()

    for ax, (point, repetition) in zip(axes, selected_indices):
        despiked = trainset_after_despike[point, repetition]
        # Only this spectrum's slices are added: summing the full arrays here built a
        # whole-dataset temporary for every subplot. despiked + spikes is presumably
        # the spectrum before despiking — confirm against despike_iterative.
        with_spikes = despiked + spikes[point, repetition]
        ax.plot(despiked, label='Despiked', zorder=1)
        ax.plot(with_spikes, label='Removed Spikes', zorder=0)
        ax.set_title(f'Point {point}, Rep {repetition}')
        ax.legend()

    plt.tight_layout()
    plt.show()

## 4. Baseline correction

In [None]:
# Apply baseline_correction (with polyfit_degree) only when enabled in the config;
# otherwise pass the despiked data through unchanged.
trainset_after_baseline = (
    baseline_correction(trainset_after_despike, polyfit_degree)
    if apply_baseline_correction
    else trainset_after_despike
)

In [None]:
# Drop the name for the despiked array; if baseline correction was skipped,
# `trainset_after_baseline` still references the same object.
del trainset_after_despike

## 5. Saving preprocessed training data

In [None]:
# Write the preprocessed array to <project>/data/preprocessed/trainset/noisy.npy.
# Create the output directory first: np.save raises FileNotFoundError when it is missing.
preprocessed_dir = PROJECT_ROOT / "data" / "preprocessed" / "trainset"
os.makedirs(preprocessed_dir, exist_ok=True)
np.save(preprocessed_dir / "noisy.npy", trainset_after_baseline)