# Detection of bruxism events in polysomnographic data of tinnitus patients
This notebook will:
- load the EMG channels of polysomnographic data
- detect EMG bursts in an unsupervised way
- classify EMG bursts into bruxism episode types (e.g. phasic, mixed)
- give insights into the bruxism events

In [1]:
import os
import sys

import matplotlib.pyplot as plt
%matplotlib widget
import mne
import numpy as np
import scipy
import seaborn as sns
from IPython.core.display import display
from ipywidgets import widgets

# Make the local `tinnsleep` package (one directory up) importable.
PATH = os.getcwd()
sys.path.append(PATH + '/../')

from tinnsleep.config import Config
from tinnsleep.data import CreateRaw, RawToEpochs_sliding, CleanAnnotations, AnnotateRaw_sliding
from tinnsleep.classification import AmplitudeThresholding
from tinnsleep.check_impedance import create_annotation_mne, Impedance_thresholding_sliding, check_RMS, fuse_with_classif_result
from tinnsleep.signal import rms
from tinnsleep.signal import is_good_epochs
from tinnsleep.scoring import classif_to_burst, burst_to_episode
from tinnsleep.visualization import plotTimeSeries, plotAnnotations, zoom_effect
print("Config loaded")

print(Config.bruxisme_files)

Config loaded
['/Users/louis/Data/SIOPI/bruxisme/1BA07_nuit_hab.edf', '/Users/louis/Data/SIOPI/bruxisme/1DA15_nuit_hab.edf', '/Users/louis/Data/SIOPI/bruxisme/1GB19_nuit_hab.edf', '/Users/louis/Data/SIOPI/bruxisme/1RA17_nuit_hab.edf']


## Load, filter, and prepare data

In [2]:
filename = Config.bruxisme_files[0]  # first recording listed in the config
picks_chan = ['1', '2']              # subset of EMG electrodes

# Crop window: drop the first 2 hours and the last hour of the recording.
CROP_START_S = 2 * 3600
CROP_END_S = 3600
# Band-pass edges for the EMG filter, in Hz.
L_FREQ, H_FREQ = 20., 99.

raw_edf = mne.io.read_raw_edf(filename, preload=False)  # lazy: samples are not loaded yet
tmin = raw_edf.times[0]
tmax = raw_edf.times[-1]

croptimes = dict(tmin=tmin + CROP_START_S, tmax=tmax - CROP_END_S)
raw_edf.crop(**croptimes)

# Pick the EMG channels, load only their samples, and wrap them in a fresh EMG Raw.
raw = CreateRaw(raw_edf[picks_chan][0], picks_chan, ch_types=['emg'])

raw = raw.filter(L_FREQ, H_FREQ, n_jobs=4,
                 fir_design='firwin', filter_length='auto', phase='zero-double',
                 picks=picks_chan)
ch_names = raw.info["ch_names"]
print("Data filtered")

# Time origin of the cropped recording; used as orig_time for all annotations below.
offset = raw.times[0]
print(f"keeping {(raw.times[-1]-raw.times[0])/3600:0.2f} hours of recording out of {(tmax-tmin)/3600:0.2f} hours")

Extracting EDF parameters from /Users/louis/Data/SIOPI/bruxisme/1BA07_nuit_hab.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


['Inductance Abdom', 'Inductance Thora', 'Intensit? lumine', 'Jambe droite Imp', 'Jambe gauche Imp', 'Tension (aliment', 'Tension (Bluetoo']
  raw  = mne.io.read_raw_edf(filename, preload=False)  # prepare loading


Data filtered
keeping 7.28 hours of recording out of 8.83 hours


## Epoching data

In [3]:
sfreq = raw.info["sfreq"]
window_length = 0.25                           # epoch length, in seconds
# Round before converting: plain int() truncates float error
# (e.g. 0.29 * 100 == 28.999...), which would silently drop a sample.
duration = int(round(window_length * sfreq))   # epoch length, in samples
assert duration > 0, "window_length is too short for this sampling rate"
interval = duration                            # step == length -> no overlap
epochs = RawToEpochs_sliding(raw, duration=duration, interval=interval)
print(f"Epochs done, shape {epochs.shape}")

Epochs done, shape (104900, 2, 50)


## Get the impedance and artefact annotations

In [4]:
# Impedance threshold above which an electrode is flagged as bad
# (NOTE(review): units are those of the EDF impedance channels — confirm).
THR_imp = 6000

# Re-open the same recording to read its impedance channels, which are not part of `raw`.
raw_imp = mne.io.read_raw_edf(filename, preload=False)
# In this file layout, channels 1 and 5 are the impedance traces of EMG electrodes
# '1' and '2' (see the printed names); the indices are file-specific.
imp_ch_names = raw_imp.info["ch_names"]
imp_picks = [imp_ch_names[1], imp_ch_names[5]]
print(imp_picks)

# Crop with the same window as the EMG data so impedance epochs line up with `epochs`.
raw_imp.crop(**croptimes)

# Table of bad-electrode booleans from the sliding impedance thresholding
check_imp = Impedance_thresholding_sliding(raw_imp[imp_picks][0], duration, interval, THR_imp)

# Convert to one label per epoch: True if any entry along the last axis is flagged
# (presumably the last axis indexes electrodes — confirm against Impedance_thresholding_sliding).
impedance_labels = np.any(check_imp, axis=-1)
assert len(impedance_labels) == len(epochs), "impedance and EMG epochs are misaligned"
print(f"rejected impedances for {np.sum(impedance_labels)} epochs out of {len(impedance_labels)} ({np.sum(impedance_labels)/len(impedance_labels)*100:.2f}%)")

# Epoch rejection based on |min-max| (peak-to-peak) amplitude thresholding
params = dict(ch_names=raw.info["ch_names"],
              rejection_thresholds=dict(emg=5e-04),  # ~two orders of magnitude above the q0.01 amplitude
              flat_thresholds=dict(emg=1e-09),       # ~one order of magnitude below the median amplitude
              # every channel of `raw` was created with ch_types=['emg'], so all indices are EMG
              channel_type_idx=dict(emg=list(range(len(raw.info["ch_names"])))),
              full_report=True
              )
amplitude_labels, bad_lists = is_good_epochs(epochs, **params)
print(f"good amplitudes for {np.sum(amplitude_labels)} epochs out of {len(amplitude_labels)} ({np.sum(amplitude_labels)/len(amplitude_labels)*100:.2f}%)")

# Merge the impedance and amplitude checks with a logical AND: an epoch is valid
# only if its impedance was NOT rejected and its amplitude passed.
valid_labels = np.logical_and(np.invert(impedance_labels), amplitude_labels)

print(f"good epochs for {np.sum(valid_labels)} epochs out of {len(valid_labels)} ({np.sum(valid_labels)/len(valid_labels)*100:.2f}%)")

dict_annotations_artefacts = {1: "artefact"}

# One fixed-length "artefact" annotation per rejected (non-valid) epoch.
# Onsets are in seconds from the start of the cropped recording (orig_time=offset).
rejected_epochs = np.flatnonzero(np.invert(valid_labels))
annotations_artefacts = [
    dict(
        onset=int(epoch_idx) * interval / sfreq,
        duration=duration / sfreq,
        description=dict_annotations_artefacts[1],
        orig_time=offset,
    )
    for epoch_idx in rejected_epochs
]

Extracting EDF parameters from /Users/louis/Data/SIOPI/bruxisme/1BA07_nuit_hab.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
['1 Imp?dance', '2 Imp?dance']


['Inductance Abdom', 'Inductance Thora', 'Intensit? lumine', 'Jambe droite Imp', 'Jambe gauche Imp', 'Tension (aliment', 'Tension (Bluetoo']
  raw_imp  = mne.io.read_raw_edf(filename, preload=False)  # prepare loading


rejected impedances for 5 epochs out of 104900 (0.00%)
good amplitudes for 104713 epochs out of 104900 (99.82%)
104708
good epochs for 104708 epochs out of 104900 (99.82%)
104900


## Classifying bursts with different adaptive threshold lengths (`n_adaptive`)

In [5]:
from tinnsleep.scoring import classif_to_burst, burst_to_episode

def pipeline2annotations(X, pipeline, artefact_free=None, time_interval=None, orig_time=None):
    """Run a burst-detection pipeline and turn its labels into episode annotations.

    Parameters
    ----------
    X : array-like
        Features of the valid (artefact-free) epochs only; passed to
        ``pipeline.fit_predict``.
    pipeline : object
        Estimator exposing ``fit_predict`` and an ``n_adaptive`` attribute
        (e.g. ``AmplitudeThresholding``).
    artefact_free : array of bool, optional
        Per-epoch validity mask over *all* epochs; ``X`` holds the rows where it
        is True. Defaults to the notebook-level ``valid_labels``.
    time_interval : float, optional
        Epoch duration in seconds. Defaults to the notebook-level ``window_length``.
    orig_time : float, optional
        ``orig_time`` given to each generated annotation. Defaults to the
        notebook-level ``offset``.

    Returns
    -------
    list
        One annotation per detected episode, as produced by
        ``episode.generate_annotation``.
    """
    # Fall back to the notebook globals so existing two-argument calls keep working.
    if artefact_free is None:
        artefact_free = valid_labels
    if time_interval is None:
        time_interval = window_length
    if orig_time is None:
        orig_time = offset

    labels = pipeline.fit_predict(X)
    # Add back the labels of the epochs removed as artefacts.
    labels = fuse_with_classif_result(np.invert(artefact_free), labels)
    n_bursts = np.sum(labels)
    print(f"bursts count ({pipeline.n_adaptive}): {n_bursts}/{len(labels)} ({n_bursts / len(labels) * 100:.2f}%)")
    print(f"bursts time ({pipeline.n_adaptive}): {n_bursts * time_interval} seconds")
    bursts = classif_to_burst(labels, time_interval=time_interval)
    annotations_episodes = [episode.generate_annotation(orig_time=orig_time) for episode in burst_to_episode(bursts)]
    print(annotations_episodes[:3])
    return annotations_episodes

In [6]:
X = rms(epochs[valid_labels])  # RMS features of the valid (artefact-free) epochs only

# n_adaptive=0 is the non-adaptive baseline; the other values use an adaptive threshold.
N_ADAPTIVE_VALUES = [0, 120, 240, 480]
annotations_episodes = [
    pipeline2annotations(X, AmplitudeThresholding(abs_threshold=0., rel_threshold=2, n_adaptive=n_adaptive))
    for n_adaptive in N_ADAPTIVE_VALUES
]


bursts count (0): 7136/104900 (6.80%)
bursts time (0): 1784.0 seconds
[{'onset': 58.75, 'duration': 5.0, 'description': 'Phasic', 'orig_time': 0.0}, {'onset': 136.75, 'duration': 10.5, 'description': 'Mixed', 'orig_time': 0.0}, {'onset': 150.25, 'duration': 32.75, 'description': 'Phasic', 'orig_time': 0.0}]
bursts count (120): 3079/104900 (2.94%)
bursts time (120): 769.75 seconds
[{'onset': 866.5, 'duration': 2.75, 'description': 'Phasic', 'orig_time': 0.0}, {'onset': 965.25, 'duration': 6.5, 'description': 'Phasic', 'orig_time': 0.0}, {'onset': 1034.5, 'duration': 8.75, 'description': 'Phasic', 'orig_time': 0.0}]
bursts count (240): 3525/104900 (3.36%)
bursts time (240): 881.25 seconds
[{'onset': 58.75, 'duration': 3.0, 'description': 'Phasic', 'orig_time': 0.0}, {'onset': 140.25, 'duration': 2.0, 'description': 'Phasic', 'orig_time': 0.0}, {'onset': 965.25, 'duration': 6.5, 'description': 'Mixed', 'orig_time': 0.0}]
bursts count (480): 3953/104900 (3.77%)
bursts time (480): 988.25 se

## Display episodes for different adaptive thresholds

In [7]:
def plotall(raw, ax1, annotations, annotations_artefacts, plotargs):
    """Plot the signals of ``raw`` on ``ax1`` and overlay two annotation sets.

    Artefact annotations are drawn in red and detected episodes in green.
    ``plotargs`` is forwarded as keyword arguments to ``plotTimeSeries``.
    Returns ``ax1`` so calls can be chained.
    """
    # MNE returns (n_channels, n_times); transpose to (n_times, n_channels).
    signals = raw.get_data().T
    plotTimeSeries(signals, ax=ax1, **plotargs)

    plotAnnotations(annotations_artefacts, ax=ax1, color="red")

    episode_color = "green"
    plotAnnotations(annotations, ax=ax1, text_prop=dict(color=episode_color), color=episode_color)
    return ax1

def set_xlim_all(xmin, xmax, *args):
    """Apply the same x-limits ``(xmin, xmax)`` to every axis given in ``args``."""
    limits = (xmin, xmax)
    for axis in args:
        axis.set_xlim(*limits)

In [8]:
plt.close("all")
# Decimate to 100 Hz so the long recording is readable and quick to redraw.
raw_ds = raw.copy().resample(100)

scalings = 5e-5  # vertical scaling passed to plotTimeSeries
plotargs = dict(sfreq=raw_ds.info["sfreq"], scalings=scalings, offset=offset, linewidth=0.5)

# One row per n_adaptive setting; sharex links the panels so an interactive zoom applies to all.
fig, axes_grid = plt.subplots(len(annotations_episodes), 1, sharex=True, squeeze=False)
axes = list(axes_grid[:, 0])
for ax, annotations in zip(axes, annotations_episodes):
    plotall(raw_ds, ax, annotations, annotations_artefacts, plotargs)
set_xlim_all(10000, 10200, *axes)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [13]:
zoom_window = (8175, 8500)  # x-range to inspect on every panel
set_xlim_all(*zoom_window, *axes)