In [None]:
# Move the working directory one level above the directory the kernel started in.
# The start directory is remembered the first time this cell runs, so executing
# the cell again is harmless instead of climbing one more level on every run.
import os

try:
    _NOTEBOOK_START_DIR
except NameError:
    _NOTEBOOK_START_DIR = os.getcwd()

os.chdir(os.path.join(_NOTEBOOK_START_DIR, os.pardir))

In [None]:
import numpy as np
import scipy
from scipy import signal
import mne
import glob
from sklearn.decomposition import PCA
import csv
import os

import utils.logger
from utils.experiments_classification import classify_nusvm_cross_valid

In [None]:
import importlib

# Reload the utils.experiments_classification module so that edits made to it
# are picked up without restarting the kernel.
importlib.reload(utils.experiments_classification)

# reload() does not update names that were bound with `from ... import ...`,
# so rebind the function to the freshly loaded definition; otherwise the cells
# below would keep calling the old version.
classify_nusvm_cross_valid = utils.experiments_classification.classify_nusvm_cross_valid


#### Calculate the band power over a time series

In [None]:
# Calculate the band power over a time series

def bandpower(x, fs, fmin, fmax):
    """
    Returns the band power over the specified frequency interval

    x - input time series (1D array)
    fs - sampling frequency
    fmin - min frequency
    fmax - max frequency

    The spectrum is estimated with scipy.signal.periodogram and integrated
    with the trapezoidal rule. Integration starts at the last frequency bin
    that is <= fmin. The upper slice bound is exclusive, so it stops one bin
    before the last bin that is <= fmax; this is kept unchanged so that values
    match earlier runs. A bound that lies below the first frequency bin is
    clamped to index 0 instead of wrapping around to the end of the spectrum.
    """

    f, Pxx = signal.periodogram(x, fs=fs)

    def _last_bin_at_or_below(bound):
        # Index just before the first bin strictly above `bound`.
        above = np.flatnonzero(f > bound)
        if above.size == 0:
            # No bin exceeds the bound: use the last bin.
            return len(f) - 1
        # Clamp so a bound below f[0] does not turn into index -1.
        return max(int(above[0]) - 1, 0)

    ind_min = _last_bin_at_or_below(fmin)
    ind_max = _last_bin_at_or_below(fmax)

    band_f = f[ind_min: ind_max]
    band_p = Pxx[ind_min: ind_max]

    # Trapezoidal rule written out with NumPy primitives so it does not depend
    # on which trapz/trapezoid alias the installed NumPy/SciPy provides.
    return (np.diff(band_f) * (band_p[1:] + band_p[:-1]) / 2.0).sum()

# Apply bandpower to datasets

In [None]:
def get_datasets(patient_type_location, recording_type_expression):
    """
    Returns relevant datasets (e.g. all right-hand recordings of patients with pain) as a numpy array
    with one entry per matched file.

    Both parameters are glob patterns (not regular expressions): the first one selects the
    location, the second one selects the dataset type.
    E.g. if right-hand movement datasets of patients with pain are in /data/pp/ and their file names contain '_RH_'
    then patient_type_location=/data/pp/ and recording_type_expression='_RH_'

    Relies on the notebook-level globals `suffix` and `l_new`, which must be defined before the first call.

    When every recording has the same shape the result is a regular array; otherwise it is a
    1D object array holding one array per recording.
    """

    # Find locations of matching dataset files
    if recording_type_expression != l_new:
        pattern = patient_type_location + recording_type_expression + suffix
    else:
        # For the newer (PDP/PP) dataset we had to use a separate expression which includes the file extension
        pattern = patient_type_location + recording_type_expression

    # glob returns matches in arbitrary order; sort so the entry order is reproducible
    sets_locations = sorted(glob.glob(pattern))

    # `_data` is a private attribute of the loaded epochs object
    recordings = [mne.io.read_epochs_eeglab(path)._data for path in sets_locations]

    if len({recording.shape for recording in recordings}) <= 1:
        return np.array(recordings)

    # Recordings of different shapes cannot form a regular array (recent NumPy raises
    # instead of building an object array implicitly), so fill an object array explicitly.
    ragged = np.empty(len(recordings), dtype=object)
    for index, recording in enumerate(recordings):
        ragged[index] = recording
    return ragged

In [None]:
# Band power of every channel of one recording repetition
# (fmin, fmax) pairs handed to bandpower, and the sample indices taken from each channel
bands = [(4, 8), (8, 13), (13, 30)]
time_series_index = range(1250)

def channels_bandpower(channels, bands, fs=250):
    """
    Returns an array with one row per entry of `channels` and one column per band.

    channels - iterable of 1D time series
    bands - sequence of (fmin, fmax) pairs
    fs - sampling frequency passed on to bandpower

    Only the samples selected by the global `time_series_index` are used.
    """
    rows = []
    for series in channels:
        window = series[time_series_index]
        rows.append([bandpower(window, fs, band[0], band[1]) for band in bands])
    return np.array(rows)

#### Define dataset locations and expressions

In [None]:
# Base directory of the data folders and the glob suffix for dataset files
root = './../../'
suffix = '*.set'

# Old (PP/PNP datasets): one directory glob per patient group
location_healthy, location_pain, location_nopain = (
    root + 'data/raw/' + group + '/*/' for group in ('HV', 'PP', 'PnP')
)

# New (PDP/PNP datasets): one directory glob per patient group
location_pwp, location_pdp, location_pnp = (
    root + 'data_new/raw/' + group + '/*/' for group in ('PwP', 'PdP', 'PnP')
)

# File-name patterns selecting the recording type
rh, lh = '*_RH*', '*_LH*'
l_new = '*_L.set'   # NO SUFFIX: already contains the extension
l_old = '*_L_*'


# As an example, get paths of all PP/PNP datasets from right-hand movements
# NOTE(review): the variable name says "healthy" but the pattern uses location_pain.
sets_healthy_rh = glob.glob(location_pain + rh + suffix)
sets_healthy_rh

#### Now read the chosen datasets

In [None]:
# Load the recordings matching the `rh` pattern from the pain (PP) and
# no-pain (PnP) directories of the old dataset.
pp_rh_raw = get_datasets(location_pain, rh)
pnp_rh_raw = get_datasets(location_nopain, rh)

In [None]:
# The entry for a patient should have shape (n_repetitions, n_channels, n_readings)
# Index 4 is only a spot check of one entry; it requires at least five loaded files.
pp_rh_raw[4].shape

#### Apply the bandpower 

In [None]:
def patients_bandpower(patients, bands, fs=250):
    """
    Applies channels_bandpower to every repetition of every patient.

    patients - iterable of per-patient arrays, each iterated repetition by repetition
    bands - sequence of (fmin, fmax) pairs
    fs - sampling frequency passed on to channels_bandpower

    Returns a regular array when every patient yields the same shape; otherwise a
    1D object array with one (n_repetitions, n_channels, n_bands) array per patient.
    """
    per_patient = [
        np.array([channels_bandpower(repetition, bands, fs) for repetition in patient])
        for patient in patients
    ]

    if len({features.shape for features in per_patient}) <= 1:
        return np.array(per_patient)

    # Patients with different repetition counts cannot form a regular array (recent
    # NumPy raises instead of building an object array implicitly), so fill one explicitly.
    ragged = np.empty(len(per_patient), dtype=object)
    for index, features in enumerate(per_patient):
        ragged[index] = features
    return ragged


pp_rh_bp = patients_bandpower(pp_rh_raw, bands)
pnp_rh_bp = patients_bandpower(pnp_rh_raw, bands)

In [None]:
# Get the total number of repetitions for each class
pp_count = np.vstack(pp_rh_bp).shape[0]
pnp_count = np.vstack(pnp_rh_bp).shape[0]
pnp_count

In [None]:
# Band-power array of the first no-pain entry: one channels_bandpower result
# (rows = channels, columns = bands) per repetition.
pnp_rh_bp[0].shape

#### Concatenate the two classes

In [None]:
# Join both classes along the first axis: pain entries first, then no-pain entries.
pp_and_pnp_bp = np.concatenate((pp_rh_bp, pnp_rh_bp))
pp_and_pnp_bp.shape

In [None]:
# Values handed to classify_nusvm_cross_valid through its log_* keyword arguments.
log_proc_method = 'PCA + Bandpower'
log_dataset = 'PP/PNP-RH'
# NOTE(review): this text is hard-coded; keep it in sync with the pca_components
# value set in the classification cells.
log_notes = 'pca_components=3'
log_db_name = 'log.db'

# SVM classification

In [None]:
# Single cross-validated run with a fixed nu and a fixed channel selection.
nu = 0.8585
# NOTE(review): indices 1 and 3 each appear twice in this list - confirm that is intended.
channels = [0, 1, 1, 2, 3, 3, 5, 12, 13, 23, 30, 52, 57]
pca_components=3

# The call returns four values, unpacked here in the order it returns them.
acc, sensitivity, specificity, avg_acc = classify_nusvm_cross_valid(pp_rh_bp, pnp_rh_bp, nu, channels,
                                                                    pca_components=pca_components,
                                                                    log_db_name=log_db_name,
                                                                    log_proc_method=log_proc_method,
                                                                    log_dataset=log_dataset,
                                                                    log_notes=log_notes
                                                                   )
print('Accuracy', acc)
print('Sensitivity', sensitivity)
print('Specificity', specificity)
print('Average accuracy', avg_acc)

In [None]:
# Try each candidate channel index on top of the already chosen channels and
# remember the candidate that gives the highest accuracy.
previous_channels = [11, 36, 52]
nu = 0.8
pca_components = 3

# Keyword arguments that stay the same for every candidate
shared_kwargs = dict(
    verbose=False,
    pca_components=pca_components,
    log_db_name=log_db_name,
    log_proc_method=log_proc_method,
    log_dataset=log_dataset,
    log_notes=log_notes,
)

max_acc = {'index': 0, 'value': 0}
for channel in range(61):
    scores = classify_nusvm_cross_valid(pp_rh_bp, pnp_rh_bp, nu,
                                        previous_channels + [channel],
                                        **shared_kwargs)
    accuracy, sensitivity, specificity, avg_accuracy = scores
    print(channel, accuracy, sensitivity, specificity, avg_accuracy)

    # Strict comparison: on ties the earliest candidate is kept
    if accuracy > max_acc['value']:
        max_acc.update(index=channel, value=accuracy)

print('Max accuracy:', max_acc['index'], max_acc['value'])

#### Cross validate over multiple nu values

In [None]:
# Sweep nu over a range of values for a fixed channel selection and remember
# the value that gives the highest accuracy.
channels = [11, 36, 52]

max_acc = {'index': 0, 'value': 0}
for param in np.arange(0.5, 0.875, 0.01):
    # Pass the swept value as nu and the channel list defined in this cell.
    # Previously the fixed `nu` and `previous_channels + [channel]` left over
    # from the channel-search cell were passed, so every iteration evaluated
    # the same configuration.
    accuracy, sensitivity, specificity, avg_accuracy = classify_nusvm_cross_valid(pp_rh_bp, pnp_rh_bp, param,
                                                                                  channels,
                                                                                  verbose=False,
                                                                                  pca_components=pca_components,
                                                                                  log_db_name=log_db_name,
                                                                                  log_proc_method=log_proc_method,
                                                                                  log_dataset=log_dataset,
                                                                                  log_notes=log_notes
                                                                                 )

    print(param, accuracy, sensitivity, specificity, avg_accuracy)

    # Strict comparison: on ties the smallest nu is kept
    if accuracy > max_acc['value']:
        max_acc['index'] = param
        max_acc['value'] = accuracy


print('Max accuracy:', max_acc['index'], max_acc['value'])