In [None]:
import numpy as np 
import matplotlib.pyplot as plt 
%matplotlib inline 
from scipy.ndimage import measurements

In [None]:
import os
import sys

import scipy
import scipy.interpolate  # loaded explicitly: `import scipy` alone does not guarantee the submodule

sys.path.append('../')  # must run before the project-local imports below
import utils as ut
import analysis_stn_data.plotting_functions as plotter
from definitions import DATA_PATH, SAVE_PATH_DATA

In [None]:
# Locate the collected STN recordings and the folder for derived results.
data_folder = os.path.join(DATA_PATH, 'STN_data_PAC', 'collected')

# os.listdir returns entries in arbitrary, platform-dependent order. The cells
# below address subjects by their position in `subject_file_list`
# (pos_file_idx, n_file_idx), so the listing is sorted to make those positions
# refer to the same file on every run and every machine.
# NOTE(review): indices picked against the previous unsorted listing may point
# to a different file now - re-confirm them against the plot titles.
file_list = sorted(os.listdir(data_folder))
save_folder = os.path.join(SAVE_PATH_DATA, 'stn')

# keep only the per-subject pickle files
subject_file_list = [
    file_name for file_name in file_list
    if file_name.startswith('subject') and file_name.endswith('.p')
]

data_dict = dict()

### Use this loop to find a good positive example of significant PAC with nice PSD 

In [None]:
# Scan every subject file, channel and condition ('off'/'on') for a clear
# positive example: a recording whose largest connected cluster in the
# significance map exceeds `cluster_criterion`. Every hit is drawn as a PAC map
# (left panel) next to the epoch-averaged PSD (right panel) so that a good
# example can be picked by eye from the 'f.., ch.., cond..' titles.
max_cluster_list = []    # largest cluster size of every (file, channel, condition)
cluster_criterion = 250  # a cluster must be larger than this to count as a hit
conditions = ['off', 'on']
n_conditions = len(conditions)

plot_idx = 1
# the 30 x 2 subplot grid used below holds at most 30 hits (two panels per hit)
plt.figure(figsize=(15, 30))
for file_idx, file in enumerate(subject_file_list):

    print('Analysing subject file', file)

    # The subject files are pickles ('.p'). NumPy >= 1.16.3 refuses to unpickle
    # unless allow_pickle=True is passed. Unpickling can execute arbitrary
    # code, so only load files from a trusted source.
    super_dict = np.load(os.path.join(data_folder, file), allow_pickle=True)

    # collect data
    lfp_dict = super_dict['lfp']
    fs = super_dict['fs']
    pac_matrix = super_dict['pac_matrix']
    sig_matrix = super_dict['sig_matrix']
    f_phase = super_dict['pac']['on']['F_phase'].squeeze()

    # channel labels are taken from the 'on' condition and used for both conditions
    channel_labels = [chan[0] for chan in np.squeeze(lfp_dict['on']['channels'])]

    # 1 where the cluster criterion is exceeded (reset for every subject file)
    significant_pac = np.zeros((len(channel_labels), n_conditions))

    for channel_idx in range(len(channel_labels)):
        for condition_idx, condition in enumerate(conditions):

            # label the connected groups in the significance map and sum the
            # map within each label (index 0 is the background)
            # from http://stackoverflow.com/questions/25664682/how-to-find-cluster-sizes-in-2d-numpy-array
            sig_map = sig_matrix[channel_idx, condition_idx]
            lw, num = measurements.label(sig_map)
            area = measurements.sum(sig_map, lw, index=np.arange(lw.max() + 1))
            # size of the largest group
            max_cluster_size = np.max(area)
            max_cluster_list.append(max_cluster_size)

            if max_cluster_size > cluster_criterion:
                significant_pac[channel_idx, condition_idx] = 1

                # left panel: the PAC map
                plt.subplot(30, 2, plot_idx)
                plt.imshow(pac_matrix[channel_idx, condition_idx].T, interpolation='None')

                # right panel: PSD averaged over epochs (epochs run along axis 1)
                plt.subplot(30, 2, plot_idx + 1)
                current_lfp_epochs = lfp_dict[condition]['data'][channel_idx]
                epoch_psds = []
                for lfp_epoch in current_lfp_epochs.T:
                    f_psd, epoch_psd = ut.calculate_psd(y=lfp_epoch, fs=fs, window_length=1024)
                    epoch_psds.append(epoch_psd)
                psd = np.mean(epoch_psds, axis=0)

                # interpolate the PSD onto the sample points of the PAC phase axis
                psd_inter_f = scipy.interpolate.interp1d(f_psd, psd)
                plt.plot(f_phase, psd_inter_f(f_phase))
                plt.title('f{}, ch{}, cond{}'.format(file_idx, channel_idx, condition_idx))
                plot_idx += 2

### Now choose a good combination of file index, channel index, and condition index

In [None]:
# Hand-picked positive example, read off the 'f.., ch.., cond..' plot titles:
# (index into subject_file_list, channel index, condition index)
pos_file_idx, pos_channel_idx, pos_condition_idx = 8, 5, 0

### Look for a good negative example

In [None]:
# Scan every subject file, channel and condition ('off'/'on') for a negative
# example: a recording whose largest connected cluster in the significance map
# stays below `cluster_criterion`. Candidates are drawn as small PAC maps in a
# 10 x 10 grid; once the grid is full, further candidates are only recorded in
# `max_cluster_list`.
n_grid_rows, n_grid_cols = 10, 10
plot_idx = 1
plt.figure(figsize=(15, 10))
max_cluster_list = []    # largest cluster size of every (file, channel, condition)
cluster_criterion = 150  # a candidate's largest cluster must be smaller than this
conditions = ['off', 'on']
for file_idx, file in enumerate(subject_file_list):

    print('Analysing subject file', file)

    # The subject files are pickles ('.p'). NumPy >= 1.16.3 refuses to unpickle
    # unless allow_pickle=True is passed. Unpickling can execute arbitrary
    # code, so only load files from a trusted source.
    super_dict = np.load(os.path.join(data_folder, file), allow_pickle=True)

    # collect data
    lfp_dict = super_dict['lfp']
    pac_matrix = super_dict['pac_matrix']
    sig_matrix = super_dict['sig_matrix']

    # channel labels are taken from the 'on' condition and used for both conditions
    channel_labels = [chan[0] for chan in np.squeeze(lfp_dict['on']['channels'])]

    for channel_idx in range(len(channel_labels)):
        for condition_idx, condition in enumerate(conditions):

            # label the connected groups in the significance map and sum the
            # map within each label (index 0 is the background)
            # from http://stackoverflow.com/questions/25664682/how-to-find-cluster-sizes-in-2d-numpy-array
            sig_map = sig_matrix[channel_idx, condition_idx]
            lw, num = measurements.label(sig_map)
            area = measurements.sum(sig_map, lw, index=np.arange(lw.max() + 1))
            # size of the largest group
            max_cluster_size = np.max(area)
            max_cluster_list.append(max_cluster_size)

            # subplot indices run from 1 to rows * cols inclusive; the previous
            # `plot_idx < 100` test never used the last panel of the grid
            if max_cluster_size < cluster_criterion and plot_idx <= n_grid_rows * n_grid_cols:
                plt.subplot(n_grid_rows, n_grid_cols, plot_idx)
                plt.imshow(pac_matrix[channel_idx, condition_idx].T, interpolation='None')
                plt.xticks([], [])
                plt.yticks([], [])
                plot_idx += 1
                plt.title('f{}, ch{}, cond{}'.format(file_idx, channel_idx, condition_idx))

In [None]:
# Hand-picked negative example, read off the 'f.., ch.., cond..' plot titles:
# (index into subject_file_list, channel index, condition index)
n_file_idx, n_channel_idx, n_cond_idx = 0, 2, 1

### Load the PAC matrix of the negative example

In [None]:
# Load the subject file that contains the chosen negative example.
file = subject_file_list[n_file_idx]
# The subject files are pickles ('.p'). NumPy >= 1.16.3 refuses to unpickle
# unless allow_pickle=True is passed. Unpickling can execute arbitrary code,
# so only load files from a trusted source.
super_dict = np.load(os.path.join(data_folder, file), allow_pickle=True)
subject_id = super_dict['id']

# collect data
pac_matrix = super_dict['pac_matrix']
sig_matrix = super_dict['sig_matrix']

# PAC map and significance map of the selected channel / condition
pac_matrix_nonsig = pac_matrix[n_channel_idx, n_cond_idx, :, :]
sig_matrix2 = sig_matrix[n_channel_idx, n_cond_idx, :, :]

In [None]:
# Build the poster illustration from the chosen positive example
# (pos_file_idx / pos_channel_idx / pos_condition_idx) and the negative example
# loaded earlier (pac_matrix_nonsig / sig_matrix2).
#
# Only the selected subject file is loaded and only the selected channel and
# condition are processed, instead of looping over every recording and acting
# on a single match. An out-of-range selection now raises an IndexError rather
# than silently producing no figure.
file = subject_file_list[pos_file_idx]
# The subject files are pickles ('.p'). NumPy >= 1.16.3 refuses to unpickle
# unless allow_pickle=True is passed. Unpickling can execute arbitrary code,
# so only load files from a trusted source.
super_dict = np.load(os.path.join(data_folder, file), allow_pickle=True)
subject_id = super_dict['id']

# collect data
lfp_dict = super_dict['lfp']
fs = super_dict['fs']
pac_dict = super_dict['pac']
pac_matrix = super_dict['pac_matrix']
sig_matrix = super_dict['sig_matrix']
pac_phase = pac_matrix.mean(axis=2)  # average over amplitude frequencies

# frequency axes are read from the 'on' condition and used for both conditions
f_amp = pac_dict['on']['F_amp'].squeeze()
f_phase = pac_dict['on']['F_phase'].squeeze()
n_amplitude = pac_dict['on']['F_amp'].size
n_phase = pac_dict['on']['F_phase'].size

# LFP epochs of the selected channel in the selected condition
conditions = ['off', 'on']
current_lfp_epochs = lfp_dict[conditions[pos_condition_idx]]['data'][pos_channel_idx]

# restrict the phase frequencies to 5-40 Hz, smooth the mean PAC there and
# locate its maximum
mask = ut.get_array_mask(f_phase >= 5, f_phase <= 40).squeeze()
smoother_pac = ut.smooth_with_mean_window(pac_phase[pos_channel_idx, pos_condition_idx, mask],
                                          window_size=3)
max_idx = np.argmax(smoother_pac)

# PAC map and significance map of the positive example
pac_matrix_sig = pac_matrix[pos_channel_idx, pos_condition_idx, :, :]
sig_matrix1 = sig_matrix[pos_channel_idx, pos_condition_idx, :, :]

plotter.plot_beta_band_selection_illustration_for_poster(pac_matrix_sig, pac_matrix_nonsig,
                                                         sig_matrix1, sig_matrix2,
                                                         n_phase, n_amplitude,
                                                         f_phase, f_amp, mask, smoother_pac, max_idx,
                                                         current_lfp_epochs,
                                                         subject_id, fs, save=True)
                