# SpectraKit Quick Start

This notebook demonstrates the basic SpectraKit workflow:
1. Generate synthetic spectral data
2. Apply smoothing, baseline correction, and normalization
3. Compare results after each step
4. Chain the steps with the `Pipeline` class and process a batch of spectra

In [None]:
import matplotlib.pyplot as plt
import numpy as np

from spectrakit import (
    baseline_als,
    normalize_snv,
    smooth_savgol,
)
from spectrakit.plot import plot_comparison, plot_spectrum

## Generate Synthetic Data

Create a synthetic spectrum with peaks, baseline drift, and noise.

In [None]:
# Seeded generator: every random draw below is reproducible on a fresh kernel.
rng = np.random.default_rng(seed=42)
# Evenly spaced axis from 400 to 4000 (both endpoints included), 1000 points.
wavenumbers = np.linspace(400, 4000, num=1000)


# Create peaks
def gaussian(x, center, amp, sigma):
    """Evaluate a Gaussian profile at ``x``.

    Computes ``amp * exp(-(x - center)**2 / (2 * sigma**2))``.

    Args:
        x: Position(s) at which to evaluate; a scalar or a numpy array.
        center: Position of the peak maximum.
        amp: Peak height, i.e. the value returned at ``x == center``.
        sigma: Width parameter (standard deviation) of the profile.

    Returns:
        The profile value(s), with the same shape as ``x``.
    """
    offset = x - center
    exponent = -(offset**2) / (2 * sigma**2)
    return amp * np.exp(exponent)


# Noise-free signal: the sum of four Gaussian bands, each called as
# gaussian(axis, center, amplitude, sigma).
signal = (
    gaussian(wavenumbers, 1000, 2.0, 30)
    + gaussian(wavenumbers, 1650, 1.5, 40)
    + gaussian(wavenumbers, 2900, 3.0, 50)
    + gaussian(wavenumbers, 3400, 1.0, 60)
)

# Add baseline drift and noise
# Drift: a constant offset of 0.5 plus a slow sinusoid of amplitude 0.3.
baseline = 0.5 + 0.3 * np.sin(wavenumbers / 800)
# Zero-mean Gaussian noise (std 0.05). The size comes from the axis itself
# rather than a literal, so the arrays stay aligned if the axis length changes.
noise = rng.normal(0, 0.05, size=wavenumbers.shape)
raw = signal + baseline + noise

plot_spectrum(raw, wavenumbers, title="Raw Synthetic Spectrum")
plt.show()

## Step 1: Smoothing

In [None]:
# Smoothing settings, forwarded to smooth_savgol as keyword arguments.
savgol_kwargs = {"window_length": 11, "polyorder": 3}
smoothed = smooth_savgol(raw, **savgol_kwargs)

# Plot the raw spectrum against its smoothed version.
plot_comparison(
    raw, smoothed, wavenumbers,
    labels=("Raw", "Smoothed"),
    title="Effect of Savitzky-Golay Smoothing",
)
plt.show()

## Step 2: Baseline Correction

In [None]:
# ALS settings, forwarded to baseline_als as keyword arguments.
# NOTE(review): lam and p are presumably the smoothness and asymmetry
# weights - confirm against the baseline_als documentation.
als_kwargs = {"lam": 1e6, "p": 0.01}
corrected = baseline_als(smoothed, **als_kwargs)

# Plot the smoothed spectrum against the baseline-corrected result.
plot_comparison(
    smoothed, corrected, wavenumbers,
    labels=("Smoothed", "Baseline Corrected"),
    title="Baseline Correction with ALS",
)
plt.show()

## Step 3: Normalization

In [None]:
# Normalize the baseline-corrected spectrum; normalize_snv is called with no extra options.
normalized = normalize_snv(corrected)

# Plot the corrected spectrum against its normalized version.
snv_labels = ("Corrected", "SNV Normalized")
snv_title = "SNV Normalization"
plot_comparison(corrected, normalized, wavenumbers, labels=snv_labels, title=snv_title)
plt.show()

## Full Pipeline

Chain all steps together using the Pipeline class.

In [None]:
from spectrakit.pipeline import Pipeline

# Register the same three steps, with the same keyword arguments, that the
# step-by-step cells above applied one at a time.
pipe = Pipeline()
pipe.add("smooth", smooth_savgol, window_length=11, polyorder=3)
pipe.add("baseline", baseline_als, lam=1e6, p=0.01)
pipe.add("normalize", normalize_snv)

# Apply to a batch of spectra
# Each of the 10 rows is the raw spectrum plus an independent noise draw (std 0.02).
# The noise is sized from `raw` rather than a literal so every row matches its length.
batch = np.vstack([raw + rng.normal(0, 0.02, size=raw.shape) for _ in range(10)])
processed_batch = pipe.transform(batch)

plot_spectrum(processed_batch, wavenumbers, title="Batch of 10 Processed Spectra")
plt.show()

# Print the input and output shapes, then the pipeline's own summary.
print(f"Input shape: {batch.shape}")
print(f"Output shape: {processed_batch.shape}")
print(pipe)