In [1]:
# Move the working directory up one level, at most once per kernel session.
# The bare os.chdir('..') climbed one directory higher on every re-run of this
# cell, which silently shifts every relative path used later in the notebook.
import os

if not globals().get('_cwd_moved_to_parent', False):
    os.chdir('..')
    # Remember that the move already happened so re-running the cell is a no-op
    _cwd_moved_to_parent = True

In [59]:
import numpy as np
import scipy
from scipy import signal
import mne
import glob
from sklearn.decomposition import PCA

import utils.logger
from utils.experiments_classification import classify_nusvm_cross_valid, classify_nusvm_param_pca_seach

In [60]:
import importlib

# Pick up edits made to utils/experiments_classification.py without restarting
# the kernel. reload() only refreshes the module object itself: names bound
# earlier with `from ... import ...` keep pointing at the OLD function objects,
# so they are imported again here to rebind them to the freshly loaded code.
importlib.reload(utils.experiments_classification)
from utils.experiments_classification import classify_nusvm_cross_valid, classify_nusvm_param_pca_seach


<module 'utils.experiments_classification' from 'D:\\etc\\uni\\yr5\\project\\workspace\\eeg-cnp-final-year-project\\utils\\experiments_classification.py'>

#### Calculate the band power over a time series

In [4]:
# Calculate the band power over a time series

def bandpower(x, fs, fmin, fmax):
    """
    Returns the band power over the specified frequency interval

    The power spectral density is estimated with scipy.signal.periodogram and
    integrated with the trapezoidal rule over the selected frequency bins.

    x - input time series (1D array; expected to be real-valued so that the
        periodogram frequencies come back in ascending order)
    fs - sampling frequency
    fmin - min frequency
    fmax - max frequency

    Returns the integrated power, or 0.0 when fewer than two bins are selected.
    """

    f, Pxx = signal.periodogram(x, fs=fs)

    # Index of the last frequency bin at or below each band edge. Clamping at 0
    # keeps an edge below the first bin from turning into index -1, which would
    # silently select an empty band.
    ind_min = max(int(np.searchsorted(f, fmin, side='right')) - 1, 0)
    ind_max = max(int(np.searchsorted(f, fmax, side='right')) - 1, 0)

    # The slice is half-open: the bin at ind_max itself is not integrated.
    f_band = f[ind_min: ind_max]
    p_band = Pxx[ind_min: ind_max]

    # Trapezoidal rule written out with basic NumPy operations so that it does
    # not depend on the deprecated scipy.trapz / numpy.trapz aliases.
    return 0.5 * np.sum(np.diff(f_band) * (p_band[1:] + p_band[:-1]))

# Apply bandpower to datasets

In [24]:
def get_datasets(patient_type_location, recording_type_expression):
    """
    Returns relevant datasets (e.g. all right-hand recordings of patients with pain) as a
    1-D object array holding one array of epoch data per matching file.

    Both parameters are glob patterns (they are passed to glob.glob, not to a regex engine):
    the first selects the location, the second selects the dataset type.
    E.g. if right-hand movement datasets of patients with pain are in /data/pp/ and their file names contain '_RH_'
    then patient_type_location=/data/pp/ and recording_type_expression='*_RH_*'

    Relies on the notebook-level names `suffix` and `l_new` being defined before the call.
    """

    # Find locations of matching dataset files
    if recording_type_expression != l_new['extension']:
        pattern = patient_type_location + recording_type_expression + suffix
    else:
        # For the newer (PDP/PP) dataset we had to use a separate expression which includes the file extension
        pattern = patient_type_location + recording_type_expression

    # glob returns files in an unspecified order; sort so that patient indices are reproducible
    sets_locations = sorted(glob.glob(pattern))

    # Recordings can hold a different number of epochs each, so the per-file arrays
    # cannot be stacked into one rectangular array. Fill an object array explicitly;
    # np.array() on such a ragged list raises on recent NumPy versions.
    patients = np.empty(len(sets_locations), dtype=object)
    for index, path in enumerate(sets_locations):
        # NOTE(review): `_data` is a private attribute of the object mne returns
        patients[index] = mne.io.read_epochs_eeglab(path)._data

    return patients


def get_channel_names(patient_type_location, recording_type_expression):
    '''
    Returns the list of channel names in order, read from the first matching dataset file
    Only works if ALL chosen datasets use the same channels

    Both parameters are glob patterns; `suffix` and `l_new` must be defined at notebook level.
    Raises IndexError when no file matches the resulting pattern.
    '''
    if recording_type_expression != l_new['extension']:
        pattern = patient_type_location + recording_type_expression + suffix
    else:
        # This expression already contains the file extension, so no suffix is appended
        pattern = patient_type_location + recording_type_expression

    # Sort so that the file used as the reference is the same on every run
    sets_locations = sorted(glob.glob(pattern))
    if not sets_locations:
        # Same exception type as indexing the empty list, but with a usable message
        raise IndexError('No dataset files match ' + pattern)

    return mne.io.read_epochs_eeglab(sets_locations[0]).ch_names
    

In [6]:
# Calculate bandpower for all channels for a patient
# Frequency bands as (min frequency, max frequency) pairs, in the same units as fs
bands = [(4, 8), (8, 13), (13, 30)]
# Sample positions of each channel's time series that are fed to bandpower
time_series_index = range(0, 1250)


def channels_bandpower(channels, bands, fs=250):
    """
    Returns an array with one row per channel and one column per band.

    channels - iterable of per-channel time series (each indexable by time_series_index)
    bands - sequence of (min frequency, max frequency) pairs
    fs - sampling frequency handed to bandpower
    """
    per_channel = []
    for series in channels:
        row = [bandpower(series[time_series_index], fs, band[0], band[1]) for band in bands]
        per_channel.append(row)
    return np.array(per_channel)

#### Define dataset locations and expressions

In [28]:
# Data root (relative to the working directory) and the glob fragment that
# get_datasets / get_channel_names append to most recording expressions
root = './../../'
suffix = '*.set'

# Old (PP/PNP datasets): one sub-folder per participant
location_healthy = '{}data/raw/HV/*/'.format(root)
location_pain = '{}data/raw/PP/*/'.format(root)
location_nopain = '{}data/raw/PnP/*/'.format(root)

# New (PDP/PNP datasets): one sub-folder per participant
location_pwp = '{}data_new/raw/PwP/*/'.format(root)
location_pdp = '{}data_new/raw/PdP/*/'.format(root)
location_pnp = '{}data_new/raw/PnP/*/'.format(root)

# Recording types: a short name plus the glob expression matching their file names
rh = dict(name='RH', extension='*_RH*')
lh = dict(name='LH', extension='*_LH*')
l_new = dict(name='L', extension='*_L.set')   # already ends in .set: NO SUFFIX is appended
l_old = dict(name='L', extension='*_L_*')


# As an example, list the paths matched for the new PdP datasets with the 'L' expression.
# NOTE(review): the variable name says "healthy_rh", but the pattern selects PdP 'L' files.
sets_healthy_rh = glob.glob(''.join((location_pdp, l_new['extension'])))
sets_healthy_rh

['./../../data_new/raw/PdP\\PdP_3\\PdP_3_L.set',
 './../../data_new/raw/PdP\\PdP_4\\PdP_4_L.set',
 './../../data_new/raw/PdP\\PdP_5\\PdP_5_L.set',
 './../../data_new/raw/PdP\\PdP_6\\PdP_6_L.set',
 './../../data_new/raw/PdP\\PdP_7\\PdP_7_L.set']

#### Now read the chosen datasets

In [29]:
# Recording type to analyse; the "_rh_" in the variable names below is historical,
# the arrays hold whichever recording type `limb` selects
limb = l_old

# Load the epoch data of the pain (PP) and no-pain (PnP) groups
pp_rh_raw = get_datasets(patient_type_location=location_pain,
                         recording_type_expression=limb['extension'])
pnp_rh_raw = get_datasets(patient_type_location=location_nopain,
                          recording_type_expression=limb['extension'])

Extracting parameters from ./../../data/raw/PP\PP1\PP1_F1_L_Removed_ICA.set...


  sets.append(mne.io.read_epochs_eeglab(path))


57 matching events found
No baseline correction applied
Not setting metadata
0 projection items activated
Ready.
Extracting parameters from ./../../data/raw/PP\PP10\PP10_F1_L_Removed_ICA.set...


  sets.append(mne.io.read_epochs_eeglab(path))


59 matching events found
No baseline correction applied
Not setting metadata
0 projection items activated
Ready.
Extracting parameters from ./../../data/raw/PP\PP11\PP11_F1_L_Removed_ICA.set...


  sets.append(mne.io.read_epochs_eeglab(path))


59 matching events found
No baseline correction applied
Not setting metadata
0 projection items activated
Ready.
Extracting parameters from ./../../data/raw/PP\PP2\PP2_F1_L_Removed_ICA.set...
51 matching events found


  sets.append(mne.io.read_epochs_eeglab(path))


No baseline correction applied
Not setting metadata
0 projection items activated
Ready.
Extracting parameters from ./../../data/raw/PP\PP3\PP3_F1_L_Removed_ICA.set...
52 matching events found


  sets.append(mne.io.read_epochs_eeglab(path))


No baseline correction applied
Not setting metadata
0 projection items activated
Ready.
Extracting parameters from ./../../data/raw/PP\PP4\PP4_F1_L_Removed_ICA.set...


  sets.append(mne.io.read_epochs_eeglab(path))


57 matching events found
No baseline correction applied
Not setting metadata
0 projection items activated
Ready.
Extracting parameters from ./../../data/raw/PP\PP5\PP5_F1_L_Removed_ICA.set...
55 matching events found


  sets.append(mne.io.read_epochs_eeglab(path))


No baseline correction applied
Not setting metadata
0 projection items activated
Ready.
Extracting parameters from ./../../data/raw/PP\PP6\PP6_F1_L_Removed_ICA.set...


  sets.append(mne.io.read_epochs_eeglab(path))


32 matching events found
No baseline correction applied
Not setting metadata
0 projection items activated
Ready.
Extracting parameters from ./../../data/raw/PP\PP7\PP7_F1_L_Removed_ICA.set...
52 matching events found


  sets.append(mne.io.read_epochs_eeglab(path))


No baseline correction applied
Not setting metadata
0 projection items activated
Ready.
Extracting parameters from ./../../data/raw/PP\PP9\PP9_F1_L_Removed_ICA.set...
54 matching events found


  sets.append(mne.io.read_epochs_eeglab(path))


No baseline correction applied
Not setting metadata
0 projection items activated
Ready.
Extracting parameters from ./../../data/raw/PnP\PnP1\PnP1_F1_L_Removed_ICA.set...
50 matching events found


  sets.append(mne.io.read_epochs_eeglab(path))


No baseline correction applied
Not setting metadata
0 projection items activated
Ready.
Extracting parameters from ./../../data/raw/PnP\PnP2\PnP2_F1_L_Removed_ICA.set...


  sets.append(mne.io.read_epochs_eeglab(path))


57 matching events found
No baseline correction applied
Not setting metadata
0 projection items activated
Ready.
Extracting parameters from ./../../data/raw/PnP\PnP3\PnP3_F1_L_Removed_ICA.set...
41 matching events found


  sets.append(mne.io.read_epochs_eeglab(path))


No baseline correction applied
Not setting metadata
0 projection items activated
Ready.
Extracting parameters from ./../../data/raw/PnP\PnP4\PnP4_F1_L_Removed_ICA.set...


  sets.append(mne.io.read_epochs_eeglab(path))


58 matching events found
No baseline correction applied
Not setting metadata
0 projection items activated
Ready.
Extracting parameters from ./../../data/raw/PnP\PnP5\PnP5_F1_L_Removed_ICA.set...
50 matching events found


  sets.append(mne.io.read_epochs_eeglab(path))


No baseline correction applied
Not setting metadata
0 projection items activated
Ready.
Extracting parameters from ./../../data/raw/PnP\PnP6\PnP6_F1_L_Removed_ICA.set...


  sets.append(mne.io.read_epochs_eeglab(path))


53 matching events found
No baseline correction applied
Not setting metadata
0 projection items activated
Ready.
Extracting parameters from ./../../data/raw/PnP\PnP7\PnP7_F01_L_Removed_ICA.set...


  sets.append(mne.io.read_epochs_eeglab(path))


58 matching events found
No baseline correction applied
Not setting metadata
0 projection items activated
Ready.
Extracting parameters from ./../../data/raw/PnP\PnP8\PnP8_F1_L_Removed_ICA.set...


  sets.append(mne.io.read_epochs_eeglab(path))


50 matching events found
No baseline correction applied
Not setting metadata
0 projection items activated
Ready.
Extracting parameters from ./../../data/raw/PnP\PnP9\PnP9_F1_L_Removed_ICA.set...


  sets.append(mne.io.read_epochs_eeglab(path))


55 matching events found
No baseline correction applied
Not setting metadata
0 projection items activated
Ready.


In [37]:
# Channel names, in recording order, read from the first matching PP dataset
ch_names = get_channel_names(patient_type_location=location_pain,
                             recording_type_expression=limb['extension'])
ch_names

Extracting parameters from ./../../data/raw/PP\PP1\PP1_F1_L_Removed_ICA.set...
57 matching events found
No baseline correction applied
Not setting metadata
0 projection items activated
Ready.


  return mne.io.read_epochs_eeglab(sets_locations[0]).ch_names


['Cz',
 'C1',
 'C2',
 'C3',
 'C4',
 'C5',
 'C6',
 'C7',
 'C8',
 'FCz',
 'FC1',
 'FC2',
 'FC3',
 'FC4',
 'FC5',
 'FC6',
 'FT7',
 'FT8',
 'CPz',
 'CP1',
 'CP2',
 'CP3',
 'CP4',
 'CP5',
 'CP6',
 'TP7',
 'TP8',
 'Fz',
 'F1',
 'F2',
 'F3',
 'F4',
 'F5',
 'F6',
 'F7',
 'F8',
 'Pz',
 'P1',
 'P2',
 'P3',
 'P4',
 'P5',
 'P6',
 'P7',
 'P8',
 'AFz',
 'AF3',
 'AF4',
 'AF7',
 'AF8',
 'Fp1',
 'FP2',
 'POz',
 'PO3',
 'PO4',
 'PO7',
 'PO8',
 'Oz',
 'O1',
 'O2',
 'Iz']

In [31]:
# The entry for a patient should have shape (n_repetitions, n_channels, n_readings)
# Spot-check the patient at index 4 of the PP group to confirm that layout
pp_rh_raw[4].shape

(52, 61, 1250)

#### Apply the bandpower 

In [32]:
def _bandpower_per_patient(patients, bands):
    """
    Returns a 1-D object array with one feature array per patient.

    patients - iterable of per-patient arrays; each is iterated repetition by repetition
    bands - sequence of (min frequency, max frequency) pairs passed to channels_bandpower

    Each element stacks the channels_bandpower result of every repetition of that patient.
    """
    # Patients have different numbers of repetitions, so the per-patient arrays cannot
    # form one rectangular array. Fill an object array explicitly; np.array() on such a
    # ragged list raises on recent NumPy versions.
    features = np.empty(len(patients), dtype=object)
    for index, patient in enumerate(patients):
        features[index] = np.array([channels_bandpower(repetition, bands) for repetition in patient])
    return features


pp_rh_bp = _bandpower_per_patient(pp_rh_raw, bands)
pnp_rh_bp = _bandpower_per_patient(pnp_rh_raw, bands)

In [33]:
# Total number of repetitions per class: stack every patient's repetitions
# along the first axis and take the length of the result
pp_count = len(np.vstack(pp_rh_bp))
pnp_count = len(np.vstack(pnp_rh_bp))
pnp_count

472

In [34]:
# Inspect the feature array of the first patient in pnp_rh_bp; the last axis
# holds one band-power value per entry of `bands`
pnp_rh_bp[0].shape

(50, 61, 3)

#### Concatenate the two classes

In [35]:
# Join the per-patient feature arrays of both classes, PP entries first
pp_and_pnp_bp = np.concatenate([pp_rh_bp, pnp_rh_bp], axis=0)
pp_and_pnp_bp.shape

(19,)

In [36]:
# Metadata handed to the classification helpers through their log_* arguments
log_proc_method = 'PCA + Bandpower'
# Label the log entries with the recording type that was actually loaded.
# The label used to be hard-coded to 'PP/PNP-RH' even when `limb` selected another type.
log_dataset = 'PP/PNP-' + limb['name']
log_db_name = 'log.db'

# NOTE(review): keep in sync with the pca_components value used in the cells below
log_notes = {'pca_components': 3}

# SVM classification

In [None]:
# One classification run with a fixed nu, channel selection and PCA size
nu = 0.8585
# NOTE(review): indices 1 and 3 each appear twice in this list; confirm that is intended
channels = [0, 1, 1, 2, 3, 3, 5, 12, 13, 23, 30, 52, 57]
pca_components = 3

# Keyword options forwarded unchanged to the classification helper
run_options = dict(
    pca_components=pca_components,
    log_db_name=log_db_name,
    log_txt=True,
    log_proc_method=log_proc_method,
    log_dataset=log_dataset,
    log_notes=log_notes,
    log_details=True,
)
acc, sensitivity, specificity, avg_acc = classify_nusvm_cross_valid(
    pp_rh_bp, pnp_rh_bp, nu, channels, **run_options
)

# Report the four returned metrics, one per line
for label, value in (
    ('Accuracy', acc),
    ('Sensitivity', sensitivity),
    ('Specificity', specificity),
    ('Average accuracy', avg_acc),
):
    print(label, value)

In [None]:
# Try every channel index as one extra channel on top of the ones already
# chosen, and remember the candidate that gives the highest accuracy
previous_channels = [11, 36, 52]
nu = 0.8
pca_components = 3

max_acc = {'index': 0, 'value': 0}
for channel in range(61):
    candidate_channels = previous_channels + [channel]
    accuracy, sensitivity, specificity, avg_accuracy = classify_nusvm_cross_valid(
        pp_rh_bp,
        pnp_rh_bp,
        nu,
        candidate_channels,
        verbose=False,
        pca_components=pca_components,
        log_db_name=log_db_name,
        log_txt=True,
        log_proc_method=log_proc_method,
        log_dataset=log_dataset,
        log_notes=log_notes,
    )
    print(channel, accuracy, sensitivity, specificity, avg_accuracy)

    # Strictly greater: on ties the earliest channel is kept
    if accuracy > max_acc['value']:
        max_acc.update(index=channel, value=accuracy)

print('Max accuracy:', max_acc['index'], max_acc['value'])

#### Cross validate over multiple nu values

In [None]:
# Channels evaluated for every nu value in the sweep
channels = [11, 36, 52]

# Sweep nu and remember the value that gives the highest accuracy.
# `pca_components` and the log_* settings come from earlier cells.
max_acc = {'index': 0, 'value': 0}
for param in np.arange(0.5, 0.875, 0.01):
    # Use the channel list defined in this cell. The call used to pass
    # `previous_channels + [channel]`, i.e. leftovers from the channel-search
    # cell (with `channel` stuck at its final loop value), so `channels` was ignored.
    accuracy, sensitivity, specificity, avg_accuracy = classify_nusvm_cross_valid(pp_rh_bp, pnp_rh_bp, param,
                                                                                  channels,
                                                                                  verbose=False,
                                                                                  pca_components=pca_components,
                                                                                  log_db_name=log_db_name,
                                                                                  log_txt=True,
                                                                                  log_proc_method=log_proc_method,
                                                                                  log_dataset=log_dataset,
                                                                                  log_notes=log_notes
                                                                                 )

    print(param, accuracy, sensitivity, specificity, avg_accuracy)

    # Strictly greater: on ties the smallest nu is kept
    if accuracy > max_acc['value']:
        max_acc['index'] = param
        max_acc['value'] = accuracy


print('Max accuracy:', max_acc['index'], max_acc['value'])

(10,)

In [61]:
# Combined search delegated to the project helper. The three numbers are passed
# positionally; the printed "nu:" lines suggest they are the start, stop and
# step of the nu range (TODO confirm against the helper's signature).
search_log_options = {
    'log_db_name': log_db_name,
    'log_txt': True,
    'log_proc_method': log_proc_method,
    'log_dataset': log_dataset,
    'log_notes': log_notes,
}
max_acc_overall = classify_nusvm_param_pca_seach(
    pp_rh_bp, pnp_rh_bp, 0.5, 0.6, 0.02, ch_names, **search_log_options
)

nu: 0.5
[43] 0.58
*******************************
Channels: [43] Components: 1 Max acc: 0.58
*******************************
[43, 13] 0.62
*******************************
Channels: [43, 13] Components: 1 Max acc: 0.62
*******************************
[43, 13, 17] 0.62
*******************************
Channels: [43, 13, 17] Components: 1 Max acc: 0.62
*******************************
[43, 13, 17] 0.62
*******************************
Channels: [43, 13, 17] Components: 1 Max acc: 0.62
*******************************
Current Max Accuracy: {'channels': [43, 13, 17], 'value': 0.621, 'nu': 0.5, 'components': 1}
nu: 0.52
[11] 0.66
*******************************
Channels: [11] Components: 1 Max acc: 0.66
*******************************
[11] 0.66
*******************************
Channels: [11] Components: 1 Max acc: 0.66
*******************************
Current Max Accuracy: {'channels': [11], 'value': 0.658, 'nu': 0.52, 'components': 1}
nu: 0.54
[43] 0.55
*******************************
Channels: [