# NESTML synapse model

---

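## STDP synapse

Define a trace-based STDP synapse and an Izhikevich neuron in NESTML, build them into a NEST extension module, and measure the weight change $\Delta w$ as a function of the pre/post spike-timing difference $\Delta t$.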

In [None]:
%matplotlib widget
import matplotlib.pyplot as plt
import nest
import numpy as np
from pynestml.codegeneration.nest_code_generator_utils import NESTCodeGeneratorUtils

In [None]:
# NESTML source (not Python) for a trace-based STDP synapse.
#
# - pre_trace / post_trace are the pre- and postsynaptic spike trains
#   convolved with exponential kernels (time constants tau_tr_pre / tau_tr_post).
# - On every postsynaptic spike, w is potentiated by lambda * pre_trace.
# - On every presynaptic spike, w is depressed by alpha * lambda * post_trace,
#   and a spike carrying weight w with delay d is then emitted.
# - No lower or upper bound is applied to w, so repeated depression can
#   drive it negative.
stdp_model = """
model stdp_synapse:

    parameters:
        d ms = 1 ms
        lambda real = .01
        tau_tr_pre ms = 20 ms
        tau_tr_post ms = 20 ms
        alpha real = 1

    state:
        w real = 1.

    equations:
        kernel pre_trace_kernel = exp(-t / tau_tr_pre)
        inline pre_trace real = convolve(pre_trace_kernel, pre_spikes)

        # all-to-all trace of postsynaptic neuron
        kernel post_trace_kernel = exp(-t / tau_tr_post)
        inline post_trace real = convolve(post_trace_kernel, post_spikes)

    input:
        pre_spikes <- spike
        post_spikes <- spike

    output:
        spike(weight real, delay ms)

    onReceive(post_spikes):
        # potentiate synapse
        w += lambda * pre_trace

    onReceive(pre_spikes):
        # depress synapse
        w -= alpha * lambda * post_trace

        # deliver spike to postsynaptic partner
        emit_spike(w, d)

    update:
        integrate_odes()
"""


# NESTML source (not Python) for an Izhikevich point neuron, used as the
# postsynaptic partner of the STDP synapse.
#
# - Incoming spikes are added directly to V_m in the update block.
# - When V_m reaches V_th, V_m is reset to c, U_m is incremented by d, and a
#   spike is emitted.
# - U_m is dimensionless (its equation multiplies it by mV), so its initial
#   value divides the mV-valued b * V_m_init by mV to keep the types consistent.
#
# NOTE(review): c = -50 mV, d = 2 look like Izhikevich's "chattering"
# parameter set, under which one strong input may evoke a burst of several
# spikes instead of one. Confirm this is intended for the STDP sweep.
neuron_model = """
model izhikevich_neuron:
    parameters:
        a real = 0.02             # Describes time scale of recovery variable
        b real = 0.2              # Sensitivity of recovery variable
        c mV = -50 mV             # After-spike reset value of V_m
        d real = 2.0              # After-spike increment of U_m
        V_m_init mV = -65 mV      # Initial membrane potential
        V_min mV = -inf * mV      # Absolute lower value for the membrane potential.
        V_th mV = 30 mV           # Threshold potential

        # Constant external input current
        I_e pA = 0 pA

    state:
        V_m mV = V_m_init               # Membrane potential
        U_m real = b * V_m_init / mV    # Membrane potential recovery variable (dimensionless)

    equations:
        V_m' = (0.04 * V_m * V_m / mV + 5.0 * V_m + (140 - U_m) * mV + ((I_e + I_stim) * GOhm)) / ms
        U_m' = a * (b * V_m - U_m * mV) / (mV * ms)

    input:
        spikes <- spike
        I_stim pA <- continuous

    output:
        spike

    update:
        integrate_odes()

        # Add synaptic current
        V_m += spikes * mV * s

        # lower bound of membrane potential
        V_m = max(V_min, V_m)

    onCondition(V_m >= V_th):
        # threshold crossing
        V_m = c
        U_m += d
        emit_spike()
"""

In [None]:
# Generate and build a NEST extension module from the two NESTML models above.
# The call returns the module name plus the (possibly mangled) names under
# which the neuron and the synapse are registered; use those names, not the
# NESTML model names, when creating nodes and connections.
#   post_ports    - the synapse's "post_spikes" input is fed by the
#                   postsynaptic neuron's spikes.
#   codegen_opts  - "d" and "w" of stdp_synapse act as the NEST connection's
#                   delay and weight.
module_name, neuron_model_name, synapse_model_name = NESTCodeGeneratorUtils.generate_code_for(
    neuron_model,
    stdp_model,
    post_ports=["post_spikes"],
    logging_level="WARNING",
    codegen_opts={
        "delay_variable": {"stdp_synapse": "d"},
        "weight_variable": {"stdp_synapse": "w"},
    },
)

In [None]:
# Load the generated extension module into the NEST kernel so its neuron and
# synapse models can be instantiated. The sweep below resets the kernel on
# every trial and calls nest.Install again after each reset.
nest.Install(module_name)

In [None]:
# Sweep the timing of a postsynaptic spike relative to a fixed presynaptic
# spike and record the resulting weight change of the STDP synapse.
# Each trial starts from a freshly reset kernel, so the synapse always begins
# at its initial weight.
sim_time = 1000                            # simulated time per trial (ms)
spike_pre = 300                            # time of the probe presynaptic spike (ms)
spike_post = range(1, sim_time-100, 5)     # postsynaptic trigger times to sweep (ms)
delay = 15                                 # delay of the STDP connection (ms)
dt = []                                    # t_post - t_pre per trial (ms)
dw = []                                    # weight change divided by lambda per trial

for spike in spike_post:
    nest.ResetKernel()
    nest.Install(module_name)  # re-installed after every ResetKernel()

    # The extra presynaptic spike just before the end of the run presumably
    # forces the synapse to account for postsynaptic spikes that follow the
    # first pre spike (NEST processes postsynaptic history when a presynaptic
    # spike is delivered) -- NOTE(review): confirm.
    sg_pre = nest.Create("spike_generator", params={"spike_times": [spike_pre, sim_time - 2]})
    sg_post = nest.Create("spike_generator", params={"spike_times": [spike]})

    pre_neuron = nest.Create("parrot_neuron")
    post_neuron = nest.Create(neuron_model_name)

    sr_pre = nest.Create("spike_recorder")
    sr_post = nest.Create("spike_recorder")

    nest.Connect(sg_pre, pre_neuron)
    # Very large weight so the trigger reliably makes the postsynaptic neuron fire.
    nest.Connect(sg_post, post_neuron, syn_spec={"weight": 9999.0})
    nest.Connect(pre_neuron, sr_pre)
    nest.Connect(post_neuron, sr_post)
    nest.Connect(pre_neuron, post_neuron, syn_spec={"synapse_model": synapse_model_name, "delay": delay})

    syn = nest.GetConnections(source=pre_neuron, synapse_model=synapse_model_name)
    initial_weight = nest.GetStatus(syn)[0]["w"]
    l = nest.GetDefaults(synapse_model_name)["lambda"]  # learning rate, used to normalize dw

    nest.Simulate(sim_time)

    updated_weight = nest.GetStatus(syn)[0]["w"]

    pre_times = sr_pre.get()["events"]["times"]
    post_times = sr_post.get()["events"]["times"]
    # Fail with a clear message (instead of a bare IndexError) if a trial
    # produced no spike on either side, e.g. after changing weights/parameters.
    if len(pre_times) == 0 or len(post_times) == 0:
        raise RuntimeError(
            f"no spikes recorded for post trigger at {spike} ms "
            f"(pre spikes: {len(pre_times)}, post spikes: {len(post_times)})"
        )
    # NOTE(review): only the first spike of each recorder is used for dt,
    # while dw reflects every recorded post spike (e.g. a burst).
    t_pre_spike = pre_times[0]
    t_post_spike = post_times[0]

    dt.append(t_post_spike - t_pre_spike)
    dw.append((updated_weight - initial_weight)/l)

In [None]:
# Plot the lambda-normalized weight change against the spike-timing
# difference. Dashed guide lines mark dw = 0 and dt = -delay.
plt.figure()
plt.scatter(dt, dw, s=30)

# Minor ticks are disabled by default; without them the minor x grid is empty.
plt.minorticks_on()
plt.grid(which="major", axis="both")
plt.grid(which="minor", axis="x")

# axhline/axvline span the whole axes regardless of the data limits.
plt.axhline(0, color='k', linestyle="--", linewidth=2, alpha=0.5)
plt.axvline(-delay, color='k', linestyle="--", linewidth=2, alpha=0.5)

plt.xlim((-200, 200))
dw_lo, dw_hi = np.min(dw), np.max(dw)
# Pad the y range so extreme markers are not clipped, and avoid identical
# limits (all dw equal), which matplotlib cannot use as an axis range.
dw_pad = 0.05 * (dw_hi - dw_lo) if dw_hi > dw_lo else 1.0
plt.ylim((dw_lo - dw_pad, dw_hi + dw_pad))
plt.xlabel(r"$\Delta t$ (ms)")
plt.ylabel(r"$\Delta w / \lambda$")