# Detection of Bruxism events in Tinnitus patients polysomnographic data
This notebook will 
- load EMG channels of polusomnographic data
- detect the EMG bursts in a unsupervised way
- classify EMG bursts as different bruxism events
- give insights on the bruxism events

In [5]:
import os
PATH = os.getcwd() 
import sys
sys.path.append(PATH + '/../')
import matplotlib.pyplot as plt
%matplotlib widget
import numpy as np
import mne
import scipy
import seaborn as sns
from tinnsleep.config import Config
from tinnsleep.data import CreateRaw, RawToEpochs_sliding, CleanAnnotations, AnnotateRaw_sliding
from tinnsleep.classification import AmplitudeThresholding
from tinnsleep.check_impedance import create_annotation_mne, Impedance_thresholding_sliding, check_RMS, fuse_with_classif_result
from tinnsleep.signal import rms
from tinnsleep.visualization import plotTimeSeries, plotAnnotations, zoom_effect

from IPython.core.display import display
from ipywidgets import widgets
print("Config loaded")


Config loaded


## Load, filter, and prepare data

In [2]:
filename = Config.bruxisme_files[0]  # load file from config
picks_chan = ['1', '2']           # subset of EMG electrodes

raw  = mne.io.read_raw_edf(filename, preload=False)  # prepare loading
raw  = CreateRaw(raw[picks_chan][0], picks_chan, ch_types=['emg'])        # pick channels and load
raw  = raw.load_data()  # load data into memory 
print("Data loaded")

raw  = raw.filter(20., 99., n_jobs=4, 
                  fir_design='firwin', filter_length='auto', phase='zero-double',
                  picks=picks_chan)
ch_names = raw.info["ch_names"]
print("Data filtered")

tmin = raw.times[0]                     
tmax = raw.times[-1]

# remove the two first hours and last hour
croptimes=dict(tmin=raw.times[0]+3600*2, tmax=raw.times[-1]-3600)
raw.crop(**croptimes)
offset = raw.times[0]
print(f"keeping {(raw.times[-1]-raw.times[0])/3600:0.2f} hours of recording out of {(tmax-tmin)/3600:0.2f} hours")

Extracting EDF parameters from /Users/louis/Data/SIOPI/bruxisme/1BA07_nuit_hab.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


['Inductance Abdom', 'Inductance Thora', 'Intensit? lumine', 'Jambe droite Imp', 'Jambe gauche Imp', 'Tension (aliment', 'Tension (Bluetoo']
  raw  = mne.io.read_raw_edf(filename, preload=False)  # prepare loading


Data loaded
Data filtered
keeping 8.03 hours of recording out of 11.03 hours


## Epoching data

In [3]:
sfreq = raw.info["sfreq"]
window_length = 0.25                    # in seconds
duration = int(window_length * sfreq)   # in samples
interval = duration                     # no overlapping
epochs = RawToEpochs_sliding(raw, duration=duration, interval=interval)
print(f"Epochs done, shape {epochs.shape}")

Epochs done, shape (115700, 2, 50)


## Get the impedance & artefacts annotations

In [60]:
# Value of the impedance threshold
THR_imp = 5000

raw_imp  = mne.io.read_raw_edf(filename, preload=False)  # prepare loading
ch_names = raw_imp.info["ch_names"]
picks_chan = [ch_names[1],ch_names[5]]
print(picks_chan)

#Get the table of bad electrodes booleans from the impedance thresholding algo
check_imp  = Impedance_thresholding_sliding(raw_imp[picks_chan][0], duration, interval,THR_imp) 
print(check_imp[:3])

# convert to labels per epoch
impedance_labels = np.any(check_imp, axis=-1)
print(f"rejected impedances for {np.sum(impedance_labels)} epochs out of {len(impedance_labels)} ({np.sum(impedance_labels)/len(impedance_labels)*100:.2f}%)")

Extracting EDF parameters from /Users/louis/Data/SIOPI/bruxisme/1BA07_nuit_hab.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
['1 Imp?dance', '2 Imp?dance']


['Inductance Abdom', 'Inductance Thora', 'Intensit? lumine', 'Jambe droite Imp', 'Jambe gauche Imp', 'Tension (aliment', 'Tension (Bluetoo']
  raw_imp  = mne.io.read_raw_edf(filename, preload=False)  # prepare loading


[[False False]
 [False False]
 [False False]]
rejected impedances for 26 epochs out of 158900 (0.02%)


In [52]:
# baseline epochs amplitudes
print(f"mean {np.mean(np.mean(epochs, axis=0), axis=-1)}")
print(f"median {np.median(np.median(epochs, axis=0), axis=-1)}")
print(f"quantile 0.05 {np.quantile(np.quantile(epochs, 0.05, axis=0),0.05, axis=-1)}")
print(f"quantile 0.95 {np.quantile(np.quantile(epochs, 0.95, axis=0),0.95, axis=-1)}")

mean [-9.05867275e-13 -1.18524668e-12]
median [2.65982311e-08 1.10616584e-08]
quantile 0.05 [-6.97165086e-06 -7.10919239e-06]
quantile 0.95 [6.93789249e-06 7.09388268e-06]


In [61]:
# Epoch rejection based on |min-max| thresholding 
from tinnsleep.signal import is_good_epochs
params = dict(ch_names=raw.info["ch_names"],
             rejection_thresholds=dict(emg=1e-04), # two order of magnitude higher q0.01
             flat_thresholds=dict(emg=1e-09),    # one order of magnitude lower median
             channel_type_idx=dict(emg=[0, 1]),
             full_report=True
            )
amplitude_labels, bad_lists = is_good_epochs(epochs, **params)
print(labels[:10])
print(bad_lists[:10])
print(f"good amplitudes for {np.sum(amplitude_labels)} epochs out of {len(amplitude_labels)} ({np.sum(amplitude_labels)/len(amplitude_labels)*100:.2f}%)")

[True, True, True, True, True, True, True, True, True, True]
[None, None, None, None, None, None, None, None, None, None]
good amplitudes for 112938 epochs out of 115700 (97.61%)


In [63]:
# Merge is_good and amplitude
#TODO: THERE IS A MISSMATCH BETWEEN LABEL SIZE
valid_labels = np.any(np.c_[np.invert(impedance_labels),  amplitude_labels], axis=-1) # Logical OR
print(valid_labels[:3])

ValueError: all the input array dimensions for the concatenation axis must match exactly, but along dimension 0, the array at index 0 has size 158900 and the array at index 1 has size 115700

## Classifying epochs and annotate raw

In [6]:
# compute the sum of power over electrodes and samples in each window
pipeline = AmplitudeThresholding(abs_threshold=0., rel_threshold=2)
X        = rms(epochs)
labels   = pipeline.fit_predict(X)
print(f"bursts count: {np.sum(labels)}/{len(labels)} ({np.sum(labels) / len(labels) * 100:.2f}%)")
print(f"bursts time: {np.sum(labels) * window_length} seconds")

dict_annotations = {1: "burst"}
annotations = []
for k, label in enumerate(labels):
    if label > 0:
        annotations.append(dict(
            onset=k*interval/sfreq,
            duration=duration/sfreq,
            description=dict_annotations[label],
            orig_time=offset
        )
            
        )

bursts count: 7417/115700 (6.41%)
bursts time: 1854.25 seconds


## Display Annotations

In [7]:
plt.close("all")

# decimate signal to make it more readible 
raw_ds = raw.copy().resample(100)

%matplotlib widget
scalings=1e-4

ax1 = plt.subplot(211)
plotTimeSeries(raw_ds.get_data().T, sfreq=raw_ds.info["sfreq"], ax=ax1, scalings=scalings, offset=offset)
plotAnnotations(annotations, color="red")
ax1.set_xlim(5145,5165)
ax2 = plt.subplot(212)
plotTimeSeries(raw_ds.get_data().T, sfreq=raw_ds.info["sfreq"], ax=ax2, scalings=scalings, offset=offset)
z = zoom_effect(ax1, ax2)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [None]:
# update manually the time axis
ax1.set_xlim(5145,5165) # in seconds

### (OPTIONAL) Enable widget

In [14]:
from ipywidgets import interact, FloatSlider
def update_axis(xmin, xmax):
    if xmin<xmax:
        ax1.set_xlim(xmin,xmax)

i=FloatSlider(min=raw.times[0], max=raw.times[-1], step=10, continuous_update=False)
ii=FloatSlider(min=raw.times[0], max=raw.times[-1], step=10, continuous_update=False)
from ipywidgets import FloatSlider
interact(update_axis,xmin=i, xmax=ii);

interactive(children=(FloatSlider(value=0.0, continuous_update=False, description='xmin', max=28924.995, step=…