## Todo list

Use this document as an example document for RamanLib.
- [ ] Visualising the effect of preprocessing steps on data

## Next Steps:

Preprocessing:
- Implement automatic substrate removal:
  - https://pubs.rsc.org/en/content/articlelanding/2009/an/b821856k
  - https://doi.org/10.1016%2Fj.forsciint.2013.04.033
- Implement substrate removal for heterogeneous substrates using multivariate curve resolution (e.g. in forensic analysis):
  - https://www.nature.com/articles/s41596-021-00620-3#ref-CR55
- Implement automatic baseline optimisation from Guo and Bocklitz:
  - https://doi.org/10.1039%2FC6AN00041J


## Useful Sources

RamanSPy docs: https://ramanspy.readthedocs.io/en/latest/

Guo et al Raman Data Analysis SOP paper and data availability: https://www.nature.com/articles/s41596-021-00620-3

Worked example using CLS regression: https://medium.com/data-science/data-science-for-raman-spectroscopy-a-practical-example-e81c56cf25f

Raman Spectrum of E2: https://pmc.ncbi.nlm.nih.gov/articles/PMC6201238/

Carly Shea's slides on E2 peak identification: https://docs.google.com/presentation/d/1n_KoqyFBQshDlwcOG39hBMPghaZAIIzefnRZgJ6n77w/edit?usp=sharing

Baseline subtraction algorithms (pybaselines): https://pybaselines.readthedocs.io/en/latest/algorithms/index.html

Paul Eilers' 'A Perfect Smoother' paper: https://doi.org/10.1021/ac034173t

Whitaker-Hayes despiking: https://www.sciencedirect.com/science/article/pii/S0169743918301758

## Import packages

In [1]:
import os
os.chdir('/home/linux_thoma/git/RamanLib/src')

In [4]:
import ramanlib as rl

In [6]:
# Import the packages used by the function definitions below
import ramanspy as rp
import pandas as pd
import matplotlib.pyplot as plt
import random
import numpy as np
import seaborn as sns
# import os
# import re
# from datetime import datetime
# from datetime import date
# import shutil
import sklearn.linear_model as linear_model
from scipy.optimize import curve_fit
# from collections import Counter
import warnings

# RamanLib: adding custom functionality to RamanSPy for personal use and added convenience

In my initial work with RamanSPy, I found several frequently used pieces of functionality that I would like to implement as custom classes and methods in a new library, RamanLib.

Some examples of core functionality include:
- adding a metadata parameter to spectral containers, so that the spectra within a container can be grouped by metadata attributes such as replicate / sample, collection parameters or labels.
- adding plotting functions for plots I have needed, such as correlation coefficient plots and comparisons to mean spectra.

## GroupedSpectralContainer class
Next step: Make a version that is suited to storing and manipulating spectral images.

In [None]:
class GroupedSpectralContainer:
    # GroupedSpectralContainer is a wrapper for a pandas DataFrame object, which should always have a "spectrum" column of
    # rp.Spectrum objects, and other columns representing corresponding metadata.

    # Basic data manipulation methods. For more advanced functionality, the user is expected to
    # use GroupedSpectralContainer.df to apply changes to the data.
    def __init__(self, spectral_list, metadata):   # metadata is a list of dictionaries
        # Raise errors if data isn't spectrum objects or not the same length as the metadata
        if not all(isinstance(s, rp.Spectrum) for s in spectral_list):
            raise TypeError("All items in spectral_list must be RamanSPy Spectrum objects.")
        if len(spectral_list) != len(metadata):
            raise ValueError("spectral_list and metadata must be the same length.")
        
        rows = [{"spectrum": s, **meta} for s, meta in zip(spectral_list, metadata)]
        self.df = pd.DataFrame(rows)
    
    @classmethod
    def from_dataframe(cls, df):
        # Create a GroupedSpectralContainer using a dataframe, going through the __init__ constructor
        # for increased robustness

        # Check for a spectrum column and the type of the spectrum columns
        if "spectrum" not in df.columns:
            raise ValueError("DataFrame must contain a 'spectrum' column.")
        if not all(isinstance(s, rp.Spectrum) for s in df["spectrum"]):
            raise TypeError("All entries in 'spectrum' column must be Spectrum objects.")
        
        spectra = df["spectrum"].tolist()
        metadata = df.drop(columns=["spectrum"]).to_dict(orient="records")
        return cls(spectra, metadata)

    def copy(self):
        return GroupedSpectralContainer.from_dataframe(self.df.copy())

    def to_spectral_container(self):
        # Return a SpectralContainer with the data of the GroupedSpectralContainer.
        axes = [s.spectral_axis for s in self.df['spectrum']]
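        # Check that all spectra within have the same spectral axis.
        # np.array_equal also copes with axes of different lengths, where == cannot be compared elementwise.
        if not all(np.array_equal(axes[0], ax) for ax in axes[1:]):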
            raise ValueError("All spectra must have the same spectral axis to convert to a SpectralContainer.")
        return rp.SpectralContainer.from_stack(self.df["spectrum"].tolist())

    def mean(self, by=None, include_stats=False, ddof=1):
        """
        Compute mean spectra per group and return a new GroupedSpectralContainer
        with one row per group.

        Parameters
        ----------
        by : str | list[str] | callable | None
            Grouping key(s) for pandas .groupby. If None, the whole container is one group.
        include_stats : bool
            If True, add columns: 'n' (group size), 'std_vector', 'var_vector'.
            (Vectors are numpy arrays aligned with the spectrum's spectral_axis.)
        ddof : int
            Delta degrees of freedom for variance/std (default=1 ⇒ sample stats).

        Returns
        -------
        GroupedSpectralContainer
            A new GSC with mean Spectrum per group. Group keys are preserved
            back into metadata columns.
        """
        # Build iterable of (group_key, group_df)
        grouped = [("all", self.df)] if by is None else list(self.df.groupby(by, dropna=False))

        rows = []
        for key, gdf in grouped:
            if gdf.empty:
                continue

            spectra = gdf["spectrum"].tolist()
            container = rp.SpectralContainer.from_stack(spectra)
            mean_spec = container.mean

            # Rehydrate group key(s) into metadata columns
            meta = {}
            if by is None:
                meta["group"] = "all"
            else:
                by_cols = by if isinstance(by, (list, tuple)) else [by]
                # pandas returns tuple keys for multi-column groups
                key_vals = key if isinstance(key, tuple) else (key,)
                meta.update(dict(zip(by_cols, key_vals)))

            # Always include the mean Spectrum
            meta["spectrum"] = mean_spec

            if include_stats:
                var = np.var(container.spectral_data, axis=0, ddof=ddof)
                meta["n"] = container.shape[0]
                meta["var_vector"] = var
                meta["std_vector"] = np.sqrt(var)

            rows.append(meta)

        mean_df = pd.DataFrame(rows)
        return GroupedSpectralContainer.from_dataframe(mean_df)

    def plot_mean(self, by=None, interval=None, plot_type="separate", ci_z=1.96, **kwargs):
        return mean_per_group(self, by=by, interval=interval, plot_type=plot_type, ci_z=ci_z, **kwargs)

    def plot_random(self, by=None, n_samples=3, plot_type="single", seed=None, **kwargs):
        return random_per_group(self, by=by, n_samples=n_samples, plot_type=plot_type, seed=seed, **kwargs)

    def apply_pipeline(self, pipeline):
        return apply_pipeline_to_container(self, pipeline)
    
    def __len__(self):
        return len(self.df)

    def __getitem__(self, key):
        result = self.df[key]
        if isinstance(result, pd.DataFrame):
            return GroupedSpectralContainer.from_dataframe(result)
        return result

    def __repr__(self):
        return f"GroupedSpectralContainer({len(self.df)} spectra)\n\n{self.df.head()}"
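
As a rough illustration of what `mean(by=..., include_stats=True)` computes, the sketch below reproduces the per-group mean, group size and sample standard deviation with plain NumPy arrays standing in for `rp.Spectrum` objects. The helper name `grouped_mean_stats` and the toy data are invented for this example and are not part of RamanLib.

```python
import numpy as np
import pandas as pd

def grouped_mean_stats(df, by, ddof=1):
    """Per-group mean, n and std for a DataFrame whose 'spectrum' column holds 1D arrays."""
    rows = []
    for key, gdf in df.groupby(by, dropna=False):
        stack = np.vstack(gdf["spectrum"].tolist())  # shape: (n_spectra, n_points)
        rows.append({
            by: key,
            "spectrum": stack.mean(axis=0),
            "n": stack.shape[0],
            "std_vector": stack.std(axis=0, ddof=ddof),
        })
    return pd.DataFrame(rows)

# Hypothetical toy data: three-point "spectra" with an annotation label
toy = pd.DataFrame({
    "spectrum": [
        np.array([1.0, 2.0, 3.0]),
        np.array([3.0, 4.0, 5.0]),
        np.array([10.0, 10.0, 10.0]),
        np.array([12.0, 8.0, 10.0]),
    ],
    "annotation": ["Normal", "Normal", "Adenom", "Adenom"],
})
stats = grouped_mean_stats(toy, by="annotation")
```

Each output row carries the group key back as a metadata column, which is the same idea as the "rehydrate group key(s)" step in `mean()`.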


## Custom Library Functions

In [None]:
                            ###### Domain-specific functions: ######
# While they provide less flexibility, they are intended to allow for common
# functions to be applied to the data quickly with little customisation. If more functionality is required,
# the user is encouraged to create a copy of the GroupedSpectralContainer to do any required filtering
# and use the to_spectral_container method to use other RamanSPy functions on the spectral data.

In [None]:
# Note: for grouped functions, inputting by=None (the default value for these parameters) causes
# no grouping to take place, so the function is applied to the entire set of data without grouping.

In [None]:
# Note: I have decided not to add plt.show() at the end of plot functions, so that the user has the ability
# to modify the plot within their code before displaying / closing the plot.

### Utilities - To be used only by internal functions, not by users

In [None]:
from typing import List
import matplotlib.axes as maxes

def _normalize_axes_obj(axes_obj) -> List[maxes.Axes]:
    """
    RamanSPy claims to return Axes or list[Axes], but some plot_types (e.g. 'stacked')
    may return a Figure. Normalize to a list[Axes].
    """
    if isinstance(axes_obj, list):
        return axes_obj
    # Single Axes
    if hasattr(axes_obj, "plot") and hasattr(axes_obj, "fill_between"):
        return [axes_obj]
    # Figure
    if hasattr(axes_obj, "get_axes"):
        return axes_obj.get_axes()
    # Unknown
    raise TypeError(f"Unexpected return type from rp.plot.spectra: {type(axes_obj)}")

### Plot Spectra
Functions for plotting spectra

In [None]:
def mean_per_group(gsc, by=None, interval=None, plot_type="separate", ci_z=1.96, **kwargs):
    """
    Plot mean spectrum per group using precomputed statistics from GSC.mean(include_stats=True).
    interval: None | 'ci' | 'sd'
        'ci' plots ± z * (std / sqrt(n)); 'sd' plots ± std.
    """
    # 1) Compute group means + stats once
    means_gsc = gsc.mean(by=by, include_stats=True, ddof=1)
    df = means_gsc.df

    # 2) Prepare means and labels
    group_means = df["spectrum"].tolist()
    
    # Use user labels if provided; otherwise build them. Avoid double-passing via kwargs.
    group_labels = kwargs.pop("label", None)
    if group_labels is None:
        if by is None:
            group_labels = ["all"]
        else:
            grouped = df.groupby(by, dropna=False)
            group_labels = [
                ", ".join(map(str, key)) if isinstance(key, tuple) else str(key)
                for key, _ in grouped
            ]
    spectral_axis = group_means[0].spectral_axis if group_means else None

    # 3) Precompute bands if requested
    bands = [None] * len(df)
    if interval in ("ci", "sd"):
        if "std_vector" not in df or (interval == "ci" and "n" not in df):
            warnings.warn("Required statistics not present; skipping interval bands.")
        else:
            for i, row in df.iterrows():
                std = row["std_vector"]
                if std is None or (isinstance(std, np.ndarray) and std.size == 0):
                    bands[i] = None
                    continue
                if interval == "sd":
                    band = std
                else:  # 'ci'
                    n = int(row["n"])
                    band = (ci_z * std / np.sqrt(n)) if n > 0 else None
                bands[i] = band

    # 4) Plot means
    axes_obj = rp.plot.spectra(group_means, label=group_labels, plot_type=plot_type, **kwargs)
    axes = _normalize_axes_obj(axes_obj)

    # 5) Handle stacked-with-offset limitation
    if (plot_type or "").lower() == "single stacked" and interval is not None:
        warnings.warn("Interval bands disabled for 'single stacked' due to vertical offsets.")
        return axes_obj

    # 6) Overlay bands on the correct axes
    if spectral_axis is not None and any(b is not None for b in bands):
        for ax, mean_spec, band in zip(axes, group_means, bands):
            if band is None:
                continue
            ax.fill_between(
                spectral_axis,
                mean_spec.spectral_data - band,
                mean_spec.spectral_data + band,
                alpha=0.2
            )

    return axes_obj
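
The two `interval` options differ only in how the band half-width is derived from the per-group standard deviation. A minimal sketch of that arithmetic, using a hypothetical helper `interval_band` and made-up numbers:

```python
import numpy as np

def interval_band(std, n, interval="ci", ci_z=1.96):
    """Half-width of the shaded band around a mean spectrum."""
    std = np.asarray(std, dtype=float)
    if interval == "sd":
        return std                        # plus/minus one standard deviation
    if interval == "ci":
        return ci_z * std / np.sqrt(n)    # z times the standard error of the mean
    raise ValueError("interval must be 'sd' or 'ci'")

std = np.array([2.0, 4.0])   # illustrative std at two wavenumbers
sd_band = interval_band(std, n=16, interval="sd")
ci_band = interval_band(std, n=16, interval="ci")
```

The 'ci' band shrinks with the square root of the group size, so for large groups it can look much tighter than the spread of individual spectra.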

In [None]:
def random_per_group(gsc, by=None, n_samples=3, plot_type="single", seed=None, **kwargs):
    """
    Plot n random spectra from each group in the container.

    Parameters:
        gsc (GroupedSpectralContainer)
        by (str or list or callable): Metadata column(s) or grouping method.
        n_samples (int): Number of spectra to sample per group.
        plot_type (str): Passed to rp.plot.spectra.
        seed (int or None): Random seed for reproducible sampling. Default None.
        **kwargs: Forwarded to rp.plot.spectra.

    Returns:
        Axes | list[Axes] | Figure: Whatever rp.plot.spectra returns.
    """
    rng = random.Random(seed)  # local RNG

    def _sample_k(spectra, k):
        if len(spectra) == 0:
            return []
        if k <= len(spectra):
            return rng.sample(spectra, k)
        return spectra[:] + rng.choices(spectra, k=k - len(spectra))

    if by is None:
        spectra = gsc.df["spectrum"].tolist()
        spectra_groups = [_sample_k(spectra, n_samples)]
        group_labels = ["all"]
    else:
        grouped = gsc.df.groupby(by)
        spectra_groups, group_labels = [], []
        for key, group_df in grouped:
            sample = _sample_k(group_df["spectrum"].tolist(), n_samples)
            spectra_groups.append(sample)
            group_labels.append(", ".join(map(str, key)) if isinstance(key, tuple) else str(key))

    return rp.plot.spectra(spectra_groups, label=group_labels, plot_type=plot_type, **kwargs)
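
The sampling rule used by `_sample_k` is worth seeing in isolation: it samples without replacement when the group is large enough, and otherwise keeps every spectrum and tops up with repeats. The sketch below restates it on plain strings; `sample_k` is a stand-alone copy for illustration only.

```python
import random

def sample_k(items, k, rng):
    """Draw k items: without replacement if possible, otherwise keep all and top up with repeats."""
    if not items:
        return []
    if k <= len(items):
        return rng.sample(items, k)
    return items[:] + rng.choices(items, k=k - len(items))

rng = random.Random(0)  # local RNG, so the global random state is untouched
small_group = ["s1", "s2"]
large_group = ["s1", "s2", "s3", "s4", "s5"]

topped_up = sample_k(small_group, 3, rng)   # both originals plus one repeat
subset = sample_k(large_group, 3, rng)      # three distinct items
```

One consequence is that a group smaller than `n_samples` will show duplicated spectra in the plot.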

### Quality Control
Functions for automatically identifying potential outliers in data using statistics.

Next step: Add different algorithms for detecting outliers (see Penny, K. I. & Jolliffe, I. T. A comparison of multivariate outlier detection methods for clinical laboratory safety data. J. R. Stat. Soc. D. 50, 295–307 (2001)), for example:
- Hotelling's T-squared
- Mahalanobis distance
- Q-residuals

Next step: Search for and plot all spectra with a metric outside a certain threshold. Threshold can be mu + t * sigma where mu and sigma are the mean and standard deviation of the metric across all spectra and t is tuned to get a good threshold value.
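
A minimal sketch of the thresholding idea, assuming the metric has already been evaluated for every spectrum. The function name `flag_outliers` and the scores are invented for illustration.

```python
import numpy as np

def flag_outliers(scores, t=3.0):
    """Return indices of scores above mean + t * std (population std, ddof=0)."""
    scores = np.asarray(scores, dtype=float)
    threshold = scores.mean() + t * scores.std()
    return np.flatnonzero(scores > threshold), threshold

# Illustrative metric values: seven typical spectra and one obvious outlier
scores = np.array([1.0, 1.1, 0.9, 1.0, 1.2, 0.8, 1.0, 5.0])
idx, threshold = flag_outliers(scores, t=2.0)
```

Because the outlier itself inflates both mu and sigma, the same data would flag nothing at t = 3, which is one reason t needs tuning; a median-based threshold might be a more robust alternative.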

In [None]:
def outliers_per_group(gsc, metric, by=None, n_spectra=3, highest=True):
    """
    Compute, per group, the indices of the n highest/lowest spectra by a metric
    against the group's mean spectrum. Also returns the group mean (Spectrum).

    Parameters
    ----------
    gsc : GroupedSpectralContainer
    metric : callable
        Like rp.metrics.* with signature metric(spec_a: rp.Spectrum, spec_b: rp.Spectrum) -> float.
    by : str | list[str] | callable | None
        Grouping key for pandas .groupby. If None, the whole dataset is one group.
    n_spectra : int
        Number of spectra to select per group (clipped to group size).
    highest : bool
        If True, select largest metric values; else smallest.

    Returns
    -------
    results : dict[str, tuple[list[int], rp.Spectrum]]
        { group_label: ([row_indices_into_gsc_df], mean_spectrum) }
    """
    grouped = [("all", gsc.df)] if by is None else list(gsc.df.groupby(by))

    results = {}
    for key, group_df in grouped:
        if group_df.empty:
            continue

        spectra = group_df["spectrum"].tolist()
        cont = rp.SpectralContainer.from_stack(spectra)
        mean_spec = cont.mean

        scores = np.array([metric(spec, mean_spec) for spec in spectra])
        order = np.argsort(scores)
        if highest:
            order = order[::-1]

        k = min(n_spectra, len(group_df))
        pick_local = order[:k]
        pick_global = group_df.index.values[pick_local].tolist()

        label = ", ".join(map(str, key)) if isinstance(key, tuple) else str(key)
        results[label] = (pick_global, mean_spec)

    return results

In [None]:
def plot_outliers_per_group(gsc, results, **kwargs):
    """
    Plot outlier spectra per group and overlay the group mean.
    Layout is fixed to 'separate' for robustness. Additional keyword args
    are forwarded to rp.plot.spectra (e.g., color, linewidth, title, ax, etc.).

    Parameters
    ----------
    gsc : GroupedSpectralContainer
    results : dict[str, tuple[list[int], rp.Spectrum]]
        Output of outliers_per_group: {group_label: (row_indices, mean_spectrum)}
    **kwargs :
        Forwarded to rp.plot.spectra, EXCEPT 'plot_type' which is ignored.

    Returns
    -------
    axes_obj :
        Whatever rp.plot.spectra returns (Axes | list[Axes] | Figure).
    """
    if not results:
        return None

    # Don’t allow plot_type injection here to avoid edge cases.
    if "plot_type" in kwargs:
        warnings.warn("plot_type is fixed to 'separate' for this plot; ignoring provided plot_type.")
        kwargs = {k: v for k, v in kwargs.items() if k != "plot_type"}

    group_labels = list(results.keys())

    spectra_groups = []
    means_for_overlay = []
    for label in group_labels:
        idxs, mean_spec = results[label]
        spectra_groups.append(gsc.df.loc[idxs, "spectrum"].tolist())
        means_for_overlay.append(mean_spec)

    axes_obj = rp.plot.spectra(
        spectra_groups,
        label=group_labels,
        plot_type="separate",
        **kwargs
    )

    # Normalize to a list of Axes to overlay the mean
    if isinstance(axes_obj, list):
        axes_list = axes_obj
    elif hasattr(axes_obj, "fill_between"):   # single Axes
        axes_list = [axes_obj]
    elif hasattr(axes_obj, "get_axes"):       # Figure
        axes_list = axes_obj.get_axes()
    else:
        axes_list = []

    # Overlay the mean line (no labels/legend/tight_layout/show here)
    for ax, mean_spec in zip(axes_list, means_for_overlay):
        ax.plot(mean_spec.spectral_axis, mean_spec.spectral_data, color="red", linewidth=1.5)

    return axes_obj

### Calibration
Functions for calibrating based on a calibration spectrum.

Next step: shift wavenumbers of spectra based on the calibration peak of the closest day.

Next step: apply a polynomial calibration based on calibration spectra with multiple peaks (like polystyrene or gas emission spectra)
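
One possible shape for the polynomial calibration is to fit a low-order polynomial that maps measured peak centres onto their expected positions, then apply it to the whole spectral axis. The sketch below uses `np.polyfit`; the peak positions are made-up numbers, not measured values.

```python
import numpy as np

# Hypothetical peak centres (cm^-1): where the instrument measured each reference
# peak versus where that peak is expected. These numbers are for illustration only.
measured = np.array([622.5, 1003.2, 1604.6, 3057.8])
expected = np.array([620.9, 1001.4, 1602.3, 3054.3])

# Fit a low-order polynomial mapping measured -> expected wavenumber
coeffs = np.polyfit(measured, expected, deg=2)
calibrate = np.poly1d(coeffs)

# Apply the mapping to a whole spectral axis
raw_axis = np.linspace(400.0, 3200.0, 8)
calibrated_axis = calibrate(raw_axis)

# Residuals at the reference peaks indicate how well the polynomial fits
residuals = calibrate(measured) - expected
```

The polynomial degree should stay well below the number of reference peaks, otherwise the fit simply interpolates the peak noise.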

In [None]:
from scipy.optimize import curve_fit
def fit_peak(calib: rp.Spectrum, plot: bool = False):
    """
    Fit a Gaussian to a single peak in a calibration spectrum (e.g., internal Si).

    Parameters
    ----------
    calib : rp.Spectrum
        Has `.spectral_axis` and `.spectral_data` (1D arrays).
    plot : bool
        If True, plot data + Gaussian fit using RamanSPy. (No plt.show() here.)

    Returns
    -------
    peak_center : float
        Estimated center of the peak (cm⁻¹).
    peak_height : float
        Estimated peak intensity (amplitude).
    sigma : float
        Estimated Gaussian σ (related to FWHM via FWHM = 2*sqrt(2*ln2)*σ).
    """

    # Gaussian model
    def gaussian(x, a, x0, sigma):
        return a * np.exp(-(x - x0) ** 2 / (2 * sigma ** 2))

    x = np.asarray(calib.spectral_axis)
    y = np.asarray(calib.spectral_data)

    # Initial guesses
    a0 = float(np.max(y))
    x0 = float(x[np.argmax(y)])
    sigma0 = 5.0
    p0 = [a0, x0, sigma0]

    # Fit
    popt, _ = curve_fit(gaussian, x, y, p0=p0)
    a_fit, x0_fit, sigma_fit = popt

    if plot:
        # Build a Spectrum for the fitted curve so we can use rp.plot.spectra
        y_fit = gaussian(x, *popt)
        fit_spec = rp.Spectrum(y_fit, x)

        # Plot data + fit in a single axes using RamanSPy
        axes_obj = rp.plot.spectra(
            [calib, fit_spec],
            label=["Data", "Gaussian Fit"],
            plot_type="single"
        )

        # Ensure we have an Axes to draw the vertical line on
        ax = _normalize_axes_obj(axes_obj)[0]
        if ax is not None:
            ax.axvline(x0_fit, linestyle=":", color="r")  # Peak center marker

        # (No plt.show(); let caller decide)

    return x0_fit, a_fit, sigma_fit
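
The σ returned by the fit can be converted to a full width at half maximum with the relation quoted in the docstring. A short check of that relation, with a hypothetical helper `sigma_to_fwhm`:

```python
import numpy as np

def sigma_to_fwhm(sigma):
    """Convert a Gaussian sigma to full width at half maximum."""
    # abs() because the Gaussian model is symmetric in sigma, so a fit may return a negative value
    return 2.0 * np.sqrt(2.0 * np.log(2.0)) * abs(sigma)

def gaussian(x, a, x0, sigma):
    return a * np.exp(-(x - x0) ** 2 / (2 * sigma ** 2))

sigma = 5.0
fwhm = sigma_to_fwhm(sigma)

# Half a FWHM away from the centre, the Gaussian should sit at half its amplitude
half_max = gaussian(520.0 + fwhm / 2, a=1.0, x0=520.0, sigma=sigma)
```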

In [None]:
def get_wn_shift(calib: rp.Spectrum, expected_x0_value, plot=False):
    '''Fit a gaussian to a single calibration peak of spectrum calib. Return the distance to expected_x0_value.
    
    expected_x0_value is a single number representing the value to correct to.'''
    x0, _, _ = fit_peak(calib, plot=plot)
    return x0 - expected_x0_value
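
Since the returned shift is the fitted centre minus the expected value, a rigid correction subtracts it from the spectral axis. A minimal sketch, assuming a constant offset across the whole axis; `apply_wn_shift` and the numbers are hypothetical.

```python
import numpy as np

def apply_wn_shift(spectral_axis, shift):
    """Return a corrected copy of the axis; shift = fitted centre - expected centre."""
    return np.asarray(spectral_axis, dtype=float) - shift

axis = np.array([510.0, 520.0, 522.0, 530.0])
fitted_x0 = 522.0      # hypothetical fitted peak centre
expected_x0 = 520.5    # hypothetical reference value
shift = fitted_x0 - expected_x0

corrected = apply_wn_shift(axis, shift)  # the fitted peak now sits at expected_x0
```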

In [None]:
def get_gsc_wn_shifts(
    calib_gsc: GroupedSpectralContainer,
    expected_x0_range,
    expected_x0_value: float,
    in_place: bool = False,
    plot: bool = False,
):
    """
    Fit calibration peaks for all spectra in a GSC and annotate the GSC with:
      - 'x0_fit' : fitted peak center (cm^-1)
      - 'shift'  : x0_fit - expected_x0_value

    Parameters
    ----------
    calib_gsc : GroupedSpectralContainer
        Must have a 'spectrum' column of rp.Spectrum objects.
    expected_x0_range : iterable of two floats
        [low, high] expected cm^-1 window for valid peak centers.
    expected_x0_value : float
        Target cm^-1 value to calibrate to.
    in_place : bool
        If True, modify calib_gsc.df in place. If False, return a new GSC.
    plot : bool
        If True, plot each fit (slow; for debugging small sets).

    Returns
    -------
    fail_gsc : GroupedSpectralContainer          (if in_place=True)
        Rows from the (now-annotated) calib_gsc.df that fail the range check.
    fail_gsc, new_gsc : (GroupedSpectralContainer, GroupedSpectralContainer)   (if in_place=False)
        The failing rows, followed by a new annotated GSC.
    """
    if expected_x0_range is None or len(expected_x0_range) != 2:
        raise ValueError("expected_x0_range must be a two-element iterable [low, high].")
    low, high = sorted(expected_x0_range)

    # Work on a copy unless modifying in-place
    df = calib_gsc.df if in_place else calib_gsc.df.copy()

    # Compute fits
    x0_vals = []
    shift_vals = []
    for _, spec in df["spectrum"].items():
        try:
            x0_fit, _, _ = fit_peak(spec, plot=plot)
            shift_val = x0_fit - expected_x0_value
        except Exception:
            x0_fit = np.nan
            shift_val = np.nan
        x0_vals.append(x0_fit)
        shift_vals.append(shift_val)

    # Attach columns
    df["x0_fit"] = x0_vals
    df["shift"] = shift_vals

    # Failing mask: NaN or outside range
    mask_fail = ~np.isfinite(df["x0_fit"]) | (df["x0_fit"] < low) | (df["x0_fit"] > high)
    fail_df = df.loc[mask_fail].copy()

    if in_place:
        # Persist columns back to the original df
        calib_gsc.df["x0_fit"] = df["x0_fit"]
        calib_gsc.df["shift"]  = df["shift"]
        return GroupedSpectralContainer.from_dataframe(fail_df)

    # Build new GSCs with the augmented dataframes
    new_gsc = GroupedSpectralContainer.from_dataframe(df)
    fail_gsc = GroupedSpectralContainer.from_dataframe(fail_df)
    
    return fail_gsc, new_gsc

### Preprocessing
Functions for plotting and optimising the baseline correction / preprocessing pipeline.

Next step: add Guo et al.'s baseline optimisation.

In [None]:
def apply_pipeline_to_container(container, pipeline):
    df = container.df.assign(spectrum=container.df["spectrum"].apply(pipeline.apply))
    return GroupedSpectralContainer.from_dataframe(df)

In [None]:
def baseline(spectrum, baseline_process, **kwargs):
    '''Plot the baseline resulting from a single baseline process alongside its raw and baseline-subtracted spectra'''
    corrected_spectrum = baseline_process.apply(spectrum)
    # The removed baseline is the difference between the raw and corrected intensities
    baseline_spec = rp.Spectrum(spectrum.spectral_data - corrected_spectrum.spectral_data, spectrum.spectral_axis)
    spectra = [spectrum, baseline_spec, corrected_spectrum]
    labels = ["Original spectrum", "Removed baseline", "Corrected spectrum"]
    return rp.plot.spectra(spectra, label=labels, plot_type="single", alpha=0.9, **kwargs)

In [None]:
def n_baselines(raw_gsc, baseline_process, process_name, n_samples=3, figsize=(8,7), seed=None):
    '''Plot n randomly selected spectra along with their baselines and corrected spectra in a single figure
    with multiple subplots'''
    # Pass the seed through so that the sampled spectra are reproducible
    spec_samples = raw_gsc.df.sample(n=n_samples, random_state=seed)["spectrum"]
    fig, axs = plt.subplots(n_samples, 1, figsize=figsize)
    axs = np.atleast_1d(axs)  # plt.subplots returns a bare Axes when n_samples == 1
    for i, spec in enumerate(spec_samples):
        baseline(spec, baseline_process, ax=axs[i], title="", xlabel="")
    fig.suptitle(f"{process_name}")
    plt.xlabel("Wavenumber (cm⁻¹)")
    plt.tight_layout()
    return axs

In [None]:
def compare_baselines(spectrum, baseline_processes, process_names, figsize=(8,7)):
    '''Plot a single spectrum with multiple different baseline processes applied to compare how
    each algorithm handles that spectrum'''
    fig, axs = plt.subplots(len(baseline_processes), 1, figsize=figsize)
    axs = np.atleast_1d(axs)  # plt.subplots returns a bare Axes for a single process
    for i, process in enumerate(baseline_processes):
        baseline(spectrum, process, ax=axs[i], title=f"{process_names[i]}", xlabel="")
    plt.xlabel("Wavenumber (cm⁻¹)")
    plt.tight_layout()
    return axs

### Statistical Plots

Next step: PCA Loading and Score Plots (Guo et al SOP point 23)

In [None]:
def mean_difference(group1_stats, group2_stats, ci_z=1.96):
    """
    Compute the difference in mean spectra and 95% confidence interval band.

    group1_stats and group2_stats are gsc rows containing the mean spectrum in column "spectrum" and containing
    stats columns "n", "var_vector" and "std_vector" for group statistics.

    Returns:
        - Spectrum object for the difference
        - np.ndarray for the CI band at each wavenumber
    """

    # Validate both inputs, not just the first
    for name, stats in (("group1_stats", group1_stats), ("group2_stats", group2_stats)):
        if any(stat not in stats.df.columns for stat in ["n", "var_vector", "std_vector"]) or (len(stats) != 1):
            raise ValueError(f"{name} missing statistics columns or includes multiple rows. Use include_stats=True in "
                             "GSC.mean() to ensure stats are included.")
    
    s1 = group1_stats["spectrum"].iloc[0]
    s2 = group2_stats["spectrum"].iloc[0]

    diff = s1.spectral_data - s2.spectral_data
    axis = s1.spectral_axis
    diff_spectrum = rp.Spectrum(diff, axis)

    var1 = group1_stats["var_vector"].iloc[0]
    var2 = group2_stats["var_vector"].iloc[0]
    n1 = group1_stats["n"].iloc[0]
    n2 = group2_stats["n"].iloc[0]

    ci_band = ci_z * np.sqrt((var1 / n1) + (var2 / n2))

    return diff_spectrum, ci_band
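
The band computed above is the z-scaled standard error of a difference between two independent means. The sketch below works through the same formula on synthetic intensity matrices (rows are spectra, columns are wavenumber points); all data are randomly generated for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic intensities: rows are spectra, columns are wavenumber points
group1 = rng.normal(loc=1.0, scale=0.2, size=(30, 5))
group2 = rng.normal(loc=1.0, scale=0.2, size=(40, 5))

diff = group1.mean(axis=0) - group2.mean(axis=0)
var1 = group1.var(axis=0, ddof=1)   # sample variance per wavenumber
var2 = group2.var(axis=0, ddof=1)

# Standard error of a difference of independent means, scaled by z
ci_band = 1.96 * np.sqrt(var1 / len(group1) + var2 / len(group2))

# Points where the difference falls outside the band centred on zero
significant = np.abs(diff) > ci_band
```

Note that this treats each wavenumber independently, so with many wavenumbers some points are expected to fall outside the band by chance.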

In [None]:
import warnings
def plot_mean_difference(diff_spectrum, ci_band, label="Difference in Means", **kwargs):
    """
    Plot the difference between two mean spectra with a 95% CI band centered at 0.

    Parameters:
        diff_spectrum (rp.Spectrum): Difference between two mean spectra.
        ci_band (np.ndarray): 1D array of CI half-widths, one per wavenumber.
        label (str): Legend label for the difference spectrum.
        **kwargs: Additional parameters forwarded to rp.plot.spectra() (e.g., title, color).
    """

    if "plot_type" in kwargs.keys():
        warnings.warn('Only plot_type="single" is supported for plot_mean_difference')
        kwargs.pop("plot_type")

    ax_obj = rp.plot.spectra(diff_spectrum, label=label, plot_type="single", **kwargs)
    axs = _normalize_axes_obj(ax_obj)
    axs[0].fill_between(diff_spectrum.spectral_axis, -ci_band, ci_band, color='gray', alpha=0.3, label='95% Confidence Band')
    # Draw the zero line on the same Axes rather than on whichever Axes is current
    axs[0].axhline(0, color='gray', linestyle='--', linewidth=1)
    return ax_obj

In [None]:
def mean_correlation_per_group(gsc, by):
    """
    Compute the Pearson correlation matrix between mean spectra of each group in a GroupedSpectralContainer.

    Parameters:
        gsc (GroupedSpectralContainer): The grouped spectral data.
        by (str): Column name to group by.

    Returns:
        correlation_matrix (pd.DataFrame): Correlation matrix between mean spectra of each group.
    """
    group_means_gsc = gsc.mean(by=by)
    spectral_data = [row["spectrum"].spectral_data for _, row in group_means_gsc.df.iterrows()]
    group_keys = [", ".join(map(str, k)) if isinstance(k, tuple) else str(k) for k in group_means_gsc.df.groupby(by).groups.keys()]
    df_group_means = pd.DataFrame({k: v for k, v in zip(group_keys, spectral_data)})
    correlation_matrix = df_group_means.corr(method='pearson')
    return correlation_matrix
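
Pearson correlation compares spectral shape only: it is unchanged by rescaling or adding a constant offset to either spectrum. A small NumPy sketch on synthetic "mean spectra" shows this; the peak positions and widths are arbitrary.

```python
import numpy as np

x = np.linspace(0.0, 1.0, 50)
mean_a = np.exp(-((x - 0.3) ** 2) / 0.01)   # synthetic peak at 0.3
mean_b = 2.0 * mean_a + 0.5                 # same shape, rescaled and offset
mean_c = np.exp(-((x - 0.7) ** 2) / 0.01)   # synthetic peak elsewhere

# Rows are variables for np.corrcoef, so each row is one mean spectrum
corr = np.corrcoef(np.vstack([mean_a, mean_b, mean_c]))
```

Here `corr[0, 1]` is 1 even though `mean_b` has twice the intensity, so differences in overall intensity between groups will not show up in the correlation matrix.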

In [None]:
def plot_mean_correlation_per_group(
    correlation_matrix,
    title="Correlation Matrix of Raman Spectra",
    vmin=0,
    vmax=1,
    annot=True,
    cmap="coolwarm",
    figsize=(8, 6),
    **kwargs
):
    """
    Plot a heatmap of the correlation matrix between group mean spectra.

    Parameters:
        correlation_matrix (pd.DataFrame): Correlation matrix to plot.
        title (str): Title of the plot.
        vmin (float): Minimum value for heatmap color scale.
        vmax (float): Maximum value for heatmap color scale.
        annot (bool): Whether to annotate cells with values.
        cmap (str): Color map to use for the heatmap.
        figsize (tuple): Figure size.
        **kwargs: Additional keyword arguments passed to seaborn.heatmap().
    """
    plt.figure(figsize=figsize)
    sns.heatmap(correlation_matrix, annot=annot, cmap=cmap, vmin=vmin, vmax=vmax, **kwargs)
    plt.title(title)

### Classical Least Squares for Spectral Subtraction

In [None]:
def CLS(query_spec, components_spec, component_names, plot=True, verbose=True):
    """
    Perform classical least squares (CLS) spectral unmixing.

    Decomposes a query spectrum into a linear combination of given component spectra
    using linear regression. Returns the fitted coefficients, residual spectrum, and
    individual component contributions. Optionally displays a plot and prints the results.

    Parameters:
        query_spec (rp.Spectrum): The target spectrum to be decomposed.
        components_spec (list of rp.Spectrum): Reference component spectra.
        component_names (list of str): Names corresponding to each component.
        plot (bool): If True, plots the query, residual, and component spectra.
        verbose (bool): If True, prints component names and their corresponding coefficients.

    Returns:
        tuple:
            - cs (np.ndarray): Fitted CLS coefficients.
            - res_spec (rp.Spectrum): Residual spectrum after fitting.
            - fitted_components_spec (list of rp.Spectrum): Scaled component spectra.
    """
    
    # Check that all components have the same number of datapoints as the query spectrum
    assert all(len(query_spec.spectral_axis) == len(component.spectral_axis) for component in components_spec), "All components must have the same spectral axis length as the mixture spectrum"

    # Get the CLS coefficients
    components = np.array([component.spectral_data for component in components_spec])
    query = query_spec.spectral_data
    cs = linear_model.LinearRegression().fit(components.T, query).coef_

    # Calculate the residual spectrum and fitted components spectra
    res = query_spec.spectral_data.copy()
    fitted_components_spec = []
    for c, component in zip(cs, components):
        res -= c * component
        fitted_components_spec.append(rp.Spectrum(c * component, query_spec.spectral_axis))
    
    res_spec = rp.Spectrum(res, query_spec.spectral_axis)
    
    # Only print the coefficients when requested
    if verbose:
        print("components:\n" + "\n".join(f"{name}, {c}" for name, c in zip(component_names, cs)))

    # Only draw the figure (query, residual, and fitted components) when requested
    if plot:
        rp.plot.spectra(
            [query_spec, res_spec] + fitted_components_spec,
            label=["query", "residual"] + component_names,
            plot_type="single",
            alpha=0.8,
        )

    return cs, res_spec, fitted_components_spec
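
To make the underlying linear algebra concrete, the sketch below unmixes a synthetic, noise-free mixture of two Gaussian "components" with `np.linalg.lstsq`. This is a stand-in for the scikit-learn call above, not the same code path: `lstsq` here fits no intercept, whereas `LinearRegression` fits one by default. All spectra are synthetic.

```python
import numpy as np

x = np.linspace(0.0, 1.0, 200)
comp_a = np.exp(-((x - 0.3) ** 2) / 0.002)   # synthetic component A
comp_b = np.exp(-((x - 0.6) ** 2) / 0.002)   # synthetic component B
query = 2.0 * comp_a + 0.5 * comp_b          # noise-free mixture

# Columns of the design matrix are the component spectra
design = np.column_stack([comp_a, comp_b])
coeffs, *_ = np.linalg.lstsq(design, query, rcond=None)

# Whatever the components cannot explain is left in the residual
residual = query - design @ coeffs
```

With real data the residual is where unmodelled contributions, such as an analyte sitting on top of a substrate signal, would be expected to appear.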

## Code Examples
Todo: add a section on mean_per_group() and filtering using GSC

### Loading Raman Data From .csv File

In [None]:
def load_colon_spectra(csv_path):
    """
    Load colon tissue Raman spectra into a GroupedSpectralContainer from a CSV file.

    Filters to spectra labeled as 'Normal', 'Adenom', or 'Karzinom' and tissue type 'preparation'.
    Returns a GroupedSpectralContainer and a NumPy array of binary labels (0 = normal, 1 = abnormal).
    """

    # Read CSV
    df = pd.read_csv(csv_path)

    # Filter to valid, usable spectra
    df = df[
        df['Annotation'].isin(['Normal', 'Adenom', 'Karzinom']) &
        (df['Tissue'] == 'preparation')
    ]

    # Create binary labels
    label_map = {'Normal': 0, 'Adenom': 1, 'Karzinom': 1}
    labels = df['Annotation'].map(label_map).to_numpy()

    # Extract metadata
    replicates = df["Name"].str.split('_').str[0]
    metadata = [
        {
            "annotation": annotation,
            "name": name,
            "tissue": tissue,
            "replicate": replicate,
            "label": label  # binary: 0 or 1
        }
        for annotation, name, tissue, replicate, label in zip(
            df['Annotation'], df['Name'], df['Tissue'], replicates, labels
        )
    ]

    # Isolate spectral data
    df = df.loc[:, ~df.columns.str.contains('^Unnamed')]  # Remove index column if present
    spectral_df = df.drop(columns=['Name', 'Annotation', 'Tissue'])

    # Extract spectra and axis
    spectra_array = spectral_df.to_numpy()
    spectral_axis = spectral_df.columns.astype(float).to_numpy()

    # Build Spectrum objects
    spectra = [
        rp.Spectrum(spectral_data, spectral_axis)
        for spectral_data in spectra_array
    ]

    # Return wrapped container
    container = GroupedSpectralContainer(spectral_list=spectra, metadata=metadata)
    return container, labels

In [None]:
raw_gsc, labels = load_colon_spectra("/home/linux_thoma/git/aranexx_sers_analysis/ramanlib/DATA_dp_wc_bc.csv")

### Visualising Raw Data

In [None]:
mean_per_group(raw_gsc, "annotation", plot_type="stacked", interval="sd")

In [None]:
random_per_group(raw_gsc, "annotation", n_samples=3, plot_type="stacked")

In [None]:
len(raw_gsc)

In [None]:
# Comments:
# The mean spectra suggest relatively low variance, but plotting random spectra shows that individual spectra vary widely.
# I assumed the data had already been preprocessed, but I appear to have misread the paper.
# The spectra will need preprocessing, and outliers with significant measurement problems will need to be removed.

In [None]:
# See counts of each annotation
print("Annotation counts:")
print(raw_gsc.df['annotation'].value_counts())

In [None]:
# Check combinations of tissue type, annotation and label
print("Annotation by tissue type:")
print(raw_gsc.df.groupby(['tissue', 'annotation', 'label']).size())

### Calibrate Raw Spectra

#### Example: Detecting wavenumber shift from internal Si calibration standard

To demonstrate the calibration functions, I will generate hypothetical internal silicon standard spectra, each containing a single high-intensity peak. Itoh et al. estimate that the main silicon band lies at 520.45 ± 0.28 cm⁻¹ for Si(100).

In [None]:
# Generate hypothetical internal silicon standard spectra
from datetime import date

# Parameters
n = 100
ideal_center = 520.45
allowed_halfwidth = 0.28
center_sd = 0.2      # window is ±1.4 sd, so roughly 16% of centers fall outside it
amp_range = (800, 1200)
sigma_range = (0.3, 0.6)
noise_frac = 0.02
seed = 7

rng = np.random.default_rng(seed)

# Wavenumber axis around the Si line
x = np.linspace(480, 560, 1601)

# Random peak parameters
centers = rng.normal(ideal_center, center_sd, size=n)
amps    = rng.uniform(*amp_range, size=n)
sigmas  = rng.uniform(*sigma_range, size=n)

spectra = []
for c, a, s in zip(centers, amps, sigmas):
    y = a * np.exp(-(x - c)**2 / (2 * s**2))
    y += rng.normal(0.0, noise_frac * a, size=x.shape)  # noise
    spectra.append(rp.Spectrum(y, x))

# Minimal fake metadata: rotating dates and 3x3 spots
dates = [date(2025, 8, 1), date(2025, 8, 2)]
spots = [f"({r},{c})" for r in range(1, 4) for c in range(1, 4)]
metadata = [{"date": dates[i % len(dates)], "spot": spots[i % len(spots)]} for i in range(n)]

# Build GroupedSpectralContainer
calib_gsc = GroupedSpectralContainer(spectra, metadata)

Plot randomly selected spectra to verify them visually.

In [None]:
calib_gsc.plot_random(n_samples=3, seed=seed)

Check the calibration of a single spectrum

In [None]:
shift_value = get_wn_shift(calib_gsc["spectrum"].iloc[5], 520.45, plot=True)
print(shift_value)
plt.show()
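
The internals of `get_wn_shift` are not shown in this notebook. As a rough illustration of what such a check might involve, the sketch below locates the peak maximum with a three-point parabolic fit and reports its offset from the reference position. The name `estimate_wn_shift` and the sign convention (measured minus reference) are assumptions made for this example, not RamanLib's API.

```python
import numpy as np

def estimate_wn_shift(wavenumbers, intensities, reference=520.45):
    """Hypothetical helper: fitted peak position minus reference, in cm^-1."""
    i = int(np.argmax(intensities))
    # Clamp so that a three-point neighbourhood always exists
    i = min(max(i, 1), len(intensities) - 2)
    y0, y1, y2 = intensities[i - 1], intensities[i], intensities[i + 1]
    denom = y0 - 2 * y1 + y2
    # Vertex of the parabola through the three points, in units of grid steps
    offset = 0.0 if denom == 0 else 0.5 * (y0 - y2) / denom
    step = wavenumbers[i + 1] - wavenumbers[i]
    peak = wavenumbers[i] + offset * step
    return peak - reference

# Synthetic Si-like peak centred 0.15 cm^-1 above the reference
x = np.linspace(480, 560, 1601)
y = 1000 * np.exp(-(x - 520.60) ** 2 / (2 * 0.5 ** 2))
shift = estimate_wn_shift(x, y)
print(f"estimated shift: {shift:.3f} cm^-1")
```

A parabolic fit is only one option. Fitting a Gaussian or Lorentzian with `scipy.optimize.curve_fit` would likely be more robust on noisy spectra.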

Check the calibration of the entire GSC and return a GSC of uncalibrated spectra

In [None]:
fail_gsc, new_gsc = get_gsc_wn_shifts(
    calib_gsc,
    [520.17, 520.73],
    520.45
)

In [None]:
print(fail_gsc)
print(fail_gsc["shift"].abs().min())
# All shifts in fail_gsc exceed 0.28 in magnitude, as expected for the allowed range 520.45 ± 0.28

In [None]:
print(new_gsc)

Using the shift values and dates stored in these GroupedSpectralContainers, we can apply a wavenumber correction to other data collected on or near the same day.
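
One way such a correction could be applied is to move the wavenumber axis by the measured shift and resample back onto the original grid, so corrected spectra stay directly comparable. This is a minimal sketch using plain NumPy arrays. `apply_wn_shift` is a hypothetical helper, and it assumes that a positive shift means the instrument reads too high.

```python
import numpy as np

def apply_wn_shift(wavenumbers, intensities, shift):
    """Hypothetical helper: correct a spectrum whose axis reads `shift` too high."""
    corrected_axis = wavenumbers - shift
    # Resample onto the original grid; values beyond the ends are held constant
    return np.interp(wavenumbers, corrected_axis, intensities)

x = np.linspace(480, 560, 1601)
# Synthetic peak that appears at 520.65 instead of 520.45 (shift of +0.20)
measured = np.exp(-(x - 520.65) ** 2 / (2 * 0.5 ** 2))
corrected = apply_wn_shift(x, measured, shift=0.20)
peak_position = x[np.argmax(corrected)]
print(f"peak after correction: {peak_position:.2f} cm^-1")
```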

### Choose Best Baseline Subtraction Algorithm

In [None]:
baseline(raw_gsc["spectrum"].iloc[0], rp.preprocessing.baseline.ASLS(lam=1000))

In [None]:
n_baselines(raw_gsc, rp.preprocessing.baseline.ASLS(lam=1000), "ASLS")

In [None]:
processes = [
    rp.preprocessing.baseline.IARPLS(lam=10),
    rp.preprocessing.baseline.IARPLS(lam=100),
    rp.preprocessing.baseline.IARPLS(lam=1000),
    rp.preprocessing.baseline.IARPLS(lam=10000),
    rp.preprocessing.baseline.IARPLS(lam=100000)
]

names = [
    "lam = 10",
    "lam = 100",
    "lam = 1000",
    "lam = 10000",
    "lam = 100000"
]
compare_baselines(raw_gsc["spectrum"].iloc[0], processes, names, figsize=(10,10))
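
To build intuition for what `lam` controls, here is a small dense-matrix sketch of asymmetric least squares (ASLS) in the style of Eilers. It is not the implementation RamanSPy uses, and because the effect of `lam` depends on the number of points, the values below are not directly comparable to those in the cells above. The synthetic spectrum and the function name `asls_baseline` are assumptions for illustration.

```python
import numpy as np

def asls_baseline(y, lam=1e5, p=0.01, n_iter=10):
    """Dense-matrix sketch of asymmetric least squares baseline estimation."""
    n = len(y)
    # Second-order difference operator, shape (n - 2, n)
    D = np.diff(np.eye(n), n=2, axis=0)
    penalty = lam * D.T @ D
    w = np.ones(n)
    for _ in range(n_iter):
        # Solve (W + lam * D'D) z = W y for the current weights
        z = np.linalg.solve(np.diag(w) + penalty, w * y)
        # Points above the baseline (peaks) get the small weight p
        w = np.where(y > z, p, 1 - p)
    return z

# Synthetic spectrum: linear baseline plus one narrow peak
x = np.linspace(0, 1, 201)
true_baseline = 5 + 3 * x
peak = 10 * np.exp(-(x - 0.5) ** 2 / (2 * 0.01 ** 2))
y = true_baseline + peak

for lam in (1e1, 1e3, 1e5):
    z = asls_baseline(y, lam=lam)
    print(f"lam={lam:g}: max baseline error = {np.max(np.abs(z - true_baseline)):.3f}")
```

Larger `lam` makes the baseline stiffer, so it follows peaks less. Smaller `lam` lets the baseline bend into the peaks and remove real signal.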

### Pipeline

In [None]:
# Define the pipeline.
pipe = rp.preprocessing.Pipeline([
    rp.preprocessing.despike.WhitakerHayes(),
    rp.preprocessing.baseline.IARPLS(lam = 300),
    rp.preprocessing.normalise.MaxIntensity(),
])

In [None]:
prepro_gsc = apply_pipeline_to_container(raw_gsc, pipe)

### Visualising Preprocessed Spectra

#### Individual Spectra

In [None]:
# Plot random spectra before and after preprocessing
# Create 1 row, 2 column subplot layout
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))  # Adjust figsize as needed

# Plot randomly selected raw spectra
random_per_group(raw_gsc, by="annotation", plot_type="single", ax=ax1, title="Raw Spectra", n_samples=1)
# Plot randomly selected preprocessed spectra
random_per_group(prepro_gsc, by="annotation", plot_type="single", ax=ax2, title="Preprocessed Spectra", n_samples=1)
plt.tight_layout()
plt.show()

#### Mean Spectra

In [None]:
# Demonstrate the full customisability of the plot functions
prepro_gsc.plot_mean(
    by=["annotation", "label"],
    interval="sd",
    plot_type="single",
    linewidth=1,
    label=["Adenoma", "Carcinoma", "Normal"],
    color=["red", "blue", "yellow"],
    title="My Raman Spectra",
    xlabel="My x axis",
    ylabel="Counts"
)
                     

In [None]:
# Plot mean spectra before and after preprocessing
# Create 1 row, 2 column subplot layout
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))  # Adjust figsize as needed

# Raw means on the left, preprocessed means on the right
mean_per_group(raw_gsc, "annotation", plot_type="single", ax=ax1, title="Raw Means", interval="sd")
mean_per_group(prepro_gsc, "annotation", plot_type="single", ax=ax2, title="Preprocessed Means", interval="sd")
plt.tight_layout()
plt.show()

### Quality Control

In [None]:
results = outliers_per_group(prepro_gsc, rp.metrics.MSE, by="annotation", n_spectra=1)

In [None]:
results

In [None]:
axes = outliers_per_group(prepro_gsc, results, linewidth=1)
plt.show()

The carcinoma spectrum at index 69 is too noisy, probably because of a measurement error, so we will remove it from the dataset.

In [None]:
cleaned_gsc = GroupedSpectralContainer.from_dataframe(prepro_gsc.df.drop(prepro_gsc.df.index[69]))

In [None]:
results = calc_outlier_indices_per_group(cleaned_gsc, rp.metrics.MSE, by="annotation", n_spectra=1)

In [None]:
results

In [None]:
axes = plot_outlier_indices_per_group(cleaned_gsc, results, linewidth=1)
plt.show()
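
The outlier functions above are not defined in this section, so the following is only a sketch of the underlying idea: rank each spectrum in a group by its mean squared error to the group mean and flag the largest. `most_deviant_indices` and the synthetic data are hypothetical.

```python
import numpy as np

def most_deviant_indices(spectra, n_spectra=1):
    """Hypothetical helper: indices of spectra with the largest MSE to the group mean."""
    mean_spectrum = spectra.mean(axis=0)
    mse = ((spectra - mean_spectrum) ** 2).mean(axis=1)
    # Sort descending and keep the n_spectra worst offenders
    order = np.argsort(mse)[::-1]
    return order[:n_spectra], mse

rng = np.random.default_rng(2)
axis = np.linspace(400, 1800, 300)
clean = np.exp(-(axis - 1000) ** 2 / (2 * 50 ** 2))
spectra = clean + rng.normal(0, 0.02, size=(20, axis.size))
spectra[7] += rng.normal(0, 0.3, size=axis.size)  # one deliberately noisy spectrum

worst, mse = most_deviant_indices(spectra)
print("most deviant spectrum:", worst[0])
```

Note that a strong outlier also pulls the group mean towards itself, so a median reference spectrum may be a safer choice for small groups.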

### Between-Group Analysis

#### Difference in Mean Spectra

In [None]:
group_stats = prepro_gsc.mean(by="annotation", include_stats=True)
group1_stats = group_stats[group_stats["annotation"]=="Normal"]
group2_stats = group_stats[group_stats["annotation"]=="Adenom"]
group3_stats = group_stats[group_stats["annotation"]=="Karzinom"]

In [None]:
diff_spectrum_1, ci_band_1 = mean_difference(group1_stats, group2_stats)
diff_spectrum_2, ci_band_2 = mean_difference(group2_stats, group3_stats)
diff_spectrum_3, ci_band_3 = mean_difference(group3_stats, group1_stats)

In [None]:
rp.plot.spectra([group1_stats["spectrum"].iloc[0], group2_stats["spectrum"].iloc[0]], plot_type="single", label=["Normal", "Adenom"])

In [None]:
fig, axs = plt.subplots(3, 1, figsize=(8, 6))
mean_difference(diff_spectrum_1, ci_band_1, title="Normal - Adenoma", ax=axs[0])
mean_difference(diff_spectrum_2, ci_band_2, title="Adenoma - Carcinoma", ax=axs[1])
mean_difference(diff_spectrum_3, ci_band_3, title="Carcinoma - Normal", ax=axs[2])
plt.tight_layout()
plt.show()
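
For reference, a difference spectrum with a confidence band can be computed from two groups of spectra as below. This is a sketch under the assumption of independent groups and a 95% normal-approximation interval. The actual calculation inside `mean_difference` is not shown in this notebook and may differ. `mean_difference_with_ci` and the synthetic groups are hypothetical.

```python
import numpy as np

def mean_difference_with_ci(group_a, group_b, z=1.96):
    """Hypothetical helper: difference of group means and CI half-width per wavenumber."""
    mean_a, mean_b = group_a.mean(axis=0), group_b.mean(axis=0)
    # Standard error of the difference between two independent means
    se = np.sqrt(
        group_a.var(axis=0, ddof=1) / len(group_a)
        + group_b.var(axis=0, ddof=1) / len(group_b)
    )
    return mean_a - mean_b, z * se

rng = np.random.default_rng(0)
axis = np.linspace(400, 1800, 50)
band = np.exp(-(axis - 1000) ** 2 / (2 * 50 ** 2))
# Two synthetic groups that differ only in the height of one band
group_a = 1.0 * band + rng.normal(0, 0.05, size=(40, axis.size))
group_b = 1.5 * band + rng.normal(0, 0.05, size=(40, axis.size))

diff, half_width = mean_difference_with_ci(group_a, group_b)
significant = np.abs(diff) > half_width  # True where the CI band excludes zero
print("wavenumbers with a CI band excluding zero:", int(significant.sum()))
```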


#### Correlation Coefficient by Annotation

In [None]:
corr = mean_correlation_per_group(prepro_gsc, "annotation")

In [None]:
mean_correlation_per_group(corr, vmin=0.98)
plt.show()
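
The narrow colour range (`vmin=0.98`) is needed because spectra that share most of their bands correlate very strongly. The sketch below shows this with `np.corrcoef` on hypothetical group mean spectra. The internals of `mean_correlation_per_group` are not shown here, and it may instead average pairwise correlations between individual spectra.

```python
import numpy as np

axis = np.linspace(400, 1800, 200)
shared = np.exp(-(axis - 1000) ** 2 / (2 * 60 ** 2))
extra = np.exp(-(axis - 1450) ** 2 / (2 * 40 ** 2))

# Hypothetical group mean spectra: the groups differ only in one band's height
group_means = {
    "Normal": shared + 0.2 * extra,
    "Adenom": shared + 0.3 * extra,
    "Karzinom": shared + 0.8 * extra,
}
names = list(group_means)
# np.corrcoef treats each row as one variable, i.e. one spectrum
corr = np.corrcoef(np.vstack([group_means[name] for name in names]))

for name, row in zip(names, corr):
    print(name, np.round(row, 4))
```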

#### Correlation Coefficient by Replicate

In [None]:
corr = mean_correlation_per_group(prepro_gsc, "replicate")

In [None]:
mean_correlation_per_group(corr, annot=False, vmin=0.90, vmax=1.0)

#### Next step: Loading and score plots (Guo et al SOP step 23)

### Classical Least Squares for Spectral Subtraction

In [None]:
import os

# Base directory: the current working directory (not necessarily this notebook's folder)
BASE_DIR = os.getcwd()

# Example data folder inside ramanlib
DATA_DIR = os.path.join(BASE_DIR, "CLS_example_data")

# Build the data file paths from DATA_DIR
grating = rp.load.renishaw(os.path.join(DATA_DIR, "grating 10s 50_ 3acc 532nm 1200lmm extended range 200-3600 spot 1.wdf"))
water = rp.load.renishaw(os.path.join(DATA_DIR, "water 10s 10_ 3acc 532nm 1200lmm extended range 200-3600 spot 3.wdf"))
ethanol = rp.load.renishaw(os.path.join(DATA_DIR, "ethanol 10s 1_ 3acc 532nm 1200lmm extended range 200-3600 spot 1.wdf"))
water_and_ethanol = rp.load.renishaw(os.path.join(DATA_DIR, "water and ethanol 10s 10_ 3acc 532nm 1200lmm extended range 200-3600 spot 1.wdf"))

In [None]:
rp.plot.spectra([grating, water, ethanol, water_and_ethanol], plot_type="single")

In [None]:
pipe = rp.preprocessing.Pipeline([
    rp.preprocessing.despike.WhitakerHayes(),
    rp.preprocessing.baseline.IARPLS(lam = 3000),
    rp.preprocessing.normalise.MaxIntensity(),
])

grating_prepro = pipe.apply(grating)
water_prepro = pipe.apply(water)
ethanol_prepro = pipe.apply(ethanol)
water_and_ethanol_prepro = pipe.apply(water_and_ethanol)

In [None]:
rp.plot.spectra([grating_prepro, water_prepro, ethanol_prepro, water_and_ethanol_prepro], plot_type="single")

In [None]:
cs, res_spec, fitted_components_spec = CLS(water_and_ethanol_prepro, [grating_prepro, water_prepro, ethanol_prepro], ["grating", "water", "ethanol"])
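
To sanity-check the idea behind `CLS` without instrument data, the fit can be reproduced on synthetic spectra with a plain least-squares solve. This sketch assumes no intercept term and uses made-up Gaussian bands as stand-ins for pure components. It is not the `CLS` function defined above.

```python
import numpy as np

rng = np.random.default_rng(3)
axis = np.linspace(200, 3600, 500)

def band(center, width):
    """Gaussian band on the shared axis (illustrative stand-in for a Raman peak)."""
    return np.exp(-(axis - center) ** 2 / (2 * width ** 2))

# Hypothetical pure-component spectra, not real water or ethanol data
components = np.vstack([
    band(3400, 120),
    band(2900, 40) + 0.5 * band(880, 15),
])
true_coeffs = np.array([0.7, 0.3])
mixture = true_coeffs @ components + rng.normal(0, 0.005, size=axis.size)

# CLS: find coeffs minimising ||mixture - components.T @ coeffs||^2
coeffs, *_ = np.linalg.lstsq(components.T, mixture, rcond=None)
residual = mixture - coeffs @ components
print("fitted coefficients:", np.round(coeffs, 3))
```

If the residual shows structured peaks rather than noise, that suggests a component is missing from the model or that the spectra are misaligned on the wavenumber axis.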