In [None]:
""" import settings """
# Automatically reload edited local modules (e.g. matplotlib_settings, utils_ssep_global)
# before each cell executes, so helper edits are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

import numpy as np
from matplotlib import pyplot as plt
from matplotlib_settings import set_plot_settings, reset_plot_settings

# Set the plot settings (project helper; applied once for the whole notebook)
set_plot_settings()

# import global variables
# NOTE(review): star import. The names this notebook uses that presumably come from here are
# DATA_DIR, FS, RS, TS, NCH, SEP_T0, SEP_T1, BASELINE_T0, BASELINE_T1 and STIM_LABELS — confirm,
# and consider importing them explicitly so their origin is visible.
from utils_ssep_global import *

In [None]:
""" load data, define static params """
# Channel indices flagged as bad by an earlier processing step.
# NOTE(review): the trial-sampling cell later rebinds `bad_chs = []`, so this loaded array is
# not used for channel exclusion in this notebook — confirm that is intended.
bad_ch_idx_dir = f'{DATA_DIR}/1_bad_channels'
bad_chs = np.load(f"{bad_ch_idx_dir}/bad_ch_idx.npy")

# presumably the sampling rate / sampling period after downsampling by RS — confirm
fs, Ts = FS/RS, TS*RS
ds_data_dir = f'{DATA_DIR}/5_downsample'
t = np.load(f"{ds_data_dir}/t_{RS}.npy")

# Sample indices lying strictly inside the SEP window and the baseline window.
sep_window = (t > SEP_T0) & (t < SEP_T1)
baseline_window = (t > BASELINE_T0) & (t < BASELINE_T1)
sep_idxs = np.flatnonzero(sep_window)
baseline_idxs = np.flatnonzero(baseline_window)

# Directory holding the per-site z-scored segments read in the next cells.
zscored_data_dir = f'{DATA_DIR}/7_zscore'

Load data

In [None]:
stim_sites = [0, 1, 4, 2, 3] # stimulation-site ids, re-ordered for plotting
n_sites = len(stim_sites) # total number of stimulation locations (derived so it cannot drift from stim_sites)
min_trial = 100 # NOTE(review): not referenced elsewhere in this notebook

In [None]:
""" load z-scored data """
def _zscore_path(site):
    """Path of the z-scored, downsampled segment file for one stimulation-site id."""
    label = STIM_LABELS[site].replace(" ", "_").lower()
    return f"{zscored_data_dir}/{label}_ds_{RS}_zscore.npy"

# One array per site, in stim_sites (plotting) order; each is 256*Trial*time.
zds_datas = [np.load(_zscore_path(site)) for site in stim_sites]

In [None]:
""" For each stim location, sort the data in terms of trial, ordering each trial
    by the number of non-saturated channels """

for idx, site_data in enumerate(zds_datas):
    # site_data is channel x trial x time (256*Trial*time, as loaded above).
    # A channel counts as non-saturated on a trial when its first time sample is not NaN.
    good_ch_counts = np.sum(~np.isnan(site_data[:, :, 0]), axis=0)  # one count per trial

    # Descending order of the count. A stable sort keeps tied trials in their original order
    # (the default sort kind may permute them arbitrarily), so the trials picked later are
    # well defined and re-running this cell leaves the data unchanged.
    sorted_idx = np.argsort(-good_ch_counts, kind='stable')

    # rearrange the original data in the order of non-sat. channels
    zds_datas[idx] = site_data[:, sorted_idx, :]

Generate input for t-SNE and LDA

In [None]:
""" take the first `sample_size` trials from each stimulation site. Trials keep the order
    set by the trial-sorting cell (most non-saturated channels first), so this is a
    deterministic selection, not a random sample. """
sample_size = 100
# trial x channel x SEP-window time; sites stacked in stim_sites order
zds_subset = np.zeros((sample_size*n_sites, NCH, len(sep_idxs)))

# start empty; the channel-screening cell collects channels containing NaNs
bad_chs = []
for site_idx, zds_segs in enumerate(zds_datas):
    zds_segs = np.transpose(zds_segs, (1, 0, 2)) # trial*256*time

    # fail with a clear message instead of a numpy broadcasting error when a site is short of trials
    n_trials = zds_segs.shape[0]
    if n_trials < sample_size:
        raise ValueError(f'stim site index {site_idx} has {n_trials} trials; '
                         f'at least {sample_size} are required')

    offset_idx = site_idx*sample_size
    zds_subset[offset_idx:offset_idx+sample_size,:,:] = zds_segs[:sample_size,:,sep_idxs]

In [None]:
""" identify channels that are not consistently non-saturated (for all trials) """
# A channel is flagged if it is NaN in any sampled trial at any SEP time point.
# The list is rebuilt from zds_subset (not appended to), so this cell does not depend on a
# reset in another cell and re-running it does not accumulate duplicate entries.
nan_anywhere = np.isnan(zds_subset).any(axis=(0, 2))  # one flag per channel
bad_chs = [int(ch) for ch in np.flatnonzero(nan_anywhere)]

In [None]:
""" discard the identified channels """
# zds_subset is overwritten below; re-running this cell without rebuilding it would index an
# already-reduced array with full-montage channel numbers, so refuse instead.
assert zds_subset.shape[1] == NCH, 'zds_subset was already reduced; re-run the sampling cell first'

# all channels of the montage (NCH, the size zds_subset was built with) minus the flagged ones;
# setdiff1d returns them sorted and keeps an integer dtype even when nothing is left
common_chs = np.setdiff1d(np.arange(NCH), bad_chs)

zds_subset = zds_subset[:,common_chs,:]
print(len(common_chs))

t-SNE

In [None]:
""" true label. for both tSNE and LDA """
# zds_subset stacks sample_size trials per site, so site k owns rows k*sample_size .. (k+1)*sample_size-1
true_label = np.repeat(np.arange(n_sites), sample_size)

In [None]:
""" create a custom colormap """
from matplotlib.colors import LinearSegmentedColormap

# RGBA colours (alpha 0.7), indexed by stimulation-site id
color_list = [
    (1, 0, 0, 0.7),        # red
    (0.5, 0.2, 0.7, 0.7),  # purple
    (0, 0.5, 0.8, 0.7),    # blue
    (0, 0.5, 0, 0.7),      # green
    (0.5, 0.2, 0, 0.7),    # brown
]

# reorder so that colour k belongs to plotting position k, i.e. site stim_sites[k]
colors = [color_list[site] for site in stim_sites]

# from_list with N=n_sites quantises the map into n_sites discrete colours, one per site
custom_cmap = LinearSegmentedColormap.from_list('custom_colormap', colors, N=n_sites)

In [None]:
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

def get_tsne(data, n_pca, exp_var_ratio, perplexity=30, tsne_seed=42, tsne_iter=1000):

    # Step 1. PCA
    if exp_var_ratio == 1:
        pca_result = data
    else:
        # step A, identify the number of components that explains the given variance (exp_var_ratio)
        pca = PCA(n_components=n_pca) # n_pca should be chosen to be a sufficiently large value
        _ = pca.fit_transform(data)
        var_ratio_cumsum = np.cumsum(pca.explained_variance_ratio_)
        n = np.argmax(var_ratio_cumsum > exp_var_ratio)
        print(f'# of PCA components: {n}')
    
        # step B, re-do PCA with reduced number of components
        pca2 = PCA(n_components=n)
        pca_result = pca2.fit_transform(data)

    # Step 2. TSNE
    tsne = TSNE(n_components=2, perplexity=perplexity, n_iter=tsne_iter,
                random_state=tsne_seed)
    tsne_pca_result = tsne.fit_transform(pca_result)

    return tsne_pca_result

In [None]:
""" tSNE """
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score, silhouette_samples
from sklearn.metrics import adjusted_rand_score, adjusted_mutual_info_score

# 2-D embedding of each trial, flattened to one (channel x time) feature vector
flat_trials = zds_subset.reshape(zds_subset.shape[0], -1)
tsne_result = get_tsne(flat_trials, n_pca=150, exp_var_ratio=0.8)

# unsupervised clustering of the embedding, one cluster per stimulation site
kmeans = KMeans(n_clusters=n_sites, random_state=42, n_init='auto')
predicted_label = kmeans.fit_predict(tsne_result)

# cluster/site agreement, plus per-sample silhouette of the predicted clusters
ARI = adjusted_rand_score(true_label, predicted_label)
AMIS = adjusted_mutual_info_score(true_label, predicted_label)
silhouette = silhouette_samples(tsne_result, predicted_label)
sil_avg, sil_std = silhouette.mean(), silhouette.std()

print(f'{ARI=:.2f}, {AMIS=:.2f}')
print(f'{sil_avg=:.2f}, {sil_std=:.2f}')

In [None]:
""" plot tSNE result """
fig, ax = plt.subplots(1, 1, figsize=(4, 4))

# title_str = f'nch={len(common_chs)}, silhouette: {sil_avg:.2f}±{sil_std:.2f}'
# ax.set_title(title_str)

ax.set_xticks([])
ax.set_yticks([])
ax.set_xlabel('tSNE-1')
ax.set_ylabel('tSNE-2')

# colour each trial by its true stimulation site
im = ax.scatter(tsne_result[:,0], tsne_result[:,1], c=true_label, cmap=custom_cmap)

cbar_ax = fig.add_axes([0.93, 0.11, 0.04, 0.77]) # left, bottom, width, height
cbar = fig.colorbar(im, cax=cbar_ax, ticks=[], orientation='vertical')
custom_labels = ['MN', 'SL', 'SI', 'SS', 'SM']  # one per site, in plotting (stim_sites) order
# Put each label at the centre of its colour band: the colour range spans the label range and
# the colormap has n_sites discrete bands (labels 0..4 give 0.4, 1.2, 2.0, 2.8, 3.6).
band_edges = np.linspace(true_label.min(), true_label.max(), n_sites + 1)
cbar.set_ticks((band_edges[:-1] + band_edges[1:]) / 2)
cbar.set_ticklabels(custom_labels, fontsize=16)
cbar.ax.tick_params(length=0)  # hide tick marks, keep labels
cbar.ax.invert_yaxis()  # first site (lowest label) at the top

# save_dir = './figures/ssep/tSNE_LDA'
# save_fn = 'tSNE'
# plt.savefig(f"{save_dir}/{save_fn}.svg", bbox_inches='tight')
# plt.savefig(f"{save_dir}/{save_fn}.png", bbox_inches='tight', dpi=1200)

In [None]:
""" Initialize LDA Classifier """
from sklearn.decomposition import PCA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

def get_lda(data, true_label):
    """Fit a 2-component LDA on (data, true_label) and return the projected data.

    data : array-like, shape (n_samples, n_features)
    true_label : array-like, shape (n_samples,); needs at least 3 classes for 2 components
    Returns an ndarray of shape (n_samples, 2).
    """
    return LinearDiscriminantAnalysis(n_components=2).fit_transform(data, true_label)

In [None]:
""" Build and Evaluate LDA Classifier """
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import confusion_matrix, accuracy_score

def get_lda_score(X, y, lda, cv):
    """Cross-validate a classifier and collect per-fold accuracy and confusion matrices.

    Parameters
    ----------
    X : ndarray, shape (n_samples, n_features)
    y : ndarray, shape (n_samples,)
    lda : classifier with fit/predict; refit in place on every training fold
        (after the call it holds the fit from the last fold)
    cv : splitter whose ``split(X, y)`` yields (train_index, test_index) pairs

    Returns
    -------
    [accuracies, conf_matrices]
        accuracies : ndarray (n_folds,), fraction of correct test predictions per fold.
        conf_matrices : ndarray (n_folds, n_classes, n_classes), row-normalised, in percent;
            rows are true classes, columns predicted classes, both ordered as np.unique(y).
    """
    # Fix the class axis for every fold so all matrices share one shape and class order,
    # even when a fold's test labels or predictions do not contain every class.
    labels = np.unique(y)
    accuracies = []
    conf_matrices = []

    for train_index, test_index in cv.split(X, y):
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]

        lda.fit(X_train, y_train)
        y_pred = lda.predict(X_test)

        acc = accuracy_score(y_test, y_pred)
        conf_matrix = confusion_matrix(y_test, y_pred, labels=labels, normalize='true') * 100

        accuracies.append(acc)
        conf_matrices.append(conf_matrix)

    return [np.array(accuracies), np.array(conf_matrices)]

n_splits = 10
# The splitter must use n_splits: the confusion-matrix plot averages over this many folds.
# With shuffle=False each fold tests a contiguous block of trials from every site.
cv_stratified_k_fold = StratifiedKFold(n_splits=n_splits, shuffle=False)
lda = LinearDiscriminantAnalysis()

# one feature vector (channel x time) per trial
X = zds_subset.reshape(zds_subset.shape[0], -1)
y = true_label

result = get_lda_score(X, y, lda, cv_stratified_k_fold)
accuracies, cms = result[0], result[1]

# Display the results
print(f"Mean accuracy: {np.mean(accuracies)*100:.2f} +/- {np.std(accuracies)*100:.2f}")

In [None]:
""" Plot Confusion Matrix """
import seaborn as sns
# Visualize the mean confusion matrix using seaborn heatmap

custom_labels = ['MN', 'SL', 'SI', 'SS', 'SM']

plt.close('all')
fig, ax = plt.subplots(figsize=(4, 4))

# Average the per-fold row-normalised (percent) matrices. Taking the mean over the fold axis
# avoids depending on a separately maintained fold count.
mean_cm = cms.mean(axis=0)
sns.heatmap(mean_cm, annot=True, fmt=".0f", cmap="Blues", vmin=0, vmax=100,
            xticklabels=custom_labels, yticklabels=custom_labels, cbar=False, ax=ax)
ax.set_xlabel('Predicted', fontsize=18)
ax.set_ylabel('True', fontsize=18)
ax.tick_params(labelsize=16)
ax.tick_params(axis='both', which='both', length=0)  # hide tick marks, keep labels

# save_dir = './figures/ssep/tSNE_LDA'
# save_fn = 'lda_cm'
# plt.savefig(f"{save_dir}/{save_fn}.svg", bbox_inches='tight')
# plt.savefig(f"{save_dir}/{save_fn}.png", bbox_inches='tight', dpi=1200)