In [1]:
# Partitioned datasets: each maps a session key (e.g. 'ccCenterOut_20221027_S03') to a
# zero-argument loader (see the next cell). `catalog` is provided by the notebook
# environment and is not defined in this notebook.
curated_states_partition, spectrogram_dict = (
    catalog.load(dataset_name)
    for dataset_name in ("center_out_curated_states_pkl", "center_out_spectrogram_std_pkl")
)

In [2]:
# Materialize one session's curated states and spectrogram by calling each partition's loader.
curated_states_dict, sxx_data_dict = (
    partition['ccCenterOut_20221027_S03']()
    for partition in (curated_states_partition, spectrogram_dict)
)

In [3]:
from scripts.utility_scripts import create_closure, create_closure_func

import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

In [90]:
# Peek at one curated state of the session loaded above (not referenced again in the cells shown).
state = 1
session_details_dict = curated_states_dict[state]

# Run configuration; `context` is provided by the notebook environment.
model_data_params = context.params['model_data_params']

pre_stimulus_time, post_completion_time, window_size, shift = (
    model_data_params[param_name]
    for param_name in ('pre_stimulus_time', 'post_completion_time', 'window_size', 'shift')
)

current_experiment, sessions, patient_id = (
    context.params[param_name]
    for param_name in ('current_experiment', 'sessions', 'patient_id')
)

def _generate_trial_sxx_data(data_func, indices):
    sxx_data_dict = data_func()
    
    sxx = sxx_data_dict['sxx']
    
    return sxx[:, indices, :].astype(np.float16)

def _window_spectrogram(data_func, indices, shift, window_size):
    sxx_data_dict = data_func()
    
    sxx = sxx_data_dict['sxx']
    
    return np.moveaxis(sliding_window_view(sxx[:, indices, :][:, ::shift, :], window_shape=window_size, axis=1), [0, -1], [1, -2])
    

def generate_model_windowed_sxx_data(spectrogram_dict, curated_states_partition, sessions, model_data_params, current_experiment, patient_id):
    """Build lazy per-trial spectrogram loaders for the sessions selected in ``sessions``.

    For every spectrogram partition whose key ends in ``_<date>_<session>`` and whose
    date/session is listed under ``sessions[patient_id][current_experiment][session_type]``,
    the partition's curated states are loaded. Each trial's ``(start, end)`` pair from
    ``start_end_idx`` is then widened by the pre-stimulus / post-completion padding.
    State 0 is skipped.

    Args:
        spectrogram_dict: partition key -> zero-argument loader returning a dict with an
            ``'sxx'`` array.
        curated_states_partition: partition key -> zero-argument loader returning
            ``{state: {'start_end_idx': [(start, end), ...], 'sampling_rate': ..., ...}}``.
        sessions: nested mapping; ``sessions[patient_id][current_experiment]`` is
            ``{session_type: {date: <container of session ids>}}``.
        model_data_params: must contain ``'pre_stimulus_time'``, ``'post_completion_time'``
            (presumably seconds, since they are multiplied by ``sampling_rate``) and ``'shift'``.
        current_experiment: experiment key into ``sessions[patient_id]``.
        patient_id: patient key into ``sessions``.

    Returns:
        ``(model_data_dict, model_data_filenames_dict)``:

        - ``model_data_dict`` maps ``"{session_type}_{date}_{session}_T{trial_idx}_state{state}"``
          to ``create_closure_func(_generate_trial_sxx_data, loader, indices)``. That value is
          later called with no arguments to load the trial's float16 slice.
        - ``model_data_filenames_dict[partition_key][state]`` holds the date, session_type,
          session, state, and the local and global trial index lists.
    """
    pre_stimulus_time = model_data_params['pre_stimulus_time']
    post_completion_time = model_data_params['post_completion_time']
    shift = model_data_params['shift']

    model_data_dict = {}
    model_data_filenames_dict = {}
    # Running trial counter across all partitions/states; used for train/test splitting.
    global_trial_idx = 0
    for session_type, session_data in sessions[patient_id][current_experiment].items():
        for sxx_partition_key, sxx_partition_func in spectrogram_dict.items():
            key_parts = sxx_partition_key.split('_')
            date, session = key_parts[-2], key_parts[-1]

            if date not in session_data.keys() or session not in session_data[date]:
                continue

            curated_states_dict = curated_states_partition[sxx_partition_key]()

            for state, state_information in curated_states_dict.items():
                if state == 0:
                    continue

                sampling_rate = state_information['sampling_rate']
                # Padding converted from time to index units of start_end_idx.
                # NOTE(review): dividing by `shift` assumes those indices are spaced
                # `shift` samples apart — confirm against the spectrogram pipeline.
                samples_pre = int(np.ceil((pre_stimulus_time * sampling_rate) / shift))
                samples_post = int(np.ceil((post_completion_time * sampling_rate) / shift))

                trial_idx_list = []
                local_trial_idx_list = []
                for trial_idx, bounds in enumerate(state_information['start_end_idx']):
                    # Clamp at 0: a negative start would make numpy indexing wrap around
                    # and silently pull frames from the end of the recording.
                    padded_start = max(bounds[0] - samples_pre, 0)
                    padded_end = bounds[1] + samples_post
                    indices = np.arange(padded_start, padded_end + 1)

                    trial_key = f"{session_type}_{date}_{session}_T{trial_idx}_state{state}"
                    model_data_dict[trial_key] = create_closure_func(_generate_trial_sxx_data, sxx_partition_func, indices)

                    trial_idx_list.append(global_trial_idx)
                    global_trial_idx += 1
                    local_trial_idx_list.append(trial_idx)

                model_data_filenames_dict.setdefault(sxx_partition_key, {})[state] = {
                    'date': date,
                    'session_type': session_type,
                    'session': session,
                    'state': state,
                    'local_trials_idx_list': local_trial_idx_list,
                    'global_trials_idx_list': trial_idx_list,
                }

    return model_data_dict, model_data_filenames_dict

In [84]:
# Per-partition (min, max) of the spectrogram values, presumably to check they fit the
# float16 range (finite max 65504) used by _generate_trial_sxx_data.
# A comprehension keeps the loop variables local. This cell therefore no longer
# overwrites the `sxx_data_dict` loaded earlier, nor keeps the last full `sxx` in memory.
min_max_list = [
    (np.min(partition_data['sxx']), np.max(partition_data['sxx']))
    for partition_data in (load_partition() for load_partition in spectrogram_dict.values())
]

In [85]:
# Per-partition (min, max) spectrogram values; presumably compared against the float16
# range (finite max 65504) that _generate_trial_sxx_data casts to.
min_max_list

In [91]:
# Build the lazy per-trial loaders and the per-partition trial bookkeeping.
model_data_dict, model_data_filenames_dict = generate_model_windowed_sxx_data(
    spectrogram_dict=spectrogram_dict,
    curated_states_partition=curated_states_partition,
    sessions=sessions,
    model_data_params=model_data_params,
    current_experiment=current_experiment,
    patient_id=patient_id,
)

In [101]:
# Call the lazy loader for session_type 'overt', date 20221027, session S03, trial 0, state 1
# (key format from generate_model_windowed_sxx_data), then check the dtype (float16 expected).
cur_sxx_data = model_data_dict['overt_20221027_S03_T0_state1']()

cur_sxx_data.dtype

In [88]:

# Display model_data_dict (one entry per trial; can be large — consider len() or a few keys).
model_data_dict
    

In [56]:
# NOTE(review): `data` is not defined in any cell shown here, so this raises NameError on a
# fresh kernel; it presumably referred to an array from a deleted cell — TODO point it at a
# defined array (e.g. cur_sxx_data) or drop the cell.
np.min(data)

In [86]:
# Float16 copy, presumably to compare size/precision against the original dtype.
# NOTE(review): depends on `data`, which no cell shown here defines — fails on a fresh kernel.
data_small = data.astype(np.float16)

In [82]:
# Probe float16 precision: between 2**15 and 2**16 adjacent float16 values are 32 apart,
# so 32790 rounds to 32800.
x = np.array([32790]).astype(np.float16)
x

In [65]:
# dtype of the float16 copy (requires the `data_small` cell, which itself needs an undefined `data`).
data_small.dtype

In [102]:
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq

import h5py


# Persist one trial's spectrogram slice as .npy, presumably to inspect the on-disk size/dtype.
np.save(file="mydata.npy", arr=cur_sxx_data)


In [14]:
# Display model_data_dict again (one entry per trial; can be large).
model_data_dict

In [100]:
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

# Toy check of window-then-hop: windows of `window_size` consecutive frames, then every
# `shift`-th window is kept. This differs from `_window_spectrogram`, which decimates by
# `shift` before windowing.
# The demo_* names avoid overwriting the configured `shift`, `window_size` and `indices` globals.
demo_sxx = np.array([[[x] for x in range(50)]])  # shape (1, 50, 1); entry t along axis 1 holds t
demo_shift = 2
demo_indices = np.array(range(50))
demo_window_size = 10

demo_windows = np.moveaxis(
    sliding_window_view(demo_sxx[:, demo_indices, :], window_shape=demo_window_size, axis=1),
    [0, -1],
    [1, -2],
)[::demo_shift, ...]
demo_windows

In [13]:
# Display the bookkeeping: partition key -> state -> date/session_type/session/state/trial index lists.
model_data_filenames_dict

In [16]:
# Bookkeeping for a single partition: state -> trial metadata and local/global trial index lists.
model_data_filenames_dict['ccCenterOut_20221027_S03']

In [6]:
# Rebuild the per-trial keys of model_data_dict
# ("{session_type}_{date}_{session}_T{trial_idx}_state{state}") from the bookkeeping dict.
# Iterating inside a comprehension keeps its names local. The old loop variable
# `model_data_dict` overwrote the trial-loader dict, and `state` overwrote the global state.
trial_filenames_list = [
    f"{info['session_type']}_{info['date']}_{info['session']}_T{trial_idx}_state{state}"
    for partition_states in model_data_filenames_dict.values()
    for state, info in partition_states.items()
    for trial_idx in info['local_trials_idx_list']
]

len(trial_filenames_list)

In [105]:
# Split global trial indices into train/test sets for the current experiment.
# model_data_params[current_experiment] selects the split type. That split type's params
# give the leave-out count, whether to shuffle, and the RNG seed.
data_split_type = model_data_params[current_experiment]['sel_split_type']

split_type_params = model_data_params[current_experiment]['split_types'][data_split_type]

leave_out = split_type_params['leave_out']
randomized = split_type_params['randomized']
random_seed = split_type_params['random_seed']
sel_session_type = model_data_params[current_experiment]['sel_session_type']


def _maybe_shuffle(items, rng, shuffle):
    """Return `items` as a list, permuted with `rng` when `shuffle` is truthy."""
    items = list(items)
    if not shuffle:
        return items
    return [items[i] for i in rng.permutation(len(items))]


# Seeded generator so the split is reproducible (`random_seed` and `randomized` come from config).
split_rng = np.random.default_rng(random_seed)

# (date, session, trial_information_dict) for every state entry of the selected session
# type, in model_data_filenames_dict order. Keys end in "..._<date>_<session>".
selected_entries = [
    (partition_key.split('_')[-2], partition_key.split('_')[-1], trial_information_dict)
    for partition_key, partition_states in model_data_filenames_dict.items()
    for trial_information_dict in partition_states.values()
    if trial_information_dict['session_type'] == sel_session_type
]

if data_split_type == 'leave_day_out':
    # sorted(), not set iteration order: string hashing is seeded per process, so set order varies.
    dates_list_permuted = _maybe_shuffle(
        sorted({date for date, _, _ in selected_entries}), split_rng, randomized
    )

    test_list_dates = dates_list_permuted[:leave_out]
    train_list_dates = dates_list_permuted[leave_out:]

    train_list = [
        idx
        for date, _, info in selected_entries
        if date in train_list_dates
        for idx in info['global_trials_idx_list']
    ]
    test_list = [
        idx
        for date, _, info in selected_entries
        if date in test_list_dates
        for idx in info['global_trials_idx_list']
    ]

elif data_split_type == 'leave_session_out':
    date_session_pairs = _maybe_shuffle(
        sorted({(date, session) for date, session, _ in selected_entries}), split_rng, randomized
    )

    test_pairs = date_session_pairs[:leave_out]
    train_pairs = date_session_pairs[leave_out:]
    # {'date', 'session'} records kept for inspection in later cells.
    test_list_dates_and_sessions = [{'date': d, 'session': s} for d, s in test_pairs]
    train_list_dates_and_sessions = [{'date': d, 'session': s} for d, s in train_pairs]

    train_list = [
        idx
        for date, session, info in selected_entries
        if (date, session) in train_pairs
        for idx in info['global_trials_idx_list']
    ]
    test_list = [
        idx
        for date, session, info in selected_entries
        if (date, session) in test_pairs
        for idx in info['global_trials_idx_list']
    ]

elif data_split_type == 'leave_trial_out':
    trials_list_permuted = _maybe_shuffle(
        [idx for _, _, info in selected_entries for idx in info['global_trials_idx_list']],
        split_rng,
        randomized,
    )

    test_list = trials_list_permuted[:leave_out]
    train_list = trials_list_permuted[leave_out:]

else:
    # Fail loudly rather than leaving train_list/test_list from a previous run in the kernel.
    raise ValueError(f"Unknown sel_split_type: {data_split_type!r}")
    

In [107]:
# Number of held-out trial indices produced by the split cell.
len(test_list)

In [90]:
# Training (date, session) records.
# NOTE(review): only the 'leave_session_out' branch of the split cell assigns this name;
# under other split types it is undefined or stale from an earlier run.
train_list_dates_and_sessions

In [34]:
# Sanity check: membership in a list of dicts compares whole dicts, so a partial match is False.
x = [{'a': 1, 'b': 2}, {'c': 1, 'd': 2}]
any(candidate == {'a': 1, 'b': 3} for candidate in x)

In [37]:
# Held-out global trial indices from the split cell.
test_list

In [32]:
# Sizes of the test and train index lists, one per line.
print(f"{len(test_list)}\n{len(train_list)}")

120
680


In [35]:
# Sizes of the test and train index lists (re-run after changing the split), one per line.
print(len(test_list), len(train_list), sep="\n")

200
600


In [42]:
# Display model_data_dict (one entry per trial; can be large).
model_data_dict

In [None]:
# NOTE(review): `train_dat` is created by the concatenation cell further down; on a fresh
# kernel run top to bottom this raises NameError — move this after that cell.
train_dat.shape

In [None]:
# Prepend 0 to the cumulative trial lengths (row offsets).
# NOTE(review): `total_cum_sum` is defined by the concatenation cell below, which already
# does this insert — this cell is redundant and fails on a fresh top-to-bottom run.
np.insert(total_cum_sum, 0, 0)

In [None]:
# Stack per-trial training arrays along axis 0 into one preallocated array.
# NOTE(review): assumes `train_list` holds per-trial arrays with matching trailing
# dimensions. The split cell above currently fills `train_list` with integer trial
# indices (no `.shape`) — TODO load the trial arrays (e.g. via model_data_dict) first.
assert len(train_list) > 0, "train_list is empty"

trial_lengths = [x.shape[0] for x in train_list]
total_len = np.sum(trial_lengths)
# Row offsets with a leading 0: trial k occupies rows [total_cum_sum[k], total_cum_sum[k + 1]).
total_cum_sum = np.insert(np.cumsum(trial_lengths), 0, 0)

# Keep the source dtype; casting to an integer type would truncate float spectrogram values.
train_dat = np.zeros((total_len, *train_list[0].shape[1:]), dtype=train_list[0].dtype)

for start_idx, end_idx, trial_data in zip(total_cum_sum[:-1], total_cum_sum[1:], train_list):
    train_dat[start_idx:end_idx] = trial_data

In [15]:
# NOTE(review): appears to target an older nested layout. generate_model_windowed_sxx_data
# now returns a flat dict keyed like 'overt_20221027_S03_T0_state1', so model_data_dict['overt']
# raises KeyError — TODO update or drop this cell.
model_data_dict['overt']['20221027'][1][3].shape

In [24]:
# Reproducible dedupe-and-shuffle, the same pattern as the split cell.
# sorted() fixes the order that set iteration leaves to per-process string hash seeding.
# A seeded generator (seed 0 here, demo only) replaces the unseeded global np.random state.
demo_permuted = np.random.default_rng(0).permutation(np.array(sorted(set(['a', 'b', 'c', 'c']))))
demo_permuted