# imports

In [None]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go
import math
from sklearn.metrics import confusion_matrix
from matplotlib import pyplot as plt
from neurokit.io import read_edf
from neurokit.io import Recording
from typing import Sequence, Tuple
from neurokit.preprocessing.filters import bandpass
from plotly.subplots import make_subplots
from scipy.ndimage.morphology import binary_opening
from neurokit.utils import mask_to_intervals
from neurokit.utils import intervals_to_mask

##### helper functions

In [None]:
def _calculate_perf_measures(prediction, reference, threshold):
    """Score a binary per-sample prediction against a reference annotation.

    Both inputs are interpreted as boolean masks (``True`` / non-zero means
    "event present"). Counts are computed directly, so degenerate inputs
    that contain only one class (e.g. no events in either mask) are
    handled instead of failing to unpack a non-2x2 confusion matrix.

    Parameters
    ----------
    prediction : array-like of bool or {0, 1}
        Predicted per-sample labels.
    reference : array-like of bool or {0, 1}
        Ground-truth per-sample labels, same length as ``prediction``.
    threshold : float
        Threshold that produced ``prediction``; passed through unchanged so
        the returned dict can be collected into a results table.

    Returns
    -------
    dict
        Keys ``accuracy``, ``sensitivity``, ``specificity``, ``precision``,
        ``fnr``, ``fpr`` and ``threshold``. A ratio whose denominator is
        zero (e.g. precision when nothing is predicted positive) is NaN.

    Raises
    ------
    ValueError
        If ``prediction`` and ``reference`` differ in shape.
    """
    pred = np.asarray(prediction).astype(bool)
    ref = np.asarray(reference).astype(bool)
    if pred.shape != ref.shape:
        raise ValueError(
            f"prediction and reference differ in shape: {pred.shape} vs {ref.shape}")

    tp = int(np.count_nonzero(pred & ref))
    tn = int(np.count_nonzero(~pred & ~ref))
    fp = int(np.count_nonzero(pred & ~ref))
    fn = int(np.count_nonzero(~pred & ref))

    def _ratio(num, den):
        # NaN instead of ZeroDivisionError, matching numpy's 0/0 result
        return num / den if den else float('nan')

    return {'accuracy': _ratio(tp + tn, pred.size),
            'sensitivity': _ratio(tp, tp + fn),
            'specificity': _ratio(tn, tn + fp),
            'precision': _ratio(tp, tp + fp),
            'fnr': _ratio(fn, fn + tp),
            'fpr': _ratio(fp, fp + tn),
            'threshold': threshold}

def _get_values(values, idx):
    """Split ``values`` into two NaN-padded copies according to a mask.

    Parameters
    ----------
    values : numpy.ndarray or pandas.Series
        1-D signal. Non-floating dtypes are converted to float so that NaN
        can be written into the copies (an integer ndarray cannot hold NaN).
    idx : array-like of bool
        Per-sample mask, same length as ``values``. It is coerced to bool so
        that ``~`` inverts the mask rather than acting as a bitwise NOT.

    Returns
    -------
    tuple
        ``(outside, inside)``: ``outside`` is NaN wherever ``idx`` is True,
        ``inside`` is NaN wherever ``idx`` is False. ``values`` itself is
        not modified.
    """
    mask = np.asarray(idx, dtype=bool)
    if not np.issubdtype(values.dtype, np.floating):
        values = values.astype(float)
    outside = values.copy()
    inside = values.copy()
    outside[mask] = np.nan
    inside[~mask] = np.nan
    return outside, inside

def _plot_suppression(data, pred, ref, channel, sup_type):
    """Plot one channel with reference and predicted suppressions highlighted.

    Two stacked panels share the x axis. The top panel draws the samples
    flagged in ``ref[sup_type]`` in red, and the bottom panel draws the
    samples flagged in ``pred`` in red. All other samples are drawn in blue.

    Parameters
    ----------
    data : pandas.DataFrame
        Signal table; ``data.index`` is used as the x axis.
    pred : array-like of bool
        Predicted per-sample mask aligned with ``data`` (a mask, not a
        DataFrame of detected intervals).
    ref : pandas.DataFrame
        Reference table with a ``sup_type`` column aligned with ``data``.
    channel : str
        Column of ``data`` to plot.
    sup_type : str
        Column of ``ref`` to use; also the legend label for highlighted samples.
    """
    values = data.loc[:, channel]
    # Coerce both masks the same way so `~mask` in _get_values is a logical NOT
    masks = (ref.loc[:, sup_type].values.astype(bool),
             np.asarray(pred).astype(bool))
    fig = make_subplots(rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.05,
                        subplot_titles=('Reference', 'Prediction'))
    for row, mask in enumerate(masks, start=1):
        outside, inside = _get_values(values, mask)
        # Both rows carry the same two series: group them and list each once
        first_row = row == 1
        fig.add_trace(go.Scatter(x=data.index, y=outside, line_color="#1f77b4",
                                 name=channel, legendgroup='signal',
                                 showlegend=first_row),
                      row=row, col=1)
        fig.add_trace(go.Scatter(x=data.index, y=inside, line_color="#d62728",
                                 name=sup_type, legendgroup='suppression',
                                 showlegend=first_row),
                      row=row, col=1)
    fig.show()

def _detect_suppressions(recording: Recording,
                         channels: Sequence = None,
                         threshold: float = None,
                         min_duration: float = 1.):
    """Return a per-sample boolean mask of low-amplitude (suppressed) signal.

    A sample is flagged when the largest absolute amplitude across the
    selected channels is below ``threshold``. Flagged runs shorter than
    ``min_duration`` are then removed with a binary opening.

    Parameters
    ----------
    recording : Recording
        Source recording; ``recording.artifacts_to_nan()`` is applied first.
    channels : sequence of str, optional
        Columns of the recording data to use. ``None`` or an empty sequence
        selects ``recording.channels``. Any sized sequence (list, tuple,
        ``pd.Index``, ndarray) is accepted.
    threshold : float, optional
        Amplitude threshold in the data's units. ``None`` derives it from
        the data via ``_find_threshold``.
    min_duration : float
        Minimum run length, multiplied by ``rec.frequency`` to get samples
        (presumably seconds × sampling rate in Hz — confirm). Values that
        round to fewer than one sample disable the duration filter.

    Returns
    -------
    numpy.ndarray of bool
        One entry per row of the artifact-cleaned data.
    """
    # Explicit None/length check: `not channels` raises for pd.Index/ndarray
    if channels is None or len(channels) == 0:
        channels = recording.channels

    # NOTE(review): presumably replaces artifact samples with NaN (per the name)
    rec = recording.artifacts_to_nan()
    if threshold is None:
        threshold = _find_threshold(rec.data.loc[:, channels])
    # numpy max propagates NaN, so any NaN channel makes that sample's envelope NaN
    envelope = rec.data.loc[:, channels].abs().values.max(axis=1)
    # At least one sample: an empty structuring element is not a valid opening
    min_length = max(1, math.ceil(min_duration * rec.frequency))
    with np.errstate(invalid='ignore'):
        # NaN < threshold is False, so NaN samples are never flagged
        ies_mask = envelope < threshold
    ies_mask = binary_opening(ies_mask, np.ones(min_length))
    return ies_mask


def _find_threshold(data: pd.DataFrame, threshold: float = 8., *,
                    amplitude_limit: float = 30., scale: float = 1.25):
    """Return a suppression threshold adapted to the overall signal level.

    The signal level is the mean over columns of each column's mean absolute
    amplitude. When it is below ``amplitude_limit``, the base ``threshold``
    is divided by ``scale``; otherwise ``threshold`` is returned unchanged.
    If the level is NaN (e.g. all-NaN data), the comparison is False and
    ``threshold`` is returned unchanged.

    Parameters
    ----------
    data : pandas.DataFrame
        Signal, one column per channel.
    threshold : float
        Base amplitude threshold.
    amplitude_limit : float, keyword-only
        Mean absolute amplitude below which the threshold is lowered.
    scale : float, keyword-only
        Divisor applied to ``threshold`` for low-amplitude data.

    Returns
    -------
    float
    """
    mean_amplitude = data.abs().mean().mean()
    if mean_amplitude < amplitude_limit:
        return threshold / scale
    return threshold


def _detect_alpha_suppressions(
    recording: Recording,
    channels: Sequence = None,
    frequency_band: Tuple[float, float] = (8., 16.),
    threshold: float = None
):
    """Return a per-sample mask of suppressed activity within a frequency band.

    The selected channels are band-pass filtered to ``frequency_band`` and
    then passed to ``_detect_suppressions`` (with its default
    ``min_duration``).

    Parameters
    ----------
    recording : Recording
        Source recording; it is copied, not modified.
    channels : sequence of str, optional
        Columns to analyse. ``None`` or an empty sequence selects
        ``recording.channels``.
    frequency_band : tuple of float
        ``(low, high)`` band passed to ``bandpass`` (presumably Hz).
    threshold : float, optional
        Amplitude threshold for the filtered signal. ``None`` scales a base
        value of 8 by the RMS ratio filtered / unfiltered.

    Returns
    -------
    numpy.ndarray of bool
        Mask from ``_detect_suppressions`` on the filtered recording.
    """
    # Explicit None/length check: `not channels` raises for pd.Index/ndarray
    if channels is None or len(channels) == 0:
        channels = recording.channels
    rec = recording.copy()
    rec.data = recording.data.loc[:, channels]
    filtered = bandpass(rec, frequency_band)
    if threshold is None:
        # nan-aware RMS: callers (SuppressionAnalyzer) blank samples with NaN,
        # and a plain mean would turn the whole threshold into NaN.
        # NOTE(review): whether bandpass propagates NaN is not visible here.
        rms_before = np.sqrt(np.nanmean(rec.data.values ** 2))
        rms_after = np.sqrt(np.nanmean(filtered.data.values ** 2))
        threshold = 8 * rms_after / rms_before
    return _detect_suppressions(filtered, threshold=threshold)


class SuppressionAnalyzer:
    """Detects isoelectric- and α-suppressions in a Recording.

    Each ``detect_*`` method returns a DataFrame with the columns
    ``start``, ``end``, ``channel`` and ``description``, one row per
    detected interval, and caches it on the instance. The per-sample IES
    mask is cached as well and reused by ``detect_alpha_suppressions``.
    """

    _DETECTION_COLUMNS = ('start', 'end', 'channel', 'description')

    def __init__(self, recording: Recording):
        self.recording = recording
        self._ies_detections = None
        self._alpha_detections = None
        self._ies_mask = None

    def _mask_to_detections(self, mask, description):
        """Convert a per-sample boolean mask into a detections DataFrame."""
        intervals = mask_to_intervals(mask, self.recording.data.index)
        rows = [{'start': start,
                 'end': end,
                 'channel': None,
                 'description': description}
                for start, end in intervals]
        # Explicit columns so an empty result still has the expected schema
        return pd.DataFrame(rows, columns=list(self._DETECTION_COLUMNS))

    def detect_ies(self, **kwargs):
        """Detect isoelectric suppressions (IES).

        Keyword arguments are forwarded to ``_detect_suppressions``
        (``channels``, ``threshold``, ``min_duration``). The resulting mask
        is cached in ``self._ies_mask``.
        """
        self._ies_mask = _detect_suppressions(self.recording, **kwargs)
        self._ies_detections = self._mask_to_detections(self._ies_mask, 'IES')
        return self._ies_detections

    def detect_alpha_suppressions(
            self,
            channels: Sequence = None,
            frequency_band: Tuple[float, float] = (8., 16.),
            threshold: float = None
    ):
        """Detect alpha suppressions outside the already-detected IES.

        If ``detect_ies`` has not been run, an IES mask is computed first
        with ``min_duration=2.5``. The IES samples are set to NaN in a copy
        of the recording before ``_detect_alpha_suppressions`` runs with
        the given arguments.
        """
        if self._ies_mask is None:
            self._ies_mask = _detect_suppressions(
                self.recording, min_duration=2.5)

        rec = self.recording.copy()
        # Blank IES rows, presumably to keep them out of the alpha detection.
        # NOTE(review): if Recording.copy() is shallow, this also mutates
        # self.recording.data — confirm.
        rec.data.loc[self._ies_mask] = np.nan
        alpha_mask = _detect_alpha_suppressions(
            rec, channels, frequency_band, threshold)

        self._alpha_detections = self._mask_to_detections(
            alpha_mask, 'alpha_suppression')
        return self._alpha_detections




#### Automatic (quantile-based) threshold

In [None]:
def automatic_threshold(recording, quantile):
    """Detect suppressions using a threshold taken from the amplitude distribution.

    Parameters
    ----------
    recording : Recording
        Recording whose ``data`` (all channels) defines the distribution.
    quantile : float
        Quantile in [0, 1] of the absolute amplitudes used as the threshold.

    Returns
    -------
    tuple
        ``(pred, threshold)``: the boolean mask from ``_detect_suppressions``
        and the threshold that produced it.
    """
    amplitudes = recording.data.abs().values
    # nanquantile: NaN samples would otherwise make the threshold NaN
    threshold = np.nanquantile(amplitudes, quantile)
    # Call the in-file helper directly (the old `suppressions.` prefix was undefined)
    pred = _detect_suppressions(recording, threshold=threshold)
    return pred, threshold

#### Evaluating performance

In [None]:
# Load the EEG segment and its reference annotation table. NOTE(review):
# presumably one row per EEG sample; its 'as' column is used below as the
# alpha-suppression reference.
seg1 = read_edf('data/EEG_BM65_ies_as.edf')
ref1 = pd.read_csv('data/BM65_ies_as.csv')
rec = seg1.copy()
# Band-pass to (8, 16) (presumably Hz). The analyzer below is built on this
# filtered copy, so IES detection also runs on filtered data, not the raw signal.
filtered = bandpass(rec, (8, 16))
analyzer = SuppressionAnalyzer(filtered)
# Frontal channels only. min_duration=2.5 matches the fallback that
# detect_alpha_suppressions uses when no IES mask is cached. This call caches
# analyzer._ies_mask, which later cells reuse.
ies = analyzer.detect_ies(channels=['EEG L1(Fp1)', 'EEG R1(Fp2)'], min_duration=2.5)

In [None]:
# Sweep fixed alpha-suppression thresholds and score each prediction against
# the reference 'as' column. The result is one row of metrics per threshold.

# Absolute-amplitude range of the recording (not used below; kept for inspection)
minVal = rec.data.abs().values.min()
maxVal = rec.data.abs().values.max()

thresholds = np.linspace(2, 10, 20)
perf_columns = ['accuracy', 'sensitivity', 'specificity', 'precision', 'fnr', 'fpr', 'threshold']

rows = []
for th in thresholds:
    # NOTE(review): `filtered` is already band-passed, and
    # _detect_alpha_suppressions band-passes it again — confirm this is intended.
    pred = _detect_alpha_suppressions(filtered, threshold=th)
    # Exclude samples the analyzer already flagged as IES
    pred[analyzer._ies_mask] = False
    rows.append(_calculate_perf_measures(pred, ref1.loc[:, 'as'], th))

# DataFrame.append was removed in pandas 2.0; build the table once instead
perf = pd.DataFrame(rows, columns=perf_columns)
pd.set_option('display.max_rows', None)
perf

In [None]:
# Run the analyzer's own alpha-suppression detection. It reuses the IES mask
# cached by detect_ies. The result is a DataFrame with one row per detected
# interval, so len(pred) is the number of detections.
pred = analyzer.detect_alpha_suppressions()
# NOTE: _plot_suppression expects a per-sample boolean mask, not this
# DataFrame. Convert the intervals to a mask first (e.g. with
# intervals_to_mask — TODO confirm its signature) before re-enabling this call.
#_plot_suppression(seg1.data, pred, ref1, 'EEG R1(Fp2)','as')
len(pred)