In [None]:
import os
import glob
import json
import xarray as xr
import numpy as np 
import scipy.io
import matplotlib.pyplot as plt

# magic to autoreload modules called
%load_ext autoreload
%autoreload 2

In [None]:
# will build the path for the sessions as 'default_directory'//'experiment_directory[n]'//(specific selected session)
# root directory; set the RECORDINGS_ROOT environment variable to run the notebook on another machine
default_directory = os.environ.get("RECORDINGS_ROOT", "D://EinsteinMed Dropbox//Gabriel Baltazar//")
experiment_directory = ["Chat-Cre X ai40 (Revised)"] # ["Chat-Cre X ai40 (Revised)", "VIP-Cre x Ai40 (Revised)"] # folders that are in the root directory and store specific recording folders

# metadata file
metadata_file = 'session_metadata.json'
criteria_dict = {
    "include_in_analysis":  True,
    "visual_stimulation_type": "Drifting Grating" # including only drifting grating bcs natural stimuli recordings didn't have spontaneous block
}

# NOTE(review): `div` (and `intanGab`, `filtsGab`, `stateGab` used in later cells) are project modules
# that are not imported in the imports cell — add them there so Restart & Run All works.
# Materialized as a list because it is iterated by more than one cell below.
valid_paths_list = list(div.get_recordings_path(default_directory, experiment_directory, criteria_dict))
assert len(valid_paths_list) > 0, "no recordings matched criteria_dict"

### Checking dropped frames

In [None]:
# Inspect the facemap motion traces of the first selected session. The file is loaded explicitly so this
# cell does not depend on `proc` leaking from the dropped-frames loop in the next cell.
first_proc_file = sorted(glob.glob(os.path.join(valid_paths_list[0], '*proc.npy')))[0]
print(np.load(first_proc_file, allow_pickle=True).item()["motion"])

In [None]:
adc_xarray = "adc_channels.zarr"

# Compare, per recording, the number of facemap frames with the number of camera-strobe onsets
# detected on the ADC channel; a non-zero difference indicates dropped (or extra) frames.
for session in valid_paths_list:

    # loading the xarrays
    session_processed = os.path.join(session, "processed")
    adc = xr.open_zarr(os.path.join(session_processed, adc_xarray), consolidated=True)
    camera_strobe = adc["camera_strobe"].values

    # getting the facemap processed file: glob order is filesystem-dependent, so sort explicitly
    # to really select the first '*proc.npy' file in alphabetical order
    npy_files = sorted(glob.glob(os.path.join(session, '*proc.npy')))
    if not npy_files:
        raise FileNotFoundError(f"no facemap '*proc.npy' file found in {session}")
    video_data_file_add = npy_files[0]
    # allow_pickle is needed to read the dict stored in the .npy; only load trusted local files
    proc = np.load(video_data_file_add, allow_pickle=True).item()
    face_motion = proc["motion"][1]

    # detecting the strobe transitions
    strobe_onset, _, _, _ = intanGab.detect_transitions(
        camera_strobe,
        lowpass_filter=True,
        output_plot=False
    )

    session_name = os.path.basename(session)
    print(f"({session_name}) number of frames: {len(face_motion)}")
    print(f"({session_name}) number of strobes: {len(strobe_onset)}")
    print(f"({session_name}) strobes - frames: {len(strobe_onset) - len(face_motion)}")

In [None]:
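# Dropped frames: frame and strobe counts are equal, or differ by one when the strobe train starts on an up
# transition, in every recording EXCEPT ehsan_210415_102404 (2 more strobes than frames) and
# ehsan_210507_133101 (43 more strobes than frames) — check those two before aligning frames to strobe times.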

# 'read_intan_ports.detect_transitions: detected 159368 transitions, with the start trigger as down transition
# (ehsan_221206_100358) number of frames: 159368
# (ehsan_221206_100358) number of strobes: 159368
# 'read_intan_ports.detect_transitions: detected 146342 transitions, with the start trigger as up transition
# (ehsan_221206_104734) number of frames: 146341
# (ehsan_221206_104734) number of strobes: 146342
# 'read_intan_ports.detect_transitions: detected 192372 transitions, with the start trigger as down transition
# (ehsan_221207_105129) number of frames: 192372
# (ehsan_221207_105129) number of strobes: 192372
# 'read_intan_ports.detect_transitions: detected 186462 transitions, with the start trigger as down transition
# (ehsan_221208_095858) number of frames: 186462
# (ehsan_221208_095858) number of strobes: 186462
# 'read_intan_ports.detect_transitions: detected 185990 transitions, with the start trigger as up transition
# (ehsan_221208_103742) number of frames: 185989
# (ehsan_221208_103742) number of strobes: 185990
# 'read_intan_ports.detect_transitions: detected 197245 transitions, with the start trigger as up transition
# (ehsan_221209_082822) number of frames: 197244
# (ehsan_221209_082822) number of strobes: 197245
# 'read_intan_ports.detect_transitions: detected 160860 transitions, with the start trigger as up transition
# (ehsan_210415_102404) number of frames: 160858
# (ehsan_210415_102404) number of strobes: 160860
# 'read_intan_ports.detect_transitions: detected 164406 transitions, with the start trigger as up transition
# (ehsan_210416_124048) number of frames: 164405
# (ehsan_210416_124048) number of strobes: 164406
# 'read_intan_ports.detect_transitions: detected 161685 transitions, with the start trigger as up transition
# (ehsan_210416_150531) number of frames: 161684
# (ehsan_210416_150531) number of strobes: 161685
# 'read_intan_ports.detect_transitions: detected 193498 transitions, with the start trigger as up transition
# (ehsan_210422_103751) number of frames: 193497
# (ehsan_210422_103751) number of strobes: 193498
# 'read_intan_ports.detect_transitions: detected 213467 transitions, with the start trigger as up transition
# (ehsan_210422_135429) number of frames: 213466
# (ehsan_210422_135429) number of strobes: 213467
# 'read_intan_ports.detect_transitions: detected 189261 transitions, with the start trigger as up transition
# (ehsan_210423_100903) number of frames: 189260
# (ehsan_210423_100903) number of strobes: 189261
# 'read_intan_ports.detect_transitions: detected 188361 transitions, with the start trigger as up transition
# (ehsan_210423_124453) number of frames: 188360
# (ehsan_210423_124453) number of strobes: 188361
# 'read_intan_ports.detect_transitions: detected 193850 transitions, with the start trigger as up transition
# (ehsan_210506_125649) number of frames: 193849
# (ehsan_210506_125649) number of strobes: 193850
# 'read_intan_ports.detect_transitions: detected 136499 transitions, with the start trigger as up transition
# (ehsan_210506_154650) number of frames: 136498
# (ehsan_210506_154650) number of strobes: 136499
# 'read_intan_ports.detect_transitions: detected 175482 transitions, with the start trigger as up transition
# (ehsan_210507_102850) number of frames: 175481
# (ehsan_210507_102850) number of strobes: 175482
# 'read_intan_ports.detect_transitions: detected 187046 transitions, with the start trigger as up transition
# (ehsan_210507_133101) number of frames: 187003
# (ehsan_210507_133101) number of strobes: 187046
# 'read_intan_ports.detect_transitions: detected 242581 transitions, with the start trigger as up transition
# (ehsan_210617_102107) number of frames: 242580
# (ehsan_210617_102107) number of strobes: 242581
# 'read_intan_ports.detect_transitions: detected 187107 transitions, with the start trigger as up transition
# (ehsan_210618_114033) number of frames: 187106
# (ehsan_210618_114033) number of strobes: 187107

### Test: state sorting of visual-stimulation trials

General settings

In [None]:
# --- zarr stores read from each session's "processed" folder ---
state_xarray = "state.zarr"
vis_stim_xarray = "optogenetics_visual_stimulation.zarr"
amplifier_xarray = "amplifier_reorder_low_and_high_pass_int16.zarr"

# --- temporal domain (seconds, relative to stimulus onset) ---
window_size = 0.5              # length of the visual-response window
time_to_plot_before = -0.5     # start of the plotted interval
time_to_plot_after = 2.5       # end of the plotted interval
window_to_plot = sum(abs(edge) for edge in (time_to_plot_before, time_to_plot_after))
# NOTE(review): Intan RHD amplifiers have a 0.195 µV/bit resolution, so scaling raw int16 samples by this
# factor yields µV rather than mV — confirm the units of the stored data and of the "LFP [mV]" labels.
int16_to_mV = 0.195

# --- classification thresholds ---
camera_sample_rate = 30            # nominal camera rate (Hz); the plotting loop reloads it from session metadata
emg_perc = 80                      # EMG percentile used as locomotion threshold
face_motion_perc = 70              # face-motion percentile used as activity threshold
face_motion_time_threshold = 5     # seconds
locomotion_threshold = 0.5         # threshold on |treadmill|
locomotion_time_threshold = 5      # seconds
# merge windows in frames at the nominal camera rate
face_motion_window, locomotion_window = (
    seconds * camera_sample_rate
    for seconds in (face_motion_time_threshold, locomotion_time_threshold)
)
classification_threshold = 0.9     # fraction of frames that must share one state for a homogeneous trial

# --- states: integer scores and the labels used for trials ---
quiescence_indicator, face_motion_indicator, locomotion_indicator = range(3)
quies_label, active_label, loc_label, disc_label = "quiet", "wake", "loc.", "disc."
all_labels = np.asarray([quies_label, active_label, loc_label, disc_label])

Figure settings

In [None]:
# --- figure layout: a 16 x 4 grid ---
figure_size = [14, 7]
figure_rows, figure_columns = 16, 4

# four 3-row bands for the per-state panels (theta, delta, treadmill, face motion)
first_row, second_row, third_row, fourth_row = (slice(r, r + 3) for r in range(0, 12, 3))
# four 1-row strips for the whole-session traces (EMG, treadmill, face motion, scored states)
fifth_row, sixth_row, seventh_row, eigth_row = (slice(r, r + 1) for r in range(12, 16))
trial_row = slice(12, 16)  # rows shared by the trial-count panel
col_left, col_midl, col_midr, col_right = (slice(c, c + 1) for c in range(4))
col_sess = slice(0, 3)     # whole-session traces span the first three columns

# --- colours ---
replicates_color = "#ebebeb"   # individual trials
window_color = "#f4c542"       # shading of the visual-response window
quies_color, active_color, loc_color, disc_color = "#6666ff", "#66ff66", "#ff6666", "#000000"
state_color = "#000000"
opto_off_color, opto_on_color = "#000000", "#0cb7f4"
highlight_color = "#FF0000"

# --- saving ---
save_fig = True
folder_to_save = os.path.join("D://EinsteinMed Dropbox//Gabriel Baltazar//Chat-Cre VIP-Cre figs", "state_sorting_accurate_score")

Plot spine settings

In [None]:
# Per-side spine settings dictionaries, passed to div.spine_settings.
_drawn_axis = dict(linewidth=1.2, tick_width=1.2)

# per-trial panels: left/bottom axes drawn without ticks, right/top hidden
general_settings = {
    "left": dict(_drawn_axis, ticks=[]),
    "bottom": dict(_drawn_axis, ticks=[]),
    "right": dict(visible=False),
    "top": dict(visible=False),
}

# whole-session traces: every spine hidden
whole_sess_settings = {side: dict(visible=False) for side in ("left", "bottom", "right", "top")}

# trial-count panel: labelled left axis, rotated bottom tick labels
trial_settings = {
    "left": dict(_drawn_axis, ticks=("data_lim", 0), label="Trial\nCount", label_pad=-8),
    "bottom": dict(_drawn_axis, tick_rotation=45),
    "right": dict(visible=False),
    "top": dict(visible=False),
}

del _drawn_axis

State sorting and plotting figures

In [None]:
%matplotlib widget

# loop for recordings
for session in valid_paths_list:

    # loading session metadata
    with open(os.path.join(session, "session_metadata.json"), 'r') as json_file:
        metadata = json.load(json_file)
    # sample_rate = metadata["amplifier_sample_rate"]
    lfp_sample_rate = metadata["amplifier_low_pass_sample_rate"]
    camera_sample_rate = metadata["camera_sample_rate"]
    animal_id = metadata["animal_id"]
    experiment_id = metadata["experiment_id"]
    group = metadata["group"]

    # loading the xarrays
    session_processed = os.path.join(session, "processed")
    vis_stim_opto = xr.open_zarr(os.path.join(session_processed, vis_stim_xarray), consolidated=True)
    state = xr.open_zarr(os.path.join(session_processed, state_xarray), consolidated=True)
    amplifier = xr.open_zarr(os.path.join(session_processed, amplifier_xarray), consolidated=True)
    
    # processing the lfp 
    lfp_mat_chs = scipy.io.loadmat(os.path.join(session, "amplifierReorder", "amplifierReorder.SleepScoreLFP.LFP.mat"))
    lfp_mat_emg = scipy.io.loadmat(os.path.join(session, "amplifierReorder", "amplifierReorder.EMGFromLFP.LFP.mat"))
    th_ch = lfp_mat_chs["SleepScoreLFP"]["THchanID"].item()[0][0]
    sw_ch = lfp_mat_chs["SleepScoreLFP"]["SWchanID"].item()[0][0]
    ts_emg = lfp_mat_emg["EMGFromLFP"]["timestamps"][0][0]
    data_emg = lfp_mat_emg["EMGFromLFP"]["data"][0][0]
    lfp = amplifier["low_pass"]
    lfp_timepoint = amplifier.coords["time_in_sec_low_pass"].values
    slow_waves = filtsGab.butter_bandpass_filter(data=lfp.sel(channel=sw_ch).values, cutoff=[0.5, 4], fs=lfp_sample_rate, order=2) * int16_to_mV
    theta_waves = filtsGab.butter_bandpass_filter(data=lfp.sel(channel=th_ch).values, cutoff=[6, 10], fs=lfp_sample_rate, order=2) * int16_to_mV

    # loading state data
    strobe_in_seconds = state.coords["strobe_in_seconds"].values
    face_motion = state["face_motion_normalized"].values
    treadmill = state["treadmill_analog"].values
    scored_state = np.zeros_like(strobe_in_seconds, dtype=int)

    # scoring face motion epochs
    face_motion_threshold = np.percentile(face_motion, face_motion_perc)
    face_motion_inds = np.where(face_motion > face_motion_threshold)[0]
    i = 0
    while i < len(face_motion_inds): # looping over face motion peaks
        
        onset = face_motion_inds[i]

        # Look ahead until the first peak farther than the window
        j = i + 1
        while j < len(face_motion_inds) and (face_motion_inds[j] - face_motion_inds[i]) < face_motion_window:
            i = j
            j += 1

        # The offset is the last close peak + the window
        offset = face_motion_inds[j-1] + face_motion_window
        scored_state[onset:offset] = face_motion_indicator

        i += 1

    # scoring locomotion epochs according to treadmill and emg data
    treadmill_inds = np.where(abs(treadmill) > locomotion_threshold)[0]
    emg_threshold = np.percentile(data_emg, emg_perc)
    emg_inds = np.where(data_emg > emg_threshold)[0]
    emg_inds_converted = np.zeros(len(emg_inds))
    for i, emg_ind in enumerate(emg_inds): # converting emg indexes
        emg_inds_converted[i] = np.argmin(abs(strobe_in_seconds - ts_emg[emg_ind]))
    locomotion_inds = np.union1d(treadmill_inds, emg_inds_converted).astype(int)

    while i < len(locomotion_inds): # looping over face motion peaks
        
        onset = locomotion_inds[i]

        # Look ahead until the first peak farther than the window
        j = i + 1
        while j < len(locomotion_inds) and (locomotion_inds[j] - locomotion_inds[i]) < locomotion_window:
            i = j
            j += 1

        # The offset is the last close peak 
        offset = locomotion_inds[j-1]
        scored_state[onset:offset] = locomotion_indicator

        i += 1

    # getting visual stimulation trials onset timestamps
    vis_stim_mask = vis_stim_opto.coords["visual_stimulation"].astype(bool)
    stim_onset = vis_stim_opto["visual_stimulation_timestamp"].sel(event_type="onset").where(vis_stim_mask, drop=True).values
    stim_offset = stim_onset + window_size

    sample_no_lfp = int(lfp_sample_rate * window_to_plot)
    sample_no_state = int(camera_sample_rate * window_to_plot)
    trial_class = np.empty(len(stim_onset), dtype='object')
    th_class = np.zeros([len(stim_onset), sample_no_lfp], dtype=float)
    sw_class = np.zeros([len(stim_onset), sample_no_lfp], dtype=float)
    tr_class = np.zeros([len(stim_onset), sample_no_state], dtype=float)
    fm_class = np.zeros([len(stim_onset), sample_no_state], dtype=float)

    # loop to state-score the trials
    for i, (onset, offset) in enumerate(zip(stim_onset, stim_offset)):

        # getting the scored-frames for the current trial
        trial_state = scored_state[(strobe_in_seconds >= onset) & (strobe_in_seconds <= offset)]

        # counting how many frames were scored in each state
        unique_states, count_states = np.unique(trial_state, return_counts=True)
        if np.max(count_states / len(trial_state)) >= classification_threshold: # case: homogeneous trial, at least 90% of frames with same score
            if unique_states[np.argmax(count_states)] == quiescence_indicator: # case: quiescent trial
                st = quies_label
            elif unique_states[np.argmax(count_states)] == face_motion_indicator: # case: active trial
                st = active_label
            elif unique_states[np.argmax(count_states)] == locomotion_indicator: # case: locomotion trial
                st = loc_label
            trial_class[i] = st
        else: # case: inhomogeneous trial, less than 90% of frames with same score
            trial_class[i] = disc_label
        
        # getting state measures for the current trial
        lfp_first_ind = np.argmin(abs(lfp_timepoint - (onset + time_to_plot_before)))
        state_first_ind = np.argmin(abs(strobe_in_seconds - (onset + time_to_plot_before)))
        th_class[i] = theta_waves[lfp_first_ind : lfp_first_ind + sample_no_lfp]
        sw_class[i] = slow_waves[lfp_first_ind : lfp_first_ind + sample_no_lfp]
        tr_class[i] = treadmill[state_first_ind : state_first_ind + sample_no_state]
        fm_class[i] = face_motion[state_first_ind : state_first_ind + sample_no_state]

    # getting the axis objects for all individual subplots
    fig = plt.figure(figsize=figure_size, constrained_layout=True)
    grid = fig.add_gridspec(figure_rows, figure_columns)
    ax_quiet_th = fig.add_subplot(grid[first_row, col_left])
    ax_quiet_sw = fig.add_subplot(grid[second_row, col_left])
    ax_quiet_tr = fig.add_subplot(grid[third_row, col_left])
    ax_quiet_fm = fig.add_subplot(grid[fourth_row, col_left])
    ax_motion_th = fig.add_subplot(grid[first_row, col_midl])
    ax_motion_sw = fig.add_subplot(grid[second_row, col_midl])
    ax_motion_tr = fig.add_subplot(grid[third_row, col_midl])
    ax_motion_fm = fig.add_subplot(grid[fourth_row, col_midl])
    ax_loc_th = fig.add_subplot(grid[first_row, col_midr])
    ax_loc_sw = fig.add_subplot(grid[second_row, col_midr])
    ax_loc_tr = fig.add_subplot(grid[third_row, col_midr])
    ax_loc_fm = fig.add_subplot(grid[fourth_row, col_midr])
    ax_disc_th = fig.add_subplot(grid[first_row, col_right])
    ax_disc_sw = fig.add_subplot(grid[second_row, col_right])
    ax_disc_tr = fig.add_subplot(grid[third_row, col_right])
    ax_disc_fm = fig.add_subplot(grid[fourth_row, col_right])
    ax_emg_sess = fig.add_subplot(grid[fifth_row, col_sess])
    ax_tread_sess = fig.add_subplot(grid[sixth_row, col_sess])
    ax_facem_sess = fig.add_subplot(grid[seventh_row, col_sess])
    ax_state_sess = fig.add_subplot(grid[eigth_row, col_sess])
    ax_trial_count = fig.add_subplot(grid[trial_row, col_right])

    theta_ylims = []; delta_ylims = []; tread_ylims = []; motion_ylims = []

    # looping over groups of axis to plot the theta, delta, treadmill and face motion data for each trial type
    for ax_theta, ax_delta, ax_tread, ax_motion, trial_type, color in zip(
        [ax_quiet_th,  ax_motion_th,  ax_loc_th,  ax_disc_th],
        [ax_quiet_sw,  ax_motion_sw,  ax_loc_sw,  ax_disc_sw],
        [ax_quiet_tr,  ax_motion_tr,  ax_loc_tr,  ax_disc_tr],
        [ax_quiet_fm,  ax_motion_fm,  ax_loc_fm,  ax_disc_fm],
        [quies_label,  active_label,  loc_label,  disc_label],
        [quies_color,  active_color,  loc_color,  disc_color]
        ):
        
        # getting trials that match the current trial type
        selected_trials = np.where(trial_class == trial_type)[0]
        ax_theta.set_title(f"{trial_type}: {len(selected_trials)} trials", fontsize=10)

        # checking if there is at least one trial classified in the current trial type
        if len(selected_trials) > 0: # case: at least one trial for the current trial type

            # plotting the filetered lfp measures
            time_lfp_to_plot = (np.arange(sample_no_lfp) / lfp_sample_rate) + time_to_plot_before
            div.plot_with_sem(axis_to_plot = ax_theta,
                              sem = "compute",
                              signal_x = time_lfp_to_plot,
                              signal_y = th_class[selected_trials],
                              color = color, line_width = 2,
                              zorder=1)
            div.plot_with_sem(axis_to_plot = ax_delta,
                              sem = "compute",
                              signal_x = time_lfp_to_plot,
                              signal_y = sw_class[selected_trials],
                              color = color, line_width = 2,
                              zorder=1)

            # plotting the state measures
            time_state_to_plot = (np.arange(sample_no_state) / camera_sample_rate) + time_to_plot_before
            ax_tread.plot(time_state_to_plot, abs(tr_class[selected_trials].T), linewidth=0.5, color = replicates_color, zorder=0)
            ax_tread.plot(time_state_to_plot, abs(np.mean(tr_class[selected_trials], axis=0)), linewidth=2, color = color, zorder=1)
            ax_tread.plot(time_state_to_plot, np.ones(len(time_state_to_plot)) * locomotion_threshold, color = "#808080", zorder = 2, linestyle = "--")
            ax_motion.plot(time_state_to_plot, fm_class[selected_trials].T, linewidth=0.5, color = replicates_color, zorder=0)
            ax_motion.plot(time_state_to_plot, np.mean(fm_class[selected_trials], axis=0), linewidth=2, color = color, zorder=1)
            ax_motion.plot(time_state_to_plot, np.ones(len(time_state_to_plot)) * face_motion_threshold, color = "#808080", zorder = 2, linestyle = "--")

            # highlighting the window considered for the visual response
            for axis in [ax_theta, ax_delta, ax_tread, ax_motion]:
                div.highlight_event(axis_to_plot = axis,
                                    highlight_type = "shaded_area",
                                    event_onset = 0, event_offset = window_size,
                                    event_color = window_color, zorder=3)
                
            # cumulating ylims to allow for inter-state comparisons (all plots with same scale)
            theta_ylims.append(ax_theta.get_ylim())
            delta_ylims.append(ax_delta.get_ylim())
            tread_ylims.append(ax_tread.get_ylim())
            motion_ylims.append(ax_motion.get_ylim())

    # whole session plots
    ax_emg_sess.plot(ts_emg[3:-3], abs(data_emg[3:-3]), linewidth = 0.5, color = state_color)
    ax_emg_sess.plot(ts_emg[3:-3], np.ones(len(data_emg[3:-3])) * emg_threshold, linewidth = 0.75, color= highlight_color, linestyle= "--")
    ax_tread_sess.plot(strobe_in_seconds, abs(treadmill), linewidth = 0.5, color = state_color)
    ax_tread_sess.plot(strobe_in_seconds, np.ones(len(treadmill)) * locomotion_threshold, linewidth = 0.75, color=highlight_color, linestyle='--')
    ax_facem_sess.plot(strobe_in_seconds, face_motion, linewidth=0.5, color = state_color)
    ax_facem_sess.plot(strobe_in_seconds, np.ones(len(face_motion)) * face_motion_threshold, linewidth = 0.75, color=highlight_color, linestyle='--')

    stateGab.plot_states(axis_to_plot=ax_state_sess, state_score=scored_state, frames_in_sec=strobe_in_seconds, input_format="integer",
                         colors = [quies_color, active_color, loc_color, disc_color])

    # getting trial-wise info
    opto_on_mask = vis_stim_opto.coords["opto_stimulation"].where(vis_stim_mask).dropna(dim="trial").astype(bool).values
    trial_direcs = vis_stim_opto.coords["stimulus_orientation"].where(vis_stim_mask).dropna(dim="trial").values

    # counting number of state-scored trials for opto-on and opto-off conditions
    unique_directions = np.unique(trial_direcs)
    opto_on_scored_trials = np.zeros([len(all_labels), len(unique_directions)])
    opto_off_scored_trials = np.zeros([len(all_labels), len(unique_directions)])

    # all_labels = np.unique(trial_class)
    unique_directions = np.sort(np.unique(trial_direcs))

    # Expanding dimensions for broadcasting
    labels_exp = trial_class[:, None]       # shape (n_trials, 1)
    direcs_exp = trial_direcs[:, None]      # shape (n_trials, 1)
    opto_exp = opto_on_mask[:, None, None]        # shape (n_trials, 1, 1)

    # Compare with each label/direction
    label_matches = labels_exp == all_labels[None, :]        # shape (n_trials, n_labels)
    direc_matches = direcs_exp == unique_directions[None, :] # shape (n_trials, n_directions)

    # Count opto-on and opto-off trials
    opto_on_counts = np.sum(opto_exp & label_matches[:, :, None] & direc_matches[:, None, :], axis=0)
    opto_off_counts = np.sum(~opto_exp & label_matches[:, :, None] & direc_matches[:, None, :], axis=0)

    # reshaping the data to have it grouped for the scatter plot
    grouped = np.stack([opto_off_counts, opto_on_counts], axis=0)
    grouped = grouped.transpose(0, 2, 1)
    div.categorical_scatter(
        axis = ax_trial_count,
        grouped_data = grouped,
        conditions_labels = all_labels,
        group_labels= ["Opto-OFF", "Opto-ON"],
        replicates_color = replicates_color,
        averages_color = [opto_off_color, opto_on_color],
        color_scheme_avg = "group",
        group_space = 0.4
    )
    
    # adding legend to the scatter plot
    handles, labels = ax_trial_count.get_legend_handles_labels()
    ax_trial_count.legend(handles[:2], labels[:2], fontsize=7, frameon=False)

    # getting general ylim for all plots
    theta_ylim = [np.min(theta_ylims), np.max(theta_ylims)]
    delta_ylim = [np.min(delta_ylims), np.max(delta_ylims)]
    tread_ylim = [np.min(tread_ylims), np.max(tread_ylims)]
    motion_ylim = [np.min(motion_ylims), np.max(motion_ylims)]

    div.spine_settings(
        axis_objects=[ax_quiet_th, ax_quiet_sw, ax_quiet_tr, ax_quiet_fm,
                      ax_motion_th, ax_motion_sw, ax_motion_tr, ax_motion_fm,
                      ax_loc_th, ax_loc_sw, ax_loc_tr, ax_loc_fm,
                      ax_disc_th, ax_disc_sw, ax_disc_tr, ax_disc_fm,
                      ax_emg_sess, ax_tread_sess, ax_facem_sess, ax_state_sess, ax_trial_count],
        settings_dicts=[{**general_settings, "left": {**general_settings["left"], "lim": theta_ylim, "ticks": ("margin", 0), "label": "6-10 Hz\nLFP [mV]", "label_pad": -8}},
                        {**general_settings, "left": {**general_settings["left"], "lim": delta_ylim, "ticks": ("margin", 0), "label": "0.5-4 Hz\nLFP [mV]", "label_pad": -8}},
                        {**general_settings, "left": {**general_settings["left"], "lim": tread_ylim, "ticks": ("margin", 0), "label": "Velocity\n[cm/s]", "label_pad": -5}},
                        {**general_settings, "left": {**general_settings["left"], "lim": motion_ylim, "ticks": ("data_lim", 0), "label": "Norm. face\nmotion [A.U.]"},
                                             "bottom": {**general_settings["bottom"], "ticks": ([0, 2], 0), "label": "Time [sec]", "label_pad": -10}},
                        {**general_settings, "left": {**general_settings["left"], "lim": theta_ylim}},
                        {**general_settings, "left": {**general_settings["left"], "lim": delta_ylim}},
                        {**general_settings, "left": {**general_settings["left"], "lim": tread_ylim}},
                        {**general_settings, "left": {**general_settings["left"], "lim": motion_ylim},
                                             "bottom": {**general_settings["bottom"], "ticks": ([0, 2], 0), "label": "Time [sec]", "label_pad": -10}},
                        {**general_settings, "left": {**general_settings["left"], "lim": theta_ylim}},
                        {**general_settings, "left": {**general_settings["left"], "lim": delta_ylim}},
                        {**general_settings, "left": {**general_settings["left"], "lim": tread_ylim}},
                        {**general_settings, "left": {**general_settings["left"], "lim": motion_ylim},
                                             "bottom": {**general_settings["bottom"], "ticks": ([0, 2], 0), "label": "Time [sec]", "label_pad": -10}},
                        {**general_settings, "left": {**general_settings["left"], "lim": theta_ylim}},
                        {**general_settings, "left": {**general_settings["left"], "lim": delta_ylim}},
                        {**general_settings, "left": {**general_settings["left"], "lim": tread_ylim}},
                        {**general_settings, "left": {**general_settings["left"], "lim": motion_ylim},
                                             "bottom": {**general_settings["bottom"], "ticks": ([0, 2], 0), "label": "Time [sec]", "label_pad": -10}},
                        {**whole_sess_settings, "left": {**whole_sess_settings["left"], "lim": [0, np.percentile(data_emg, 99)]}},
                        {**whole_sess_settings, "left": {**whole_sess_settings["left"], "lim": [0, np.percentile(treadmill, 99)]}},
                        {**whole_sess_settings, "left": {**whole_sess_settings["left"], "lim": [0, np.percentile(face_motion, 99)]}},
                        whole_sess_settings, trial_settings])

    if save_fig:
        fig_name = "_".join([os.path.basename(session), f"loc_{locomotion_threshold}_{locomotion_time_threshold}_emg_{emg_perc}__fm_{face_motion_perc}_{face_motion_time_threshold}_test1".replace(".","")])
        fig.savefig(os.path.join(folder_to_save, fig_name), dpi=300, bbox_inches='tight')
        
    plt.close(fig)
