In [1]:
%matplotlib inline
import os
import glob
import datetime
import seaborn as sns
import numpy as np
import scipy as sp
import pandas as pd
import scipy.io
import numpy.fft
import scipy.signal
import matplotlib as mpl
import matplotlib.pyplot as plt
from scipy.stats import linregress
from sklearn import linear_model
mpl.rcParams['figure.figsize'] = (16, 10)

In [2]:
def get_filelist(import_path):
    matfiles = []
    for root, dirs, files in os.walk(import_path):
        matfiles += glob.glob(os.path.join(root, '*.mat'))
    return matfiles

def import_subject(subj, i, import_path):
    """ 
    Imports a single subject and adds them to the subj
    data structure. Additionally, merges 
    """
    subj[i] = {}
    datafile = sp.io.loadmat(import_path)
    subj[i]['name'] = str(np.squeeze(datafile['name']))
    subj[i]['srate'] = int(np.squeeze(datafile['srate']))
    subj[i]['events'] = []
    for event in np.squeeze(datafile['evts']):
        subj[i]['events'].append([event[0][0], event[1][0][0], event[2][0][0]])
    subj[i]['data'] = np.squeeze(datafile['data'])
    subj[i]['nbchan'] = len(subj[i]['data'])
    return subj

def _print_window_info(events, port_code):
    evts = [[events[i][1], events[i+1][1]] for i in range(len(events)) if events[i][0] == port_code]
    total_wins = 0
    total_secs = 0
    for e in evts:
        if (e[1] - e[0]) >= 1024:
            pts  = e[1] - e[0]
            secs = (e[1] - e[0])//512
            nwin = (e[1] - e[0])//512 - 1
            total_wins += nwin
            total_secs += secs
            print('Event {}:\t{} points, {} seconds, {} windows'.format(e, pts, secs, nwin))
    print('Total windows able to be extracted: ', total_wins)

def get_windows(data, events, port_code, nperwindow=512*2, noverlap=512):
    windows = []
    # The following line restructures events of type port_code into the 
    # following format:
    #         [start_time, end_time]
    evts = [[events[i][1], events[i+1][1]] for i in range(len(events)) if events[i][0] == port_code]
    for event in evts:
        if event[1]-event[0] >= nperwindow:
            nwindows = (event[1] - event[0])//noverlap - 1
            for i in range(nwindows):
                windows.append(data[event[0] + noverlap*i : event[0] + noverlap*i + nperwindow])
    return windows

def welch(windows, srate):
    """
    Takes a list of data segments (each size 1xN), computes each segment's PSD,
    and averages them to get a final PSD.
    """
    psds = [sp.signal.welch(window, srate, nperseg=len(window), window='hamming')[1] for window in windows]
    return np.mean(psds, axis=0)

def remove_freq_buffer(data, lofreq, hifreq):
    """
    Removes a frequency buffer from a PSD or frequency vector.
    """
    data = np.delete(data, range(lofreq*2, hifreq*2))
    return data.reshape(len(data), 1)

def compute_subject_psds(import_path, import_path_csv):
    """ Returns subj data structure with calculated PSDs and subject information.
    Arguments:
        import_path:     String, path to .mat files
        import_path_csv: String, path to .csv containing subject class, sex, and
                         age information. 
    """
    matfiles = get_filelist(import_path)
    df = pd.read_csv(import_path_csv)
    df.SUBJECT = df.SUBJECT.astype(str)

    subj = {}
    subj['nbsubj'] = len(matfiles)
    subj['f'] = np.linspace(0, 256, 513)
    subj['f'] = subj['f'].reshape(len(subj['f']), 1)
    subj['f_rm_alpha'] = remove_freq_buffer(subj['f'], 7, 14)
    for i in range(len(matfiles)):
        
        subj = import_subject(subj, i, matfiles[i])
        subj[i]['age']   = df[df.SUBJECT == subj[i]['name']].AGE.values[0]
        subj[i]['class'] = df[df.SUBJECT == subj[i]['name']].CLASS.values[0]
        subj[i]['sex']   = df[df.SUBJECT == subj[i]['name']].SEX.values[0]

        for ch in range(subj[i]['nbchan']):
            subj[i][ch] = {}
            windows = get_windows(subj[i]['data'][ch], subj[i]['events'], 'C1')
            subj[i][ch]['psd'] = welch(windows, 512)
            subj[i][ch]['psd_rm_alpha'] = remove_freq_buffer(subj[i][ch]['psd'], 7, 14)
        subj[i]['data'] = np.nan # No longer needed, so clear it from memory
        subj[i]['psd'] = np.mean([subj[i][ch]['psd'] for ch in range(subj[i]['nbchan'])], axis=0)
        subj[i]['psd_rm_alpha'] = remove_freq_buffer(subj[i]['psd'], 7, 14)
        print("Processed: ", subj[i]['name'])
    subj['psd'] = np.mean([subj[i]['psd'] for i in range(subj['nbsubj'])], axis=0)
    return subj

In [3]:
subjoa = np.load('../../data/GoNoGo/2016-10-22-22:43:12.989774/subjoa-2-24-ransac.npy').item()

In [4]:
matfiles = get_filelist('../../data/GoNoGo/oaExclFiltCARClust-mat/')
df = pd.read_csv('../../data/GoNoGo/ya-oa-gng.csv')

In [11]:
subj = {}
datafile = sp.io.loadmat('../../data/GoNoGo/oaExclFiltCARClust-mat/120132101.mat')
subj['name'] = str(np.squeeze(datafile['name']))
subj['srate'] = int(np.squeeze(datafile['srate']))
subj['events'] = []
for event in np.squeeze(datafile['evts']):
    subj['events'].append([event[0][0], event[1][0][0], event[2][0][0]])
subj['data'] = np.squeeze(datafile['data'])
subj['nbchan'] = len(subj['data'])
# return subj

In [14]:
_print_window_info(subj['events'], 'C1')

Event [207252, 210038]:	2786 points, 5 seconds, 4 windows
Event [210243, 211669]:	1426 points, 2 seconds, 1 windows
Event [211874, 214984]:	3110 points, 6 seconds, 5 windows
Event [216851, 218752]:	1901 points, 3 seconds, 2 windows
Event [222591, 223875]:	1284 points, 2 seconds, 1 windows
Event [235627, 236937]:	1310 points, 2 seconds, 1 windows
Event [237202, 238249]:	1047 points, 2 seconds, 1 windows
Event [238454, 239562]:	1108 points, 2 seconds, 1 windows
Event [241102, 242787]:	1685 points, 3 seconds, 2 windows
Event [243275, 245174]:	1899 points, 3 seconds, 2 windows
Event [245632, 247432]:	1800 points, 3 seconds, 2 windows
Event [249558, 251427]:	1869 points, 3 seconds, 2 windows
Event [252053, 253372]:	1319 points, 2 seconds, 1 windows
Event [257785, 258879]:	1094 points, 2 seconds, 1 windows
Event [259084, 260132]:	1048 points, 2 seconds, 1 windows
Event [260588, 263446]:	2858 points, 5 seconds, 4 windows
Event [265523, 267061]:	1538 points, 3 seconds, 2 windows
Event [270006,

In [15]:
subj['events']

[['C1', 207252, 1],
 ['C2', 210038, 2],
 ['C1', 210243, 3],
 ['C2', 211669, 4],
 ['C1', 211874, 5],
 ['C2', 214984, 6],
 ['C1', 216851, 7],
 ['C2', 218752, 8],
 ['C1', 222591, 9],
 ['C2', 223875, 10],
 ['C1', 235627, 11],
 ['C2', 236937, 12],
 ['C1', 237202, 13],
 ['C2', 238249, 14],
 ['C1', 238454, 15],
 ['C2', 239562, 16],
 ['C1', 241102, 17],
 ['C2', 242787, 18],
 ['C1', 243275, 19],
 ['C2', 245174, 20],
 ['C1', 245632, 21],
 ['C2', 247432, 22],
 ['C1', 249558, 23],
 ['C2', 251427, 24],
 ['C1', 252053, 25],
 ['C2', 253372, 26],
 ['C1', 257785, 27],
 ['C2', 258879, 28],
 ['C1', 259084, 29],
 ['C2', 260132, 30],
 ['C1', 260588, 31],
 ['C2', 263446, 32],
 ['C1', 265523, 33],
 ['C2', 267061, 34],
 ['C1', 270006, 35],
 ['C2', 271786, 36],
 ['C1', 274285, 37],
 ['C2', 275695, 38],
 ['C1', 275787, 39],
 ['C2', 277315, 40],
 ['C1', 279576, 41],
 ['C2', 281453, 42],
 ['C1', 286006, 43],
 ['C2', 287383, 44],
 ['C1', 287588, 45],
 ['C2', 288933, 46],
 ['C1', 294729, 47],
 ['C2', 296674, 48],
 