In [18]:
%matplotlib notebook
import pickle
import matplotlib
import os

import matplotlib.pyplot as plt
import numpy as np
import matplotlib.cm as cm

from matplotlib import gridspec

In [19]:
# Lookup table keyed by the integer that appears in the data file names
# (`IN_<key>`); the value is the name that, with '(0.5)' appended, forms the
# dendritic segment label looked up in the recordings.
dends_dict = dict([
    (20, 'dend2_01000'),
    (66, 'dend4_0110'),
    (86, 'dend5_0110110'),
    (108, 'dend5_01111111100'),
    (117, 'dend5_01111111111101'),
    (127, 'dend5_011111111111110'),
])

In [20]:
# Folder holding the pickled simulation results. Set the SIMULATION_DATA_DIR
# environment variable to point somewhere else without editing the notebook;
# when it is unset the original location is used.
data_folder = os.environ.get('SIMULATION_DATA_DIR', 'L:/simulation_data/ca_middle')
dend = 127                 # key into dends_dict; also part of the file name
fname = f'IN_{dend}'


def _load_simulation_pickle(kind):
    """Load and return the object stored in `dendint_<kind>_<fname>.pkl`.

    Args:
        kind: file-name infix, 'single' or 'together'.

    Raises:
        FileNotFoundError: if the file does not exist under `data_folder`.
    """
    path = f'{data_folder}/dendint_{kind}_{fname}.pkl'
    # NOTE: unpickling can execute arbitrary code -- only load files you
    # produced yourself or otherwise trust.
    with open(path, 'rb') as f:
        return pickle.load(f)


simulation_data_single = _load_simulation_pickle('single')
simulation_data_together = _load_simulation_pickle('together')

### plotting functions

In [1]:
# --- Helper: Plot NSyn Traces ---
def plot_together_nsyn_range(sim_together_list, seg_label='soma(0.5)', nsyn=30, legend=True, selected_ns=None):
    """Plot the voltage trace of one segment for several values of n.

    For every n in `selected_ns`, the record `sim_together_list[n - 1]` is
    read, the row belonging to `seg_label` is extracted and drawn against the
    record's time axis. Curves are coloured along the viridis colormap
    according to (n - 1) / (nsyn - 1).

    Args:
        sim_together_list: sequence of dict records. Each record provides
            'membrane_potential_data' as a (segment_labels, voltages) pair and
            'taxis' as the time axis.
        seg_label: label of the segment to plot.
        nsyn: upper end of the colour scale (presumably the largest number of
            synapses simulated -- confirm against the simulation script).
        legend: whether to add a legend listing the plotted n values.
        selected_ns: iterable of n values to plot; defaults to
            [1, 2, 4, 6, 8, 10, 15, 20, 25, 30].

    Returns:
        The newly created matplotlib Figure.
    """
    if selected_ns is None:
        selected_ns = [1, 2, 4, 6, 8, 10, 15, 20, 25, 30]

    fig, ax = plt.subplots(figsize=(10, 6))
    colormap = matplotlib.colormaps.get_cmap('viridis')
    # nsyn == 1 would make the colour normalisation divide by zero.
    color_span = max(nsyn - 1, 1)

    for n in selected_ns:
        # n < 1 would wrap around to the end of the list via index n - 1;
        # n beyond the list length has no matching record.
        if n < 1 or n > len(sim_together_list):
            continue

        sim = sim_together_list[n - 1]
        seg_labels, v = sim['membrane_potential_data']
        t = np.array(sim['taxis'])

        try:
            idx = list(seg_labels).index(seg_label)
        except ValueError:
            print(f"[Warning] Segment {seg_label} not found in sim {n}. Skipping.")
            continue

        color = colormap((n - 1) / color_span)
        # Row indexing works for both 2-D arrays and lists of traces.
        ax.plot(t, v[idx], label=f'{n}', color=color)

    ax.set_xlabel('Time (ms)')
    ax.set_ylabel('Membrane Potential (mV)')
    ax.set_title(f'{seg_label}')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    if legend:
        ax.legend(fontsize=7, loc='upper right')
    fig.tight_layout()

    return fig


# --- Helper: Compute EPSPs for Input-Output Plot ---
def _segment_trace(sim, seg_label):
    """Return (time, voltage) arrays for `seg_label` in one record, or None if the label is absent."""
    seg_labels, v = sim['membrane_potential_data']
    try:
        idx = list(seg_labels).index(seg_label)
    except ValueError:
        return None
    return np.array(sim['taxis']), np.array(v[idx])


def _sum_single_traces(sims, seg_label):
    """Sum the `seg_label` traces of `sims` sample by sample; None if any record lacks the label."""
    time_vector = None
    summed = None
    for sim in sims:
        found = _segment_trace(sim, seg_label)
        if found is None:
            return None
        t, trace = found
        if summed is None:
            time_vector, summed = t, trace
        else:
            # Traces may differ in length: keep only the overlapping samples.
            common = min(len(summed), len(trace))
            summed = summed[:common] + trace[:common]
            time_vector = time_vector[:common]
    if summed is None:
        return None
    return time_vector, summed


def _peak_above_baseline(t, trace, baseline_start, baseline_end):
    """Return (max(trace) - baseline, baseline); baseline is the mean over baseline_start <= t < baseline_end."""
    # Align lengths first so the time mask can never index past the trace.
    common = min(len(t), len(trace))
    t, trace = t[:common], trace[:common]
    in_window = (t >= baseline_start) & (t < baseline_end)
    baseline = np.mean(trace[in_window])
    return np.max(trace) - baseline, baseline


def compute_epsp(sim_single_list, sim_together_list, seg_label, selected_ns, baseline_start=50, baseline_end=90):
    """Compute expected vs. measured peak voltages for an input-output plot.

    For each n in `selected_ns`:
      * expected: the `seg_label` traces of the first n records of
        `sim_single_list` are summed; the peak of that sum above its own
        baseline is taken.
      * measured: the peak of the `seg_label` trace in
        `sim_together_list[n - 1]` above its baseline.
    Both amplitudes are returned with the baseline of the *together* trace
    added back, i.e. as absolute voltages.

    An n is skipped (no entry appended to either list) when it is smaller than
    1, exceeds the length of either list, or when `seg_label` is missing from
    any record involved. If no time sample falls inside the baseline window
    the baseline, and therefore the result for that n, is NaN.

    Args:
        sim_single_list: sequence of dict records with
            'membrane_potential_data' -> (segment_labels, voltages) and
            'taxis' -> time axis.
        sim_together_list: sequence of records of the same layout.
        seg_label: label of the segment to analyse.
        selected_ns: iterable of n values.
        baseline_start: inclusive start of the baseline window, in the units
            of 'taxis'.
        baseline_end: exclusive end of the baseline window.

    Returns:
        (expected_epsps, measured_epsps): two lists of equal length, one entry
        per n that was not skipped, in the order of `selected_ns`.
    """
    expected_epsps = []
    measured_epsps = []

    for n in selected_ns:
        # n < 1 has no singles to sum, and n - 1 would wrap to the list's end.
        if n < 1 or n > len(sim_single_list) or n > len(sim_together_list):
            continue

        # --- Expected EPSP: summed singles ---
        summed = _sum_single_traces(sim_single_list[:n], seg_label)
        if summed is None:
            continue
        expected_peak, _ = _peak_above_baseline(summed[0], summed[1], baseline_start, baseline_end)

        # --- Measured EPSP: from combined simulation ---
        together = _segment_trace(sim_together_list[n - 1], seg_label)
        if together is None:
            continue
        measured_peak, baseline_t = _peak_above_baseline(together[0], together[1], baseline_start, baseline_end)

        # Append absolute EPSP values with baseline re-added
        expected_epsps.append(expected_peak + baseline_t)
        measured_epsps.append(measured_peak + baseline_t)

    return expected_epsps, measured_epsps


# --- Prepare Subplots and Composite Figure ---
def plot_combined_summary(sim_single_list, sim_together_list, seg_label_dend, dends_dict, dend, nsyn=30):
    """Show a three-panel summary figure: soma traces, dendrite traces, I/O curve.

    Layout (2 x 2 grid): top-left the 'soma(0.5)' traces, bottom-left the
    `seg_label_dend` traces, and the right column (both rows) the
    expected-vs-measured input-output curve at the soma. The figure is shown
    with `plt.show()`; nothing is returned.

    Args:
        sim_single_list: records passed on to `compute_epsp` as the single
            simulations.
        sim_together_list: records passed on to `compute_epsp` and
            `plot_together_nsyn_range` as the combined simulations.
        seg_label_dend: segment label plotted in the dendrite panel.
        dends_dict: not used by this function; kept so existing calls work.
        dend: not used by this function; kept so existing calls work.
        nsyn: upper end of the colour scale shared by all panels.
    """
    selected_ns = [1, 2, 4, 6, 8, 10, 15, 20, 25, 30]
    colormap = matplotlib.colormaps.get_cmap('viridis')
    # nsyn == 1 would make the colour normalisation divide by zero.
    color_span = max(nsyn - 1, 1)

    # Generate individual figures
    fig_soma = plot_together_nsyn_range(sim_together_list, seg_label='soma(0.5)', nsyn=nsyn, legend=True)
    fig_dend = plot_together_nsyn_range(sim_together_list, seg_label=seg_label_dend, nsyn=nsyn, legend=False)

    ax_soma = fig_soma.axes[0]
    ax_dend = fig_dend.axes[0]

    plt.close(fig_soma)  # Prevent display
    plt.close(fig_dend)  # Prevent display

    # Compute I/O data one n at a time: compute_epsp silently drops n values
    # it cannot evaluate, so this keeps every point paired with its own n
    # (needed for a colour list that matches the number of points).
    io_ns = []
    expected_epsps = []
    measured_epsps = []
    for n in selected_ns:
        expected, measured = compute_epsp(
            sim_single_list, sim_together_list,
            seg_label='soma(0.5)',
            selected_ns=[n]
        )
        if expected:
            io_ns.append(n)
            expected_epsps.extend(expected)
            measured_epsps.extend(measured)

    # --- Setup composite layout ---
    fig = plt.figure(figsize=(14, 8), constrained_layout=True)
    # The GridSpec must know its figure for constrained layout to be applied.
    gs = gridspec.GridSpec(2, 2, figure=fig, width_ratios=[3, 2], height_ratios=[1, 1], hspace=0.4)

    ax0 = fig.add_subplot(gs[0, 0])   # Soma
    ax1 = fig.add_subplot(gs[1, 0])   # Dendrite
    ax2 = fig.add_subplot(gs[:, 1])   # IO

    # Copy traces from generated figures
    for src_ax, target_ax in zip([ax_soma, ax_dend], [ax0, ax1]):
        for line in src_ax.get_lines():
            target_ax.plot(line.get_xdata(), line.get_ydata(), label=line.get_label(), color=line.get_color())
        target_ax.set_title(src_ax.get_title())
        target_ax.set_xlabel(src_ax.get_xlabel())
        target_ax.set_ylabel(src_ax.get_ylabel())
        # Mirror the source panel: only panels that were drawn with a legend
        # (and actually contain curves) get one.
        if src_ax.get_legend() is not None and src_ax.get_lines():
            target_ax.legend(fontsize=7, loc='upper right')
        target_ax.spines['top'].set_visible(False)
        target_ax.spines['right'].set_visible(False)

    # --- Plot I/O curve ---
    # Same colour scale as the traces: (n - 1) / (nsyn - 1).
    colors = [colormap((n - 1) / color_span) for n in io_ns]
    if expected_epsps:
        ax2.scatter(expected_epsps, measured_epsps, color=colors, zorder=3)

    # Identity line spanning the data (falls back to the former fixed range
    # when there is no finite data point).
    finite_values = [value for value in expected_epsps + measured_epsps if np.isfinite(value)]
    if finite_values:
        lo, hi = min(finite_values), max(finite_values)
        pad = 0.05 * (hi - lo) or 0.5
        lo, hi = lo - pad, hi + pad
    else:
        lo, hi = -69, -63
    ax2.plot([lo, hi], [lo, hi], '-', color='lightgray', zorder=1)

    # Connect lines
    ax2.plot(expected_epsps, measured_epsps, color='black', linewidth=0.8, zorder=1)

    ax2.set_xlabel('Expected EPSP (mV)')
    ax2.set_ylabel('Measured EPSP (mV)')
    ax2.set_title('Input-Output Curve @ Soma')
    ax2.axis('equal')
    ax2.spines['top'].set_visible(False)
    ax2.spines['right'].set_visible(False)

    plt.show()

### plot all figures together

In [2]:
# Relies on dends_dict, dend and both simulation_data_* lists created by the
# cells above -- on a fresh kernel run those cells first, otherwise the lookup
# below raises NameError.
seg_label_dend = '{}(0.5)'.format(dends_dict[dend])
plot_combined_summary(
    simulation_data_single,
    simulation_data_together,
    seg_label_dend,
    dends_dict,
    dend,
    nsyn=30,
)

NameError: name 'dends_dict' is not defined