# Feature extraction from nanopore event traces

This notebook loads `combined_peaks_data.json` files produced by the screening pipeline and computes the five baseline features (ΔI, standard deviation, skewness, kurtosis, and dwell time `t_off`) described in Wang *et al.* (2023). The resulting feature table will be reused for quick model prototyping (e.g. an SVM classifier).



## Workflow overview

1. Configure the paths to the processed data folders that contain `combined_peaks_data.json`.
2. Load every event trace and estimate the baseline and blockade regions.
3. Compute ΔI, SD, skewness, kurtosis, and `t_off` for each event.
4. Export the aggregated feature table for downstream modeling.



In [None]:
from pathlib import Path
import json
from typing import Dict, Iterable, List, Optional, Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import kurtosis, skew

# Widen the pandas console/notebook rendering: up to 20 columns and 120 characters per row.
pd.set_option('display.max_columns', 20)
pd.set_option('display.width', 120)

# Use seaborn's "whitegrid" theme for the matplotlib figures created in later cells.
sns.set_theme(style='whitegrid')


## Configure data sources

Set `DATA_ROOT` to the directory that contains one or more subfolders produced by `combine_peaks_data.py`. Each subfolder should contain a `combined_peaks_data.json`. If you prefer to point to specific files, populate `EXPLICIT_FILES` with their paths instead.



In [None]:
# Path to the directory that contains processed runs (update this to your environment)
DATA_ROOT = Path('/path/to/processed/data/root')  # <-- change me

# Optional: list explicit combined JSON files if they live in different roots
EXPLICIT_FILES: List[str] = []  # e.g. ['/data/run1/combined_peaks_data.json']

# Folder that receives the exported feature table; created eagerly so that the
# export cell at the end of the notebook cannot fail on a missing directory.
OUTPUT_DIR = Path('feature_tables')
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Expand "~" *before* probing the filesystem: Path.exists() does not expand it,
# so a root such as Path('~/data') would otherwise be reported as missing.
_data_root = DATA_ROOT.expanduser()

# Resolution order:
#   1. EXPLICIT_FILES, when non-empty, wins and DATA_ROOT is ignored.
#   2. DATA_ROOT pointing directly at a combined_peaks_data.json file.
#   3. DATA_ROOT as a directory, searched recursively for combined_peaks_data.json.
if EXPLICIT_FILES:
    combined_files = [Path(p).expanduser().resolve() for p in EXPLICIT_FILES]
elif _data_root.exists():
    if _data_root.is_file() and _data_root.name == 'combined_peaks_data.json':
        combined_files = [_data_root.resolve()]
    else:
        # Sorted so that the processing order is reproducible between runs.
        combined_files = sorted(path.resolve() for path in _data_root.rglob('combined_peaks_data.json'))
else:
    combined_files = []
    print('Update DATA_ROOT or EXPLICIT_FILES with the location of combined_peaks_data.json files.')

# Last expression of the cell: the notebook displays the discovered files.
combined_files



## Helper utilities

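The functions below load events, identify the blockade portion of each trace, and compute the five requested features. The baseline level is estimated as the median of the highest-current samples (the top 20 % of the normalized trace) and the blockade level as the median of the lowest-current samples (the bottom 20 %). Samples that drop at least 25 % of the way from the baseline level towards the blockade level are assigned to the blockade; all remaining samples are treated as baseline.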



In [None]:
def load_events(json_path: Path) -> List[Dict]:
    """Load the list of events from a combined peaks JSON file.

    Parameters
    ----------
    json_path:
        Location of the JSON file. A plain string is accepted as well.

    Returns
    -------
    list of dict
        The decoded events. Every event that lacks a ``'source_file'`` key is
        tagged in place with the stem of ``json_path``.

    Raises
    ------
    ValueError
        If the top-level JSON value is not a list.
    OSError, json.JSONDecodeError
        Propagated unchanged when the file cannot be read or parsed.
    """
    json_path = Path(json_path)

    # JSON is UTF-8 by specification; do not depend on the locale's default encoding.
    with open(json_path, 'r', encoding='utf-8') as fh:
        events = json.load(fh)

    # Fail early with a readable message instead of a TypeError deep in the loop below.
    if not isinstance(events, list):
        raise ValueError(
            f'Expected a JSON list of events in {json_path}, got {type(events).__name__}.'
        )

    for event in events:
        event.setdefault('source_file', json_path.stem)
    return events


def _safe_median(values: np.ndarray) -> float:
    values = values[np.isfinite(values)]
    if values.size == 0:
        return float('nan')
    return float(np.median(values))


def _event_masks(norm_signal: np.ndarray, fraction: float = 0.25) -> Dict[str, np.ndarray]:
    """Return boolean masks that delineate blockade vs baseline samples."""
    sorted_norm = np.sort(norm_signal[np.isfinite(norm_signal)])
    if sorted_norm.size == 0:
        mask = np.zeros_like(norm_signal, dtype=bool)
        return {
            'event_mask': mask,
            'baseline_mask': ~mask,
            'baseline_level': float('nan'),
            'blockade_level': float('nan'),
            'delta_norm': float('nan'),
            'threshold': float('nan'),
        }

    top_k = max(int(np.ceil(sorted_norm.size * 0.2)), 1)
    baseline_level = float(np.median(sorted_norm[-top_k:]))
    blockade_level = float(np.median(sorted_norm[:top_k]))
    delta_norm = baseline_level - blockade_level

    if delta_norm <= 0:
        delta_norm = baseline_level - float(np.min(sorted_norm))

    threshold = baseline_level - fraction * delta_norm if delta_norm > 0 else baseline_level - 1e-3
    applied_threshold = threshold
    event_mask = norm_signal <= applied_threshold

    if not np.any(event_mask):
        fallback_threshold = baseline_level - 0.1 * max(abs(delta_norm), abs(baseline_level) * 0.05)
        event_mask = norm_signal <= fallback_threshold
        applied_threshold = fallback_threshold

    if not np.any(event_mask):
        applied_threshold = baseline_level
        event_mask = norm_signal < baseline_level

    baseline_mask = ~event_mask

    return {
        'event_mask': event_mask,
        'baseline_mask': baseline_mask,
        'baseline_level': baseline_level,
        'blockade_level': blockade_level,
        'delta_norm': delta_norm,
        'threshold': float(applied_threshold),
    }


def _extract_signals(event: Dict) -> Tuple[np.ndarray, np.ndarray]:
    """Return normalized and absolute signals with NaNs removed."""
    raw_signal = event.get('raw_signal') or event.get('norm_signal')
    if raw_signal is None:
        raise KeyError('Event is missing a normalized signal ("raw_signal" or "norm_signal").')

    raw_signal = np.asarray(raw_signal, dtype=float)
    raw_signal_abs = np.asarray(event['raw_signal_not_norm'], dtype=float)

    n_samples = min(raw_signal.size, raw_signal_abs.size)
    raw_signal = raw_signal[:n_samples]
    raw_signal_abs = raw_signal_abs[:n_samples]

    valid = np.isfinite(raw_signal) & np.isfinite(raw_signal_abs)
    return raw_signal[valid], raw_signal_abs[valid]


def _event_time_axis(event: Dict, n_samples: int) -> Tuple[np.ndarray, str]:
    """Construct a time axis suitable for plotting the event."""
    dt = float(event.get('dt', np.nan))
    if not np.isfinite(dt) or dt <= 0:
        return np.arange(n_samples, dtype=float), 'Sample index'
    axis = np.arange(n_samples, dtype=float) * dt
    return axis, f'Time (dt={dt:g})'


def _safe_std(values: np.ndarray) -> float:
    if values.size < 2:
        return float('nan')
    return float(np.std(values, ddof=1))


def _safe_skew(values: np.ndarray) -> float:
    if values.size < 3:
        return float('nan')
    return float(skew(values, bias=False))


def _safe_kurtosis(values: np.ndarray) -> float:
    if values.size < 4:
        return float('nan')
    return float(kurtosis(values, fisher=True, bias=False))


def _optional_float(value) -> float:
    """Convert ``value`` to float, mapping ``None`` (e.g. JSON ``null``) to NaN."""
    return float('nan') if value is None else float(value)


def compute_event_features(event: Dict, *, sample_id: Optional[str] = None, source_path: Optional[Path] = None) -> Dict:
    """Compute the feature record (ΔI, SD, skewness, kurtosis, ``t_off``) for one event.

    The normalized trace is used only to decide which samples belong to the
    blockade; every current-valued feature is then computed on the absolute
    trace (``raw_signal_not_norm``).

    Parameters
    ----------
    event:
        Event dictionary as returned by ``load_events``.
    sample_id:
        Identifier stored in the record. Defaults to the name of the folder
        containing ``source_path`` when omitted.
    source_path:
        Path of the combined JSON file the event came from; stored as
        ``'combined_file'`` when given.

    Returns
    -------
    dict
        * ``delta_I`` – median baseline current minus median blockade current.
        * ``sd``, ``skew``, ``kurtosis`` – statistics of the per-sample drop
          ``baseline_current - current`` over the blockade samples (over the
          whole trace when no blockade sample was found).
        * ``t_off`` – span from the first to the last blockade sample times
          ``dt``; when ``dt`` is unusable, ``t_end - t_start`` from the event.
        * bookkeeping fields (``sample_id``, ``source_file``, ``peak_index``,
          ``baseline_current``, ``blocked_current``, ``n_event_samples``,
          ``total_samples`` and, when available, ``combined_file`` / ``snr_db``).

    Raises
    ------
    KeyError
        Propagated from ``_extract_signals`` when a required signal is missing.
    """
    raw_signal, raw_signal_abs = _extract_signals(event)

    masks = _event_masks(raw_signal)
    event_mask = masks['event_mask']
    baseline_mask = masks['baseline_mask']
    has_event = bool(np.any(event_mask))

    # Fall back to the whole trace whenever one of the two segments is empty.
    baseline_samples = raw_signal_abs[baseline_mask] if np.any(baseline_mask) else raw_signal_abs
    blockade_samples = raw_signal_abs[event_mask] if has_event else raw_signal_abs

    baseline_current = _safe_median(baseline_samples)
    blockade_current = _safe_median(blockade_samples)
    drop_series = baseline_current - blockade_samples

    delta_I = float(baseline_current - blockade_current)
    sd_drop = _safe_std(drop_series)
    skew_drop = _safe_skew(drop_series)
    kurt_drop = _safe_kurtosis(drop_series)

    dt = _optional_float(event.get('dt'))
    # Same validity rule as _event_time_axis: a negative or infinite sampling
    # interval is as unusable as a missing one and must not yield a bogus dwell time.
    if not np.isfinite(dt) or dt <= 0:
        t_off = _optional_float(event.get('t_end')) - _optional_float(event.get('t_start'))
    elif has_event:
        event_indices = np.where(event_mask)[0]
        # Inclusive span, so a single blockade sample lasts one dt.
        t_off = float((event_indices[-1] - event_indices[0] + 1) * dt)
    else:
        t_off = float(raw_signal.size * dt)

    feature_record = {
        'sample_id': sample_id if sample_id is not None else (source_path.parent.name if source_path else None),
        'source_file': event.get('source_file', source_path.stem if source_path else None),
        'peak_index': event.get('peak_index'),
        'delta_I': delta_I,
        'sd': sd_drop,
        'skew': skew_drop,
        'kurtosis': kurt_drop,
        't_off': t_off,
        'baseline_current': baseline_current,
        'blocked_current': blockade_current,
        'n_event_samples': int(np.sum(event_mask)),
        'total_samples': int(raw_signal.size),
    }

    if source_path is not None:
        feature_record['combined_file'] = str(source_path)

    if 'snr_db' in event:
        feature_record['snr_db'] = event['snr_db']

    return feature_record


## Extract features from every combined file

Run the cell below after configuring the paths. A feature dictionary is created for every event and aggregated into a single DataFrame.



In [None]:
# Accumulates one feature dictionary per successfully processed event.
feature_records: List[Dict] = []

for combined_path in combined_files:
    # NOTE(review): load_events is outside the try block, so an unreadable or
    # malformed file stops the whole extraction rather than being skipped.
    events = load_events(combined_path)
    # The folder that holds the combined JSON doubles as the sample identifier.
    sample_id = combined_path.parent.name
    for event in events:
        try:
            feature_records.append(
                compute_event_features(event, sample_id=sample_id, source_path=combined_path)
            )
        except Exception as exc:
            # Best effort: report the failing event and carry on with the rest.
            print(f'Failed to process event from {combined_path}: {exc}')

features_df = pd.DataFrame(feature_records)
print(f'Loaded {len(feature_records)} events from {len(combined_files)} combined files.')
# Last expression of the cell: the notebook renders the first rows of the table.
features_df.head()



## Per-event inspection

Visualise a representative event to confirm the `_event_masks` segmentation. The plot overlays baseline and blockade samples, the derived threshold, and the `t_start`/`t_end` markers inferred from the mask.


In [None]:
# Visualise baseline vs blockade segmentation for a single event
# Only the first event of the first discovered file is plotted.
if not combined_files:
    print('No combined_peaks_data.json files were discovered. Update DATA_ROOT or EXPLICIT_FILES to plot an event.')
else:
    example_path = combined_files[0]
    events = load_events(example_path)

    if not events:
        print(f'No events found in {example_path}.')
    else:
        example_event = events[0]
        try:
            # Only the normalized trace is needed here; the absolute one is discarded.
            norm_signal, _ = _extract_signals(example_event)
        except KeyError as exc:
            print(f'Unable to read event signals: {exc}')
        else:
            masks = _event_masks(norm_signal)
            time_axis, time_label = _event_time_axis(example_event, norm_signal.size)
            event_indices = np.where(masks['event_mask'])[0]

            # Prefer an explicit sample_id stored on the event; otherwise use the folder name.
            sample_name = example_event.get('sample_id') or example_path.parent.name
            event_label = example_event.get('peak_index', 'unknown')

            fig, ax = plt.subplots(figsize=(10, 4.5))
            ax.plot(time_axis, norm_signal, color='tab:blue', linewidth=1.2, label='Normalized signal')

            # Overlay the samples of each segment so the classification is visible point by point.
            if np.any(masks['baseline_mask']):
                ax.scatter(
                    time_axis[masks['baseline_mask']],
                    norm_signal[masks['baseline_mask']],
                    color='tab:green',
                    s=12,
                    alpha=0.7,
                    label='Baseline samples',
                )

            if np.any(masks['event_mask']):
                ax.scatter(
                    time_axis[masks['event_mask']],
                    norm_signal[masks['event_mask']],
                    color='tab:orange',
                    s=12,
                    alpha=0.9,
                    label='Blockade samples',
                )

            # Horizontal guides: estimated levels and the threshold that was actually applied.
            ax.axhline(masks['baseline_level'], color='tab:green', linestyle='--', linewidth=1, label=f"Baseline level ({masks['baseline_level']:.3g})")
            ax.axhline(masks['blockade_level'], color='tab:orange', linestyle='--', linewidth=1, label=f"Blockade level ({masks['blockade_level']:.3g})")

            if np.isfinite(masks['threshold']):
                ax.axhline(masks['threshold'], color='tab:red', linestyle=':', linewidth=1.2, label=f"Threshold ({masks['threshold']:.3g})")

            # Vertical guides at the first and last blockade sample; these are
            # derived from the mask, not read from the event's own t_start/t_end fields.
            if event_indices.size:
                start_time = time_axis[event_indices[0]]
                end_time = time_axis[event_indices[-1]]
                ax.axvline(start_time, color='black', linestyle='--', linewidth=1, label='t_start')
                ax.axvline(end_time, color='black', linestyle='-.', linewidth=1, label='t_end')

            ax.set_xlabel(time_label)
            ax.set_ylabel('Normalized current')
            ax.set_title(f'Sample: {sample_name} — peak {event_label}')

            # Going through a dict keyed by label keeps a single legend entry per label.
            handles, labels = ax.get_legend_handles_labels()
            legend = dict(zip(labels, handles))
            ax.legend(legend.values(), legend.keys(), loc='best', frameon=True)
            ax.grid(True, alpha=0.3)
            plt.show()


## Feature distributions

Use histograms and scatter plots to inspect the spread of the handcrafted features and how different samples occupy feature space.


In [None]:
# Histograms / KDE plots for the five features
if features_df.empty:
    print('The feature table is empty. Run the extraction cell above to populate `features_df`.')
else:
    # `feature_columns` is assigned in a later cell (summary statistics); use it
    # when that cell has already run, otherwise fall back to the default five features.
    columns = globals().get('feature_columns', ['delta_I', 'sd', 'skew', 'kurtosis', 't_off'])
    fig, axes = plt.subplots(1, len(columns), figsize=(4 * len(columns), 3.5), constrained_layout=True)
    # With a single column plt.subplots returns a bare Axes; wrap it so zip() below works.
    if len(columns) == 1:
        axes = [axes]
    for ax, column in zip(axes, columns):
        # Drop NaN and ±inf so they cannot distort the binning.
        column_data = features_df[column].replace([np.inf, -np.inf], np.nan).dropna()
        if column_data.empty:
            # Keep the panel, but mark it as empty instead of plotting nothing silently.
            ax.text(0.5, 0.5, 'No data', ha='center', va='center', transform=ax.transAxes)
            ax.set_title(column)
            ax.set_xlabel(column)
            ax.set_ylabel('Count')
            continue
        sns.histplot(column_data, bins=30, kde=True, ax=ax, color='tab:blue', edgecolor='white', linewidth=0.5)
        ax.set_title(column)
        ax.set_xlabel(column)
        ax.set_ylabel('Count')
    plt.show()


In [None]:
# Scatter plot to inspect ΔI vs. t_off coloured by sample
if features_df.empty:
    print('The feature table is empty. Run the extraction cell above to populate `features_df`.')
else:
    # Work on a copy so the clean-up below never touches `features_df`.
    scatter_df = features_df[['delta_I', 't_off', 'sample_id']].copy()
    scatter_df[['delta_I', 't_off']] = scatter_df[['delta_I', 't_off']].replace([np.inf, -np.inf], np.nan)
    scatter_df = scatter_df.dropna(subset=['delta_I', 't_off'])

    if scatter_df.empty:
        print('No finite ΔI / t_off values are available for plotting.')
    else:
        scatter_df['sample_plot_label'] = scatter_df['sample_id'].fillna('Unknown').astype(str)
        counts = scatter_df['sample_plot_label'].value_counts()
        # With more than ten samples, keep the nine most frequent and pool the
        # rest as 'Other' — presumably to stay within the 'tab10' palette used below.
        if counts.size > 10:
            top_labels = counts.index[:9]
            scatter_df['sample_plot_label'] = scatter_df['sample_plot_label'].where(scatter_df['sample_plot_label'].isin(top_labels), 'Other')

        fig, ax = plt.subplots(figsize=(6, 5))
        sns.scatterplot(
            data=scatter_df,
            x='delta_I',
            y='t_off',
            hue='sample_plot_label',
            palette='tab10',
            s=45,
            edgecolor='white',
            linewidth=0.5,
            ax=ax,
        )
        ax.set_xlabel('ΔI (baseline − blockade)')
        ax.set_ylabel('t_off (dwell time)')
        ax.set_title('ΔI vs. t_off by sample')
        ax.grid(True, alpha=0.3)
        # Anchor the legend outside the axes (to the right) so it does not cover points.
        ax.legend(title='Sample', bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
        plt.tight_layout()
        plt.show()


## Raw vs. normalized traces

Compare several events to verify that the normalisation preserves event structure and removes slow drift.


In [None]:
# Display a few raw and normalized traces side-by-side
max_examples = 3
# Each entry: (label, time_axis, time_label, raw_signal_abs, norm_signal).
trace_examples = []

# Collect the first `max_examples` usable events, scanning the files in order.
for combined_path in combined_files:
    events = load_events(combined_path)
    for event in events:
        try:
            norm_signal, raw_signal_abs = _extract_signals(event)
        except KeyError:
            # Events without the required signals are skipped silently.
            continue

        if norm_signal.size == 0:
            continue

        time_axis, time_label = _event_time_axis(event, norm_signal.size)
        label = f"{combined_path.parent.name} — peak {event.get('peak_index', 'unknown')}"
        trace_examples.append((label, time_axis, time_label, raw_signal_abs, norm_signal))

        if len(trace_examples) >= max_examples:
            break
    # The inner `break` leaves only the event loop; repeat the test to stop scanning files.
    if len(trace_examples) >= max_examples:
        break

if not trace_examples:
    print('No events available to plot. Update DATA_ROOT or EXPLICIT_FILES and re-run the extraction cell.')
else:
    # One row per example: raw trace on the left, normalized trace on the right.
    fig, axes = plt.subplots(len(trace_examples), 2, figsize=(12, 3.2 * len(trace_examples)), sharex='col')
    # A single-row grid comes back as a 1-D array; force 2-D so axes[row, col] works.
    axes = np.atleast_2d(axes)

    for row, (label, time_axis, time_label, raw_abs, norm_sig) in enumerate(trace_examples):
        raw_ax, norm_ax = axes[row, 0], axes[row, 1]

        raw_ax.plot(time_axis, raw_abs, color='tab:blue', linewidth=1.0)
        raw_ax.set_title(f'{label} — raw')
        raw_ax.set_ylabel('Current (raw units)')
        raw_ax.grid(True, alpha=0.3)

        norm_ax.plot(time_axis, norm_sig, color='tab:orange', linewidth=1.0)
        norm_ax.set_title(f'{label} — normalized')
        norm_ax.set_ylabel('Normalized current')
        norm_ax.grid(True, alpha=0.3)

        raw_ax.set_xlabel(time_label)
        norm_ax.set_xlabel(time_label)

    fig.suptitle('Raw vs. normalized event traces', fontsize=14, y=1.02)
    fig.tight_layout()
    plt.show()


## Inspect summary statistics

The table below provides a quick sanity check for the magnitude of each feature. Adjust the baseline detection heuristics above if the distributions look suspicious.



In [None]:
# The five baseline features; the histogram cell also picks this list up via globals().
feature_columns = ['delta_I', 'sd', 'skew', 'kurtosis', 't_off']

if not features_df.empty:
    # `display` is provided by the IPython/Jupyter environment, not imported in this notebook.
    display(features_df[feature_columns].describe())
else:
    print('Feature table is empty. Verify DATA_ROOT/EXPLICIT_FILES before continuing.')



## Persist the feature table

Export the features to CSV so that other notebooks (e.g. model fitting) can consume them.



In [None]:
# Destination of the exported table inside OUTPUT_DIR (created in the configuration cell).
output_path = OUTPUT_DIR / 'event_features.csv'

if not features_df.empty:
    # index=False: the DataFrame index is not written as a column.
    features_df.to_csv(output_path, index=False)
    print(f'Saved {len(features_df)} feature rows to {output_path}')
else:
    # Nothing is written for an empty table.
    print('Skipped saving because the feature table is empty.')



## Optional: map samples to class labels

Create a dictionary that maps folder names (or `sample_id`) to the DNA classes you want to predict. Uncomment and adapt the cell below when preparing the modeling notebook.



In [None]:
# SAMPLE_LABEL_MAP = {
#     '500mV_100bp_1ngmkl_10MHz_boost_240830164444': '100bp',
#     '500mV_200bp_1ngmkl_10MHz_boost_240830165006': '200bp',
# }
#
# if not features_df.empty:
#     features_df['label'] = features_df['sample_id'].map(SAMPLE_LABEL_MAP)
#     display(features_df[['sample_id', 'label']].drop_duplicates())

