## The basics of visualizing, manipulating and processing EEG data

In this script you will be introduced to the basics of visualizing, manipulating and processing electroencephalography (EEG) data.
We will use MNE-Python, a Python package dedicated to the analysis of MEG and EEG data.

We will use a dataset recorded during a      task

In [None]:
# import the required packages and load in the data file (in *.bdf)
import mne
import matplotlib.pyplot as plt
import os
import numpy as np

# Magic command to permit interactive plotting in Jupyter.
# (Kept on its own line: text after a line magic is passed to it as an argument.)
%matplotlib

fname = 'sub-001_eeg_sub-001_task-think1_eeg.bdf'
filepath = 'data'
fullpath = os.path.join(filepath, fname)
rawIn = mne.io.read_raw_bdf(fullpath, preload=True)   # Load in the raw data.
rawIn.plot()    # Plot the raw, continuous data.


## Define the channel types and scalings

- EXG1 and EXG2: HEOG (horizontal electrooculogram)
- EXG3 and EXG4: VEOG (vertical electrooculogram or eye-blinks)
- EXG5: Mastoid 1
- EXG6: Mastoid 2
- EXG7: ECG (electrocardiogram)
- EXG8: Fp1 electrode

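Now we will set the external channels (EXG1 to EXG8) to type 'eog', the auxiliary channels (GSR1, GSR2, Plet, Temp) to 'misc', the respiration channel to 'resp', and the Erg and Status channels to 'stim'.
By doing this, each data type is scaled appropriately when visualized.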

In [None]:
my_dict = {'EXG1': 'eog', 'EXG2': 'eog', 'EXG3': 'eog', 'EXG4': 'eog', 'EXG5': 'eog', 'EXG6': 'eog', 'EXG7': 'eog',
           'EXG8': 'eog',
           'GSR1': 'misc', 'GSR2': 'misc', 'Erg1': 'stim', 'Erg2': 'stim', 'Resp': 'resp', 'Plet': 'misc',
           'Temp': 'misc', 'Status': 'stim'}
print(my_dict)
rawIn.set_channel_types(my_dict)  # Apply the channel type to our raw object.
mne.viz.plot_raw(rawIn, duration=5.0, scalings="auto", remove_dc=True)  # Plot again the Channels X Time


## Getting basic dataset information

We will extract some basic information about the dataset by looking at the **rawIn.info** attribute.
This attribute holds an **Info** object, which behaves like a Python dictionary.

In [None]:
print(rawIn.info)  # Visualising all dataset information
print(rawIn.info['sfreq'])  # Display the sampling frequency

## Exercise 1...

Display the following information:
- channel names
- the number and the type of channels

Put the code in the cell below.

In [None]:
# Put your code here...

In [None]:
# We can pass the information contained in the .info object to a new variable.

sfreq = rawIn.info['sfreq']  # assign the sampling frequency value to sfreq
print('The sampling frequency of the current dataset is {} Hz.'.format(sfreq))  # using .format() method to print variable.

## Extracting the data from the raw file

Here we extract the data from the rawIn variable. 
The data variable is in the form of a **Channels** X **Time** matrix.
We then get the dimensions of the data extracted.

In [None]:
dataIn = rawIn.get_data()                                                          
dataDims = dataIn.shape                                                            
print(f'The number of channels in the current raw dataset is {dataDims[0]}.\n')    
print(f'The number of time samples in the current raw dataset is {dataDims[1]}.\n')

## Exercise 2

Now that you know the sampling frequency (Hz) and the number of data samples,                            
try to do the following:                                                                                 
- calculate the duration of the data in seconds.                                                         
- create the time vector 

Remember that the sampling frequency tells you how many time samples there are in 1 second of data.      
The sampling frequency can be used to calculate the time step from one sample to the next.               

In [None]:
nSamps =                                                   
tstep =                                                    
T =      # time vector code here.                          
                                                           
print(f'The length of the time vector is { }')             
print(f'The duration of the time vector is { } seconds')   

## Plot individual electrodes

1. Plot a single electrode over the entire duration.
2. Plot a single electrode over a defined time interval.
3. Plot several electrodes against each other.

We can use the time vector (T) that we constructed.

In [None]:
# Plot a single electrode over all time.                                                   
channelNames = rawIn.info['ch_names']                                                      
chan2plot = 'Cz'                                                                           
chanIndx = channelNames.index(chan2plot)                                                   
print(f'The index of channel {chan2plot} is {chanIndx}.')                                  
                                                                                           
plt.plot(T, dataIn[chanIndx, :])                                                           
plt.xlabel('time (seconds)')                                                               
plt.show()                                                                                                                                                         

In [None]:
# Plot a single electrode over a pre-defined time interval.                                

timeLims = np.array([60, 70])                                                              
lim1, lim2 = (timeLims * sfreq).astype(int)                                                
                                                                                           
plt.plot(T[lim1:lim2], dataIn[chanIndx, lim1:lim2])                                        
plt.xlabel('time (seconds)')                                                               
plt.show()  

In [None]:
# Plot two electrodes against each other                                                   

chans2plot = ['Cz', 'Pz']                                                                  
channelIndices = [i for i, chan in enumerate(channelNames) if chan in chans2plot]          
                                                                                           
fig, ax = plt.subplots()                                                                   
for idx in channelIndices:                                                                 
    ax.plot(T[lim1:lim2], dataIn[idx, lim1:lim2])                                          
ax.set_xlabel('time (seconds)')                                                            
plt.show()  

# Do you notice something odd about the two electrodes? 

In [None]:
# We will plot the raw data again in the Channels X Time format, but this time setting
# the **remove_dc** parameter to False   

mne.viz.plot_raw(rawIn, scalings='auto', remove_dc=False)    

## The DC Component

You will notice that none of the electrode signals appear to be visible. This is due to what we call the "**DC offset**".
The acquisition system runs on battery power, which is direct current (DC), and it records ALL frequencies, including 0 Hz.
The 0 Hz component is a constant offset of the signal from zero.
So we need to remove this offset, after which our signals will have a zero mean.

There are different ways of removing the DC offset.
We will start by subtracting the mean activity of a channel from the activity of that channel.
Then we will plot the result.
So...
- Let's calculate the mean of a few channels.
- What do you notice about the means? How do we know that there is a DC offset?

**Note** also the use of the **copy()** method. We use this to make a copy of the original **rawIn** object.
When we apply a method such as **.pick_channels()** to a raw object, we change that object in place. Therefore, the copy() method is very useful.
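
To make the idea of a DC offset concrete, here is a minimal sketch on a synthetic signal rather than the recorded data. The sampling frequency, the offset size and all the variable names (`sfreq_demo`, `signal_demo`, and so on) are assumptions chosen only for illustration.

```python
import numpy as np

# Synthetic example (not the recorded data): a 10 Hz sine wave sitting on
# a large constant offset, sampled at a hypothetical 256 Hz for 2 seconds.
sfreq_demo = 256.0
n_demo = 512
t_demo = np.arange(n_demo) / sfreq_demo
offset_demo = 5000.0  # arbitrary units, much larger than the oscillation
signal_demo = offset_demo + 10.0 * np.sin(2 * np.pi * 10.0 * t_demo)

# The mean of the signal is an estimate of its DC offset.
signal_mean = signal_demo.mean()

# Subtracting the mean centres the signal on zero.
signal_demeaned = signal_demo - signal_mean

print(f'Mean before subtraction: {signal_mean:.2f}')
print(f'Mean after subtraction: {abs(signal_demeaned.mean()):.2f}')
```

Plotted on the same axes, the original signal would look like a flat line far from zero, because the 10-unit oscillation is tiny compared with the offset. After subtraction the oscillation becomes visible.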

In [None]:
# Fill in the missing code...

rawInbis = rawIn.copy()  # creating a copy of the raw object.                        
rawInbis.pick_channels(['Fz', 'Cz', 'Pz'])                                           
dataInbis =   # Get the data of the selected channels.                               
dataMean =                                                                           
dataMean =   # convert dataMean to a list using the .tolist() method                 
                                                                                     
print(f'The mean of each electrode is {dataMean}')                                   
                                                                                     
# Now subtract the mean of each channel for all samples of each channel.             
chan_demean1 = np.subtract(dataInbis[0, :], dataMean[0])                             
chan_demean2 =                                                                       
chan_demean3 =                                                                       

In [None]:
# Now we will visualize the channel data before and after subtracting the mean.

# Top row: the selected channels (Fz, Cz, Pz) before subtracting the mean.
ax1 = plt.subplot(231)
ax1.margins(0.5)
ax1.plot(T, dataInbis[0, :])

ax2 = plt.subplot(232)
ax2.margins(0.5)
ax2.plot(T, dataInbis[1, :])

ax3 = plt.subplot(233)
ax3.margins(0.5)
ax3.plot(T, dataInbis[2, :])

# Bottom row: the same channels after subtracting the mean.
ax4 = plt.subplot(234)
ax4.margins(0.5)
ax4.plot(T, chan_demean1)

ax5 = plt.subplot(235)
ax5.margins(0.5)
ax5.plot(T, chan_demean2)

ax6 = plt.subplot(236)
ax6.margins(0.5)
ax6.plot(T, chan_demean3)
plt.show()

## Removing the DC component via high-pass filtering

The generally applied approach in EEG data processing is to apply a **high-pass filter** to the EEG data.
Remember that we want to remove the 0 Hz (DC) component.
Here we apply a high-pass filter with a cutoff at 0.1 Hz. This means that we attenuate frequencies below 0.1 Hz and retain
frequencies above 0.1 Hz.
We will use the **.filter()** method in MNE.
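
The effect of a high-pass filter on the DC component can be sketched with a synthetic signal. Note that this is a crude FFT-based high-pass that zeroes the low-frequency bins. It is not the FIR filter that **.filter()** designs, and it is only meant to show what "removing everything below 0.1 Hz" does. The signal parameters and variable names are assumptions for illustration.

```python
import numpy as np

# Synthetic example: an 8 Hz oscillation on top of a constant offset,
# sampled at a hypothetical 100 Hz for 20 seconds.
sfreq_demo = 100.0
n_demo = 2000
t_demo = np.arange(n_demo) / sfreq_demo
oscillation = 5.0 * np.sin(2 * np.pi * 8.0 * t_demo)
signal_demo = 3000.0 + oscillation

# Move to the frequency domain and find the frequency of each bin.
spectrum = np.fft.rfft(signal_demo)
freqs = np.fft.rfftfreq(n_demo, d=1.0 / sfreq_demo)

# Crude high-pass: discard every component below the 0.1 Hz cutoff,
# which includes the 0 Hz (DC) bin.
cutoff_hz = 0.1
spectrum[freqs < cutoff_hz] = 0.0

# Back to the time domain.
signal_hp = np.fft.irfft(spectrum, n=n_demo)

print(f'Mean before filtering: {signal_demo.mean():.2f}')
print(f'Mean after filtering: {abs(signal_hp.mean()):.2f}')
```

The 8 Hz oscillation is left untouched because it lies well above the cutoff. Only the offset is removed.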

In [None]:
# Notice again that we make a copy of the rawIn object before applying the .filter() method.
rawInFilt = rawIn.copy().filter(0.1, None, fir_design='firwin')  # High-pass at 0.1 Hz; no low-pass (h_freq=None).
                                                                                                                                
mne.viz.plot_raw(rawInFilt, scalings='auto', remove_dc=False)  # Visualize in Channels X Time format.                           

## Re-referencing the data

The data needs to be re-referenced, so a reference needs to be chosen.
A potential, measured in microvolts, is always measured in relation to the potential at another point, called the reference.

This means that the activity at each channel is interpreted relative to the potential at the reference. Common choices are:
- the mean activity of all electrodes (the average reference).
- the average of the two mastoids, M1 and M2 (generally these reference channels are marked as Ref1, Ref2 or EXG1, EXG2).

Here we will not use the external (EXG) channels as the reference. Instead, we will apply an average reference.
However, we cannot include the bad channels or the EOG channels when computing the average.
We use the *pick_types()* method to exclude these channels before applying the average reference.
The data are therefore re-referenced to the average of all scalp electrodes, excluding any electrodes that have been marked as noisy.
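
What an average reference does to the data can be sketched in a few lines of NumPy. This is an illustration on random synthetic data under stated assumptions, not the MNE implementation. The array sizes, the `good_chans` mask and the other variable names are hypothetical.

```python
import numpy as np

# Synthetic Channels X Time data: 4 channels, 1000 samples, all sharing
# a common shift of 2.0 that stands in for activity at the reference.
rng = np.random.default_rng(0)
data_demo = rng.normal(size=(4, 1000)) + 2.0

# Hypothetical mask: the third channel has been marked as bad.
good_chans = np.array([True, True, False, True])
data_good = data_demo[good_chans, :]

# The average reference is the mean across the good channels
# at each time sample (one value per sample).
avg_ref = data_good.mean(axis=0, keepdims=True)

# Re-referencing subtracts that value from every good channel.
data_reref = data_good - avg_ref

print(f'Shape after re-referencing: {data_reref.shape}')
print(f'Largest mean across channels: {np.abs(data_reref.mean(axis=0)).max():.2e}')
```

After re-referencing, the channels sum to zero at every time sample, which is why a single noisy channel included in the average would contaminate all the others.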

In [None]:
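# Keep only the EEG channels (EOG channels and channels marked as bad are dropped),
# then apply an average reference.
rawInRef = rawInFilt.copy().pick_types(eeg=True, eog=False, exclude='bads').set_eeg_reference('average')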
mne.viz.plot_raw(rawInRef, scalings='auto', remove_dc=False)                                          

## Plot the spatial distribution of activity (topography) 

In addition to looking at the signal as a function of time, we can also look at the spatial distribution of the mean activity across all scalp electrodes over a given time interval. This visualization is referred to as a **topography** or a **topographical map**. 

A topography presenting the average activity over a pre-defined time interval can be used to highlight or study activity at a specific time interval. 

So, to plot the topography, for each electrode, we need to calculate the average activity over a pre-defined time-interval. This implies that, while we gain an insight into **where** activity is concentrated at a given time, we lose temporal precision and, therefore, information on **when** there is a change in activity. 
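
The trade-off between **where** and **when** can be illustrated with a small synthetic example. The helper `window_mean` and all the values below are assumptions made for illustration; they are not part of MNE.

```python
import numpy as np

# Synthetic Channels X Time data: 3 channels, 10 seconds at a hypothetical 100 Hz.
sfreq_demo = 100.0
data_demo = np.zeros((3, 1000))
data_demo[0, 200:400] = 1.0  # channel 0 is active from 2 to 4 seconds
data_demo[1, 600:800] = 1.0  # channel 1 is active from 6 to 8 seconds


def window_mean(data, sfreq, t_start, t_stop):
    """Return the mean of each channel between t_start and t_stop (seconds)."""
    start = int(round(t_start * sfreq))
    stop = int(round(t_stop * sfreq))
    return data[:, start:stop].mean(axis=1)


# A short window shows which channel is active at that moment.
mean_short = window_mean(data_demo, sfreq_demo, 2, 4)

# A window covering the whole recording gives channels 0 and 1 the same value,
# so the information about when each one was active is lost.
mean_long = window_mean(data_demo, sfreq_demo, 0, 10)

print(f'Mean per channel, 2 to 4 s: {mean_short}')
print(f'Mean per channel, 0 to 10 s: {mean_long}')
```

Each call returns one value per channel, which is exactly what a topography needs: a single number to draw at each electrode position.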

Before we can visualize a topography, we need to define the electrode layout or **montage** that corresponds to the current dataset's configuration. In this example, the electrodes are distributed according to the standard **10-20 layout**.

In [None]:
# Load the relevant montage and visualize it.

montage = mne.channels.make_standard_montage('standard_1020')  # Assigning the standard 10-20 montage        
mne.viz.plot_montage(montage)  # Visualize the montage
rawInRef.set_montage(montage)                                                                                

In [None]:
# Plot the topographies

timeIntval = [70, 75]                          # Defining the time interval over which to plot the topography.                               
timeIndx = rawInRef.time_as_index(timeIntval)  # Find the indices of the samples in the defined time interval.       
chanRange = np.arange(0, 64)                   # We take only the scalp channels.
dataSeg1 = rawInRef.get_data(chanRange, timeIndx[0], timeIndx[1])                                                    
dataSeg_mean = np.mean(dataSeg1, 1)                                                                                  
                                                                                                                     
fig1, ax1 = plt.subplots(1)                                                                                          
mne.viz.plot_topomap(dataSeg_mean, rawInRef.info, ch_type='eeg', axes=ax1)                                           

In [None]:
# We will now plot several topographies over time from 60 to 80 seconds in 5 second steps.
# This visualization allows us to explore the change in the spatial distribution of activity over time.

timeIntvals = np.arange(60, 85, 5)  # Interval edges at 60, 65, 70, 75 and 80 seconds (the stop value is excluded).
fig2, axes = plt.subplots(1, len(timeIntvals) - 1, figsize=(15, 5))                                                
for ind in range(len(timeIntvals) - 1):                                                                            
    curr_int = [timeIntvals[ind], timeIntvals[ind + 1]]                                                            
    timeIndx2 = rawInRef.time_as_index(curr_int)                                                                   
    dataSegCurr = rawInRef.get_data(chanRange, timeIndx2[0], timeIndx2[1])                                         
    dataMeanCurr = np.mean(dataSegCurr, 1)                                                                         
                                                                                                                   
    mne.viz.plot_topomap(dataMeanCurr, rawInRef.info, ch_type='eeg', axes=axes[ind])                               
    axes[ind].set_title(f'{timeIntvals[ind]} - {timeIntvals[ind + 1]} seconds', fontsize=20)
                                                                                                                   