In [166]:
import mne
from pathlib import Path 
import numpy as np
from sklearn.svm import SVC
from sklearn.model_selection import cross_val_score, ShuffleSplit
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
import matplotlib
import matplotlib.pyplot as plt
from scipy import linalg

#matplotlib.use('Qt5Agg')
mne.set_log_level('warning')

all_epochs = mne.read_epochs(Path('out_data') / 'all_asd_td_concat_cleaned_equal_1_40hz_epo.fif')

idx_asd = all_epochs.events[:, 2] == all_epochs.event_id['asd']
idx_td = all_epochs.events[:, 2] == all_epochs.event_id['td']

# trails (epochs, channels, samples)
np_all_epochs = all_epochs.get_data()
print(f"epochs: {np_all_epochs.shape[0]}, channels: {np_all_epochs.shape[1]}, samples: {np_all_epochs.shape[2]}")
n_trials = np_all_epochs.shape[0]


# Labels
y = np.empty(len(all_epochs.events), dtype=int)  

# Encode: ASD = 0, TD = 1.
y[idx_asd] = 0
y[idx_td] = 1


# epochs: 1230, channels: 32, samples: 513
#create array of shape (1230, 32, # of band powers)
num_freq_band_feature = 4
X_2d = np.empty([np_all_epochs.shape[0], np_all_epochs.shape[1], num_freq_band_feature], dtype=float)  

fmax = 30
psd_epochs_channels_freqs = all_epochs.compute_psd(fmax=fmax).get_data()

for epoch_id, epoch in enumerate(psd_epochs_channels_freqs):
    for channel_id, channel in enumerate(epoch):
        #print(channel.shape)
        delta = np.mean(channel[:4])
        theta = np.mean(channel[4:8])
        alpha = np.mean(channel[8:12])
        beta = np.mean(channel[13:30])
        X_2d[epoch_id, channel_id, :] = [delta,theta, alpha, beta]

def viewEpochChannelPSD(epoch, channel, matrix):
    # View Epoch channel PSD
    _, ax = plt.subplots()
    _x = np.indices(matrix[epoch, channel, :].shape).squeeze()
    ax.plot(_x , matrix[epoch, channel, :], color='k')
    plt.show()

#viewEpochChannelPSD(10, 30, X_2d)

# Test 1 (1230, 32, 30) 30 -> 1hz - 30hz PSD 

#print(X_2d.shape, "before")
X_2d_reshaped = X_2d.reshape(n_trials, -1)
#print(X_2d.shape, "after")

# Define an SVM classifier (SVC) with a linear kernel

clf = SVC(C=1, kernel='linear')
cv = ShuffleSplit(10, test_size=0.2, random_state=42)
scores_full = cross_val_score(clf, X_2d_reshaped, y, cv=cv, n_jobs=1)
print(f"SVM Classification score: {np.mean(scores_full)} (std. {np.std(scores_full)})")


clf = LinearDiscriminantAnalysis()
scores_full = cross_val_score(clf, X_2d_reshaped, y, cv=cv, n_jobs=1)
print(f"LDA Classification score: {np.mean(scores_full)} (std. {np.std(scores_full)})")

colors = ["navy", "turquoise"]
target_names = ["asd", "td"]

asd_powers = X_2d[idx_asd]
td_powers = X_2d[idx_td]

print("shape: ", asd_powers[:, :, 0].shape)



'''
fig, axs = plt.subplots(3, 4)

channelid = 2
for i in range(3):
    for x in range(4):
        axs[i, x].scatter(asd_powers[:, channelid, i], asd_powers[:, channelid, x], alpha=0.5, color="navy", label="ASD")
        axs[i, x].scatter(td_powers[:, channelid, i], td_powers[:, channelid, x], alpha=0.5, color="turquoise", label="TD")
        axs[i, x].legend(loc="best", shadow=False, scatterpoints=1)
        axs[i, x].set_title(f'Axis [{i}, {x}]')
'''


bar_colors = ['tab:red', 'tab:blue']
print("ASD delta avg: ", np.mean(asd_powers[:, :, 0]))
print("TD delta avg: ", np.mean(td_powers[:, :, 0]))
asd_mean_psd = np.mean(asd_powers[:, :, 0])
td_mean_psd =  np.mean(td_powers[:, :, 0])
counts = [asd_mean_psd, td_mean_psd]
        
fig, ax = plt.subplots()
ax.bar(target_names, counts, color=bar_colors)
plt.show()







epochs: 1230, channels: 32, samples: 513
SVM Classification score: 0.46991869918699186 (std. 0.013506705468159469)
LDA Classification score: 0.9195121951219513 (std. 0.01683928063187641)
shape:  (615, 32)
ASD delta avg:  4.716783585567914e-09
TD delta avg:  5.578230853433155e-09
