# Gradient-based feature extraction from event windows

This notebook implements a standalone pipeline that reloads the per-event windows saved by the segmentation workflow, detects the onset/offset of each event from the filtered trace, and integrates the corresponding baseline-corrected raw signal.

## Workflow overview

1. **Load event windows** from the JSON files generated during screening.
2. **Pre-process the filtered trace** with a rolling mean and use its gradient to highlight transition points.
3. **Locate event boundaries** by thresholding the magnitude of the normalised gradient against a robust (MAD-based) noise estimate.
4. **Reuse the stored baseline** for the window.
5. **Integrate the baseline-subtracted raw trace** between the detected bounds (trapezoidal rule); excursions below the baseline contribute negatively.
6. **Persist derived features** for downstream analysis.
7. **Perform quality checks** by plotting selected events.

In [None]:
# Imports
from __future__ import annotations

import json
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from screening_sample_ssd import get_data_folders

In [None]:
# Configuration
import os

# Root directory (or a single JSON file) holding the event windows.
# Set the EVENT_DATA_PATH environment variable to override it without editing the notebook.
DATA_PATH = Path(os.environ.get('EVENT_DATA_PATH', '/Users/hugo/MOLECL/Molecl_data_H'))
# Destination CSV for the per-event feature table (override with FEATURE_OUTPUT_PATH).
OUTPUT_PATH = Path(os.environ.get('FEATURE_OUTPUT_PATH', 'feature_metrics.csv'))

# Fail fast: every later cell depends on the event window data being present.
if not DATA_PATH.exists():
    raise FileNotFoundError(
        f"Expected event window data under {DATA_PATH}. "
        "Please update DATA_PATH (or set EVENT_DATA_PATH) before running the notebook."
    )

In [None]:
def _contains_baseline_segment(path: Path) -> bool:
    parts = [part.lower() for part in path.parts]
    return any('baseline' in part for part in parts)


def discover_event_files(data_path: Path) -> List[Path]:
    """Locate the peak-window JSON files to load.

    If *data_path* is a file, it is returned as the only entry. If it is a
    directory, the following are collected in order and de-duplicated by
    resolved path:

    * ``combined_peaks_data.json`` directly inside *data_path*;
    * for every sample folder (from ``get_data_folders``, or the immediate
      sub-directories if that call fails), its ``combined_peaks_data.json``
      and every ``peaks_data.json`` found recursively beneath it;
    * if nothing was found so far, every ``peaks_data.json`` found
      recursively under *data_path*.

    Folders whose path *below data_path* has a component containing
    "baseline" (case-insensitive) are skipped. Components above *data_path*
    are ignored, so a data root such as ``.../baseline_study/`` still works.

    Raises:
        FileNotFoundError: if no JSON file could be found.
    """
    import warnings

    if data_path.is_file():
        return [data_path]

    if data_path.is_dir():
        root = data_path.resolve()

        def is_baseline_dir(path: Path) -> bool:
            # Judge only the path components below the data root; checking the
            # absolute path would exclude everything whenever an ancestor of
            # the root happens to contain "baseline".
            try:
                relative = Path(path).resolve().relative_to(root)
            except ValueError:
                # Folder lies outside the root: inspect its full path instead.
                relative = Path(path)
            return _contains_baseline_segment(relative)

        event_files: List[Path] = []
        combined_here = data_path / 'combined_peaks_data.json'
        if combined_here.exists():
            event_files.append(combined_here)

        try:
            candidate_dirs = [Path(folder) for folder in get_data_folders(str(data_path))]
        except Exception as err:  # best-effort: the helper's failure modes are project-specific
            warnings.warn(
                f'get_data_folders failed for {data_path} ({err!r}); '
                'falling back to its immediate sub-directories.'
            )
            candidate_dirs = [path for path in data_path.iterdir() if path.is_dir()]

        seen_dirs = set()
        for folder in candidate_dirs:
            if folder in seen_dirs or not folder.exists():
                continue
            seen_dirs.add(folder)
            if is_baseline_dir(folder):
                continue

            combined = folder / 'combined_peaks_data.json'
            if combined.exists():
                event_files.append(combined)

            # NOTE(review): if combined_peaks_data.json aggregates the
            # per-file peaks_data.json below it, events may be loaded twice —
            # confirm against the segmentation pipeline that writes them.
            event_files.extend(
                json_file
                for json_file in sorted(folder.rglob('peaks_data.json'))
                if not is_baseline_dir(json_file.parent)
            )

        if not event_files:
            # Last resort: scan the whole tree for per-file windows.
            event_files.extend(
                json_file
                for json_file in sorted(data_path.rglob('peaks_data.json'))
                if not is_baseline_dir(json_file.parent)
            )

        # Drop missing files and duplicates reached through different paths.
        unique_files: List[Path] = []
        seen_paths = set()
        for json_file in event_files:
            if not json_file.exists():
                continue
            resolved = json_file.resolve()
            if resolved in seen_paths:
                continue
            seen_paths.add(resolved)
            unique_files.append(json_file)

        if unique_files:
            return unique_files

    raise FileNotFoundError(
        f'No peak window JSON files found under {data_path}. Ensure the segmentation pipeline has generated peaks_data.json files.'
    )


def load_peak_windows(event_files: Iterable[Path]) -> List[Dict]:
    """Load every event window from *event_files* into one flat list.

    Two top-level JSON layouts are accepted:

    * a list of event dicts;
    * a dict mapping a batch key to a list of event dicts (the key is recorded
      as ``batch_id``).

    Non-dict entries, and dict values that are not lists, are skipped. Each
    event is shallow-copied and enriched, without overwriting keys already
    present, with:

    * ``event_index``: position within its list;
    * ``batch_id``: dict layout only;
    * ``source_json``;
    * ``sample_folder``: parent directory name;
    * ``source_file``: JSON file stem;
    * ``global_event_index``: position in the returned list at load time.

    Raises:
        ValueError: on an unsupported top-level JSON structure, or when no
            events are found in any file.
    """
    windows: List[Dict] = []

    def append_event(event: Dict, idx: int, json_path: Path, batch_key: Optional[str] = None) -> None:
        # Copy so the parsed payload is not mutated; setdefault preserves any
        # metadata the JSON already carries.
        enriched = dict(event)
        enriched.setdefault('event_index', idx)
        if batch_key is not None:
            enriched.setdefault('batch_id', batch_key)
        enriched.setdefault('source_json', str(json_path))
        enriched.setdefault('sample_folder', json_path.parent.name)
        enriched.setdefault('source_file', json_path.stem)
        enriched.setdefault('global_event_index', len(windows))
        windows.append(enriched)

    for json_path in map(Path, event_files):
        # JSON is UTF-8 by specification; do not depend on the platform locale.
        with json_path.open(encoding='utf-8') as f:
            payload = json.load(f)

        if isinstance(payload, list):
            for idx, event in enumerate(payload):
                if isinstance(event, dict):
                    append_event(event, idx, json_path)
        elif isinstance(payload, dict):
            for batch_key, batch_events in payload.items():
                # Only JSON arrays can hold event dicts; strings/objects/numbers
                # never contribute events.
                if not isinstance(batch_events, (list, tuple)):
                    continue
                for idx, event in enumerate(batch_events):
                    if isinstance(event, dict):
                        append_event(event, idx, json_path, batch_key=batch_key)
        else:
            raise ValueError(f'Unsupported JSON structure in {json_path}')

    if not windows:
        raise ValueError('No events found in the provided JSON files.')

    return windows


# Find every JSON source, then flatten all of their events into a single list.
EVENT_FILES = discover_event_files(data_path=DATA_PATH)
raw_events = load_peak_windows(event_files=EVENT_FILES)
# Cell output: (number of JSON files, number of events loaded)
(len(EVENT_FILES), len(raw_events))

In [None]:
def ensure_odd(window: int) -> int:
    """Coerce *window* to int and bump it to the next odd value if it is even."""
    size = int(window)
    if size % 2:
        return size
    return size + 1


def smooth_signal(signal: np.ndarray, window: int = 11) -> np.ndarray:
    """Smooth *signal* with a centred rolling mean.

    The window length is forced to an odd integer via ``ensure_odd``.
    ``min_periods=1`` means the edges are averaged over the partial window
    that fits. A *window* of 1 or less returns *signal* itself, with no copy
    and no conversion.
    """
    if window <= 1:
        # Nothing to smooth: hand the input back untouched.
        return signal

    rolling = pd.Series(signal).rolling(ensure_odd(window), center=True, min_periods=1)
    return rolling.mean().to_numpy()


def compute_gradients(filtered: np.ndarray, dt: float, window: int = 11) -> Tuple[np.ndarray, np.ndarray]:
    """Return the gradient of the smoothed trace and a normalised, re-smoothed copy.

    Steps:

    1. *filtered* is smoothed with a rolling mean of length *window*.
    2. It is differentiated with ``np.gradient`` using sample spacing *dt*.
    3. The gradient is divided by its largest absolute value, so the peak
       magnitude is 1 (all zeros if the gradient is identically zero).
    4. The scaled gradient is smoothed again with a window of
       ``max(3, window // 2)``.

    Returns:
        (grad, norm_grad): the unnormalised gradient and the normalised,
        re-smoothed gradient.
    """
    slope = np.gradient(smooth_signal(filtered, window=window), dt)

    peak = np.max(np.abs(slope))
    scaled = slope / peak if peak != 0 else np.zeros_like(slope)

    return slope, smooth_signal(scaled, window=max(3, window // 2))


def mad(arr: np.ndarray) -> float:
    """Return the unscaled median absolute deviation of *arr* as a Python float."""
    deviations = np.abs(arr - np.median(arr))
    return float(np.median(deviations))


def gradient_threshold(norm_grad: np.ndarray, scale: float = 3.5) -> float:
    """Derive a detection threshold for the normalised gradient.

    The noise level is the MAD of *norm_grad*, or its standard deviation when
    the MAD is exactly zero. The threshold is ``scale * noise``. If the noise
    level is not positive (including NaN), the threshold is 0.1 instead. The
    result is clipped to [0.05, 1.0].
    """
    # A zero MAD (e.g. a mostly flat gradient) falls back to the std.
    noise = mad(norm_grad) or np.std(norm_grad)

    if noise > 0:
        candidate = scale * noise
    else:
        candidate = 0.1

    return float(np.clip(candidate, 0.05, 1.0))


def detect_bounds(norm_grad: np.ndarray, threshold: float, guard: int = 5, min_span: int = 10) -> Tuple[int, int]:
    """Locate the event window from the normalised gradient.

    The window runs from the first to the last sample whose absolute value
    reaches *threshold*. It is widened by *guard* samples on each side and
    clipped to the array.

    If the result spans fewer than *min_span* samples, it is replaced by a
    window of exactly *min_span* samples around the median supra-threshold
    index. That window is shifted, not truncated, to stay inside the array.
    If the array itself is shorter than *min_span*, the whole array is used.

    Args:
        norm_grad: 1-D normalised gradient.
        threshold: detection level for ``|norm_grad|``. A value <= 0 selects
            the whole array.
        guard: padding (in samples) added around the detected region.
        min_span: minimum ``end - start`` (in samples) of the returned window.

    Returns:
        (start, end) inclusive sample indices. The whole array is returned
        when no sample reaches the threshold.
    """
    last = len(norm_grad) - 1
    if threshold <= 0:
        return 0, last

    above = np.where(np.abs(norm_grad) >= threshold)[0]
    if above.size == 0:
        return 0, last

    start = max(0, int(above[0]) - guard)
    end = min(last, int(above[-1]) + guard)

    if end - start < min_span:
        center = int(above[above.size // 2])
        # Place a min_span-wide window around the centre. Shift it left if it
        # would run past the end, then clamp to 0; the previous clamping
        # truncated the window near the edges (and lost one sample for odd
        # min_span).
        start = max(0, min(center - min_span // 2, last - min_span))
        end = min(last, start + min_span)

    return start, end


def integrate_baseline_corrected(raw_signal: np.ndarray, baseline: float, start: int, end: int, dt: float) -> float:
    """Integrate ``raw_signal - baseline`` over samples start..end (inclusive).

    Uses the trapezoidal rule with uniform spacing *dt*. Samples below the
    baseline contribute negatively. An empty range (``start > end``) yields
    0.0.
    """
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; use the
    # new name when available so no DeprecationWarning is emitted.
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz
    segment = np.asarray(raw_signal, dtype=float)[start : end + 1] - baseline
    return float(trapezoid(segment, dx=dt))

In [None]:
def compute_event_features(event: Dict, gradient_window: int = 21, threshold_scale: float = 3.5, guard: int = 10, min_span: int = 20) -> Dict:
    """Detect the bounds of one event and integrate its baseline-corrected area.

    Steps:

    1. Compute the normalised gradient of the filtered trace.
    2. Threshold it to find (start_idx, end_idx).
    3. Integrate ``raw - baseline`` between those indices.

    Returns a flat dict that copies the event's identifying metadata (missing
    keys become None) and adds the following:

    * timing: ``t_start``, ``t_end``, ``duration``;
    * result: ``baseline``, ``area_under_curve``;
    * bounds: ``start_idx``, ``end_idx``;
    * parameters used: ``threshold``, ``gradient_window``, ``threshold_scale``.
    """
    # NOTE(review): prepare_event_arrays is not defined in the visible cells of
    # this notebook; it must be provided elsewhere and is unpacked here as
    # (raw, filtered, time_axis, dt, baseline) — confirm its contract.
    raw, filtered, time_axis, dt, baseline = prepare_event_arrays(event)

    norm_grad = compute_gradients(filtered, dt, window=gradient_window)[1]
    threshold = gradient_threshold(norm_grad, scale=threshold_scale)
    lo, hi = detect_bounds(norm_grad, threshold, guard=guard, min_span=min_span)
    area = integrate_baseline_corrected(raw, baseline, lo, hi, dt)

    # Key order matters: it becomes the column order of the feature table.
    metadata_keys = (
        'global_event_index',
        'event_index',
        'batch_id',
        'sample_folder',
        'source_file',
        'source_json',
    )
    features = {key: event.get(key) for key in metadata_keys}
    features.update(
        t_start=time_axis[lo],
        t_end=time_axis[hi],
        duration=(hi - lo) * dt,
        baseline=baseline,
        area_under_curve=area,
        start_idx=lo,
        end_idx=hi,
        threshold=threshold,
        gradient_window=gradient_window,
        threshold_scale=threshold_scale,
    )
    return features

In [None]:
# One feature record per event, using the default detection parameters.
event_features = list(map(compute_event_features, raw_events))
features_df = pd.DataFrame(event_features)
# Cell output: preview of the first rows of the feature table.
features_df.head()

In [None]:
# Persist the feature table. Create the destination folder first so a nested
# OUTPUT_PATH does not fail on a missing directory.
OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
features_df.to_csv(OUTPUT_PATH, index=False)
# Cell output: where the CSV was written.
OUTPUT_PATH

In [None]:
def plot_event(event: Dict, features: Dict, gradient_window: Optional[int] = None):
    """Plot one event with its detected bounds for visual quality control.

    Top panel:
        * the filtered trace, the baseline and the detected bounds;
        * on a secondary y-axis, the absolute normalised gradient and the
          detection threshold.

    Bottom panel:
        * the raw trace, the baseline and the integrated window.

    Args:
        event: event window dict, as loaded by ``load_peak_windows``.
        features: feature record for the same event (a dict or DataFrame
            row); must provide ``start_idx``, ``end_idx`` and ``threshold``.
        gradient_window: smoothing window used to recompute the gradient.
            Defaults to the ``gradient_window`` stored in *features* (or 21
            if absent), so the plotted gradient matches the one used for
            detection.
    """
    if gradient_window is None:
        gradient_window = int(features.get('gradient_window', 21))

    # NOTE(review): prepare_event_arrays is not defined in the visible cells of
    # this notebook; it is unpacked as (raw, filtered, time_axis, dt, baseline).
    raw, filtered, time_axis, dt, baseline = prepare_event_arrays(event)
    _, norm_grad = compute_gradients(filtered, dt, window=gradient_window)

    start_idx = int(features['start_idx'])
    end_idx = int(features['end_idx'])
    t_start, t_end = time_axis[start_idx], time_axis[end_idx]

    fig, (ax_filt, ax_raw) = plt.subplots(2, 1, figsize=(10, 6), sharex=True)

    ax_filt.plot(time_axis, filtered, label='Filtered signal', color='tab:blue')
    ax_filt.axhline(baseline, color='tab:gray', linestyle='--', label='Baseline')
    ax_filt.axvspan(t_start, t_end, color='tab:orange', alpha=0.2, label='Event bounds')
    ax_filt.set_ylabel('Filtered (a.u.)')

    ax_raw.plot(time_axis, raw, label='Raw signal', color='tab:green')
    ax_raw.axhline(baseline, color='tab:gray', linestyle='--', label='Baseline')
    ax_raw.axvspan(t_start, t_end, color='tab:orange', alpha=0.2, label='Integrated window')
    ax_raw.set_ylabel('Raw (a.u.)')
    ax_raw.set_xlabel('Time (s)')
    ax_raw.legend(loc='upper right')

    # Gradient magnitude on a secondary y-axis sharing the top panel's x-axis.
    ax_grad = ax_filt.twinx()
    ax_grad.plot(time_axis, np.abs(norm_grad), color='tab:red', alpha=0.5, label='|Normalised gradient|')
    ax_grad.axhline(features['threshold'], color='tab:red', linestyle=':', label='Threshold')
    ax_grad.set_ylabel('Gradient magnitude')

    # A twinx axis keeps its own artists, so a legend on ax_filt alone never
    # showed the gradient/threshold labels. Merge both into one legend, drawn
    # on the twin so it sits above the gradient curve.
    handles, labels = ax_filt.get_legend_handles_labels()
    grad_handles, grad_labels = ax_grad.get_legend_handles_labels()
    ax_grad.legend(handles + grad_handles, labels + grad_labels, loc='upper right')

    plt.tight_layout()
    plt.show()

# Example usage: set example_idx to the position of the event you want to inspect
# (raw_events and event_features are aligned by position).
example_idx = 0
plot_event(event=raw_events[example_idx], features=event_features[example_idx])