In [None]:
import sys
import os
import random
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..', )))
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..', '..')))

from collections import Counter
from pathlib import Path

from scipy.stats import mannwhitneyu, wilcoxon

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
from matplotlib import rc
import matplotlib.gridspec as gridspec

import seaborn as sns

from config_colors import *
from config_paths import *

from nwb_io import *

from calc_responsive_units import *
from calc_cluster_permutation import *
from plot_response_functions import *

from tabulate import tabulate
from nilearn.plotting import plot_markers

np.set_printoptions(suppress=True)
pd.set_option('display.float_format', '{:.8f}'.format)

######################

# save panels directly to the relevant svg/ subdir
panel_save_dir = Path.cwd().parent.parent / "figure_generation" / "figure_responsive_units" / "svg"

# Folder holding the cached NWB extracts (*_df_region_restricted.pkl, *_df_annotation.pkl) and the
# processed-response pickles. Defaults to the original author's machine; set the DHV_RESULTS_DIR
# environment variable to point the notebook elsewhere without editing the code.
save_path = Path(os.environ.get(
    "DHV_RESULTS_DIR",
    "/home/al/Documents/phd/code/dhv_dataset/plot_code/responsive_units/results/",
))

data_dir = NWB_data_dir

# Patients included when testing units for responsiveness (passed to identify_responses).
pat_subset = [14,20,23,28,30,31,33,39,41,42,
              44,50,52,53,60,61,64,65,66,68,
              73,83,87,88,89,90,92,96,98]

# Font sizes shared by all figure panels.
titlesize = 25
labelsize = 22
ticklabelsize = 20

## Extract data from NWB files (cached as pickles in `save_path`)

In [None]:
# Extract each annotation track from the NWB files and cache it as a pickle under save_path.
for label in ["summer", "scenes", "tom"]:
    df_annotation = collect_all_annotation_data(data_dir, label)
    annotation_pickle = save_path / f"nwb_{label}_df_annotation.pkl"
    df_annotation.to_pickle(annotation_pickle)

In [None]:
region = "A"

# Collect every unit recorded in this region (drop_waveforms=False keeps the waveforms
# column used for the spike-shape panels) and cache it as a pickle.
df_region_restricted_units = collect_all_spike_data_in_region(
    data_dir,
    region,
    region_name_dict=REGION_ALTERNATIVE_NAMES,
    drop_waveforms=False,
)
df_region_restricted_units.to_pickle(
    save_path / f"nwb_{region}_df_region_restricted.pkl"
)

In [None]:
region = "H"

# Collect every unit recorded in this region (drop_waveforms=False keeps the waveforms
# column used for the spike-shape panels) and cache it as a pickle.
df_region_restricted_units = collect_all_spike_data_in_region(
    data_dir,
    region,
    region_name_dict=REGION_ALTERNATIVE_NAMES,
    drop_waveforms=False,
)
df_region_restricted_units.to_pickle(
    save_path / f"nwb_{region}_df_region_restricted.pkl"
)

In [None]:
region = "EC"

# Collect every unit recorded in this region (drop_waveforms=False keeps the waveforms
# column used for the spike-shape panels) and cache it as a pickle.
df_region_restricted_units = collect_all_spike_data_in_region(
    data_dir,
    region,
    region_name_dict=REGION_ALTERNATIVE_NAMES,
    drop_waveforms=False,
)
df_region_restricted_units.to_pickle(
    save_path / f"nwb_{region}_df_region_restricted.pkl"
)

## PHC: camera cuts

#### Load from NWB

In [None]:
region = "PHC"

# Collect every unit recorded in this region (drop_waveforms=False keeps the waveforms
# column used for the spike-shape panels) and cache it as a pickle.
df_region_restricted_units = collect_all_spike_data_in_region(
    data_dir,
    region,
    region_name_dict=REGION_ALTERNATIVE_NAMES,
    drop_waveforms=False,
)
df_region_restricted_units.to_pickle(
    save_path / f"nwb_{region}_df_region_restricted.pkl"
)

In [None]:
label = "camera-cuts"

# Extract the camera-cut annotations and cache them next to the other annotation pickles.
df_annotation = collect_all_annotation_data(data_dir, label)
df_annotation.to_pickle(
    save_path / f"nwb_{label}_df_annotation.pkl"
)

#### Load from file

In [None]:
region = "PHC"
# Read from the same folder the extraction cells write to (save_path) instead of repeating its
# absolute path, so the two cannot drift apart.
load_path = save_path / f"nwb_{region}_df_region_restricted.pkl"
df_region_restricted_units = pd.read_pickle(load_path)

label = "camera-cuts"
df_annotation = pd.read_pickle(save_path / f"nwb_{label}_df_annotation.pkl")

#### Calculate responses

In [None]:
# Response-detection parameters for the PHC camera-cuts analysis.
baseline_time = 1000
stimulus_time = 1000
key = 1  # annotation value whose onsets are treated as events
restriction = "increase"
alpha = 0.001

# Sort units by (patient_id, unit_id): rows of identify_responses' output are mapped back to
# these rows by position below.
df_unit_data = df_region_restricted_units.sort_values(by=["patient_id", "unit_id"])

reg_act, reg_pvals, sort_inds, ct_sig_units, start_ns, end_ns = identify_responses(
    df_unit_data, df_annotation, pat_subset,
    baseline_time=baseline_time, stimulus_time=stimulus_time,
    key=key, restriction=restriction, alpha=alpha,
)

# One row per unit, ranked by p-value (smallest first).
df_ = pd.DataFrame(data=reg_act)
df_["pval"] = reg_pvals
df = df_.sort_values(by=['pval'])
pval_inds = df["pval"]
# NOTE: unlike the other analyses in this notebook, "pval" stays in df (and in the saved
# pickle); exclude it before treating df's columns as time bins.
lineplot_df = df_formatter(reg_pvals, reg_act)

# df.index holds each unit's original row position. NOTE(review): this assumes identify_responses
# returns one row per row of df_unit_data even though it is also given pat_subset — confirm.
df["patient_id"] = df_unit_data["patient_id"].to_numpy()[df.index]
df["unit_id"] = df_unit_data["unit_id"].to_numpy()[df.index]
df.to_pickle(save_path / f"nwb_{label}_{region}_{restriction}_processed_responses.pkl")

#### Plot the response.

In [None]:
# Amplitude binning for the spike-shape density panel of the camera-cuts example unit.
amp_min = -125
amp_max = 250
steps = 200

DENSITY_BINS = np.linspace(amp_min, amp_max, steps)
xticks = np.linspace(0, 19, 5)
ticklabels = [-1000, -500, 0, 500, 1000]

# Label only the two extremes of the amplitude axis.
yticks = [0, steps]
yticklabels = [amp_min, amp_max]

In [None]:
# Example unit for the camera-cuts panel, chosen explicitly by (patient, unit) rather than by rank.
patient_id = 68
unit_id = 33
unit_rank = None

(act_z, bins, binned_act, spike_times,
 event_spikes, waveform, onsets) = grab_unit_response(
    unit_rank=unit_rank,
    patient_id=patient_id,
    unit_id=unit_id,
    df=df,
    df_unit_data=df_unit_data,
    df_annotation=df_annotation,
    key=key,
    baseline_time=baseline_time,
    stimulus_time=stimulus_time,
)

In [None]:
label = "camera-cuts"
title = f"Camera Cuts\nPHC\n{patient_id}-{unit_id}, {len(onsets)} events"

# Layout: spike-shape density on top (middle column), raster in the middle, firing rate at the
# bottom; the 1-high rows are thin spacers between panels.
fig = plt.figure(figsize=(6, 13))
gs0 = gridspec.GridSpec(1, 1, figure=fig)
l_c = sns.color_palette("Spectral", n_colors=6).as_hex()[0]
gs = gridspec.GridSpecFromSubplotSpec(
    5, 3,
    subplot_spec=gs0[0],
    height_ratios=[75, 1, 300, 1, 75],
    hspace=0.,
)

ax_shape = fig.add_subplot(gs[0, 1])
ax_raster = fig.add_subplot(gs[2:4, :])
ax_firing = fig.add_subplot(gs[4, :], sharex=ax_raster)

# NOTE(review): keyword "yicks" is kept exactly as in the original call — confirm it matches
# spike_amp_plot_asset's signature (it may be a typo for "yticks").
spike_amp_plot_asset(
    ax_shape, np.array(waveform), "Spectral_r", invert=False,
    DENSITY_BINS=DENSITY_BINS, yicks=yticks, ytick_labels=yticklabels,
)
raster_plot_asset(ax_raster, event_spikes, subsample=False, raster_linelength=2,
                  y_label=False, invert=False)
firingrate_plot_asset(ax_firing, ax_raster, act_z, bins, l_c,
                      x_label=True, y_label=False, invert=False)

for panel_ax in (ax_shape, ax_raster):
    sns.despine(left=True, top=True, right=True, bottom=True, ax=panel_ax)
sns.despine(top=True, right=True, trim=True, ax=ax_firing)

ax_shape.set_title(title, fontsize=titlesize, pad=20)

fig_stem = f"{label}_pat{patient_id}_unit{unit_id}_{region}_key{key}"
print(f"{fig_stem}.svg")

for ext in ("svg", "png"):
    plt.savefig(panel_save_dir / f"{fig_stem}.{ext}", dpi=300, bbox_inches='tight')
plt.show()

## PHC: Outdoor

In [None]:
region = "PHC"
# Read from the same folder the extraction cells write to (save_path) instead of repeating its
# absolute path.
load_path = save_path / f"nwb_{region}_df_region_restricted.pkl"

df_region_restricted_units = pd.read_pickle(load_path)

In [None]:
# Cached indoor/outdoor setting annotations.
label = "indoor-setting"
annotation_pickle = save_path / f"nwb_{label}_df_annotation.pkl"
df_annotation = pd.read_pickle(annotation_pickle)

In [None]:
baseline_time = 1000
stimulus_time = 1000
key = 0  # annotation value whose onsets are treated as events
restriction = "increase"
alpha = 0.001

# Sort units by (patient_id, unit_id); output rows are mapped back to these rows by position.
df_unit_data = df_region_restricted_units.sort_values(by=["patient_id", "unit_id"])

(reg_act, reg_pvals, sort_inds,
 ct_sig_units, start_ns, end_ns) = identify_responses(
    df_unit_data, df_annotation, pat_subset,
    baseline_time=baseline_time,
    stimulus_time=stimulus_time,
    key=key,
    restriction=restriction,
    alpha=alpha,
)

# Rank units by p-value, keep the ranked p-values aside, then drop them from the activity table.
df_ = pd.DataFrame(data=reg_act)
df_["pval"] = reg_pvals
df = df_.sort_values(by=['pval'])
pval_inds = df["pval"]
df = df.drop(columns="pval")
lineplot_df = df_formatter(reg_pvals, reg_act)

# df.index still holds each unit's original row position after the sort.
unit_rows = df.index
df["patient_id"] = df_unit_data["patient_id"].to_numpy()[unit_rows]
df["unit_id"] = df_unit_data["unit_id"].to_numpy()[unit_rows]
df.to_pickle(save_path / f"nwb_{label}_{region}_{restriction}_processed_responses.pkl")

In [None]:
# Amplitude binning for a symmetric spike-shape density panel.
# NOTE(review): the plotting cell of this section calls spike_amp_plot_asset without these values,
# so the function's own defaults are used there — confirm that is intended.
amp_min = -150
amp_max = 150
steps = 150

DENSITY_BINS = np.linspace(amp_min, amp_max, steps)
xticks = np.linspace(0, 19, 5)
ticklabels = [-1000, -500, 0, 500, 1000]

# Ticks at both extremes and the midpoint of the amplitude axis (derived from steps, not a literal).
yticks = np.linspace(0, steps, 3)
yticklabels = [amp_min, 0, amp_max]

In [None]:
# Example unit for the outdoor-scenes panel, chosen explicitly by (patient, unit).
patient_id = 42
unit_id = 30
unit_rank = None

(act_z, bins, binned_act, spike_times,
 event_spikes, waveform, onsets) = grab_unit_response(
    unit_rank=unit_rank,
    patient_id=patient_id,
    unit_id=unit_id,
    df=df,
    df_unit_data=df_unit_data,
    df_annotation=df_annotation,
    key=key,
    baseline_time=baseline_time,
    stimulus_time=stimulus_time,
)

In [None]:
title = f"Outdoor Scenes\nPHC\n{patient_id}-{unit_id}, {len(onsets)} events"

# Layout: spike-shape density on top (middle column), raster in the middle, firing rate at the
# bottom; the 1-high rows are thin spacers between panels.
fig = plt.figure(figsize=(6, 13))
gs0 = gridspec.GridSpec(1, 1, figure=fig)
l_c = sns.color_palette("Spectral", n_colors=6).as_hex()[0]
gs = gridspec.GridSpecFromSubplotSpec(
    5, 3,
    subplot_spec=gs0[0],
    height_ratios=[75, 1, 300, 1, 75],
    hspace=0.,
)

ax_shape = fig.add_subplot(gs[0, 1])
ax_raster = fig.add_subplot(gs[2:4, :])
ax_firing = fig.add_subplot(gs[4, :], sharex=ax_raster)

# NOTE(review): DENSITY_BINS / yticks are not passed here, so spike_amp_plot_asset's defaults apply.
spike_amp_plot_asset(ax_shape, np.array(waveform), "Spectral_r", invert=False)
raster_plot_asset(ax_raster, event_spikes, subsample=False, raster_linelength=1,
                  y_label=False, invert=False)
firingrate_plot_asset(ax_firing, ax_raster, act_z, bins, l_c,
                      x_label=True, y_label=False, invert=False)

for panel_ax in (ax_shape, ax_raster):
    sns.despine(left=True, top=True, right=True, bottom=True, ax=panel_ax)
sns.despine(top=True, right=True, trim=True, ax=ax_firing)

ax_shape.set_title(title, fontsize=titlesize, pad=20)

fig_stem = f"{label}_pat{patient_id}_unit{unit_id}_{region}_key{key}"
print(f"{fig_stem}.svg")

for ext in ("svg", "png"):
    plt.savefig(panel_save_dir / f"{fig_stem}.{ext}", dpi=300, bbox_inches='tight')
plt.show()

## A: Summer

In [None]:
region = "A"
# Read from the same folder the extraction cells write to (save_path) instead of repeating its
# absolute path.
load_path = save_path / f"nwb_{region}_df_region_restricted.pkl"

df_region_restricted_units = pd.read_pickle(load_path)

In [None]:
# Cached "summer" annotations.
label = "summer"
annotation_pickle = save_path / f"nwb_{label}_df_annotation.pkl"
df_annotation = pd.read_pickle(annotation_pickle)

In [None]:
baseline_time = 1000
stimulus_time = 1000
key = 1  # annotation value whose onsets are treated as events
restriction = "increase"
alpha = 0.001

# Sort units by (patient_id, unit_id); output rows are mapped back to these rows by position.
df_unit_data = df_region_restricted_units.sort_values(by=["patient_id", "unit_id"])

(reg_act, reg_pvals, sort_inds,
 ct_sig_units, start_ns, end_ns) = identify_responses(
    df_unit_data, df_annotation, pat_subset,
    baseline_time=baseline_time,
    stimulus_time=stimulus_time,
    key=key,
    restriction=restriction,
    alpha=alpha,
)

# Rank units by p-value, keep the ranked p-values aside, then drop them from the activity table.
df_ = pd.DataFrame(data=reg_act)
df_["pval"] = reg_pvals
df = df_.sort_values(by=['pval'])
pval_inds = df["pval"]
df = df.drop(columns="pval")
lineplot_df = df_formatter(reg_pvals, reg_act)

# df.index still holds each unit's original row position after the sort.
unit_rows = df.index
df["patient_id"] = df_unit_data["patient_id"].to_numpy()[unit_rows]
df["unit_id"] = df_unit_data["unit_id"].to_numpy()[unit_rows]
df.to_pickle(save_path / f"nwb_{label}_{region}_{restriction}_processed_responses.pkl")

In [None]:
# Amplitude binning for a symmetric spike-shape density panel.
# NOTE(review): the plotting cell of this section calls spike_amp_plot_asset without these values,
# so the function's own defaults are used there — confirm that is intended.
amp_min = -150
amp_max = 150
steps = 150

DENSITY_BINS = np.linspace(amp_min, amp_max, steps)
xticks = np.linspace(0, 19, 5)
ticklabels = [-1000, -500, 0, 500, 1000]

# Ticks at both extremes and the midpoint of the amplitude axis (derived from steps, not a literal).
yticks = np.linspace(0, steps, 3)
yticklabels = [amp_min, 0, amp_max]

In [None]:
# Example unit for the summer panel, chosen by its response rank instead of explicit IDs.
patient_id = None
unit_id = None
unit_rank = 1

(act_z, bins, binned_act, spike_times,
 event_spikes, waveform, onsets) = grab_unit_response(
    unit_rank=unit_rank,
    patient_id=patient_id,
    unit_id=unit_id,
    df=df,
    df_unit_data=df_unit_data,
    df_annotation=df_annotation,
    key=key,
    baseline_time=baseline_time,
    stimulus_time=stimulus_time,
)

In [None]:
# When the unit was selected by rank, patient_id/unit_id are None; label it by rank instead of
# printing "None-None".
if patient_id is not None:
    unit_desc = f"{patient_id}-{unit_id}"
    unit_tag = f"pat{patient_id}_unit{unit_id}"
else:
    unit_desc = f"rank {unit_rank}"
    unit_tag = f"rank{unit_rank}"

# Use the analysed region rather than a hardcoded "H".
title = f"Summer\n{region}\n{unit_desc}, {len(onsets)} events"

# Layout: spike-shape density on top (middle column), raster in the middle, firing rate at the
# bottom; the 1-high rows are thin spacers between panels.
fig = plt.figure(figsize=(6, 13))
gs0 = gridspec.GridSpec(1, 1, figure=fig)

l_c = sns.color_palette("Spectral", n_colors=6).as_hex()[0]

gs = gridspec.GridSpecFromSubplotSpec(5, 3, subplot_spec=gs0[0], height_ratios=[75, 1, 300, 1, 75],
                                        hspace=0.)

ax_shape  = fig.add_subplot(gs[0, 1])
ax_raster = fig.add_subplot(gs[2:4, :])
ax_firing = fig.add_subplot(gs[4, :], sharex=ax_raster)

# NOTE(review): DENSITY_BINS / yticks are not passed here, so spike_amp_plot_asset's defaults apply.
spike_amp_plot_asset(ax_shape, np.array(waveform), "Spectral_r", invert=False)
raster_plot_asset(ax_raster, event_spikes, subsample=False, raster_linelength=1, y_label=False, invert=False)
firingrate_plot_asset(ax_firing, ax_raster, act_z, bins, l_c, 
                          x_label=True, y_label=False, invert=False)

sns.despine(left=True, top=True, right=True, bottom=True, ax=ax_shape)
sns.despine(left=True, top=True, right=True, bottom=True, ax=ax_raster)
sns.despine(top=True, right=True, trim=True, ax=ax_firing)

ax_shape.set_title(title, fontsize=titlesize, pad=20)

fig_stem = f"{label}_{unit_tag}_{region}_key{key}"
print(f"{fig_stem}.svg")

# Saving is disabled for this exploratory panel; set to True to write the SVG and PNG.
save_figure = False
if save_figure:
    for ext in ("svg", "png"):
        plt.savefig(panel_save_dir / f"{fig_stem}.{ext}", dpi=300, bbox_inches='tight')
plt.show()

## Sanity checks

In [None]:
region = "H"
# Read from the same folder the extraction cells write to (save_path) instead of repeating its
# absolute path.
load_path = save_path / f"nwb_{region}_df_region_restricted.pkl"

df_region_restricted_units = pd.read_pickle(load_path)

In [None]:
# Spot-check a single unit (patient 42, unit 45) in the region-H table.
is_target_unit = (
    (df_region_restricted_units["patient_id"] == 42)
    & (df_region_restricted_units["unit_id"] == 45)
)
df_region_restricted_units[is_target_unit]

## Batch plots: top-ranked units for every annotation label

In [None]:
region = "A"
# Read from the same folder the extraction cells write to (save_path) instead of repeating its
# absolute path.
load_path = save_path / f"nwb_{region}_df_region_restricted.pkl"

df_region_restricted_units = pd.read_pickle(load_path)

In [None]:
baseline_time = 1000
stimulus_time = 1000
restriction = "decrease"
alpha = 0.001
# Plot the best-ranked (lowest p-value) units of each label, capped at the number of units available.
n_top_units = 40
# Root folder for the per-label PNG dumps (absolute path on the original author's machine).
psth_root = Path("/home/al/Documents/phd/analysis/psth_SU_movieAlignedLabels/updated_brain_regions/specific_units")

# Colour of the firing-rate trace; identical for every panel, so computed once.
l_c = sns.color_palette("Spectral", n_colors=6).as_hex()[0]

# The unit table does not depend on the label, so sort it once.
df_unit_data = df_region_restricted_units.sort_values(by=["patient_id", "unit_id"])


def _clean_onsets(starts, stops, baseline_time):
    """Return the event onsets whose baseline window does not overlap the preceding event.

    Event ``i`` (``i >= 1``) is kept when ``starts[i] - baseline_time > stops[i - 1]``. The first
    event is always skipped because it has no preceding offset to compare against.

    Parameters
    ----------
    starts, stops : numpy.ndarray
        Start and stop times of consecutive events (assumed to share units with baseline_time).
    baseline_time : number
        Length of the pre-onset window that must lie after the previous event's stop.

    Returns
    -------
    list
        The elements of ``starts`` that passed the check, in their original order.
    """
    later_starts = starts[1:]
    keep = (later_starts - baseline_time) > stops[:-1]
    return list(later_starts[keep])


for label in ["indoor-setting", "summer", "camera-cuts", "scenes", "tom"]:

    # "indoor-setting" is analysed on annotation value 0, every other label on value 1.
    key = 0 if label == "indoor-setting" else 1

    df_annotation = pd.read_pickle(save_path / f"nwb_{label}_df_annotation.pkl")

    reg_act, reg_pvals, sort_inds, ct_sig_units, start_ns, end_ns = identify_responses(
        df_unit_data, df_annotation, pat_subset,
        baseline_time=baseline_time, stimulus_time=stimulus_time,
        key=key, restriction=restriction, alpha=alpha,
    )

    # Rank units by p-value; keep the ranked p-values aside and drop them from the activity table.
    df_ = pd.DataFrame(data=reg_act)
    df_["pval"] = reg_pvals
    df = df_.sort_values(by=['pval'])
    pval_inds = df["pval"]
    df = df.drop(columns="pval")
    lineplot_df = df_formatter(reg_pvals, reg_act)

    # df.index holds each unit's original row position, used to look up its identity.
    df["patient_id"] = df_unit_data["patient_id"].to_numpy()[df.index]
    df["unit_id"] = df_unit_data["unit_id"].to_numpy()[df.index]
    df.to_pickle(save_path / f"nwb_{label}_{region}_{restriction}_processed_responses.pkl")

    save_dir = psth_root / f"{region}_{label}_{restriction}"
    save_dir.mkdir(parents=True, exist_ok=True)

    # Rank order (lowest p-value first) with positional access to the identity columns.
    df_reindexed = df.reset_index()

    # Guard against labels/regions with fewer than n_top_units units (iloc would raise IndexError).
    for unit_rank in range(min(n_top_units, len(df_reindexed))):

        ranked_unit = df_reindexed.iloc[unit_rank]
        patient_id = int(ranked_unit["patient_id"])
        unit_id = int(ranked_unit["unit_id"])

        print(patient_id, unit_id)

        specific_unit_data = df_unit_data[(df_unit_data["patient_id"] == patient_id) & (df_unit_data["unit_id"] == unit_id)]
        spike_times = specific_unit_data["spike_times"].iloc[0]
        waveform = specific_unit_data["waveforms"].iloc[0]

        # This patient's annotation intervals that carry the analysed value.
        df_annotations_patient = df_annotation[df_annotation["patient_id"] == patient_id]
        values = df_annotations_patient["value"].to_numpy()
        is_key = values == key
        starts = df_annotations_patient["start_time"].to_numpy()[is_key]
        stops = df_annotations_patient["stop_time"].to_numpy()[is_key]
        values = values[is_key]

        onsets = _clean_onsets(starts, stops, baseline_time)

        event_spikes = times_by_event(spike_times, onsets, baseline_time*-1, stimulus_time)

        bins = np.arange(baseline_time*-1, stimulus_time + 1, 100)
        binned_act = times_to_histogram(event_spikes, bins)

        # norm -- currently norming to whole "trial", not baseline
        act_z = zscore(binned_act)

        title = f"{label}\n{region}\n{patient_id}-{unit_id}, {len(onsets)} events"

        fig = plt.figure(figsize=(6, 13))
        gs0 = gridspec.GridSpec(1, 1, figure=fig)

        gs = gridspec.GridSpecFromSubplotSpec(5, 3, subplot_spec=gs0[0], height_ratios=[75, 1, 300, 1, 75],
                                                hspace=0.)

        ax_shape  = fig.add_subplot(gs[0, 1])
        ax_raster = fig.add_subplot(gs[2:4, :])
        ax_firing = fig.add_subplot(gs[4, :], sharex=ax_raster)

        spike_amp_plot_asset(ax_shape, np.array(waveform), "Spectral_r", invert=False)
        raster_plot_asset(ax_raster, event_spikes, subsample=False, raster_linelength=1, y_label=False, invert=False)
        firingrate_plot_asset(ax_firing, ax_raster, act_z, bins, l_c, 
                                x_label=True, y_label=False, invert=False)

        sns.despine(left=True, top=True, right=True, bottom=True, ax=ax_shape)
        sns.despine(left=True, top=True, right=True, bottom=True, ax=ax_raster)
        sns.despine(top=True, right=True, trim=True, ax=ax_firing)

        ax_shape.set_title(title, fontsize=titlesize, pad=20)

        print(f"{label}_pat{patient_id}_unit{unit_id}_{region}_key{key}.svg")

        plt.savefig(save_dir / f"rank{unit_rank}_{label}_pat{patient_id}_unit{unit_id}_{region}_key{key}.png", dpi=100, bbox_inches='tight')
        plt.show()
        # Close explicitly: this loop can create up to 5 * n_top_units figures.
        plt.close(fig)

## Cluster permutation test: PHC (camera cuts)

In [None]:
region = "PHC"
# Read from the same folder the extraction cells write to (save_path) instead of repeating its
# absolute path.
load_path = save_path / f"nwb_{region}_df_region_restricted.pkl"
df_region_restricted_units = pd.read_pickle(load_path)

label = "camera-cuts"
df_annotation = pd.read_pickle(save_path / f"nwb_{label}_df_annotation.pkl")


baseline_time = 1000
stimulus_time = 1000
key = 1  # annotation value whose onsets are treated as events
restriction = "increase"
alpha = 0.001

# Sort units by (patient_id, unit_id): output rows are mapped back to these rows by position below.
df_unit_data = df_region_restricted_units.sort_values(by=["patient_id", "unit_id"])

reg_act, reg_pvals, sort_inds, ct_sig_units, start_ns, end_ns = identify_responses(
    df_unit_data, df_annotation, pat_subset,
    baseline_time=baseline_time, stimulus_time=stimulus_time,
    key=key, restriction=restriction, alpha=alpha,
)

# One row per unit, ranked by p-value (smallest first).
df_ = pd.DataFrame(data=reg_act)
df_["pval"] = reg_pvals
df = df_.sort_values(by=['pval'])
pval_inds = df["pval"]
# NOTE: "pval" stays in df (and in the saved pickle); exclude it before treating df's columns
# as time bins.
lineplot_df = df_formatter(reg_pvals, reg_act)

# df.index holds each unit's original row position. NOTE(review): this assumes identify_responses
# returns one row per row of df_unit_data even though it is also given pat_subset — confirm.
df["patient_id"] = df_unit_data["patient_id"].to_numpy()[df.index]
df["unit_id"] = df_unit_data["unit_id"].to_numpy()[df.index]
df.to_pickle(save_path / f"nwb_{label}_{region}_{restriction}_processed_responses.pkl")

In [None]:
nm_perms = 1000

# Seed the global RNGs so the permutation null distribution is reproducible across runs.
# NOTE(review): assumes the permutation helpers draw from numpy's / Python's global RNG — confirm.
perm_seed = 0
random.seed(perm_seed)
np.random.seed(perm_seed)

# Only per-bin activity columns may enter the test: drop the identity columns and, when present,
# the "pval" ranking column (the camera-cuts response table keeps it).
df_ = df.drop(columns=["patient_id", "unit_id", "pval"], errors="ignore")

condA, condB, condA_inds, condB_inds = parse_dataLabels(df_, pval_inds, return_inds=True)
gt_stats, gt_pvals = calc_testStat(condA, condB)

clus_stats, clus_ids, clus_inds = cluster_curves(gt_stats, gt_pvals)
# Permute on the same activity-only table the observed statistics came from; previously `df`
# (with patient_id/unit_id/pval columns) was used here, so the null covered different columns.
A_list, B_list = generate_permutations(df_, condA_inds, condB_inds, nm_perms)
sc_collection = permute_scores(df_, A_list, B_list)
clus_pvals = get_cluster_pvals(sc_collection, clus_ids, clus_stats, nm_perms)

In [None]:
import matplotlib.lines as mlines


# Panel settings (plain variables mirroring the arguments of a plotting function).
# NOTE(review): the unit count in the title is hardcoded — presumably len(df_heatmap); confirm.
title = "Camera Cuts\nPHC\n293 units"
region = "PHC"  
label_dict = {  
    "A_heatmap_cbar": "Z-Scored\nFiring Rate",
    "A_heatmap_ylabel": "Neurons",
    "B_lineplot_ylabel": "Z-Scored\nFiring Rate",
    "B_lineplot_xlabel": "Time",
    "B_lineplot_legend": True
}
key = 1  
cmap = "Spectral_r" 
vlinecolor = "tab:blue"
lineplotcolor = "peru"  
invert_colors = False  

title_pad = 10

sig_spacer = 0.35
xticks = np.linspace(0, 20, 5)
ticklabels = [-1000, -500, 0, 500, 1000]

# Plot from a separate name instead of rebinding `df`: `df` must keep its patient_id/unit_id
# columns so the cluster-permutation cell can be re-run after this one.
df_heatmap = df_

if invert_colors:
    plt.style.use('dark_background')
    text_color = "white"
    vlinecolor = "white"
else:
    text_color = "black"

fig, axes = plt.subplots(2, 2, sharex="col", 
                         gridspec_kw={'width_ratios': [100, 3], "height_ratios": [65, 35]}, 
                         figsize=(7, 13))

# --- Top panel: per-unit activity heatmap (rows in p-value order) with its colour bar ---
ax = axes[0, 0]

vmin = -0.15
vmax = 0.7

sns.heatmap(df_heatmap, ax=ax, vmin=vmin, vmax=vmax, cmap=cmap, cbar_ax=axes[0, 1], 
            cbar_kws={"ticks": [vmin, vmax]},
            rasterized=True)

axes[0, 1].set_ylabel(label_dict["A_heatmap_cbar"], fontsize=labelsize, color=text_color, rotation=0)
axes[0, 1].yaxis.set_label_coords(8,0.5)
axes[0, 1].tick_params(labelsize=ticklabelsize, colors=text_color)

ylims = ax.get_ylim()
# Vertical line at the middle x tick (the 0 label).
ax.vlines(xticks[2], ylims[0], ylims[1], linewidth=axwidth, color=vlinecolor)
if start_ns != 0:
    # Dashed row boundary at start_ns (returned by identify_responses; presumably the first
    # non-significant row — TODO confirm).
    ax.hlines(start_ns, xticks[0], xticks[-1], linestyles='--', color="grey", linewidth=axwidth)

sns.despine(ax=ax, top=True, bottom=True, left=True, right=True)
ax.set_yticks([])
ax.set_xlim(0, 20)

ax.set_title(title, pad=title_pad, fontsize=titlesize, color=text_color)

# --- Bottom panel: mean response of non-responsive vs responsive units ---
ax = axes[1, 0]
sns.lineplot(data=lineplot_df, x="bin", y="FR", hue="sig", ax=ax, hue_order=["ns", "*"], 
             linewidth=axwidth, palette=["white" if invert_colors else "black", lineplotcolor])

ylims = ax.get_ylim()
# Extend the bottom by 10% of the range to make room for the cluster significance bars.
ax.set_ylim(bottom=ylims[0] - (ylims[1] - ylims[0]) * 0.1, top=ylims[1])
yticks = [round(y - (y * 0.25), 2) for y in [ylims[0], 0, ylims[1]]]
ax.set_yticks(yticks)
ax.set_yticklabels(yticks, fontsize=ticklabelsize, color=text_color)
ax.vlines(xticks[2], yticks[0], yticks[-1], linewidth=axwidth, alpha=1, color=vlinecolor)
ax.set_xticks(xticks)
ax.set_xticklabels(ticklabels, fontsize=ticklabelsize, color=text_color)
sns.despine(ax=ax, trim=True)
ax.set_xlabel(label_dict["B_lineplot_xlabel"], fontsize=labelsize, color=text_color)

colors = ["white" if invert_colors else "black", lineplotcolor]
handles = [mlines.Line2D([0], [0], color=color, lw=4) for color in colors]

if label_dict["B_lineplot_legend"] is False: 
    ax.get_legend().remove()
elif len(np.unique(lineplot_df["sig"])) == 2:      
    ax.legend(handles=handles, labels=["Non-Resp.", "Resp."], bbox_to_anchor=(1.1, 1), fontsize=labelsize, 
              loc='upper left', borderaxespad=0, handleheight=1.5, frameon=False, labelcolor=text_color)
else: 
    ax.legend(["Non-Resp."], bbox_to_anchor=(1.1, 1), fontsize=labelsize, 
              loc='upper left', borderaxespad=0, handleheight=1.5, frameon=False, labelcolor=text_color)

axes[0, 0].set_ylabel(label_dict["A_heatmap_ylabel"], fontsize=labelsize, color=text_color, rotation=0)
axes[0, 0].yaxis.set_label_coords(-0.3,0.5)
axes[1, 0].set_ylabel(label_dict["B_lineplot_ylabel"], fontsize=labelsize, color=text_color, rotation=0)
axes[1, 0].yaxis.set_label_coords(-0.3,0.5)

# Cluster significance bars. ylims are the line-plot limits read before the 10% bottom margin
# was added, so each bar (1/30 of the range tall) sits inside that margin.
bar_width = (ylims[1] - ylims[0]) / 30
for inds in clus_inds:

    # Span from the cluster's first bin to one past its last bin (a single-bin cluster spans one bin).
    span = [inds[0], inds[-1] + 1]

    y1 = np.tile(ylims[0]-bar_width, len(span))
    y2 = np.tile(ylims[0], len(span))
    ax.fill_between(span, y1, y2, alpha=1, color=text_color)
    # NOTE(review): every cluster gets "*" — clus_pvals (the permutation p-values) are computed
    # but not used to filter or grade clusters here. Confirm that is intended.
    marker = "*"
    sig_x = ((span[-1]-span[0])/2)+span[0] - sig_spacer
    ax.text(sig_x, ylims[0], marker, fontweight='bold', color=text_color, fontsize=labelsize)

axes[1, 1].remove()

plt.savefig(panel_save_dir / f"{label}_{region}_key{key}.svg", dpi=300, bbox_inches='tight')
plt.savefig(panel_save_dir / f"{label}_{region}_key{key}.png", dpi=300, bbox_inches='tight')
plt.show()