# Data Loading

In [None]:
import sys
sys.path.insert(0, '..')

In [None]:
# imports
import numpy as np
import matplotlib.pyplot as plt

from processing_utils.feature_data_from_mat import load_subject_high_gamma

In [None]:
# Loader options shared by every subject loaded below.
cluster = False  # forwarded as `cluster=` to both loader functions
zscore = False   # forwarded as `zscore=` to both loader functions
sig = True       # forwarded as `sig_channel=` to load_subject_high_gamma

In [None]:
# Display names for the phoneme and articulator classes.
phon_labels = ['a', 'ae', 'i', 'u', 'b', 'p', 'v', 'g', 'k']
artic_labels = ['low', 'high', 'labial', 'dorsal']

# Phoneme index -> articulator index.
# NOTE(review): both sides look like 1-based positions into phon_labels /
# artic_labels (e.g. 1 'a' -> 1 'low') -- confirm against the data files.
phon_to_artic_dict = {
    1: 1, 2: 1,
    3: 2, 4: 2,
    5: 3, 6: 3, 7: 3,
    8: 4, 9: 4,
}

In [None]:
def phon_to_artic(phon_idx, phon_to_artic_conv):
    """Look up the articulator label for a single phoneme label.

    Args:
        phon_idx: Phoneme label used as the lookup key.
        phon_to_artic_conv: Mapping from phoneme label to articulator label.

    Returns:
        The value stored under ``phon_idx`` in ``phon_to_artic_conv``.

    Raises:
        KeyError: If ``phon_idx`` is not a key of a dict mapping.
    """
    artic_idx = phon_to_artic_conv[phon_idx]
    return artic_idx

def phon_to_artic_seq(phon_seq, phon_to_artic_conv):
    """Convert an array of phoneme labels to articulator labels.

    Args:
        phon_seq: NumPy array of phoneme labels (any shape).
        phon_to_artic_conv: Mapping from phoneme label to articulator label.

    Returns:
        A NumPy array with the same shape as ``phon_seq`` in which every
        element has been replaced by its articulator label.
    """
    converted = []
    # Convert element by element on a flattened copy, then restore the shape.
    for phon_idx in phon_seq.flatten():
        converted.append(phon_to_artic(phon_idx, phon_to_artic_conv))
    return np.array(converted).reshape(phon_seq.shape)

### Relative to Response Onset

Load in S14 Data

In [None]:
# S14 high-gamma data for the response-onset section: trace, map and phoneme labels.
S14_hg_trace, S14_hg_map, S14_phon_labels = load_subject_high_gamma(
    'S14',
    sig_channel=sig,
    zscore=zscore,
    cluster=cluster,
    data_dir='../data/',
)

In [None]:
# Report the array shapes returned by the loader.
print(S14_hg_trace.shape)
print(S14_hg_map.shape)
print(S14_phon_labels.shape)

# Average over axis 2 once and reuse it for both curves.
# NOTE(review): axis 2 is presumably channels, leaving (trials, time) -- confirm.
S14_trace_mean = np.mean(S14_hg_trace, axis=2)

# Take the number of time points from the data rather than hard-coding 200,
# so the plot still works if the trace length changes.
S14_t = np.linspace(-0.5, 0.5, S14_trace_mean.shape[1])

plt.figure()
# One faint grey line per row of the averaged trace, then their mean in black.
plt.plot(S14_t, S14_trace_mean.T, 'grey', alpha=0.35)
plt.plot(S14_t, np.mean(S14_trace_mean, axis=0), 'black')
plt.xlabel('Time (s)')
plt.ylabel('HG (Mean-subtracted)')
plt.title('S14 HG Trace by Trial')
plt.show()

Load in S26 Data

In [None]:
# S26 high-gamma data for the response-onset section: trace, map and phoneme labels.
S26_hg_trace, S26_hg_map, S26_phon_labels = load_subject_high_gamma(
    'S26',
    sig_channel=sig,
    zscore=zscore,
    cluster=cluster,
    data_dir='../data/',
)

In [None]:
# Report the array shapes returned by the loader.
for S26_arr in (S26_hg_trace, S26_hg_map, S26_phon_labels):
    print(S26_arr.shape)

# Mean over axis 0, computed once and reused for both curves.
# NOTE(review): axis 0 is presumably trials, leaving (time, channels) -- confirm.
S26_trial_mean = np.mean(S26_hg_trace, axis=0)

plt.figure()
plt.plot(S26_trial_mean, 'grey')
plt.plot(np.mean(S26_trial_mean, axis=1), 'black')
plt.title('S26 HG Trace by Channel')
plt.show()

Load in S23 Data

In [None]:
# S23 high-gamma data for the response-onset section: trace, map and phoneme labels.
S23_hg_trace, S23_hg_map, S23_phon_labels = load_subject_high_gamma(
    'S23',
    sig_channel=sig,
    zscore=zscore,
    cluster=cluster,
    data_dir='../data/',
)

In [None]:
# Report the array shapes returned by the loader.
for S23_arr in (S23_hg_trace, S23_hg_map, S23_phon_labels):
    print(S23_arr.shape)

# Mean over axis 0, computed once and reused for both curves.
# NOTE(review): axis 0 is presumably trials, leaving (time, channels) -- confirm.
S23_trial_mean = np.mean(S23_hg_trace, axis=0)

plt.figure()
plt.plot(S23_trial_mean, 'grey')
plt.plot(np.mean(S23_trial_mean, axis=1), 'black')
plt.title('S23 HG Trace by Channel')
plt.show()

Load in S33 Data

In [None]:
# S33 high-gamma data for the response-onset section: trace, map and phoneme labels.
S33_hg_trace, S33_hg_map, S33_phon_labels = load_subject_high_gamma(
    'S33',
    sig_channel=sig,
    zscore=zscore,
    cluster=cluster,
    data_dir='../data/',
)

In [None]:
# Report the array shapes returned by the loader.
for S33_arr in (S33_hg_trace, S33_hg_map, S33_phon_labels):
    print(S33_arr.shape)

# Mean over axis 0, computed once and reused for both curves.
# NOTE(review): axis 0 is presumably trials, leaving (time, channels) -- confirm.
S33_trial_mean = np.mean(S33_hg_trace, axis=0)

plt.figure()
plt.plot(S33_trial_mean, 'grey')
plt.plot(np.mean(S33_trial_mean, axis=1), 'black')
plt.title('S33 HG Trace by Channel')
plt.show()

### Relative to Different Phoneme Onsets

In [None]:
from processing_utils.feature_data_from_mat import load_subject_high_gamma_phoneme

S14 Data

In [None]:
# S14 high-gamma data requested for phoneme positions 1-3 (returned as a dict).
S14_hg_data = load_subject_high_gamma_phoneme(
    'S14',
    phons=[1, 2, 3],
    cluster=cluster,
    zscore=zscore,
    data_dir='../data/',
)

In [None]:
# Shapes of the per-position data, maps and labels (one printed line per group).
print(*(S14_hg_data[key].shape for key in ('X1', 'X2', 'X3')))
print(*(S14_hg_data[key].shape for key in ('X1_map', 'X2_map', 'X3_map')))
print(*(S14_hg_data[key].shape for key in ('y1', 'y2', 'y3', 'y_full_phon')))

# Time axis spans -0.5 s to 0.5 s with one sample per point along axis 1 of X1.
t = np.linspace(-0.5, 0.5, S14_hg_data['X1'].shape[1])
f, (ax1, ax2, ax3) = plt.subplots(1, 3, sharey=True, figsize=(18, 5))

# One panel per phoneme position: axis-0 mean in grey, its axis-1 mean in black.
for ax, x_key, panel_title in zip((ax1, ax2, ax3),
                                  ('X1', 'X2', 'X3'),
                                  ('P1', 'P2', 'P3')):
    position_mean = np.mean(S14_hg_data[x_key], axis=0)
    ax.plot(t, position_mean, 'grey')
    ax.plot(t, np.mean(position_mean, axis=1), 'black')
    ax.set_title(panel_title)
    ax.set_xlabel('Time (s)')

# The y axis is shared, so only the left-most panel carries the label.
ax1.set_ylabel('High Gamma')

plt.show()

S26 Data

In [None]:
# S26 high-gamma data requested for phoneme positions 1-3 (returned as a dict).
S26_hg_data = load_subject_high_gamma_phoneme(
    'S26',
    phons=[1, 2, 3],
    cluster=cluster,
    zscore=zscore,
    data_dir='../data/',
)

In [None]:
# Shapes of the per-position data, maps and labels (one printed line per group).
print(*(S26_hg_data[key].shape for key in ('X1', 'X2', 'X3')))
print(*(S26_hg_data[key].shape for key in ('X1_map', 'X2_map', 'X3_map')))
print(*(S26_hg_data[key].shape for key in ('y1', 'y2', 'y3', 'y_full_phon')))

# Time axis spans -0.5 s to 0.5 s with one sample per point along axis 1 of X1.
t = np.linspace(-0.5, 0.5, S26_hg_data['X1'].shape[1])
f, (ax1, ax2, ax3) = plt.subplots(1, 3, sharey=True, figsize=(18, 5))

# One panel per phoneme position: axis-0 mean in grey, its axis-1 mean in black.
for ax, x_key, panel_title in zip((ax1, ax2, ax3),
                                  ('X1', 'X2', 'X3'),
                                  ('P1', 'P2', 'P3')):
    position_mean = np.mean(S26_hg_data[x_key], axis=0)
    ax.plot(t, position_mean, 'grey')
    ax.plot(t, np.mean(position_mean, axis=1), 'black')
    ax.set_title(panel_title)
    ax.set_xlabel('Time (s)')

# The y axis is shared, so only the left-most panel carries the label.
ax1.set_ylabel('High Gamma')

plt.show()


S23 Data

In [None]:
# S23 high-gamma data requested for phoneme positions 1-3 (returned as a dict).
S23_hg_data = load_subject_high_gamma_phoneme(
    'S23',
    phons=[1, 2, 3],
    cluster=cluster,
    zscore=zscore,
    data_dir='../data/',
)

In [None]:
# Shapes of the per-position data, maps and labels (one printed line per group).
print(*(S23_hg_data[key].shape for key in ('X1', 'X2', 'X3')))
print(*(S23_hg_data[key].shape for key in ('X1_map', 'X2_map', 'X3_map')))
print(*(S23_hg_data[key].shape for key in ('y1', 'y2', 'y3', 'y_full_phon')))

# Time axis spans -0.5 s to 0.5 s with one sample per point along axis 1 of X1.
t = np.linspace(-0.5, 0.5, S23_hg_data['X1'].shape[1])
f, (ax1, ax2, ax3) = plt.subplots(1, 3, sharey=True, figsize=(18, 5))

# One panel per phoneme position: axis-0 mean in grey, its axis-1 mean in black.
for ax, x_key, panel_title in zip((ax1, ax2, ax3),
                                  ('X1', 'X2', 'X3'),
                                  ('P1', 'P2', 'P3')):
    position_mean = np.mean(S23_hg_data[x_key], axis=0)
    ax.plot(t, position_mean, 'grey')
    ax.plot(t, np.mean(position_mean, axis=1), 'black')
    ax.set_title(panel_title)
    ax.set_xlabel('Time (s)')

# The y axis is shared, so only the left-most panel carries the label.
ax1.set_ylabel('High Gamma')

plt.show()

S33 Data

In [None]:
# S33 high-gamma data requested for phoneme positions 1-3 (returned as a dict).
S33_hg_data = load_subject_high_gamma_phoneme(
    'S33',
    phons=[1, 2, 3],
    cluster=cluster,
    zscore=zscore,
    data_dir='../data/',
)

In [None]:
# Shapes of the per-position data, maps and labels (one printed line per group).
print(*(S33_hg_data[key].shape for key in ('X1', 'X2', 'X3')))
print(*(S33_hg_data[key].shape for key in ('X1_map', 'X2_map', 'X3_map')))
print(*(S33_hg_data[key].shape for key in ('y1', 'y2', 'y3', 'y_full_phon')))

# Time axis spans -0.5 s to 0.5 s with one sample per point along axis 1 of X1.
t = np.linspace(-0.5, 0.5, S33_hg_data['X1'].shape[1])
f, (ax1, ax2, ax3) = plt.subplots(1, 3, sharey=True, figsize=(18, 5))

# One panel per phoneme position: axis-0 mean in grey, its axis-1 mean in black.
for ax, x_key, panel_title in zip((ax1, ax2, ax3),
                                  ('X1', 'X2', 'X3'),
                                  ('P1', 'P2', 'P3')):
    position_mean = np.mean(S33_hg_data[x_key], axis=0)
    ax.plot(t, position_mean, 'grey')
    ax.plot(t, np.mean(position_mean, axis=1), 'black')
    ax.set_title(panel_title)
    ax.set_xlabel('Time (s)')

# The y axis is shared, so only the left-most panel carries the label.
ax1.set_ylabel('High Gamma')

plt.show()

In [None]:
# Map each subject's phoneme labels to articulator labels, in subject order
# S14, S26, S23, S33.
(S14_artic_labels,
 S26_artic_labels,
 S23_artic_labels,
 S33_artic_labels) = (
    phon_to_artic_seq(subject_phon_labels, phon_to_artic_dict)
    for subject_phon_labels in (S14_phon_labels, S26_phon_labels,
                                S23_phon_labels, S33_phon_labels)
)

### Collapse Across Positions

In [None]:
# Keys of the per-position data and label arrays inside each subject dict.
position_x_keys = ('X1', 'X2', 'X3')
position_y_keys = ('y1', 'y2', 'y3')

# Stack the three phoneme positions along axis 0, separately for each subject.
S14_hg_collapsed = np.concatenate(
    [S14_hg_data[key] for key in position_x_keys], axis=0)
S14_phon_labels_collapsed = np.concatenate(
    [S14_hg_data[key] for key in position_y_keys], axis=0)

S26_hg_collapsed = np.concatenate(
    [S26_hg_data[key] for key in position_x_keys], axis=0)
S26_phon_labels_collapsed = np.concatenate(
    [S26_hg_data[key] for key in position_y_keys], axis=0)

S23_hg_collapsed = np.concatenate(
    [S23_hg_data[key] for key in position_x_keys], axis=0)
S23_phon_labels_collapsed = np.concatenate(
    [S23_hg_data[key] for key in position_y_keys], axis=0)

S33_hg_collapsed = np.concatenate(
    [S33_hg_data[key] for key in position_x_keys], axis=0)
S33_phon_labels_collapsed = np.concatenate(
    [S33_hg_data[key] for key in position_y_keys], axis=0)

In [None]:
# Articulator labels for the position-collapsed phoneme labels, in subject
# order S14, S26, S23, S33.
(S14_artic_labels_collapsed,
 S26_artic_labels_collapsed,
 S23_artic_labels_collapsed,
 S33_artic_labels_collapsed) = (
    phon_to_artic_seq(collapsed_phon_labels, phon_to_artic_dict)
    for collapsed_phon_labels in (S14_phon_labels_collapsed,
                                  S26_phon_labels_collapsed,
                                  S23_phon_labels_collapsed,
                                  S33_phon_labels_collapsed)
)

# Joint PCA Decomp Decoding

### Data Preparation

In [None]:
# Attach the derived arrays to each subject's data dict. Keys are added in the
# order: y_full_artic, X_collapsed, y_phon_collapsed, y_artic_collapsed.
S14_hg_data.update({
    'y_full_artic': S14_artic_labels,
    'X_collapsed': S14_hg_collapsed,
    'y_phon_collapsed': S14_phon_labels_collapsed,
    'y_artic_collapsed': S14_artic_labels_collapsed,
})
S26_hg_data.update({
    'y_full_artic': S26_artic_labels,
    'X_collapsed': S26_hg_collapsed,
    'y_phon_collapsed': S26_phon_labels_collapsed,
    'y_artic_collapsed': S26_artic_labels_collapsed,
})
S23_hg_data.update({
    'y_full_artic': S23_artic_labels,
    'X_collapsed': S23_hg_collapsed,
    'y_phon_collapsed': S23_phon_labels_collapsed,
    'y_artic_collapsed': S23_artic_labels_collapsed,
})
S33_hg_data.update({
    'y_full_artic': S33_artic_labels,
    'X_collapsed': S33_hg_collapsed,
    'y_phon_collapsed': S33_phon_labels_collapsed,
    'y_artic_collapsed': S33_artic_labels_collapsed,
})


In [None]:
# Notebook display expression: list the keys now stored in the S14 data dict.
S14_hg_data.keys()

In [None]:
# Per-patient lookup table. Each entry is that patient's data dict merged with
# a 'pre_pts' list naming the other three patients.
pt_dict = {
    'S14': S14_hg_data | {'pre_pts': ['S26', 'S23', 'S33']},
    'S26': S26_hg_data | {'pre_pts': ['S14', 'S23', 'S33']},
    'S23': S23_hg_data | {'pre_pts': ['S14', 'S26', 'S33']},
    # NOTE(review): S33 lists S23 before S26, unlike the S14/S26/S23/S33
    # ordering used by the other entries; the alternative
    # ['S14', 'S26', 'S23'] was previously tried here.
    'S33': S33_hg_data | {'pre_pts': ['S14', 'S23', 'S26']},
}

In [76]:
import warnings


def warn(*args, **kwargs):
    """Accept any arguments and do nothing (stand-in for ``warnings.warn``)."""
    return None


# Silence every warning raised through `warnings.warn` by replacing it with the
# no-op above.
# NOTE(review): this is a process-wide monkeypatch that also hides warnings
# from unrelated code for the rest of the session.
warnings.warn = warn

from sklearn.model_selection import StratifiedKFold
from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import balanced_accuracy_score
from sklearn.svm import SVC
from sklearn.ensemble import BaggingClassifier
from tqdm.notebook import tqdm

from alignment_methods import JointPCADecomp, CCAAlign

# patient and target params
# `pt` is the key into pt_dict for the target patient; `p_ind` picks the
# phoneme position ('X<p_ind>' / 'y<p_ind>'), with -1 selecting the
# position-collapsed arrays instead.
pt = 'S14'
p_ind = 1

# experiment params
# pool_pre: when False the classifier is trained on target data only.
# tar_in_train: when pooling, whether target training trials are included
#               alongside the other patients' data.
# cca_algn: whether the other patients' data are CCA-aligned to the target.
pool_pre = True
tar_in_train = True
cca_algn = True

# constant params
# n_iter: number of repetitions of the cross-validation;
# n_folds: StratifiedKFold splits; n_comp: JointPCADecomp n_components.
n_iter = 10
n_folds = 5
n_comp = 30

# alignment label type
# The last four characters ('_seq') are stripped to pick the 'y_full_phon' or
# 'y_full_artic' entry; algn_grouping is passed to CCAAlign as `type`.
# algn_type = 'artic_seq'
algn_type = 'phon_seq'
algn_grouping = 'class'

# decoding label type
# 'phon' decodes phoneme labels; 'artic' decodes articulator labels.
lab_type = 'phon'
# lab_type = 'artic'

# dimensionality reduction
# NOTE(review): red_method and dim_red are not referenced in the rest of this
# cell -- confirm whether they are still needed.
red_method = 'PCA'
dim_red = PCA

# Resolve the target patient and the three patients pooled with it.
pre_pts = pt_dict[pt]['pre_pts']
tar_entry = pt_dict[pt]
pre_entry1, pre_entry2, pre_entry3 = (pt_dict[pre_pts[i]] for i in range(3))

# p_ind == -1 means "use the position-collapsed arrays"; any other value picks
# a single phoneme position.
use_collapsed = p_ind == -1
if use_collapsed:
    feat_key = 'X_collapsed'
    lab_key = 'y_' + lab_type + '_collapsed'
else:
    feat_key = 'X' + str(p_ind)
    lab_key = 'y' + str(p_ind)

# Features and decoding labels for the target and each pooled patient.
D_tar, lab_tar = tar_entry[feat_key], tar_entry[lab_key]
D1, lab1 = pre_entry1[feat_key], pre_entry1[lab_key]
D2, lab2 = pre_entry2[feat_key], pre_entry2[lab_key]
D3, lab3 = pre_entry3[feat_key], pre_entry3[lab_key]

# Per-position labels are stored as phonemes, so convert them when decoding
# articulators (the collapsed arrays already have an articulator version).
if not use_collapsed and lab_type == 'artic':
    lab_tar = phon_to_artic_seq(lab_tar, phon_to_artic_dict)
    lab1 = phon_to_artic_seq(lab1, phon_to_artic_dict)
    lab2 = phon_to_artic_seq(lab2, phon_to_artic_dict)
    lab3 = phon_to_artic_seq(lab3, phon_to_artic_dict)

# Full label sequences used for alignment: strip the trailing '_seq' from
# algn_type to get 'y_full_phon' or 'y_full_artic'.
full_key = 'y_full_' + algn_type[:-4]
lab_tar_full = tar_entry[full_key]
lab1_full = pre_entry1[full_key]
lab2_full = pre_entry2[full_key]
lab3_full = pre_entry3[full_key]

# Collapsed data stacks the three positions along axis 0, so repeat the full
# sequences three times to keep one row per stacked trial.
if use_collapsed:
    lab_tar_full = np.tile(lab_tar_full, (3, 1))
    lab1_full = np.tile(lab1_full, (3, 1))
    lab2_full = np.tile(lab2_full, (3, 1))
    lab3_full = np.tile(lab3_full, (3, 1))

# Repeated stratified k-fold decoding on the target patient. Each repetition
# pools predictions over all folds and records one balanced accuracy.
iter_accs = []
for _ in tqdm(range(n_iter)):
    y_true_all, y_pred_all = [], []
    # No random_state is set, so every repetition draws a different split.
    cv = StratifiedKFold(n_splits=n_folds, shuffle=True)
    for train_idx, test_idx in cv.split(D_tar, lab_tar):
        # Start each fold from the untransformed pooled-patient data; the
        # names below are rebound to transformed copies further down.
        X1, X2, X3 = D1, D2, D3
        y1, y2, y3 = lab1, lab2, lab3
        y1_full, y2_full, y3_full = lab1_full, lab2_full, lab3_full

        # split target data into train and test
        X_tar_train, X_tar_test = D_tar[train_idx], D_tar[test_idx]
        y_tar_train, y_tar_test = lab_tar[train_idx], lab_tar[test_idx]
        y_tar_full_train, y_tar_full_test = (lab_tar_full[train_idx],
                                             lab_tar_full[test_idx])

        # learn joint PCA decomposition from the full label sequences
        # (phoneme or articulator, as selected by algn_type); only target
        # training trials are used so the test fold stays unseen
        jointPCA = JointPCADecomp(n_components=n_comp)
        X1, X2, X3, X_tar_train = jointPCA.fit_transform([X1, X2, X3,
                                                          X_tar_train],
                                                         [y1_full, y2_full,
                                                          y3_full,
                                                          y_tar_full_train])
        # apply target transformation to test data
        # NOTE(review): idx=3 presumably selects the 4th dataset passed to
        # fit_transform (the target) -- confirm against JointPCADecomp.
        X_tar_test = jointPCA.transform(X_tar_test, idx=3)

        # X_tar_train_p = X_tar_train.reshape(-1, X_tar_train.shape[-1])
        # X_tar_test_p = X_tar_test.reshape(-1, X_tar_test.shape[-1])
        # pca = PCA(n_components=n_comp)
        # X_tar_train_p = pca.fit_transform(X_tar_train_p)
        # X_tar_test_p = pca.transform(X_tar_test_p)
        # X_tar_train = X_tar_train_p.reshape(X_tar_train.shape[0], -1, n_comp)
        # X_tar_test = X_tar_test_p.reshape(X_tar_test.shape[0], -1, n_comp)

        # align each pooled patient data to target data with CCA
        if cca_algn:
            cca1 = CCAAlign(type=algn_grouping)
            cca2 = CCAAlign(type=algn_grouping)
            cca3 = CCAAlign(type=algn_grouping)
            cca1.fit(X_tar_train, X1, y_tar_full_train, y1_full)
            cca2.fit(X_tar_train, X2, y_tar_full_train, y2_full)
            cca3.fit(X_tar_train, X3, y_tar_full_train, y3_full)
            X1 = cca1.transform(X1)
            X2 = cca2.transform(X2)
            X3 = cca3.transform(X3)

        # reshape to trials x features
        X_tar_train = X_tar_train.reshape(X_tar_train.shape[0], -1)
        X_tar_test = X_tar_test.reshape(X_tar_test.shape[0], -1)
        X1 = X1.reshape(X1.shape[0], -1)
        X2 = X2.reshape(X2.shape[0], -1)
        X3 = X3.reshape(X3.shape[0], -1)

        # choose the training set: target only, pooled patients only, or both
        if not pool_pre:
            X_train, y_train = X_tar_train, y_tar_train
        else:
            if not tar_in_train:
                X_train = np.concatenate((X1, X2, X3), axis=0)
                y_train = np.concatenate((y1, y2, y3), axis=0)
            else:
                X_train = np.concatenate((X_tar_train, X1, X2, X3), axis=0)
                y_train = np.concatenate((y_tar_train, y1, y2, y3), axis=0)
                # X_train = np.concatenate((X_tar_train, X1, X3), axis=0)
                # y_train = np.concatenate((y_tar_train, y1, y3), axis=0)
        X_test = X_tar_test
        y_test = y_tar_test

        # sc = MinMaxScaler()
        # X_train = sc.fit_transform(X_train)
        # X_test = sc.transform(X_test)

        # Pass the base learner positionally: the keyword was renamed from
        # `base_estimator` to `estimator` in scikit-learn 1.2 and the old name
        # was removed in 1.4, but it is the first positional parameter in
        # every version.
        clf = BaggingClassifier(SVC(kernel='linear', C=0.5),
                                n_estimators=10)
        clf.fit(X_train, y_train)
        y_pred = clf.predict(X_test)

        y_true_all.extend(y_test)
        y_pred_all.extend(y_pred)

    iter_acc = balanced_accuracy_score(y_true_all, y_pred_all)
    print(iter_acc)
    iter_accs.append(iter_acc)

print(iter_accs)
print(f'Mean acc: {np.mean(iter_accs)}, Std: {np.std(iter_accs)}')
print()


  0%|          | 0/10 [00:00<?, ?it/s]

0.5058865970630676
0.4956702883173471
0.4857482386894151
0.5282955606485018
0.47988144752850637
0.524355600826189
0.5033655886597063
0.5197818521347933
0.5271814786520669
0.5032163370398666
[0.5058865970630676, 0.4956702883173471, 0.4857482386894151, 0.5282955606485018, 0.47988144752850637, 0.524355600826189, 0.5033655886597063, 0.5197818521347933, 0.5271814786520669, 0.5032163370398666]
Mean acc: 0.507338298955946, Std: 0.016335303722812205



# Decoding Results

# New Section