## panels for fig 1 

#### Figure Purpose: Give readers enough information about the dataset at a glance to decide whether it is of interest to them.

* electrode localizations
* unit distribution across patients, stratified by region
* overall unit information
* general stimulus information

In [None]:
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..', )))
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..', '..')))

from collections import Counter, OrderedDict
from pathlib import Path

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
from matplotlib import rc
from matplotlib.colors import ListedColormap
from matplotlib.lines import Line2D
import seaborn as sns
from nilearn.plotting import plot_markers

from config_colors import *
from config_paths import *
from config_plot_params import *
from nwb_io import *
from plot_patientwise_unit_distribution import *

# save panels directly to the relevant svg/ subdir
# (resolved from the notebook's working directory, two levels up)
panel_save_dir = Path.cwd().parent.parent / "figure_generation" / "figure_data_overview" / "svg"
# savefig does not create missing folders; make sure the target exists so a
# fresh checkout can Restart & Run All without a FileNotFoundError.
panel_save_dir.mkdir(parents=True, exist_ok=True)

# Folder containing one .nwb file per patient (NWB_data_dir comes from config_paths).
data_dir = NWB_data_dir

In [None]:
## collect all units from all nwb files
# Read the units table of every NWB file in `data_dir` and stack them into one
# DataFrame, tagging each row with the patient id parsed from the file name.
# NOTE(review): the id parsing drops the first three characters of the file stem
# (e.g. "xyz10.nwb" -> 10) -- confirm this matches the naming of the data folder.
unit_frames = []

# sorted(): glob order is filesystem-dependent; sorting makes the row order reproducible.
for path in tqdm(sorted(data_dir.glob("*.nwb"))):
    if path.is_dir():
        continue

    patient_id = int(path.name.split(".")[0][3:])

    # The context manager closes the HDF5 handle even if reading raises.
    with NWBHDF5IO(path, mode="r") as io:
        nwbfile = io.read()
        df_units = nwbfile.units.to_dataframe()
        # Running index within this file, so unit ids are only unique per patient.
        df_units["unit_id"] = np.arange(len(df_units))
        df_units.insert(0, "patient_id", patient_id)
        unit_frames.append(df_units)

if not unit_frames:
    raise FileNotFoundError(f"No .nwb files found in {data_dir}")

# Concatenate once at the end; concatenating inside the loop is quadratic.
df_units_all = pd.concat(unit_frames, ignore_index=True)

## raster plot

In [None]:
# Patient whose units are shown in the example raster panel.
pat = 10

# Spike-time entries of all units of that patient, one element per unit.
patient_mask = df_units_all["patient_id"] == pat
spikes = df_units_all.loc[patient_mask, "spike_times"].to_numpy()

In [None]:
# Crop every unit's spike train to the first `minutes` of the recording for the raster.
# NOTE(review): `time_limit` is in milliseconds (minutes * 60 * 1000). NWB's Units
# table conventionally stores spike_times in seconds; if these files follow that,
# the threshold is 1000x too large and nothing is cropped -- confirm the unit.
# It also assumes spike times start near 0 -- confirm.
minutes = 5
time_limit = minutes * 60 * 1000

# Vectorised boolean mask per unit instead of a Python loop over every spike.
data = []
for unit in spikes:
    unit_spikes = np.asarray(unit)
    data.append(unit_spikes[unit_spikes <= time_limit])

In [None]:
# Example spike raster: one row per unit (first unit at the top), no tick marks,
# only a coarse time label on the x-axis.
fig, ax = plt.subplots(1, 1, figsize=(10, 4))

ax.eventplot(data, linewidths=0.8, linelengths=.7, color='black', alpha=0.5)
ax.invert_yaxis()

# Remove all ticks and tick labels on both axes.
ax.set(yticks=[], yticklabels=[])
ax.yaxis.set_ticks_position('none')
ax.set(xticks=[], xticklabels=[])
ax.set_xlabel("Time [5 minutes]", fontsize=labelsize)

sns.despine(ax=ax, left=True, bottom=True)
fig.tight_layout()

for extension in ("png", "svg"):
    fig.savefig(panel_save_dir / f"raster.{extension}", bbox_inches="tight", dpi=300)

plt.show()

## electrode localizations

In [None]:
# Read the electrodes table of every NWB file in `data_dir` and stack them into
# one DataFrame, tagging each row with the patient id parsed from the file name.
# NOTE(review): the id parsing drops the first three characters of the file stem
# -- confirm this matches the naming of the data folder.
electrode_frames = []

# sorted(): glob order is filesystem-dependent; sorting makes the row order reproducible.
for path in tqdm(sorted(data_dir.glob("*.nwb"))):
    if path.is_dir():
        continue

    patient_id = int(path.name.split(".")[0][3:])

    # The context manager closes the HDF5 handle even if reading raises.
    with NWBHDF5IO(path, mode="r") as io:
        nwbfile = io.read()
        df_electrodes = nwbfile.electrodes.to_dataframe()
        df_electrodes.insert(0, "patient_id", patient_id)
        electrode_frames.append(df_electrodes)

if not electrode_frames:
    raise FileNotFoundError(f"No .nwb files found in {data_dir}")

# Concatenate once at the end; concatenating inside the loop is quadratic.
df_electrodes_all = pd.concat(electrode_frames, ignore_index=True)

In [None]:
# Copy of the electrode table with the regions PIC, FF, LG and PRC relabelled "Other".
# Built with .assign: `df[col].replace(..., inplace=True)` is chained assignment,
# which under pandas Copy-on-Write does not modify the DataFrame at all.
df_electrodes_all_Others_renamed = df_electrodes_all.assign(
    brain_region=df_electrodes_all["brain_region"].replace(["PIC", "FF", "LG", "PRC"], "Other")
)

In [None]:
# Electrode coordinates for nilearn: one (x, y, z) row per electrode, in table order.
node_coords = df_electrodes_all_Others_renamed[["x", "y", "z"]].to_numpy(dtype=float)

# Integer colour code per region. AH/MH/PH share one code, as do PHC/APH/MPH/PPH.
# The codes are consecutive (0..4) so that a 5-colour ListedColormap normalised
# over [min, max] puts each code in its own colour bin.
region_map = {
    "A": 1,
    "AH": 2,
    "MH": 2,
    "PH": 2,
    "EC": 3,
    "PHC": 4,
    "APH": 4,
    "MPH": 4,
    "PPH": 4,
    "Other": 0,
}

# Exact codes, without random jitter: jittered values cross colour-bin
# boundaries (electrodes pick up a neighbouring region's colour) and change on every run.
region_codes = df_electrodes_all_Others_renamed["brain_region"].map(region_map)
unmapped = sorted(set(df_electrodes_all_Others_renamed["brain_region"][region_codes.isna()]))
if unmapped:
    raise KeyError(f"Regions missing from region_map: {unmapped}")
node_values = region_codes.to_numpy(dtype=float)


In [None]:
# Electrode positions drawn with nilearn's plot_markers, coloured by region code
# (`node_values`), with a hand-built region legend.
# Intended colour per region (same hex values as the units-by-region bar chart):
# '#FCAF4A' Other, '#E53E24' A, '#54aead' H, '#D8D806' EC, '#60A02C' PHC.
colors = [
    '#FCAF4A',
    '#E53E24',
    '#54aead',
    "#D8D806",
    '#60A02C',
]
cmap = ListedColormap(colors, name='brain_v2')
fig, ax = plt.subplots(1, 1, figsize=(10, 3))

node_size = 2.
alpha = 0.6

# One explicit colour scale shared by the markers and the legend swatches.
# These equal nilearn's defaults; passing them explicitly keeps both in sync.
node_vmin, node_vmax = min(node_values), max(node_values)
norm = plt.Normalize(vmin=node_vmin, vmax=node_vmax)

plot_markers(
    node_values,
    node_coords,
    node_size=node_size,
    node_cmap=cmap,
    node_vmin=node_vmin,
    node_vmax=node_vmax,
    alpha=alpha,
    figure=fig,
    axes=ax,
    colorbar=False,
    black_bg=False,
)

# Legend: (label, region_map key whose code the label represents). MH/PH share
# AH's code and APH/MPH/PPH share PHC's, so they get no separate entries.
# Line2D proxy artists are used instead of empty plt.scatter calls, which would
# attach stray collections to whichever axes is current after nilearn draws.
legend_entries = [("A", "A"), ("H", "AH"), ("EC", "EC"), ("PHC", "PHC"), ("Other", "Other")]
handles = [
    Line2D(
        [], [],
        linestyle="",
        marker="o",
        markersize=10,  # diameter in pt; same size as scatter's s=100 (area in pt^2)
        markerfacecolor=cmap(norm(region_map[key])),
        markeredgecolor="w",
        label=label,
    )
    for label, key in legend_entries
]

ax.legend(handles=handles, title="", markerscale=1.5, bbox_to_anchor=(1.13, 1), fontsize=12, frameon=False)

fig.savefig(panel_save_dir / "localizations.png", bbox_inches="tight", dpi=300)
fig.savefig(panel_save_dir / "localizations.svg", bbox_inches="tight", dpi=300)
plt.show()

## units across regions

In [None]:
# Units per consolidated region: drop a few region labels entirely, fold the
# subregions into the five plotted categories, then count.
excluded_regions = ["H", "T", "Ta", "Tb", "I"]
region_groups = {
    "AH": "H", "MH": "H", "PH": "H",
    "APH": "PHC", "MPH": "PHC", "PPH": "PHC",
    "PIC": "Other", "FF": "Other", "LG": "Other", "PRC": "Other",
}

raw_regions = [region for region in df_units_all["brain_region"] if region not in excluded_regions]
units_regions = pd.Series(raw_regions).replace(region_groups)
data = Counter(units_regions)

# Fixed plotting order; regions without units get a count of 0 (Counter default).
order_list = ["A", "H", "EC", "PHC", "Other"]
ordered_data = OrderedDict((region, data[region]) for region in order_list)

In [None]:
# One colour per bar, in `order_list` order: A, H, EC, PHC, Other.
colors = [
    '#E53E24',
    '#54aead',
    "#D8D806",
    '#60A02C',
    '#FCAF4A',
]

bar_width = 0.85

fig, ax = plt.subplots(1, 1, figsize=(3.75, 3))
bars = ax.bar(ordered_data.keys(), ordered_data.values(), bar_width, color=colors)

# Write each unit count just above its bar (offset in data units).
for bar in bars:
    count = bar.get_height()
    x_center = bar.get_x() + bar.get_width() / 2
    ax.text(x_center, count + 0.5, f'{int(count)}', ha='center', va='bottom', fontsize=ticksize)

ax.set_xlabel("Region", fontsize=labelsize)
ax.set_ylabel("Number of units", fontsize=labelsize)

sns.despine()
ymin, ymax = ax.get_ylim()
# The left spine is cut at a fixed value of 800 units.
ax.spines['left'].set_bounds(ymin, 800)

for extension in ("png", "svg"):
    fig.savefig(panel_save_dir / f"units_by_region.{extension}", bbox_inches="tight", dpi=300)
plt.show()


## units across patients, stratified by region

In [None]:
# Presumably builds a per-patient overview table from the NWB folder
# (create_data_overview comes from one of the star imports -- not defined here).
# NOTE(review): df_patient_overview is not used by any later cell shown here;
# confirm it is still needed.
df_patient_overview = create_data_overview(data_dir)

In [None]:
# Copy of the pooled units table with the regions PIC, FF, LG and PRC relabelled "Other".
# Built with .assign: `df[col].replace(..., inplace=True)` is chained assignment,
# which under pandas Copy-on-Write does not modify the DataFrame at all.
df_units_all_Other_filtered = df_units_all.assign(
    brain_region=df_units_all["brain_region"].replace(["PIC", "FF", "LG", "PRC"], "Other")
)

In [None]:
# Per-patient unit counts for each consolidated region, computed with the helpers
# from plot_patientwise_unit_distribution. target_regions, replacement_regions and
# final_regions are defined outside this notebook (presumably via a star import).
data_collector = UnitRegionDataCollectorNWB(df_units_all_Other_filtered)
data_processor = UnitRegionDataProcessor(data_collector)
filtered_units, nm_units, filtered_regions = data_processor.filter_by_region(target_regions)
consolidated_regions = data_processor.consolidate_subregions(replacement_regions, filtered_regions)
data = format_regions_for_barplot(consolidated_regions, final_regions)

# Stacking order of the regions (bottom to top).
order_list = ["A", "H", "EC", "PHC", "Other"]
ordered_data = OrderedDict((region, data[region]) for region in order_list)

# NOTE(review): the x-axis labels are 1..N (N = number of distinct patient ids),
# not the actual ids; this assumes ids are contiguous starting at 1 -- confirm.
patients = np.arange(1, df_units_all_Other_filtered["patient_id"].nunique() + 1)
rows = order_list
columns = patients
n_rows = len(rows)
data_list = list(ordered_data.values())



In [None]:
# Global axis/tick styling; `rc` mutates matplotlib rcParams for the rest of the session.
rc('axes', linewidth=axwidth)
rc('xtick.major', width=tickwidth, size=ticksize)
rc('xtick', labelsize=ticklabelsize)
rc('ytick.major', width=tickwidth, size=ticksize)
rc('ytick', labelsize=ticklabelsize)

fig_wid = 8.5
fig_height = 3.
fig, ax = plt.subplots(1, 1, figsize=(fig_wid, fig_height))

bar_width = 0.85

index = np.arange(len(columns))
y_offset = np.zeros(len(columns))

# One stacked bar per patient (column); each region (row) adds a segment
# drawn in colors[row].
for row in range(n_rows):
    segment = ax.bar(index, data_list[row], bar_width, align='center', bottom=y_offset,
                     label=rows[row], color=colors[row])

    # Label only the top segment. With the default label_type="edge", bar_label
    # prints the segment's end position, i.e. the stacked total per patient.
    if row == n_rows - 1:
        ax.bar_label(segment, padding=5, fontsize=ticksize)
    y_offset = y_offset + data_list[row]

sns.despine()
ymin, ymax = ax.get_ylim()
ax.spines['left'].set_bounds(ymin, ymax + 3)

ax.set_xticks(range(len(patients)))
ax.set_xticklabels(patients, fontsize=ticksize)

# Ticks every 20 units up to at least 140, extended so the tallest stack is
# never cut off (the previous fixed range stopped at 140).
ytick_step = 20
ytick_top = max(140, int(np.ceil(y_offset.max(initial=0) / ytick_step)) * ytick_step)
yticks = range(0, ytick_top + 1, ytick_step)
ax.set_yticks(yticks)
ax.set_yticklabels(yticks, fontsize=ticksize)

ax.set_xlabel("Patient ID", fontsize=labelsize, labelpad=15)
ax.set_ylabel("Number of units", fontsize=labelsize)
ax.margins(x=0.01)

fig.savefig(panel_save_dir / "unit_distribution_across_patients.png", bbox_inches="tight", dpi=300)
fig.savefig(panel_save_dir / "unit_distribution_across_patients.svg", bbox_inches="tight", dpi=300)
plt.show()