In [86]:
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import h5py
from collections import Counter


from utils.functions import *
from utils.alignment import *
from utils.indication import *
from plotting.plots import *
from utils.save_plots import *

from plotting.plot_values import compute_fec_CR_data, compute_fec_averages
from plotting.plots import plot_histogram, plot_scatter, plot_hexbin, plot_fec_trial

# Session to analyse; every data/output path in this notebook is keyed on this date.
session_date = "20241030"
# Threshold handed to CR_stat_indication when labelling trials by CR status.
static_threshold = 2
# Value passed as `y_lim` to the raw (non-normalized) FEC average panels of the summary report.
min_FEC = 0.3
# Destination of the multi-panel overall summary PDF.
_summary_dir = f"./outputs/{session_date}/summaries"
overal_summary_file = f"{_summary_dir}/overal_summary_{session_date}.pdf"

In [87]:
# processing_files(bpod_file = f"./data/{session_date}/bpod_session_data.mat", 
#                      raw_voltage_file = f"./data/{session_date}/raw_voltages.h5", 
#                      dff_file = f"./data/{session_date}/dff.h5", 
#                      save_path = f"./data/{session_date}/saved_trials.h5", 
#                      exclude_start=20, exclude_end=20)

In [88]:
# Open the preprocessed trial file explicitly read-only. Older h5py releases default to
# mode 'a' (read/write, create if missing), which could silently create or modify the
# input. The File object is intentionally kept open: `trials` is a live h5py handle
# that the helpers below read from.
trials = h5py.File(f"./data/{session_date}/saved_trials.h5", mode="r")["trial_id"]

# Event timing (indices/times of trial start/end, LED onset, air puff) per trial.
init_time, init_index, ending_time, ending_index, led_index, ap_index = aligning_times(trials=trials)

# FEC traces zeroed on the trial time axis, smoothed (window of 7) and min-max normalized.
fec, fec_time_0 = fec_zero(trials)
fec_0 = moving_average(fec, window_size=7)
fec_normed = min_max_normalize(fec_0)

# Split trial ids by block type, then label trials by CR status.
shorts, longs = block_type(trials)
CR_stat, CR_interval_avg, base_line_avg = CR_stat_indication(trials, static_threshold=static_threshold, AP_delay=3)

# FEC grouped by block type x CR outcome, for both the smoothed and the normalized FEC.
short_CRp_fec, short_CRn_fec, long_CRp_fec, long_CRn_fec = block_and_CR_fec(CR_stat, fec_0, shorts, longs)
short_CRp_fec_normed, short_CRn_fec_normed, long_CRp_fec_normed, long_CRn_fec_normed = block_and_CR_fec(CR_stat, fec_normed, shorts, longs)

all_id = sort_numbers_as_strings(shorts + longs)
event_diff, ap_diff, ending_diff = index_differences(init_index, led_index, ending_index, ap_index)

In [89]:
# Align every ROI's dF/F to the trial events for each block type (short / long) and
# CR outcome: the *_crp groups pass CR flag 1 and the *_crn groups pass 0 to aligned_dff.
# The first trial id of the block is passed as the last argument (presumably a reference
# trial for the time axis -- TODO confirm).
(
    (short_crp_aligned_dff, short_crp_aligned_time),
    (short_crn_aligned_dff, short_crn_aligned_time),
    (long_crp_aligned_dff, long_crp_aligned_time),
    (long_crn_aligned_dff, long_crn_aligned_time),
) = [
    aligned_dff(trials, block_ids, CR_stat, cr_flag, init_index, ending_index, block_ids[0])
    for block_ids, cr_flag in ((shorts, 1), (shorts, 0), (longs, 1), (longs, 0))
]

_aligned_groups = (short_crp_aligned_dff, short_crn_aligned_dff, long_crp_aligned_dff, long_crn_aligned_dff)

# Average / SEM with all trials pooled, plus the accompanying n.
(
    (short_crp_avg_pooled, short_crp_sem_pooled, n_short_crp_pooled),
    (short_crn_avg_pooled, short_crn_sem_pooled, n_short_crn_pooled),
    (long_crp_avg_pooled, long_crp_sem_pooled, n_long_crp_pooled),
    (long_crn_avg_pooled, long_crn_sem_pooled, n_long_crn_pooled),
) = [calculate_average_dff_pool(group) for group in _aligned_groups]

# Per-ROI average / SEM across trials, plus the accompanying n.
(
    (short_crp_avg_dff, short_crp_sem_dff, n_short_crp_roi),
    (short_crn_avg_dff, short_crn_sem_dff, n_short_crn_roi),
    (long_crp_avg_dff, long_crp_sem_dff, n_long_crp_roi),
    (long_crn_avg_dff, long_crn_sem_dff, n_long_crn_roi),
) = [calculate_average_dff_roi(aligned_dff=group) for group in _aligned_groups]

# Grand average / SEM over ROIs of the per-ROI averages.
(
    (short_crp_avg_roi, short_crp_sem_roi),
    (short_crn_avg_roi, short_crn_sem_roi),
    (long_crp_avg_roi, long_crp_sem_roi),
    (long_crn_avg_roi, long_crn_sem_roi),
) = [
    average_over_roi(roi_average)
    for roi_average in (short_crp_avg_dff, short_crn_avg_dff, long_crp_avg_dff, long_crn_avg_dff)
]

0 52
0 52
0 52
0 52


In [90]:
# Per-ROI trial averages (short and long blocks, CR+ vs CR-) written to one PDF.
save_roi_plots_to_pdf(
    short_crp_avg_dff, short_crn_avg_dff, short_crp_sem_dff, short_crn_sem_dff,
    short_crp_aligned_time,
    long_crp_avg_dff, long_crn_avg_dff, long_crp_sem_dff, long_crn_sem_dff,
    long_crp_aligned_time,
    trials,
    pdf_filename=f"./outputs/{session_date}/individual_ROI_plots.pdf",
)

# Per-trial FEC traces, drawn together with each trial's CR status.
save_fec_plots_to_pdf(
    trials, fec_time_0, fec_0, CR_stat, all_id,
    f"./outputs/{session_date}/individual_FEC_plots.pdf",
)

All plots have been saved to ./outputs/20241030/individual_ROI_plots.pdf


In [91]:
# Window length handed to `intervals` for the LED, CR, air-puff and baseline intervals
# (units are whatever `intervals` expects -- presumably samples, TODO confirm).
interval_window_led = interval_window_cr = interval_window_ap = interval_window_bl = 120

# Slice each aligned dF/F group into CR / LED / AP / baseline intervals.
# Short blocks pass isi_time=200, long blocks isi_time=400; the sample id is the first
# trial key of ROI 0 in each group.
(
    (cr_interval_short_crn, led_interval_short_crn, ap_interval_short_crn, base_line_interval_short_crn),
    (cr_interval_short_crp, led_interval_short_crp, ap_interval_short_crp, base_line_interval_short_crp),
    (cr_interval_long_crn, led_interval_long_crn, ap_interval_long_crn, base_line_interval_long_crn),
    (cr_interval_long_crp, led_interval_long_crp, ap_interval_long_crp, base_line_interval_long_crp),
) = [
    intervals(
        group_dff, led_index, ap_index,
        interval_window_led, interval_window_cr, interval_window_ap, interval_window_bl,
        isi_time=isi, sample_id=list(group_dff[0].keys())[0],
    )
    for group_dff, isi in (
        (short_crn_aligned_dff, 200),
        (short_crp_aligned_dff, 200),
        (long_crn_aligned_dff, 400),
        (long_crp_aligned_dff, 400),
    )
]


def _interval_summary(baseline, led, cr, ap, color):
    """Per-ROI interval averages for one trial type, plus its plotting color.

    Keys are built in the order baseline, led, cr, ap, color, which is also the order
    in which `interval_averaging` is called.
    """
    return {
        "baseline": interval_averaging(baseline),
        "led": interval_averaging(led),
        "cr": interval_averaging(cr),
        "ap": interval_averaging(ap),
        "color": color,
    }


trial_types = {
    "Short CRN": _interval_summary(base_line_interval_short_crn, led_interval_short_crn, cr_interval_short_crn, ap_interval_short_crn, "blue"),
    "Short CRP": _interval_summary(base_line_interval_short_crp, led_interval_short_crp, cr_interval_short_crp, ap_interval_short_crp, "red"),
    "Long CRN": _interval_summary(base_line_interval_long_crn, led_interval_long_crn, cr_interval_long_crn, ap_interval_long_crn, "blue"),
    "Long CRP": _interval_summary(base_line_interval_long_crp, led_interval_long_crp, cr_interval_long_crp, ap_interval_long_crp, "red"),
}

# An ROI is a candidate for an event when its mean event-window value exceeds its mean
# baseline value (a missing baseline yields NaN, so the comparison is False).
valid_ROIs = {
    trial_type: {
        event: [
            roi
            for roi, event_values in data[event].items()
            if np.nanmean(event_values) > np.nanmean(data["baseline"].get(roi, np.nan))
        ]
        for event in ["led", "cr", "ap"]
    }
    for trial_type, data in trial_types.items()
}
print(valid_ROIs)


def _event_ttests(trial_type, base_interval, led_interval, ap_interval, cr_interval):
    """Baseline-vs-event t-tests for one trial type, returned in the order (led, ap, cr).

    Each entry is whatever `ttest_intervals` returns (unpacked below as t-stat, p-value),
    restricted to the candidate ROIs stored in `valid_ROIs` for that event.
    """
    candidates = valid_ROIs[trial_type]
    return (
        ttest_intervals(base_interval=base_interval, interval_under_test=led_interval, roi_list=candidates["led"]),
        ttest_intervals(base_interval=base_interval, interval_under_test=ap_interval, roi_list=candidates["ap"]),
        ttest_intervals(base_interval=base_interval, interval_under_test=cr_interval, roi_list=candidates["cr"]),
    )


(t_stat_short_crn_led, p_value_short_crn_led), (t_stat_short_crn_ap, p_value_short_crn_ap), (t_stat_short_crn_cr, p_value_short_crn_cr) = _event_ttests(
    "Short CRN", base_line_interval_short_crn, led_interval_short_crn, ap_interval_short_crn, cr_interval_short_crn)
(t_stat_short_crp_led, p_value_short_crp_led), (t_stat_short_crp_ap, p_value_short_crp_ap), (t_stat_short_crp_cr, p_value_short_crp_cr) = _event_ttests(
    "Short CRP", base_line_interval_short_crp, led_interval_short_crp, ap_interval_short_crp, cr_interval_short_crp)
(t_stat_long_crn_led, p_value_long_crn_led), (t_stat_long_crn_ap, p_value_long_crn_ap), (t_stat_long_crn_cr, p_value_long_crn_cr) = _event_ttests(
    "Long CRN", base_line_interval_long_crn, led_interval_long_crn, ap_interval_long_crn, cr_interval_long_crn)
(t_stat_long_crp_led, p_value_long_crp_led), (t_stat_long_crp_ap, p_value_long_crp_ap), (t_stat_long_crp_cr, p_value_long_crp_cr) = _event_ttests(
    "Long CRP", base_line_interval_long_crp, led_interval_long_crp, ap_interval_long_crp, cr_interval_long_crp)

# Average t-statistics, in trial-type order and (led, ap, cr) order within each type.
(
    t_avg_short_crn_led, t_avg_short_crn_ap, t_avg_short_crn_cr,
    t_avg_short_crp_led, t_avg_short_crp_ap, t_avg_short_crp_cr,
    t_avg_long_crn_led, t_avg_long_crn_ap, t_avg_long_crn_cr,
    t_avg_long_crp_led, t_avg_long_crp_ap, t_avg_long_crp_cr,
) = [
    calculate_average_ttest(stat)
    for stat in (
        t_stat_short_crn_led, t_stat_short_crn_ap, t_stat_short_crn_cr,
        t_stat_short_crp_led, t_stat_short_crp_ap, t_stat_short_crp_cr,
        t_stat_long_crn_led, t_stat_long_crn_ap, t_stat_long_crn_cr,
        t_stat_long_crp_led, t_stat_long_crp_ap, t_stat_long_crp_cr,
    )
]

t_stats = {
    "led": [t_avg_short_crn_led, t_avg_short_crp_led, t_avg_long_crn_led, t_avg_long_crp_led],
    "ap": [t_avg_short_crn_ap, t_avg_short_crp_ap, t_avg_long_crn_ap, t_avg_long_crp_ap],
    "cr": [t_avg_short_crn_cr, t_avg_short_crp_cr, t_avg_long_crn_cr, t_avg_long_crp_cr],
}

# For each event, count how often each ROI is returned by extract_top_rois across the
# four trial types and keep the 7 most frequent as (roi, count) pairs.
common_rois = {}
for event_name, event_t_stats in t_stats.items():
    common_rois[event_name] = Counter(extract_top_rois(event_t_stats)).most_common(7)

# Significant ROI ids per event (as ints); led_roi / ap_roi / cr_roi alias these lists.
sig_rois = {event_name: [int(roi) for roi, _count in common_rois[event_name]] for event_name in ("led", "ap", "cr")}
led_roi = sig_rois["led"]
ap_roi = sig_rois["ap"]
cr_roi = sig_rois["cr"]

{'Short CRN': {'led': [0, 1, 5, 8, 10, 13, 14, 20, 21, 22, 24, 25, 26, 30, 33, 34, 35, 38, 42, 43, 44, 46, 47, 51], 'cr': [5, 6, 8, 11, 14, 15, 21, 25, 26, 28, 29, 30, 33, 34, 38, 43, 51], 'ap': [3, 5, 11, 14, 15, 17, 20, 24, 25, 26, 28, 30, 34, 35, 38, 40, 41, 43, 44, 51]}, 'Short CRP': {'led': [0, 3, 5, 6, 7, 8, 9, 10, 12, 16, 19, 22, 24, 25, 26, 27, 33, 36, 39, 41, 45, 46, 50, 51], 'cr': [0, 3, 4, 5, 6, 7, 8, 11, 12, 16, 17, 18, 19, 22, 26, 27, 31, 32, 33, 35, 36, 38, 39, 41, 42, 49, 51], 'ap': [0, 1, 3, 5, 6, 7, 9, 10, 11, 12, 16, 19, 20, 21, 26, 27, 28, 30, 32, 36, 38, 39, 41, 42, 45, 46, 47, 49, 51]}, 'Long CRN': {'led': [0, 3, 6, 8, 10, 12, 13, 14, 15, 16, 17, 21, 25, 26, 31, 32, 33, 34, 36, 37, 38, 39, 40, 42, 44, 45, 46, 47, 48, 50, 51], 'cr': [3, 4, 5, 8, 10, 11, 13, 15, 17, 19, 21, 24, 25, 28, 32, 34, 35, 36, 37, 40, 42, 44, 45, 46, 48, 50], 'ap': [3, 4, 5, 8, 10, 11, 15, 17, 19, 25, 28, 29, 31, 32, 35, 36, 37, 40, 42, 43, 44, 45, 46, 48, 50]}, 'Long CRP': {'led': [5, 10, 12

In [92]:
# Trial-averaged dF/F restricted to the significant ROIs of each event, for every
# block-type / CR-outcome group. Results are produced event by event (LED, AP, CR) and,
# within each event, for short CR+, short CR-, long CR+, long CR-.
(
    (short_crp_avg_led_sig, short_crp_sem_led_sig),
    (short_crn_avg_led_sig, short_crn_sem_led_sig),
    (long_crp_avg_led_sig, long_crp_sem_led_sig),
    (long_crn_avg_led_sig, long_crn_sem_led_sig),
    (short_crp_avg_ap_sig, short_crp_sem_ap_sig),
    (short_crn_avg_ap_sig, short_crn_sem_ap_sig),
    (long_crp_avg_ap_sig, long_crp_sem_ap_sig),
    (long_crn_avg_ap_sig, long_crn_sem_ap_sig),
    (short_crp_avg_cr_sig, short_crp_sem_cr_sig),
    (short_crn_avg_cr_sig, short_crn_sem_cr_sig),
    (long_crp_avg_cr_sig, long_crp_sem_cr_sig),
    (long_crn_avg_cr_sig, long_crn_sem_cr_sig),
) = [
    calculate_average_sig(group_dff, roi_indices=event_rois)
    for event_rois in (led_roi, ap_roi, cr_roi)
    for group_dff in (short_crp_aligned_dff, short_crn_aligned_dff, long_crp_aligned_dff, long_crn_aligned_dff)
]

sig_summary_file = f"./outputs/{session_date}/significant_ROIs/sig_ROIs_summary.pdf"
# PdfPages does not create missing directories; create it so a fresh run does not fail.
os.makedirs(os.path.dirname(sig_summary_file), exist_ok=True)

with PdfPages(filename=sig_summary_file) as sig_summary_pdf:
    fig, axs = plt.subplots(7, 3, figsize=(20, 40))

    # Rows 0-2, columns 0-1: significant-ROI averages (short block left, long block right).
    # NOTE(review): the short AP panel uses event="AirPuff" while the long one uses "AP";
    # kept as-is in case plot_trial_averages_sig keys on this value -- consider unifying.
    sig_panels = [
        (axs[0, 0], short_crp_aligned_time, short_crp_avg_led_sig, short_crp_sem_led_sig, short_crn_avg_led_sig, short_crn_sem_led_sig, "Short", "LED"),
        (axs[0, 1], long_crp_aligned_time, long_crp_avg_led_sig, long_crp_sem_led_sig, long_crn_avg_led_sig, long_crn_sem_led_sig, "Long", "LED"),
        (axs[1, 0], short_crp_aligned_time, short_crp_avg_ap_sig, short_crp_sem_ap_sig, short_crn_avg_ap_sig, short_crn_sem_ap_sig, "Short", "AirPuff"),
        (axs[1, 1], long_crp_aligned_time, long_crp_avg_ap_sig, long_crp_sem_ap_sig, long_crn_avg_ap_sig, long_crn_sem_ap_sig, "Long", "AP"),
        (axs[2, 0], short_crp_aligned_time, short_crp_avg_cr_sig, short_crp_sem_cr_sig, short_crn_avg_cr_sig, short_crn_sem_cr_sig, "Short", "CR"),
        (axs[2, 1], long_crp_aligned_time, long_crp_avg_cr_sig, long_crp_sem_cr_sig, long_crn_avg_cr_sig, long_crn_sem_cr_sig, "Long", "CR"),
    ]
    for panel_ax, time_axis, crp_avg, crp_sem, crn_avg, crn_sem, block_label, event_label in sig_panels:
        plot_trial_averages_sig(trials, time_axis, crp_avg, crp_sem, crn_avg, crn_sem,
                                title_suffix=block_label, event=event_label, pooled=True, ax=panel_ax)

    # Style adjustments for all axes: drop top/right spines, ticks on left/bottom only.
    for ax in axs.flat:
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.yaxis.set_ticks_position('left')
        ax.xaxis.set_ticks_position('bottom')

    # The third column is only used by the scatter rows (3-6); hide it on rows 0-2.
    for i, ax_row in enumerate(axs):
        if i not in (3, 4, 5, 6):
            ax_row[2].axis("off")

    # Rows 3-6: one row per trial type, one column per event (LED, CR, AP).
    # Each point is one ROI: mean baseline (x) vs mean event-window value (y). Lime marks
    # ROIs that passed the event > baseline check AND are among the event's significant ROIs.
    for idx, (trial_type, data) in enumerate(trial_types.items()):
        baseline_avg = data["baseline"]
        color_other = "blue" if "CRN" in trial_type else "red"
        for event_idx, event in enumerate(["led", "cr", "ap"]):
            ax = axs[idx + 3, event_idx]
            event_avg = data[event]
            baseline_values = []
            event_values = []
            colors = []
            for roi, event_values_array in event_avg.items():
                baseline_values.append(np.nanmean(baseline_avg.get(roi, np.nan)))
                event_values.append(np.nanmean(event_values_array))
                is_significant = roi in valid_ROIs[trial_type][event] and roi in sig_rois[event]
                colors.append("lime" if is_significant else color_other)

            ax.scatter(baseline_values, event_values, c=colors, alpha=0.7, edgecolor="black")
            ax.set_title(f"{trial_type} - {event.capitalize()} Event")
            ax.set_xlabel("Baseline Average")
            ax.set_ylabel(f"Evoked signal for {event.capitalize()} event Average")
            ax.axline((0, 0), slope=1, color="gray", linestyle="--")  # y = x reference line
            ax.legend(
                handles=[
                    plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='lime', markersize=8, label="Significant ROI"),
                    plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color_other, markersize=8, label="Non-significant ROI")
                ],
                loc="upper left"
            )
            ax.grid(False)

    plt.tight_layout()
    sig_summary_pdf.savefig(fig)
    # Close this specific figure (plt.close() would close whatever figure is current).
    plt.close(fig)

# Report only after the PDF has been written and closed by the `with` block.
print(f"saved in ./outputs/{session_date}/significant_ROIs/sig_ROIs_summary.pdf")

# One PDF of individual significant-ROI plots per event.
for sig_event, event_rois in (("LED", led_roi), ("AP", ap_roi), ("CR", cr_roi)):
    save_roi_plots_to_pdf_sig(short_crp_avg_dff, short_crn_avg_dff, short_crp_sem_dff, short_crn_sem_dff,
                              short_crp_aligned_time, long_crp_avg_dff, long_crn_avg_dff, long_crp_sem_dff, long_crn_sem_dff,
                              long_crp_aligned_time, trials, pdf_filename=sig_pdf_name(session_date, sig_event=sig_event), ROI_list=event_rois)

saved in ./outputs/20241030/significant_ROIs/sig_ROIs_summary.pdf
All plots have been saved to ./outputs/20241030/significant_ROIs/individual_LED_sig_roi.pdf
All plots have been saved to ./outputs/20241030/significant_ROIs/individual_AP_sig_roi.pdf
All plots have been saved to ./outputs/20241030/significant_ROIs/individual_CR_sig_roi.pdf


In [93]:
# Per-trial FEC baseline / CR summaries, and trial-averaged FEC traces computed from the
# smoothed FEC and from its min-max normalized version.
data_fec_scatter = compute_fec_CR_data(base_line_avg, CR_interval_avg, CR_stat)
data_fec_average = compute_fec_averages(
    short_CRp_fec, short_CRn_fec, long_CRp_fec, long_CRn_fec,
    fec_time_0, shorts, longs, trials,
)
data_fec_average_normed = compute_fec_averages(
    short_CRp_fec_normed, short_CRn_fec_normed, long_CRp_fec_normed, long_CRn_fec_normed,
    fec_time_0, shorts, longs, trials,
)

# Unpack the scatter summary into the individual series used by the report plots.
(
    baselines, cr_amplitudes, cr_relative_changes,
    baselines_crp, baselines_crn,
    cr_amplitudes_crp, cr_amplitudes_crn,
    cr_relative_changes_crp, cr_relative_changes_crn,
    all_baselines, all_relative_changes,
) = [
    data_fec_scatter[key]
    for key in (
        'baselines', 'cr_amplitudes', 'cr_relative_changes',
        'baselines_crp', 'baselines_crn',
        'cr_amplitudes_crp', 'cr_amplitudes_crn',
        'cr_relative_changes_crp', 'cr_relative_changes_crn',
        'all_baselines', 'all_relative_changes',
    )
]

def _stack_heat_rows(roi_order, roi_traces):
    """Vertically stack ``roi_traces[roi]`` for each ROI, in the order given by ``roi_order``."""
    return np.vstack([roi_traces[roi] for roi in list(roi_order)])


# ROI order by the average signal in the CR window, one ordering per trial group.
sorted_avg_short_crp_roi = sort_dff_avg(short_crp_avg_dff, event_diff, ap_diff)
sorted_avg_long_crp_roi = sort_dff_avg(long_crp_avg_dff, event_diff, ap_diff)
sorted_avg_short_crn_roi = sort_dff_avg(short_crn_avg_dff, event_diff, ap_diff)
sorted_avg_long_crn_roi = sort_dff_avg(long_crn_avg_dff, event_diff, ap_diff)

# Heatmap matrices in panel order: short CR+, long CR+, short CR-, long CR-.
heat_arrays_avg = [
    _stack_heat_rows(sorted_avg_short_crp_roi, short_crp_avg_dff),
    _stack_heat_rows(sorted_avg_long_crp_roi, long_crp_avg_dff),
    _stack_heat_rows(sorted_avg_short_crn_roi, short_crn_avg_dff),
    _stack_heat_rows(sorted_avg_long_crn_roi, long_crn_avg_dff),
]

# ROI order by the time of the maximum signal in the CR window.
sorted_max_short_crp_roi = sort_dff_max_index(short_crp_avg_dff, event_diff, ap_diff)
sorted_max_short_crn_roi = sort_dff_max_index(short_crn_avg_dff, event_diff, ap_diff)
sorted_max_long_crp_roi = sort_dff_max_index(long_crp_avg_dff, event_diff, ap_diff)
sorted_max_long_crn_roi = sort_dff_max_index(long_crn_avg_dff, event_diff, ap_diff)

heat_arrays_max = [
    _stack_heat_rows(sorted_max_short_crp_roi, short_crp_avg_dff),
    _stack_heat_rows(sorted_max_long_crp_roi, long_crp_avg_dff),
    _stack_heat_rows(sorted_max_short_crn_roi, short_crn_avg_dff),
    _stack_heat_rows(sorted_max_long_crn_roi, long_crn_avg_dff),
]

# Time axes, titles and colormaps follow the same panel order as the heat arrays.
aligned_times = [short_crp_aligned_time, long_crp_aligned_time, short_crn_aligned_time, long_crn_aligned_time]

heatmap_titles_avg = [
    "Sorted heatmap according to the average signal in the CR window - Short CR+",
    "Sorted heatmap according to the average signal in the CR window - Long CR+",
    "Sorted heatmap according to the average signal in the CR window - Short CR-",
    "Sorted heatmap according to the average signal in the CR window - Long CR-",
]
heatmap_titles_max = [
    "Sorted heatmap according to the time of maximum value of the signal in the CR window - Short CR+",
    "Sorted heatmap according to the time of maximum value of the signal in the CR window - Long CR+",
    "Sorted heatmap according to the time of maximum value of the signal in the CR window - Short CR-",
    "Sorted heatmap according to the time of maximum value of the signal in the CR window - Long CR-",
]
# CR+ panels use "magma", CR- panels use "viridis".
color_maps = ["magma"] * 2 + ["viridis"] * 2


# Document-information dictionary embedded in the overall summary PDF.
metadata = dict(
    Title=f"Overall Summary of {session_date}",
    Author='Shayan Malekpour',
    Subject='Summary of Analysis',
    Keywords='FEC, CR, Baseline, Analysis',
)

# PdfPages does not create missing directories; make sure the summaries folder exists
# so a fresh run does not fail with FileNotFoundError.
_summary_dir = os.path.dirname(overal_summary_file)
if _summary_dir:
    os.makedirs(_summary_dir, exist_ok=True)

with PdfPages(filename=overal_summary_file, metadata=metadata) as summary_pdf:
    # One tall 14 x 3 figure saved as a single PDF page. The narrow third column only
    # hosts colorbars (rows 4-7 for the heatmaps, row 10 for the hexbins).
    fig, axes = plt.subplots(14, 3, figsize=(20, 80), gridspec_kw={'width_ratios': [1, 1, 0.03], 'height_ratios': [1, 1, 1, 1, 0.5, 0.5, 0.5, 0.5, 1, 1, 1, 1, 1, 1]}, squeeze=True)
    fig.suptitle(f"Overall Summary Report of {session_date}", fontsize=24)

    # Hide the third-column axes on every row that has no colorbar.
    colorbar_rows = (4, 5, 6, 7, 10)
    for i, ax_row in enumerate(axes):
        if i not in colorbar_rows:
            ax_row[2].axis("off")

    # Style adjustments for all axes: drop top/right spines, ticks on left/bottom only.
    for ax in axes.flat:
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.yaxis.set_ticks_position('left')
        ax.xaxis.set_ticks_position('bottom')

    # Rows 0-1: FEC averages (raw, then normalized), short block left, long block right.
    short_data = data_fec_average["short_trials"]
    long_data = data_fec_average["long_trials"]
    short_data_normed = data_fec_average_normed["short_trials"]
    long_data_normed = data_fec_average_normed["long_trials"]
    fec_panels = [
        (axes[0, 0], short_data, min_FEC, "FEC Average for Short Trials"),
        (axes[0, 1], long_data, min_FEC, "FEC Average for Long Trials"),
        (axes[1, 0], short_data_normed, 0, "Normalized FEC Average for Short Trials"),
        (axes[1, 1], long_data_normed, 0, "Normalized FEC Average for Long Trials"),
    ]
    for panel_ax, trial_data, panel_y_lim, panel_title in fec_panels:
        plot_fec_trial(
            panel_ax,
            trial_data["time"],
            trial_data["mean1"],
            trial_data["std1"],
            trial_data["mean0"],
            trial_data["std0"],
            trial_data["led"],
            trial_data["airpuff"],
            y_lim=panel_y_lim,
            title=panel_title,
        )

    # Row 2: averages over ROIs; row 3: trial-pooled averages.
    plot_trial_averages_side_by_side(
        axes[2, 0], axes[2, 1],
        n_short_crp_roi, n_short_crn_roi, short_crp_aligned_time, short_crp_avg_roi, short_crp_sem_roi, short_crn_avg_roi, short_crn_sem_roi,
        n_long_crp_roi, n_long_crn_roi, long_crp_aligned_time, long_crp_avg_roi, long_crp_sem_roi, long_crn_avg_roi, long_crn_sem_roi,
        trials,
        title_suffix1="Short", title_suffix2="Long"
    )

    plot_trial_averages_side_by_side(
        axes[3, 0], axes[3, 1],
        n_short_crp_pooled, n_short_crn_pooled, short_crp_aligned_time, short_crp_avg_pooled, short_crp_sem_pooled,
        short_crn_avg_pooled, short_crn_sem_pooled,
        n_long_crp_pooled, n_long_crn_pooled, long_crp_aligned_time, long_crp_avg_pooled, long_crp_sem_pooled,
        long_crn_avg_pooled, long_crn_sem_pooled,
        trials, title_suffix1="Short", title_suffix2="Long", pooled=True)

    # Rows 4-7: sorted heatmaps (CR+ rows use magma, CR- rows use viridis) with colorbars.
    # NOTE(review): these ScalarMappables carry no norm or data, so each colorbar spans
    # matplotlib's default 0-1 range rather than the heatmaps' value range -- confirm that
    # plot_heatmaps_side_by_side draws values in [0, 1]; otherwise pass a matching Normalize.
    plot_heatmaps_side_by_side(heat_arrays_avg, aligned_times, heatmap_titles_avg, trials, color_maps=color_maps, axes=[axes[4, 0], axes[4, 1], axes[5, 0], axes[5, 1]])
    for cbar_row, cmap_name in ((4, "magma"), (5, "viridis")):
        cbar = plt.colorbar(plt.cm.ScalarMappable(cmap=cmap_name), cax=axes[cbar_row, 2])
        cbar.set_label("dF/F intensity")

    plot_heatmaps_side_by_side(heat_arrays_max, aligned_times, heatmap_titles_max, trials, color_maps=color_maps, axes=[axes[6, 0], axes[6, 1], axes[7, 0], axes[7, 1]])
    for cbar_row, cmap_name in ((6, "magma"), (7, "viridis")):
        cbar = plt.colorbar(plt.cm.ScalarMappable(cmap=cmap_name), cax=axes[cbar_row, 2])
        cbar.set_label("dF/F intensity")

    # Row 8: histograms of baselines and CR amplitudes.
    # NOTE(review): titles say "Across Sessions" although this report covers one session -- confirm wording.
    plot_histogram(
        axes[8, 0], baselines, bins=20, color='lime', edgecolor='black', alpha=0.7,
        title='Distribution of Baseline Values Across Sessions',
        xlabel='Baseline Value', ylabel='Frequency'
    )
    plot_histogram(
        axes[8, 1], cr_amplitudes, bins=20, color='green', edgecolor='black', alpha=0.7,
        title='Distribution of CR Amplitudes Across Sessions',
        xlabel='CR Value', ylabel='Frequency'
    )

    # Row 9: CR amplitude (absolute, left) and relative change (right) vs. baseline, CR+ red / CR- blue.
    plot_scatter(
        axes[9, 0], baselines_crp, cr_amplitudes_crp, color='red', alpha=0.7, label='CR+',
        title='CR Amplitude (Absolute) vs. Baseline', xlabel='Baseline', ylabel='CR Amplitude (Absolute)'
    )
    plot_scatter(
        axes[9, 0], baselines_crn, cr_amplitudes_crn, color='blue', alpha=0.7, label='CR-',
        title='', xlabel='', ylabel=''  # title and labels already set by the CR+ call
    )
    plot_scatter(
        axes[9, 1], baselines_crp, cr_relative_changes_crp, color='red', alpha=0.7, label='CR+',
        title='CR Size (Relative Change) vs. Baseline', xlabel='Baseline', ylabel='CR Size (Relative Change)'
    )
    plot_scatter(
        axes[9, 1], baselines_crn, cr_relative_changes_crn, color='blue', alpha=0.7, label='CR-',
        title='', xlabel='', ylabel=''  # title and labels already set by the CR+ call
    )

    # Row 10: hexbin joint distributions, with a shared colorbar in the third column.
    plot_hexbin(
        axes[10, 0], baselines, cr_amplitudes, gridsize=30, cmap='Greens', mincnt=1, alpha=1.0,
        colorbar_label='Count',
        title='Joint Distribution of CR Amplitude and Baseline',
        xlabel='Baseline', ylabel='CR Amplitude (Absolute)'
    )
    plot_hexbin(
        axes[10, 1], all_baselines, all_relative_changes, gridsize=30, cmap='Greens', mincnt=1, alpha=0.7,
        colorbar_label='Count',
        title='Joint Distribution of CR Size (Relative Change) and Baseline',
        xlabel='Baseline', ylabel='CR Size (Relative Change)'
    )
    # The hexbins show bin counts, not dF/F, so label this colorbar accordingly.
    # NOTE(review): like the heatmap colorbars, this one spans 0-1 rather than the real counts.
    cbar = plt.colorbar(plt.cm.ScalarMappable(cmap="Greens"), cax=axes[10, 2])
    cbar.set_label("Count")

    # Rows 11-13: significant-ROI averages for LED, AP and CR (short left, long right).
    # NOTE(review): short AP uses event="AirPuff" while long AP uses "AP"; kept as-is.
    plot_trial_averages_sig(trials, short_crp_aligned_time, short_crp_avg_led_sig, short_crp_sem_led_sig, short_crn_avg_led_sig, short_crn_sem_led_sig, title_suffix="Short", event="LED", pooled=True, ax=axes[11, 0])
    plot_trial_averages_sig(trials, long_crp_aligned_time, long_crp_avg_led_sig, long_crp_sem_led_sig, long_crn_avg_led_sig, long_crn_sem_led_sig, title_suffix="Long", event="LED", pooled=True, ax=axes[11, 1])
    plot_trial_averages_sig(trials, short_crp_aligned_time, short_crp_avg_ap_sig, short_crp_sem_ap_sig, short_crn_avg_ap_sig, short_crn_sem_ap_sig, title_suffix="Short", event="AirPuff", pooled=True, ax=axes[12, 0])
    plot_trial_averages_sig(trials, long_crp_aligned_time, long_crp_avg_ap_sig, long_crp_sem_ap_sig, long_crn_avg_ap_sig, long_crn_sem_ap_sig, title_suffix="Long", event="AP", pooled=True, ax=axes[12, 1])
    plot_trial_averages_sig(trials, short_crp_aligned_time, short_crp_avg_cr_sig, short_crp_sem_cr_sig, short_crn_avg_cr_sig, short_crn_sem_cr_sig, title_suffix="Short", event="CR", pooled=True, ax=axes[13, 0])
    plot_trial_averages_sig(trials, long_crp_aligned_time, long_crp_avg_cr_sig, long_crp_sem_cr_sig, long_crn_avg_cr_sig, long_crn_sem_cr_sig, title_suffix="Long", event="CR", pooled=True, ax=axes[13, 1])

    # Leave the top 2% of the page for the suptitle.
    plt.tight_layout(rect=[0, 0, 1, 0.98])
    summary_pdf.savefig(fig)
    plt.close(fig)