# Baseline Correction Methods

Compare spectrakit's four built-in baseline correction methods (ALS, SNIP, polynomial fit, and rubberband) on a synthetic spectrum whose true baseline is known.

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from spectrakit import (
    baseline_als,
    baseline_polynomial,
    baseline_rubberband,
    baseline_snip,
)
from spectrakit.plot import plot_baseline

## Create Synthetic Spectrum

In [None]:
wavenumbers = np.linspace(400, 4000, 1000)


def gaussian(x, c, a, s):
    """Return a Gaussian peak evaluated at ``x``.

    Parameters
    ----------
    x : array_like
        Positions at which the peak is evaluated.
    c : float
        Peak centre, in the same units as ``x``.
    a : float
        Peak height, i.e. the value at ``x == c``.
    s : float
        Standard deviation that sets the peak width.
    """
    squared_offset = (x - c) ** 2
    twice_variance = 2 * s**2
    return a * np.exp(-squared_offset / twice_variance)

# Three Gaussian bands, each given as (centre, height, width).
peaks = gaussian(wavenumbers, 1000, 2.0, 25)
peaks = peaks + gaussian(wavenumbers, 1650, 1.5, 35)
peaks = peaks + gaussian(wavenumbers, 2900, 2.5, 40)

# Curved baseline: constant offset, a slow sinusoidal ripple, and a weak
# parabola centred at 2000 (overall coefficient 0.0001 / 1e4 = 1e-8).
baseline_ripple = 0.5 * np.sin(wavenumbers / 600)
baseline_bowl = 0.0001 * (wavenumbers - 2000) ** 2 / 1e4
true_baseline = 1.0 + baseline_ripple + baseline_bowl
spectrum = peaks + true_baseline

# Overview plot: the spectrum together with the baseline it was built on.
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(wavenumbers, spectrum, label="Spectrum")
ax.plot(wavenumbers, true_baseline, "--", label="True Baseline")
ax.legend()
ax.set_xlabel("Wavenumber")
ax.set_ylabel("Intensity")
ax.set_title("Synthetic Spectrum with Known Baseline")
ax.invert_xaxis()  # spectroscopy convention: high wavenumbers on the left
plt.show()

## Compare Methods

In [None]:
# Map each panel title to a callable that returns the baseline-corrected
# spectrum. This cell treats every baseline_* function as returning the
# *estimated baseline*, which is subtracted from the input.
# NOTE(review): confirm that convention against the spectrakit API docs.
methods = {
    "ALS (lam=1e6)": lambda y: y - baseline_als(y, lam=1e6, p=0.01),
    "SNIP (40 iters)": lambda y: y - baseline_snip(y, num_iterations=40),
    "Polynomial (order=3)": lambda y: y - baseline_polynomial(y, poly_order=3),
    "Rubberband": lambda y: y - baseline_rubberband(y),
}

fig, axes = plt.subplots(2, 2, figsize=(14, 8), sharex=True)

for ax, (name, method) in zip(axes.flat, methods.items()):
    corrected = method(spectrum)
    # Ideal result: the corrected spectrum should overlay the pure peaks.
    ax.plot(wavenumbers, peaks, "--", alpha=0.5, label="True Signal")
    ax.plot(wavenumbers, corrected, label="Corrected")
    ax.set_title(name)
    ax.legend(fontsize=8)
    # All four panels share one x-axis, and invert_xaxis() *toggles* it, so
    # calling it once per panel (an even number of times) would undo the
    # inversion. set_inverted(True) is idempotent and always yields a
    # descending wavenumber axis.
    ax.xaxis.set_inverted(True)

plt.suptitle("Baseline Correction Comparison", fontsize=14)
plt.tight_layout()
plt.show()

## ALS Parameter Sensitivity

In [None]:
# Show how the ALS smoothness parameter `lam` changes the fitted baseline,
# with the known true baseline overlaid for reference.
fig, axes = plt.subplots(1, 3, figsize=(14, 4), sharex=True, sharey=True)

for ax, lam in zip(axes, [1e4, 1e6, 1e8]):
    # baseline_als is treated as returning the *estimated baseline*. This
    # matches the comparison cell, which computes `y - baseline_als(y)` as
    # the corrected signal.
    # NOTE(review): confirm against the spectrakit API docs.
    baseline = baseline_als(spectrum, lam=lam, p=0.01)
    ax.plot(wavenumbers, spectrum, alpha=0.7, label="Original")
    ax.plot(wavenumbers, baseline, "--r", label="Estimated BL")
    ax.plot(wavenumbers, true_baseline, ":g", label="True BL")
    ax.set_title(f"lam = {lam:.0e}")
    ax.legend(fontsize=8)
    # The panels share the x-axis and invert_xaxis() toggles it once per
    # call; set_inverted(True) is idempotent regardless of panel count.
    ax.xaxis.set_inverted(True)

plt.suptitle("ALS: Effect of Lambda Parameter", fontsize=14)
plt.tight_layout()
plt.show()