# In [None]:
# ======================================================================
#  Temporal Summation Simulation (Leaky-Integrate-and-Fire neuron)
#  – presynaptic pulse count can be zero
#  – resting potential adjustable with slider
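#  – new post-spike behaviour: the spike waveform ends in a −80 mV undershoot
#    (after-hyperpolarisation) for 1 dt, then V decays passively while a 5 ms
#    absolute refractory period blocks new spikes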
# ======================================================================

import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display
from datetime import datetime
import os

# ----------------------------------------------------------------------
#  Core LIF simulation
# ----------------------------------------------------------------------
def _conductance_train(n_steps, dt, delay, n_pulses, isi, g_max, pulse_ms=1.0):
    """Return an (n_steps,) array of rectangular presynaptic conductance pulses.

    Pulse ``j`` (0 <= j < n_pulses) starts at ``delay + j * isi`` ms and lasts
    ``pulse_ms`` ms at amplitude ``g_max``. Pulses whose onset falls outside
    the simulated window are skipped; a pulse running past the end is cut off.
    """
    g = np.zeros(n_steps)
    pulse_steps = int(round(pulse_ms / dt))
    for j in range(int(n_pulses)):              # empty when n_pulses <= 0
        # round(), not int(): e.g. 0.29 / 0.01 == 28.999999999999996
        start = int(round((delay + j * isi) / dt))
        if 0 <= start < n_steps:                # a negative index would wrap around
            g[start:start + pulse_steps] = g_max  # slicing clips at n_steps
    return g


def _spike_waveform(threshold, dt, peak=60.0, trough=-80.0):
    """Return the stereotyped 2 ms spike: 1 ms rise threshold->peak, 1 ms fall peak->trough.

    The rise omits ``peak`` (endpoint=False) so the peak sample appears once,
    as the first sample of the fall; the fall includes ``trough`` so the
    waveform ends exactly on it.
    """
    n = int(round(1.0 / dt))
    rise = np.linspace(threshold, peak, n, endpoint=False)
    fall = np.linspace(peak, trough, n)
    return np.concatenate([rise, fall])


def _plot_traces(t, V, g_syn):
    """Draw membrane potential (top) and synaptic conductance (bottom) panels."""
    _, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 4),
                                 sharex=True,
                                 gridspec_kw={'height_ratios': [3, 1]})
    ax1.plot(t, V, color='black', lw=1.5)
    ax1.set_ylabel('Membrane V (mV)')
    ax1.set_title('Postsynaptic Membrane Potential')
    ax1.grid(True)

    ax2.plot(t, g_syn, lw=1.5)
    ax2.set_xlabel('Time (ms)')
    ax2.set_ylabel('g_syn')
    ax2.set_title('Presynaptic Conductance Pulses')
    # Keep the original 0.1 ceiling so runs are visually comparable, but add
    # head-room when pulses reach or exceed it; otherwise they are clipped or
    # drawn on top of the axis frame.
    ax2.set_ylim(top=max(0.1, 1.1 * float(g_syn.max())))
    ax2.grid(True)

    plt.tight_layout()
    plt.show()


def simulate(presyn_delay1, presyn_nPulses1, presyn_ISI1,
             presyn_E1, presyn_maxg1,
             V_rest, spike_threshold, spike_enabled, *, plot=True):
    """Simulate a leaky integrate-and-fire neuron driven by synaptic pulses.

    The membrane obeys ``dV/dt = -g_L (V - V_rest) - g_syn(t) (V - E_syn)``
    with C_m = 1 and tau_m = 20 ms (g_L = 1 / tau_m), over 500 ms with
    dt = 0.01 ms. When V exceeds ``spike_threshold`` (and spiking is enabled)
    a stereotyped 2 ms waveform is written instead of integrating
    (threshold -> +60 mV -> -80 mV). A 5 ms absolute refractory period
    follows, during which V evolves passively but cannot trigger a spike.

    Parameters
    ----------
    presyn_delay1 : float
        Onset of the first presynaptic pulse (ms).
    presyn_nPulses1 : int
        Number of 1 ms presynaptic pulses; 0 means no synaptic input.
    presyn_ISI1 : float
        Interval between successive pulse onsets (ms).
    presyn_E1 : float
        Synaptic reversal potential (mV).
    presyn_maxg1 : float
        Pulse conductance, in the same units as g_L = 1 / tau_m (C_m = 1).
    V_rest : float
        Leak reversal potential (mV); also the initial membrane potential.
    spike_threshold : float
        Threshold (mV) that V must exceed to start a spike.
    spike_enabled : bool
        If False the neuron is purely passive.
    plot : bool, keyword-only
        Draw the two-panel figure (default True, the original behaviour).

    Returns
    -------
    (t, V, g_syn) : tuple of numpy.ndarray
        Sample times (ms), membrane potential (mV) and synaptic conductance.
    """
    dt = 0.01                     # ms time step
    T = 500                       # ms total time
    n_steps = int(round(T / dt))
    # Sample i is at exactly i * dt. The previous np.linspace(0, T, n_steps)
    # used a spacing of T / (n_steps - 1), which drifts off the simulation grid.
    t = np.arange(n_steps) * dt

    g_syn = _conductance_train(n_steps, dt, presyn_delay1, presyn_nPulses1,
                               presyn_ISI1, presyn_maxg1)

    tau_m = 20.0                  # ms
    g_L = 1.0 / tau_m             # because C_m = 1

    # Exact exponential update for the linear ODE, precomputed for every step:
    #   V(t+dt) = V_inf + (V(t) - V_inf) * exp(-(g_L + g_syn) * dt)
    g_total = g_L + g_syn
    decay = np.exp(-g_total * dt)
    v_inf = (g_L * V_rest + g_syn * presyn_E1) / g_total

    waveform = _spike_waveform(spike_threshold, dt)   # ends on -80 mV
    refractory_steps = int(round(5.0 / dt))           # 5 ms absolute refractory

    V = np.zeros(n_steps)
    V[0] = V_rest
    spike_pos = None        # index of the next waveform sample, None when not spiking
    refractory_left = 0     # remaining absolute-refractory steps

    for i in range(n_steps - 1):
        if spike_pos is not None:
            # Inside a spike: paste the waveform and skip integration.
            V[i + 1] = waveform[spike_pos]
            spike_pos += 1
            if spike_pos >= len(waveform):
                # Spike finished (V already at -80 mV): enter refractory.
                spike_pos = None
                refractory_left = refractory_steps
            continue

        if refractory_left > 0:
            # Absolute refractory: no threshold test, passive dynamics still run.
            refractory_left -= 1
        elif spike_enabled and V[i] > spike_threshold:
            V[i + 1] = waveform[0]
            spike_pos = 1
            continue

        V[i + 1] = v_inf[i] + (V[i] - v_inf[i]) * decay[i]

    if plot:
        _plot_traces(t, V, g_syn)
    return t, V, g_syn

# ----------------------------------------------------------------------
#  Widgets
# ----------------------------------------------------------------------

# --- presynaptic input controls (a pulse count of zero is allowed) ---
delay1 = widgets.FloatSlider(
    value=10.0,
    min=0,
    max=100,
    step=1,
    description='Delay (ms):',
    disabled=True,               # displayed, but not user-adjustable
)
npulses1 = widgets.IntSlider(
    value=3,
    min=0,                       # 0 pulses -> no presynaptic input at all
    max=20,
    step=1,
    description='N Pulses:',
)
isi1 = widgets.FloatSlider(
    value=150,
    min=2,
    max=150,
    step=2,
    description='ISI (ms):',
)

# --- postsynaptic neuron controls ---
vrest_slider = widgets.FloatSlider(
    value=-70,
    min=-80,
    max=-60,
    step=1,
    description='V_rest (mV):',
)
spike_thresh_widget = widgets.FloatSlider(
    value=-55,
    min=-70,
    max=-30,
    step=1,
    description='Spike Thres. (mV):',
    disabled=False,
)
E1 = widgets.FloatSlider(
    value=0.0,
    min=-100,
    max=10,
    step=5,
    description='E_rev (mV):',
)
# Labelled as a receptor count; multiplied by 0.01 where simulate() is called.
maxg1 = widgets.FloatSlider(
    value=5,
    min=0,
    max=10,
    step=1,
    description='N receptors:',
)
spike_enabled_checkbox = widgets.Checkbox(
    value=True,
    description='Enable Spiking',
)

# --- run button and plot area ---
run_button = widgets.Button(
    description='Run Simulation',
    button_style='success',
)
plot_out = widgets.Output()

def on_run_clicked(b):
    """Button callback: re-run and re-plot the simulation with current widget values.

    ``b`` is the widget that fired the event; it is not used.
    """
    plot_out.clear_output(wait=True)
    with plot_out:
        sim_args = (
            delay1.value,
            npulses1.value,
            isi1.value,
            E1.value,
            maxg1.value * 0.01,      # slider value scaled to the pulse conductance
            vrest_slider.value,
            spike_thresh_widget.value,
            spike_enabled_checkbox.value,
        )
        simulate(*sim_args)


run_button.on_click(on_run_clicked)

# --- minimal layout (quiz widgets omitted for brevity) ---------------
# Left column: presynaptic controls and the run button.
ui_left = widgets.VBox(children=[
    widgets.HTML('<h3>Presynaptic Input</h3>'),
    delay1,
    npulses1,
    isi1,
    run_button,
])
# Right column: postsynaptic neuron parameters.
ui_right = widgets.VBox(children=[
    widgets.HTML('<h3>Postsynaptic Neuron</h3>'),
    vrest_slider,
    spike_thresh_widget,
    E1,
    maxg1,
    spike_enabled_checkbox,
])
display(widgets.HBox(children=[ui_left, ui_right]), plot_out)

# First draw: reuse the button callback instead of repeating its argument list,
# so the widget -> simulate() mapping (including the 0.01 conductance scaling)
# is defined in exactly one place and cannot drift out of sync.
on_run_clicked(None)


# [cell output] HBox(children=(VBox(children=(HTML(value='<h3>Presynaptic Input</h3>'), FloatSlider(value=10.0, description='D…
#
# [cell output] Output()