In [17]:
import json
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import cross_validate, StratifiedKFold
from sklearn.utils import resample

In [18]:
def get_json_files(path):
    """Collect every JSON file below ``path`` and split them into data and events files.

    The split is by file name only: a ``.json`` file whose name starts with ``'evs'``
    counts as an events file, and every other ``.json`` file counts as a data file.
    Sub-directories are searched recursively.

    Arguments
    ---------
    path (string/path) : Root directory to search.

    Returns
    -------
    data_ (list) : Full paths of all data files, sorted lexicographically.
    evs_ (list) : Full paths of all events files, sorted lexicographically.
    """
    json_entries = [
        (name, os.path.join(root, name))
        for root, _subdirs, names in os.walk(path)
        for name in names
        if name.endswith(".json")
    ]
    data_ = sorted(full for name, full in json_entries if not name.startswith('evs'))
    evs_ = sorted(full for name, full in json_entries if name.startswith('evs'))
    return data_, evs_

def concat_feats(X):
    """Flatten a 3-D (features, channels, samples) array into a 2-D samples-by-features matrix.

    Columns are grouped channel by channel: column ``c * n_features + f`` holds
    feature ``f`` of channel ``c``. The result is always float64. This layout is
    convenient as input to scikit-learn classifiers.

    Arguments
    ---------
    X (numpy array) : Array of shape (n_features, n_channels, n_samples).

    Returns
    -------
    Xm (numpy array) : Array of shape (n_samples, n_features*n_channels).
    """
    n_feats, n_chans, n_samps = X.shape[0], X.shape[1], X.shape[2]
    # Put channels on the leading axis so the reshape stacks each channel's
    # block of features one after another along the row axis.
    channel_major = np.transpose(X, (1, 0, 2)).reshape(n_chans * n_feats, n_samps)
    return channel_major.astype(np.float64).T

def classify(X, labels, lfp_chs, run, random_state=123):
    """Balance labels and classify with shrinkage LDA. Return 10-fold shuffled cross-val. mean-AP and accuracy.

    Before cross-validation, the most frequent class is randomly downsampled
    (without replacement) to the size of the second most frequent class. Both
    metrics are therefore computed on a balanced data set.

    Arguments
    ---------
    X (numpy array) : Array of shape (n_features, n_channels, n_samples). ECOG channels come first, LFP channels last.
    labels (numpy array) : Array of labels of shape (n_samples). Must have same num. of samples as X.
        Only the two most frequent classes are used; any further classes are dropped.
    lfp_chs (integer) : Number of LFP channels (0 <= lfp_chs <= n_channels).
    run (string) : Either 'All', 'All ECoG' or 'All LFP'.
    random_state (int or None) : Seed for both the downsampling and the shuffled fold
        assignment, so results are reproducible. Defaults to 123, the seed that
        was previously hard-coded for downsampling. Pass None for random behaviour.

    Returns
    -------
    mean_ap (float) : 10-fold cross-validated mean average precision, rounded to 3 decimals.
    accuracy (float) : 10-fold cross-validated mean accuracy, rounded to 3 decimals.

    Raises
    ------
    ValueError : If `run` is not one of the accepted values, `lfp_chs` is out of range,
        `labels` does not have one entry per sample, or `labels` has fewer than two classes.
    """
    n_chans = X.shape[1]
    if not 0 <= lfp_chs <= n_chans:
        raise ValueError(f"lfp_chs must be between 0 and {n_chans}, got {lfp_chs}")
    # Slice at an explicit ECoG/LFP boundary rather than with -lfp_chs. With
    # lfp_chs == 0, X[:, 0:-0] would be empty and X[:, -0:] would select ALL channels.
    n_ecog = n_chans - lfp_chs
    if run == 'All':
        Xm = concat_feats(X)
    elif run == 'All ECoG':
        Xm = concat_feats(X[:, :n_ecog, :])
    elif run == 'All LFP':
        Xm = concat_feats(X[:, n_ecog:, :])
    else:
        raise ValueError(f"run must be 'All', 'All ECoG' or 'All LFP', got {run!r}")

    if len(labels) != Xm.shape[0]:
        # The join below aligns rows on the index. A length mismatch would otherwise
        # silently yield NaN features (too many labels) or drop samples (too few).
        raise ValueError(
            f"labels has {len(labels)} entries but X has {Xm.shape[0]} samples"
        )

    df_labels = pd.DataFrame(data=labels, columns=['label'])
    df_feats = pd.DataFrame(data=Xm)
    df_join = df_labels.join(df_feats)
    value_counts = df_join['label'].value_counts()
    if len(value_counts) < 2:
        raise ValueError(
            f"labels must contain at least two classes, got {list(value_counts.index)}"
        )
    # value_counts is sorted by frequency, so index[0] is the majority class.
    df_majority = df_join[df_join.label == value_counts.index[0]]
    df_minority = df_join[df_join.label == value_counts.index[1]]
    df_maj_downsampled = resample(
        df_majority, replace=False, n_samples=len(df_minority), random_state=random_state
    )
    df_downsampled = pd.concat([df_maj_downsampled, df_minority])
    # Separate input features (X) and target variable (y)
    y_ds = df_downsampled.label
    X_ds = df_downsampled.drop('label', axis=1)

    clf = LinearDiscriminantAnalysis(solver='lsqr', shrinkage='auto')
    # Seed the shuffle so repeated runs give the same folds and therefore the same scores.
    kf = StratifiedKFold(n_splits=10, shuffle=True, random_state=random_state)
    scores = cross_validate(clf, X_ds, y_ds, cv=kf, scoring=['average_precision', 'accuracy'])
    mean_ap = round(np.mean(scores['test_average_precision']), 3)
    accuracy = round(np.mean(scores['test_accuracy']), 3)
    return mean_ap, accuracy

In [19]:
# The input JSON files and the output table share one derivatives folder by default.
# Both locations can be overridden through environment variables, so the
# notebook is not tied to a single user's machine.
DEFAULT_BIDS_DIR = '/Users/richardkoehler/OneDrive - Charité - Universitätsmedizin Berlin/Derivatives/BIDS Beijing'
inpath = os.environ.get('BIDS_BEIJING_DIR', DEFAULT_BIDS_DIR)
outpath = os.environ.get('BIDS_BEIJING_OUTDIR', inpath)
outfile = 'Beijing_LDA_scores.tsv'
# Number of LFP channels at the end of each data array (ECoG channels come first).
N_LFP_CHANNELS = 6
RUNS = ['All', 'All ECoG', 'All LFP']

data_list, events_list = get_json_files(path=inpath)
# Data and events files are paired only by their sorted position. Unequal counts
# mean the pairing is wrong, so fail loudly instead of mixing subjects.
if len(data_list) != len(events_list):
    raise ValueError(
        f"Found {len(data_list)} data files but {len(events_list)} events files in {inpath!r}"
    )

results = {}

for data_file, events_file in zip(data_list, events_list):
    # JSON text is UTF-8 by specification; don't depend on the platform default encoding.
    with open(data_file, encoding='utf-8') as json_file:
        data_dict = json.load(json_file)
    with open(events_file, encoding='utf-8') as json_file:
        events_dict = json.load(json_file)

    xf_zs_r = np.asarray(data_dict['data'])
    subject = data_dict['subject']
    labels = np.asarray(events_dict['labels'])

    for item in RUNS:
        mean_ap, accuracy = classify(xf_zs_r, labels, N_LFP_CHANNELS, item)
        key = f'Subject {subject} {item}'
        if key in results:
            # Keys identify only the subject, so a second file for the same subject
            # would otherwise silently replace the earlier scores.
            warnings.warn(
                f"Duplicate result key {key!r} from {data_file!r}; earlier scores are overwritten"
            )
        results[key] = [mean_ap, accuracy]

df = pd.DataFrame.from_dict(results, orient='index', columns=['MAP', 'Accuracy'])
df.to_csv(os.path.join(outpath, outfile), sep='\t')