# Exercises

import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import numpy as np
import sys
sys.path.append('../../AchimBrinkop/Neuro-Analysis/01-local_field_potential/code/')
import lfp_functions as lf

sns.set_theme(context='notebook',style='white',font_scale=1.5,
              rc = {'axes.spines.top':False,'axes.spines.right':False,
                     'image.cmap':plt.cm.jet})

from scipy.signal import butter
from scipy.signal import sosfilt

from ipywidgets import interactive
import ipywidgets as widgets

import math
from timeit import default_timer as timer

from pywt import scale2frequency
from pywt import cwt

from scipy.stats import zscore

### FPB comments 
*you haven't included the "code" directory here, so I had to get it somewhere else*

## Exercise 1: explore filter parameters
---
In this exercise you will explore how different parameters affect the filtering process.
- Import the data from `data/moving_lfp.pickle`
- Take a few seconds of the data and filter them in different frequency bands. Plot the resulting signal.
  Try to fix a central frequency and play with the width of the band, then try to change the central frequency of the band.
- Try to change the order of the filter, for a fixed frequency band. How does the filtered signal change? How does the computing time change? (You can try to quantify the computing time with the Jupyter magic command [`%timeit`](https://docs.python.org/3/library/timeit.html).)
- Comment on what you see
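A minimal sketch of the band-width sweep, assuming a 1000 Hz sampling rate and using synthetic white noise in place of the LFP (the pickle is not reproduced here). It uses `sosfiltfilt` rather than the `sosfilt` call in the cell below, so that the filter's phase shift does not get mixed into the comparison; `band_filter` is a hypothetical helper name.

```python
import numpy as np
from scipy.signal import butter, sosfiltfilt

fs = 1000        # assumed sampling rate (Hz); the notebook reads it from lfp_data['sampling_rate']
center = 100     # fixed central frequency (Hz)
rng = np.random.default_rng(0)
x = rng.standard_normal(10 * fs)   # 10 s of white noise standing in for the LFP

def band_filter(x, low, high, fs, order=4):
    """Zero-phase Butterworth band-pass built from second-order sections."""
    sos = butter(order, [low, high], btype='band', output='sos', fs=fs)
    return sosfiltfilt(sos, x)

# Keep the center fixed and widen the band; record the variance of the output
variances = {}
for width in (10, 50, 100):
    y = band_filter(x, center - width / 2, center + width / 2, fs)
    variances[width] = np.var(y)
    print(f'band {center - width / 2:.0f}-{center + width / 2:.0f} Hz: '
          f'output variance {variances[width]:.4f}')
```

For white noise the output variance grows roughly in proportion to the band width, which is one way to put a number on the "more frequencies contribute" intuition.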

In [4]:
# Load data
with open('../../AchimBrinkop/Neuro-Analysis/01-local_field_potential/data/moving_lfp.pickle', 'rb') as handle:
    lfp_data = pickle.load(handle)

sampling_rate = lfp_data['sampling_rate']
signal_start, signal_end = 0, 2
lfp = lfp_data['lfp'][math.floor(signal_start * sampling_rate):math.ceil(signal_end * sampling_rate)]

# Adapted from local-field-potention.ipynb
def plot_filtered_signal(center, width, order):
    # Frequency band around `center`, clipped to [1 Hz, sampling_rate / 2 - 1 Hz]
    band = [math.floor(max(center - width / 2, 1)), math.ceil(min(center + width / 2, sampling_rate / 2 - 1))]

    # Butterworth band-pass as second-order sections (numerically safer at high orders);
    # `sos` avoids shadowing the built-in `filter`
    sos = butter(
        order,
        band,
        btype='band',
        output='sos',
        fs=sampling_rate
    )

    # Time only the filtering step
    timer_start = timer()
    filtered_lfp = sosfilt(sos, lfp)
    timer_end = timer()

    print(f'Frequency band: {band} Hz')
    print(f'Processing time: {(timer_end - timer_start) * 1000:.2f} ms')

    plt.figure(figsize=(10,5))
    X = signal_start + np.arange(len(lfp)) / sampling_rate  # time axis with one point per sample
    plt.plot(X, lfp, label='raw signal')
    plt.plot(X, filtered_lfp, label='band-pass filtered')
    plt.xlabel('time (s)')
    plt.ylabel('voltage (mV)')
    plt.legend()

interactive_plot = interactive(
    plot_filtered_signal, 
    center=widgets.IntSlider(min=1, max=math.floor(sampling_rate / 2) - 1, value=100),
    width=widgets.IntSlider(min=1, max=math.floor(sampling_rate / 2) - 1, value=50),
    order=widgets.IntSlider(min=1, max=99, value=5),
)

# Adapted from https://ipywidgets.readthedocs.io/en/7.6.2/examples/Using%20Interact.html
output = interactive_plot.children[-1]
output.layout.height = '600px'
interactive_plot


When changing the center and width of the frequency band, a few observations can be made. First, increasing the band width tends to increase the complexity of the filtered signal. This makes sense, as a wider band lets more frequencies contribute to the signal. Second, when the band width is kept low (around 1 or 2 Hz) and the center frequency is varied, lower frequencies tend to contribute more to the signal than higher frequencies (i.e. they have a higher amplitude).

Increasing the order of the filter while keeping the frequency band fixed increases the processing time roughly linearly: from 0.09 ms at order 1 to 0.30 ms at order 50 and 0.59 ms at order 99. This is expected, since a band-pass Butterworth filter of order N is implemented as N second-order sections, which `sosfilt` applies one after another. A higher order also seems to delay the onset of the filtered signal and shift its phase. At very high orders the filter seems to become unstable: some combinations of band center and width make the filtered signal grow to many times the amplitude of the raw signal, or vanish completely.
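A sketch for putting numbers on the order dependence, assuming a 1000 Hz sampling rate, a fixed 95 to 105 Hz band and a synthetic signal. It times `sosfilt` with the standard-library `timeit` module (which the `%timeit` magic builds on) and reports the largest pole radius obtained from `scipy.signal.sos2zpk`. Poles closer to the unit circle mean a longer-ringing impulse response, which may be related to the longer delays seen at high orders; the sketch does not establish the cause of the instability described above.

```python
import timeit

import numpy as np
from scipy.signal import butter, sosfilt, sos2zpk

fs = 1000                      # assumed sampling rate (Hz)
band = [95, 105]               # fixed pass band (Hz)
x = np.random.default_rng(0).standard_normal(2 * fs)  # synthetic 2 s signal

results = {}
for order in (1, 10, 50, 99):
    sos = butter(order, band, btype='band', output='sos', fs=fs)
    # Average wall-clock time of one sosfilt call over 20 repetitions
    seconds = timeit.timeit(lambda: sosfilt(sos, x), number=20) / 20
    # Poles of the whole cascade; radius 1 would be the stability limit
    _, poles, _ = sos2zpk(sos)
    max_radius = np.max(np.abs(poles))
    results[order] = (seconds * 1000, max_radius)
    print(f'order {order:>2}: {len(sos):>3} sections, {seconds * 1000:.3f} ms, '
          f'max pole radius {max_radius:.6f}')
```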

### FPB Comments
Very nicely done, and the sliders are cool. We should have told you to use `sosfiltfilt` instead of `sosfilt` to avoid the order- and passband-dependent phase shift; that would make the comparisons a little smoother. Grade: 9
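To see the effect this comment describes, a minimal sketch (assuming a 1000 Hz sampling rate and an 8 to 12 Hz band) passes a unit impulse through `sosfilt` and `sosfiltfilt` and measures where the largest output sample lands. The causal filter's peak arrives after the impulse, while the forward-backward filter's response is symmetric and peaks exactly at the impulse; the printed delays let you compare orders.

```python
import numpy as np
from scipy.signal import butter, sosfilt, sosfiltfilt

fs = 1000                      # assumed sampling rate (Hz)
band = [8, 12]                 # assumed pass band (Hz)

n = 10 * fs                    # 10 s of samples, long enough for the responses to decay
center = n // 2
impulse = np.zeros(n)
impulse[center] = 1.0          # unit impulse in the middle of the trace

lags_causal, lags_zero_phase = {}, {}
for order in (2, 4, 8):
    sos = butter(order, band, btype='band', output='sos', fs=fs)
    causal = sosfilt(sos, impulse)           # forward only: output lags the input
    zero_phase = sosfiltfilt(sos, impulse)   # forward and backward: phase shifts cancel
    # Delay (ms) of the largest absolute output sample relative to the impulse
    lags_causal[order] = (np.argmax(np.abs(causal)) - center) / fs * 1000
    lags_zero_phase[order] = (np.argmax(np.abs(zero_phase)) - center) / fs * 1000
    print(f'order {order}: sosfilt peak at {lags_causal[order]:+.0f} ms, '
          f'sosfiltfilt peak at {lags_zero_phase[order]:+.0f} ms')
```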

## Exercise 2: wavelet transform with real wavelets
---
In this exercise you will explore the difference between real and complex wavelets in the continuous wavelet transform.
- Import the data from `data/moving_lfp.pickle`
- Select a period of a few seconds of the signal and compute the spectrogram.
  Try the wavelet transform with wavelet `mexh` (Mexican hat), `gaus` (Gaussian) and `morl` (Morlet). These are all real wavelets.
- How does the spectrum look different from what we computed with a complex Morlet transform? Comment on why you think this might be the case.
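The code cells in this exercise and the next convert analysis frequencies to CWT scales by calling `pywt.scale2frequency` on normalized frequencies. That works because the function returns `central_frequency(wavelet) / scale`, a reciprocal relation that is its own inverse. A minimal check, assuming a 1000 Hz sampling rate:

```python
import numpy as np
import pywt

fs = 1000                                   # assumed sampling rate (Hz)
freqs_hz = np.linspace(0.1, 20, 100)        # target analysis frequencies (Hz)
wavelet = 'cmor1.0-1.0'

# scale2frequency returns central_frequency / scale, so feeding it normalized
# frequencies (f / fs) returns the matching scales: s = fc / (f / fs)
scales = pywt.scale2frequency(wavelet, freqs_hz / fs)
explicit_scales = pywt.central_frequency(wavelet) / (freqs_hz / fs)

# Round trip: scales -> normalized frequencies -> Hz
recovered_hz = pywt.scale2frequency(wavelet, scales) * fs
print(np.allclose(scales, explicit_scales), np.allclose(recovered_hz, freqs_hz))
```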

In [6]:
# Load data
with open('../../AchimBrinkop/Neuro-Analysis/01-local_field_potential/data/moving_lfp.pickle', 'rb') as handle:
    lfp_data = pickle.load(handle)

sampling_rate = lfp_data['sampling_rate']
signal_start, signal_end = 0, 2
lfp = lfp_data['lfp']
signal = lfp[math.floor(signal_start * sampling_rate):math.ceil(signal_end * sampling_rate)]
signal = signal - np.mean(signal) # get rid of 0 frequency component

# Adapted from local-field-potention.ipynb
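def plot_wavelet(wavelet):
    frequencies = np.linspace(0.1, 20, 100) / sampling_rate  # normalized frequencies (cycles/sample)
    # scale2frequency(w, s) returns central_frequency(w) / s; the relation is reciprocal,
    # so passing normalized frequencies returns the matching scales
    scales = scale2frequency(wavelet, frequencies)
    cwtmatr, _ = cwt(signal, wavelet=wavelet, scales=scales,
                     sampling_period=1.0/sampling_rate)
    plt.figure(figsize=(10,5))
    # Row 0 is the lowest frequency, so origin='lower' puts low frequencies at the bottom
    plt.imshow(np.abs(cwtmatr)**2, extent=[signal_start, signal_end, frequencies[0] * sampling_rate,
               frequencies[-1] * sampling_rate], origin='lower', aspect='auto', cmap='viridis')
    plt.colorbar(label='power (a.u.)')
    plt.xlabel('time (s)')
    plt.ylabel('frequency (Hz)')
    plt.title(wavelet)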

interactive_plot = interactive(
    plot_wavelet, 
    wavelet=widgets.Dropdown(
        options=[
            ('Complex Morlet', 'cmor1.0-1.0'), 
            ('Mexican hat', 'mexh'), 
            ('Gaussian', 'gaus1'), 
            ('Morlet', 'morl')
        ]
    ),
)

# Adapted from https://ipywidgets.readthedocs.io/en/7.6.2/examples/Using%20Interact.html
output = interactive_plot.children[-1]
output.layout.height = '450px'
interactive_plot


Using real wavelets for the transform creates a noticeable alternating pattern along the x-axis (in the time domain), which the complex Morlet wavelet does not produce. I suspect this is because real wavelets are sensitive to the phase of the frequencies contributing to the signal: the coefficient of a real wavelet follows the oscillation itself, so its squared magnitude rises and falls twice per cycle. The modulus of the complex wavelet coefficient instead tracks the envelope of these frequencies, independent of their phase.
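A minimal sketch supporting this explanation, assuming a 500 Hz sampling rate and a pure 8 Hz sinusoid in place of the LFP. A pure sinusoid has a constant envelope, so an envelope-tracking transform should give nearly constant power. The hypothetical helper `interior_power_cv` computes the CWT at the scale whose pseudo-frequency is 8 Hz and reports the coefficient of variation (std / mean) of the power, ignoring one second at each edge. It comes out large for the real wavelets (for an ideal sinusoidal coefficient it would be 1/√2 ≈ 0.71) and close to zero for the complex Morlet.

```python
import numpy as np
import pywt

fs = 500                               # assumed sampling rate (Hz)
f0 = 8                                 # frequency of the synthetic oscillation (Hz)
t = np.arange(0, 4, 1 / fs)
x = np.sin(2 * np.pi * f0 * t)         # pure sinusoid: its envelope is constant

def interior_power_cv(wavelet):
    """Coefficient of variation of |CWT|^2 at f0, ignoring 1 s at each edge."""
    # Scale whose pseudo-frequency is f0: s = central_frequency * fs / f0
    scale = pywt.central_frequency(wavelet) * fs / f0
    coefs, _ = pywt.cwt(x, [scale], wavelet, sampling_period=1 / fs)
    power = np.abs(coefs[0]) ** 2
    interior = power[fs:-fs]
    return interior.std() / interior.mean()

cv = {w: interior_power_cv(w) for w in ('morl', 'mexh', 'cmor1.0-1.0')}
for w, value in cv.items():
    print(f'{w:>12}: CV of power = {value:.3f}')
```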

### FPB Comments 
Code is OK. Please add axis labels and a colorbar; the comments are correct. Grade: 8

## Exercise 3: REM sleep detection
---
Mammalian sleep is organized in different phases. [Rapid Eye Movement (REM) sleep](https://en.wikipedia.org/wiki/Rapid_eye_movement_sleep) is a well-studied phase, with clear markers that make it detectable from LFP data. One of the most widely used markers is the ratio (or difference) between the oscillatory power in the theta and delta bands in the hippocampus. During REM sleep, the former dominates, in a pattern that resembles active behaviour.  
In this exercise we will use the tools we learned about to detect and inspect the REM and non-REM (nREM) phases of sleep in data recorded from the hippocampus of a sleeping mouse.

- Import the data from `data/moving_lfp.pickle`
- Compute the instantaneous power in the theta (6-10 Hz) and delta (2-4 Hz) frequency bands, for the whole signal.
- z-score the two power signals (subtract the mean and divide by the standard deviation; you can use `scipy.stats.zscore()`)
- Compute and plot the power difference (`theta_power - delta_power`) for the sleep session.
- Compute the REM (power difference $>0.2$) and nREM (power difference $<0.2$) sleep periods.
- Plot examples of REM vs nREM lfp signals (wideband filtered betw. 1 and 200 Hz to get rid of some of the noise).
- Explore how changing the threshold affects the amount of REM sleep detected.
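The steps above can be sketched as follows. As an assumption, this swaps the complex Morlet CWT used in the cell below for a different but common estimate of instantaneous power: the squared magnitude of the Hilbert transform of the band-passed signal. It runs on a synthetic signal (8 Hz for 10 s, then 3 Hz for 10 s, sampled at 250 Hz), and the helper names `band_power` and `theta_delta_difference` are hypothetical.

```python
import numpy as np
from scipy.signal import butter, hilbert, sosfiltfilt
from scipy.stats import zscore

def band_power(x, band, fs, order=4):
    """Instantaneous power: squared Hilbert envelope of the band-passed signal."""
    sos = butter(order, band, btype='band', output='sos', fs=fs)
    return np.abs(hilbert(sosfiltfilt(sos, x))) ** 2

def theta_delta_difference(x, fs, theta=(6, 10), delta=(2, 4)):
    """z-scored theta power minus z-scored delta power, sample by sample."""
    return zscore(band_power(x, theta, fs)) - zscore(band_power(x, delta, fs))

# Synthetic test: 8 Hz (theta) for the first 10 s, 3 Hz (delta) for the next 10 s
fs = 250
t = np.arange(0, 20, 1 / fs)
x = np.where(t < 10, np.sin(2 * np.pi * 8 * t), np.sin(2 * np.pi * 3 * t))
x += 0.1 * np.random.default_rng(0).standard_normal(t.size)

diff = theta_delta_difference(x, fs)
rem_mask = diff > 0.2        # threshold from the exercise
print(f'REM fraction, first half: {rem_mask[t < 10].mean():.2f}, '
      f'second half: {rem_mask[t >= 10].mean():.2f}')
```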

In [7]:
# Load data
with open('../../AchimBrinkop/Neuro-Analysis/01-local_field_potential/data/moving_lfp.pickle', 'rb') as handle:
    lfp_data = pickle.load(handle)

sampling_rate = lfp_data['sampling_rate']
signal_start, signal_end = 0, 10
lfp = lfp_data['lfp']
signal = lfp[math.floor(signal_start * sampling_rate):math.ceil(signal_end * sampling_rate)]
signal = signal - np.mean(signal) # get rid of 0 frequency component

# Adapted from local-field-potention.ipynb
def plot_rem(threshold):
    bands = {
        'theta': [6, 10],
        'delta': [2, 4]
    }
    powers_zscored = {}

    for name, band in bands.items():
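        # Normalized frequencies (cycles/sample); scale2frequency(w, f) = central_frequency(w) / f,
        # so applied to normalized frequencies it returns the matching CWT scales
        frequencies =  np.linspace(band[0], band[1], 50) / sampling_rate
        wavelet = 'cmor1.0-1.0'
        scales = scale2frequency(wavelet, frequencies)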
        cwtmatr, _ = cwt(signal, wavelet=wavelet, scales=scales,
                        sampling_period=1.0/sampling_rate)

        power = np.mean(abs(cwtmatr)**2,axis=0)
        powers_zscored[name] = zscore(power)

    power_diff = powers_zscored['theta'] - powers_zscored['delta']

    # Plot the power difference
    plt.figure(figsize=(10,5))
    plt.plot(np.linspace(signal_start, signal_end, len(power_diff)), power_diff)
    plt.axhline(threshold, linestyle='dashed', color='orange', label='threshold')
    plt.xlabel('time (s)')
    plt.ylabel('power difference (a.u.)')
    plt.legend()

    # Compute REM and nREM periods (sample indices relative to `signal`);
    # `<=` makes the two masks complementary so every sample is classified
    rem_periods = period_idxs(power_diff > threshold)
    nrem_periods = period_idxs(power_diff <= threshold)

    # Filter the analysed segment itself so the period indices line up with the filtered trace
    lfp_filtered = lf.bandpass_filter(signal, low_f=1, high_f=200, sampling_rate=sampling_rate)
    signal_samples = len(signal)

    for label, periods in (('REM', rem_periods), ('nREM', nrem_periods)):
        pct = sum(len(period) for period in periods) / signal_samples
        print(f'Found {len(periods)} {label} periods, totalling {pct:.2%} of the signal')
        if periods:  # skip plotting when nothing was found (avoids an empty, zero-width figure)
            plot_examples(lfp_filtered, periods[:2], label)

# Adapted from local-field-potention.ipynb
def period_idxs(period_mask):
    '''
    Takes logical period mask and returns a list of arrays. Each array correspond to a period
    and contains its time idxs (relative to the provided period mask)
    '''
    period_starts = []
    period_ends = []
    for i in range(1, len(period_mask)):
        if not period_mask[i - 1] and period_mask[i]:
            period_starts.append(i)
        if period_mask[i - 1] and not period_mask[i]:
            period_ends.append(i)

    # handle edge cases
    if not len(period_ends) == 0 and (len(period_starts) == 0 or period_ends[0] < period_starts[0]):
        period_starts = [0] + period_starts  # if session starts with a period

    if not len(period_starts) == 0 and (len(period_ends) == 0 or period_starts[-1] > period_ends[-1]):
        period_ends.append(len(period_mask))  # if session ends with a period

    if len(period_starts) == 0 and len(period_ends) == 0 and len(period_mask) != 0 and period_mask[0]:
        period_starts = [0] # if full session is one period
        period_ends = [len(period_mask)] 

    periods = [np.arange(period_starts[i], period_ends[i])
               for i in range(len(period_starts))]

    return periods

def plot_examples(lfp_filtered, periods, title):
    fig = plt.figure(figsize=(len(periods) * 5, 5))
    fig.suptitle(f'Examples of {title} periods')

    for i, period in enumerate(periods):
        start, end = period[0], period[-1]
        length = end - start

        # Up to 30 samples of context on each side, without running off the array
        plot_margin_start, plot_margin_end = min(start, 30), min(len(lfp_filtered) - end, 30)

        plt.subplot(1, len(periods), i + 1)
        # Convert sample offsets to milliseconds
        x = np.arange(-plot_margin_start, length + plot_margin_end) / sampling_rate * 1000
        y = lfp_filtered[start - plot_margin_start:end + plot_margin_end]
        plt.plot(x, y)
        plt.axvline(x=0, linestyle='--', c='green', label='detected onset')
        plt.axvline(x=length / sampling_rate * 1000, linestyle='--', c='red', label='detected end')
        plt.xlabel('time from onset (ms)')
        plt.ylabel('lfp')
        plt.legend()

    plt.tight_layout()

interactive_plot = interactive(
    plot_rem, 
    threshold=widgets.FloatSlider(min=-6, max=4, value=0.2),
)

# Adapted from https://ipywidgets.readthedocs.io/en/7.6.2/examples/Using%20Interact.html
output = interactive_plot.children[-1]
output.layout.height = '600px'
interactive_plot

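A vectorized alternative to `period_idxs`, given as a sketch: padding the mask with `False` on both sides turns every run of `True` values into one rising and one falling edge of `np.diff`, so the edge cases handled explicitly above disappear. `mask_to_periods` is a hypothetical name.

```python
import numpy as np

def mask_to_periods(mask):
    """Return one array of sample indices per run of True values in `mask`."""
    padded = np.concatenate(([False], np.asarray(mask, dtype=bool), [False]))
    # diff is +1 where a run starts and -1 one past where it ends (indices refer to `mask`)
    edges = np.flatnonzero(np.diff(padded.astype(np.int8)))
    starts, ends = edges[0::2], edges[1::2]
    return [np.arange(start, end) for start, end in zip(starts, ends)]

mask = np.array([True, True, False, False, True, False, True, True])
periods = mask_to_periods(mask)
print([p.tolist() for p in periods])   # [[0, 1], [4], [6, 7]]
```

The detected fraction is then simply `mask.mean()`, which equals the summed period lengths divided by the number of samples.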

Increasing the threshold reduces the portion of the signal detected as REM sleep, while decreasing it enlarges that portion. With a threshold of -2, 88% of the signal is classified as REM sleep, compared with only 7% at a threshold of 2.
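A sketch of the threshold sweep from the last bullet of the exercise, using standard-normal noise as a stand-in for `power_diff` from `plot_rem`, so the printed fractions are illustrative only. Because the REM mask is `power_diff > threshold`, the detected fraction can only stay the same or shrink as the threshold rises. `rem_fraction` is a hypothetical helper.

```python
import numpy as np

def rem_fraction(power_diff, thresholds):
    """Fraction of samples classified as REM (power_diff > threshold) per threshold."""
    power_diff = np.asarray(power_diff)
    return np.array([(power_diff > thr).mean() for thr in thresholds])

# Stand-in for the z-scored theta-delta power difference
power_diff = np.random.default_rng(0).standard_normal(10_000)
thresholds = np.arange(-2, 2.5, 0.5)
fracs = rem_fraction(power_diff, thresholds)
for thr, frac in zip(thresholds, fracs):
    print(f'threshold {thr:+.1f}: {frac:.1%} REM')
```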

### FPB comments
Code is correct. We should have suggested smoothing the statistics, since NREM and REM periods typically last many seconds; here you are mostly looking at fluctuations, but this is on us. Grade: 9
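A minimal sketch of the smoothing suggested above, on a synthetic power difference (a 40 s block of +1 in unit-variance noise, sampled at an assumed 100 Hz). The moving average from `scipy.ndimage.uniform_filter1d` and the 5 s window are assumptions, not part of the exercise; longer windows suppress more fluctuations but blur the detected onsets and ends. `count_periods` is a hypothetical helper.

```python
import numpy as np
from scipy.ndimage import uniform_filter1d

def count_periods(mask):
    """Number of runs of True values in a non-empty boolean mask."""
    mask = np.asarray(mask, dtype=bool)
    return int(mask[0]) + int(np.count_nonzero(mask[1:] & ~mask[:-1]))

fs = 100                                   # assumed sampling rate of the power signal (Hz)
t = np.arange(0, 100, 1 / fs)
rng = np.random.default_rng(0)
# Synthetic power difference: a 40 s "REM" block (20 to 60 s) buried in noise
power_diff = np.where((t >= 20) & (t < 60), 1.0, -1.0) + rng.standard_normal(t.size)

threshold = 0.2
window_s = 5                               # hypothetical smoothing window (s)
smoothed = uniform_filter1d(power_diff, size=int(window_s * fs))  # moving average

raw_mask = power_diff > threshold
smooth_mask = smoothed > threshold
print(f'raw: {count_periods(raw_mask)} periods, smoothed: {count_periods(smooth_mask)} periods')
print(f'REM fraction after smoothing: {smooth_mask.mean():.2f}')
```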