<a href="https://colab.research.google.com/github/dorian-goueytes/M2_SCE_NeuroComp_Dec/blob/main/recurrent_network_wang2002.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Hands-on modeling: Python implementation of a decision-making neural network

Today we are going to study a concrete implementation of the neural network proposed in Wang, 2002 [1] (see Class 2) and extended in Wong and Wang, 2006 [2].

As a reminder, this neural network integrates the information from a cloud of randomly moving dots and decides whether the coherent motion in the cloud is towards the right or towards the left.

The architecture of the network is as follows:

![Network structure](https://neuronaldynamics-exercises.readthedocs.io/en/latest/_images/DecisionMaking_NetworkStructureAll.png)

The excitatory population is detailed below. Its selective subpopulations correspond to the attractors seen in Wang, 2002:

![Detailed network structure](https://neuronaldynamics-exercises.readthedocs.io/en/latest/_images/DecisionMaking_NetworkStructureDetail.png)
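In this architecture, recurrent connections within each selective subpopulation ("Left" = A, "Right" = B) are strengthened by a factor `w_pos`, while the connections onto a selective subpopulation from the other excitatory neurons are weakened to `w_neg`. In `sim_decision_making_network` (section 3), `w_neg = 1 - f*(w_pos - 1)/(1 - f)`, where `f` is `f_Subpop_size`. The short check below is a sketch using the function's default values (`w_pos=1.90`, `f_Subpop_size=0.25`); `w_neg_from` is a hypothetical helper, not part of the notebook. It shows that this choice keeps the average excitatory weight onto a selective neuron equal to 1.

```python
def w_neg_from(w_pos, f):
    """Weight onto a selective subpopulation from the rest of the excitatory population,
    computed as in sim_decision_making_network: w_neg = 1 - f*(w_pos - 1)/(1 - f)."""
    return 1.0 - f * (w_pos - 1.0) / (1.0 - f)

w_pos, f = 1.90, 0.25            # default w_pos and f_Subpop_size of the notebook
w_neg = w_neg_from(w_pos, f)     # 1 - 0.25*0.9/0.75 = 0.7

# Average weight received by a neuron of subpopulation A from the whole excitatory
# population: a fraction f of presynaptic neurons (A itself) connects with w_pos,
# the remaining fraction (1 - f), i.e. B and Z, connects with w_neg.
mean_weight = f * w_pos + (1.0 - f) * w_neg
print(f"w_neg = {w_neg:.3f}, mean weight = {mean_weight:.3f}")
```

Because the mean weight stays at 1 for any `w_pos` and `f`, increasing `w_pos` strengthens the selective attractors without changing the total recurrent excitation a selective neuron receives on average.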


This notebook is adapted from the following resource [3]:

[https://neuronaldynamics-exercises.readthedocs.io/en/latest/exercises/perceptual-decision-making.html](https://neuronaldynamics-exercises.readthedocs.io/en/latest/exercises/perceptual-decision-making.html)



[1] Wang, X.-J. (2002). Probabilistic Decision Making by Slow Reverberation in Cortical Circuits. Neuron, 36(5), 955–968. https://doi.org/10.1016/S0896-6273(02)01092-9

[2] Wong, K.-F., & Wang, X.-J. (2006). A Recurrent Network Mechanism of Time Integration in Perceptual Decisions. The Journal of Neuroscience, 26(4), 1314–1328. https://doi.org/10.1523/JNEUROSCI.3733-05.2006

[3] Gerstner, W., Kistler, W. M., Naud, R., & Paninski, L. (2014). Neuronal Dynamics: From Single Neurons to Networks and Models of Cognition. Cambridge University Press.



## 1 - Installing the brian2 library

Brian 2 is a Python library designed to facilitate the implementation of biologically plausible spiking neural networks.

Documentation about Brian2 can be found [here](https://brian2.readthedocs.io/en/stable/).

In [None]:
#@title Run this cell to install brian2 (skip this step if you are working on your own computer and brian2 is already installed)
!pip install brian2

## 2 - Python dependencies

We load all the Python dependencies required for the exercise.

In [None]:
#@title Run this cell to import the required dependencies
import brian2 as b2
from brian2 import NeuronGroup, Synapses, PoissonInput, PoissonGroup, network_operation
from brian2.monitors import StateMonitor, SpikeMonitor, PopulationRateMonitor
from random import sample
import numpy.random as rnd
import numpy
import matplotlib.pyplot as plt
from math import floor
import time
#b2.set_device('cpp_standalone')
b2.prefs.codegen.target = 'cython'

## 3 - Implementation of the model

This is the most important step: we define a function, `sim_decision_making_network`, that implements the model described at the top of the notebook.
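Each neuron in this model is a conductance-based leaky integrate-and-fire (LIF) unit: the simulator integrates `dv/dt`, emits a spike when the threshold condition holds, resets `v`, and holds it during the absolute refractory period. As an illustration of that mechanism outside Brian2, the pure-Python sketch below uses the excitatory parameters of the function (Cm = 0.5 nF, G_leak = 25 nS, E_leak = -70 mV, threshold -50 mV, reset -60 mV, refractory period 2 ms). It replaces all synaptic currents with a constant injected current `I_ext` and uses forward Euler instead of the `rk2` method of the notebook; both are simplifying assumptions, and `simulate_lif` is a hypothetical helper.

```python
def simulate_lif(I_ext, t_sim=1.0, dt=1e-4):
    """Spike times (s) of an LIF neuron driven by a constant current I_ext (A).
    Forward-Euler sketch with the excitatory parameters of the notebook."""
    Cm, G_leak, E_leak = 0.5e-9, 25e-9, -70e-3        # F, S, V
    v_thr, v_reset, t_ref = -50e-3, -60e-3, 2e-3      # V, V, s
    v = E_leak
    refractory_until = -1.0
    spikes = []
    for step in range(int(round(t_sim / dt))):
        t = step * dt
        if t < refractory_until:       # "(unless refractory)": v stays at the reset value
            continue
        v += dt * (-G_leak * (v - E_leak) + I_ext) / Cm
        if v > v_thr:                  # threshold condition "v > v_spike_thr_excit"
            spikes.append(t)
            v = v_reset                # reset statement "v = v_reset_excit"
            refractory_until = t + t_ref
    return spikes

# The steady-state voltage without spiking is E_leak + I_ext / G_leak.
print(len(simulate_lif(0.6e-9)))  # 0.6 nA -> steady state -46 mV, above threshold: regular firing
print(len(simulate_lif(0.4e-9)))  # 0.4 nA -> steady state -54 mV, below threshold: no spikes
```

With 0.6 nA the neuron fires at roughly 36 Hz: the membrane time constant Cm/G_leak is 20 ms, so climbing from reset to threshold takes about 25 ms, plus the 2 ms refractory period.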

### Question 1

For each population in the description of the model at the top of the notebook, find the corresponding variables in the sim_decision_making_network function. What type of brian2 object is used to create these populations?

### Question 2

Each NeuronGroup is monitored with three brian2 objects: a PopulationRateMonitor, a SpikeMonitor, and a StateMonitor. Find the variable names of these monitors. You can have a look at the [Brian2 documentation](https://brian2.readthedocs.io/en/stable/reference/brian2.monitors.html#) to understand the role of each monitor.
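A PopulationRateMonitor stores, at every time step, the number of spikes emitted by the whole group divided by the number of neurons and by the time-step length, i.e. an instantaneous population rate in Hz. The pure-Python sketch below reproduces that computation on pooled spike times, with a bin width standing in for the simulation time step; `population_rate` is a hypothetical helper, not a brian2 function.

```python
def population_rate(spike_times, n_neurons, t_sim, bin_width):
    """Population rate in Hz: spikes pooled over all neurons, counted per bin
    and divided by (n_neurons * bin_width)."""
    n_bins = int(round(t_sim / bin_width))
    counts = [0] * n_bins
    for t in spike_times:
        k = int(t / bin_width)          # index of the bin containing this spike
        if 0 <= k < n_bins:
            counts[k] += 1
    return [c / (n_neurons * bin_width) for c in counts]

# 4 neurons, 1 ms bins: the spike counts per bin are [1, 2, 0, 1]
print(population_rate([0.0005, 0.0012, 0.0015, 0.0031], n_neurons=4, t_sim=0.004, bin_width=0.001))
# -> approximately [250.0, 500.0, 0.0, 250.0]
```

With the notebook's 0.1 ms time step, a single spike among the 96 inhibitory neurons already corresponds to about 104 Hz, which is why the raw rate trace is very noisy and is smoothed before plotting.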

### Question 3

Which neuronal variable is recorded by the StateMonitor?

In [None]:
"""
Implementation of a decision making model of
[1] Wang, Xiao-Jing. "Probabilistic decision making by slow reverberation in cortical circuits."
Neuron 36.5 (2002): 955-968.

Some parts of this implementation are inspired by material from
*Stanford University, BIOE 332: Large-Scale Neural Modeling, Kwabena Boahen & Tatiana Engel, 2013*,
online available.

Note: Most parameters differ from the original publication.
"""

# This file is part of the exercise code repository accompanying
# the book: Neuronal Dynamics (see http://neuronaldynamics.epfl.ch)
# located at http://github.com/EPFL-LCN/neuronaldynamics-exercises.



# Should you reuse and publish the code for your own purposes,
# please cite the book or point to the webpage http://neuronaldynamics.epfl.ch.

# Wulfram Gerstner, Werner M. Kistler, Richard Naud, and Liam Paninski.
# Neuronal Dynamics: From Single Neurons to Networks and Models of Cognition.
# Cambridge University Press, 2014.




b2.defaultclock.dt = 0.10 * b2.ms


def sim_decision_making_network(N_Excit=384, N_Inhib=96, weight_scaling_factor=5.33,
                                t_stimulus_start=100 * b2.ms, t_stimulus_duration=9999 * b2.ms, coherence_level=0.,
                                stimulus_update_interval=30 * b2.ms, mu0_mean_stimulus_Hz=160.,
                                stimulus_std_Hz=20.,
                                N_extern=1000, firing_rate_extern=9.8 * b2.Hz,
                                w_pos=1.90, f_Subpop_size=0.25,  # .15 in publication [1]
                                max_sim_time=1000. * b2.ms, stop_condition_rate=None,
                                monitored_subset_size=512):
    """

    Args:
        N_Excit (int): total number of neurons in the excitatory population
        N_Inhib (int): nr of neurons in the inhibitory populations
        weight_scaling_factor (float): scaling of the recurrent weights. When the number of neurons is doubled, the weights should be halved
        t_stimulus_start (Quantity): time when the stimulation starts
        t_stimulus_duration (Quantity): duration of the stimulation
        coherence_level (float): coherence of the stimulus, between -1 and +1.
            Sets the difference in mean rate between the "left" and "right" stimulus PoissonGroups
        stimulus_update_interval (Quantity): the mean of the stimulating PoissonGroups is
            re-sampled at this interval
        mu0_mean_stimulus_Hz (float): maximum mean firing rate of the stimulus if c=+1 or c=-1. Each neuron
            in the populations "Left" and "Right" receives an independent poisson input.
        stimulus_std_Hz (float): std deviation of the stimulating PoissonGroups.
        N_extern (int): nr of neurons in the stimulus independent poisson background population
        firing_rate_extern (Quantity): firing rate of the stimulus-independent Poisson background population
        w_pos (float): Scaling (strengthening) of the recurrent weights within the
            subpopulations "Left" and "Right"
        f_Subpop_size (float): fraction of the neurons in the subpopulations "Left" and "Right".
            #left = #right = int(f_Subpop_size*N_Excit).
        max_sim_time (Quantity): simulated time.
        stop_condition_rate (Quantity): An optional stopping criterion: If not None, the simulation stops if the
            firing rate of either subpopulation "Left" or "Right" is above stop_condition_rate.
        monitored_subset_size (int): max nr of neurons for which a state monitor is registered.

    Returns:

        A dictionary with the following keys (strings):
        "rate_monitor_A", "spike_monitor_A", "voltage_monitor_A", "idx_monitored_neurons_A", "rate_monitor_B",
         "spike_monitor_B", "voltage_monitor_B", "idx_monitored_neurons_B", "rate_monitor_Z", "spike_monitor_Z",
         "voltage_monitor_Z", "idx_monitored_neurons_Z", "rate_monitor_inhib", "spike_monitor_inhib",
         "voltage_monitor_inhib", "idx_monitored_neurons_inhib"

    """

    print("simulating {} neurons. Start: {}".format(N_Excit + N_Inhib, time.ctime()))
    t_stimulus_end = t_stimulus_start + t_stimulus_duration

    N_Group_A = int(N_Excit * f_Subpop_size)  # size of the excitatory subpopulation sensitive to stimulus A
    N_Group_B = N_Group_A  # size of the excitatory subpopulation sensitive to stimulus B
    N_Group_Z = N_Excit - N_Group_A - N_Group_B  # (1-2f)Ne excitatory neurons do not respond to either stimulus.

    Cm_excit = 0.5 * b2.nF  # membrane capacitance of excitatory neurons
    G_leak_excit = 25.0 * b2.nS  # leak conductance
    E_leak_excit = -70.0 * b2.mV  # reversal potential
    v_spike_thr_excit = -50.0 * b2.mV  # spike condition
    v_reset_excit = -60.0 * b2.mV  # reset voltage after spike
    t_abs_refract_excit = 2. * b2.ms  # absolute refractory period

    # specify the inhibitory interneurons:
    # N_Inhib = 200
    Cm_inhib = 0.2 * b2.nF
    G_leak_inhib = 20.0 * b2.nS
    E_leak_inhib = -70.0 * b2.mV
    v_spike_thr_inhib = -50.0 * b2.mV
    v_reset_inhib = -60.0 * b2.mV
    t_abs_refract_inhib = 1.0 * b2.ms

    # specify the AMPA synapses
    E_AMPA = 0.0 * b2.mV
    tau_AMPA = 2.5 * b2.ms

    # specify the GABA synapses
    E_GABA = -70.0 * b2.mV
    tau_GABA = 5.0 * b2.ms

    # specify the NMDA synapses
    E_NMDA = 0.0 * b2.mV
    tau_NMDA_s = 100.0 * b2.ms
    tau_NMDA_x = 2. * b2.ms
    alpha_NMDA = 0.5 * b2.kHz

    # projections from the external population
    g_AMPA_extern2inhib = 1.62 * b2.nS
    g_AMPA_extern2excit = 2.1 * b2.nS

    # projections from the inhibitory populations
    g_GABA_inhib2inhib = weight_scaling_factor * 1.25 * b2.nS
    g_GABA_inhib2excit = weight_scaling_factor * 1.60 * b2.nS

    # projections from the excitatory population
    g_AMPA_excit2excit = weight_scaling_factor * 0.012 * b2.nS
    g_AMPA_excit2inhib = weight_scaling_factor * 0.015 * b2.nS
    g_NMDA_excit2excit = weight_scaling_factor * 0.040 * b2.nS
    g_NMDA_excit2inhib = weight_scaling_factor * 0.045 * b2.nS  # stronger projection to inhib.

    # weights and "adjusted" weights.
    w_neg = 1. - f_Subpop_size * (w_pos - 1.) / (1. - f_Subpop_size)
    # We use the same postsyn AMPA and NMDA conductances. Adjust the weights coming from different sources:
    w_ext2inhib = g_AMPA_extern2inhib / g_AMPA_excit2inhib
    w_ext2excit = g_AMPA_extern2excit / g_AMPA_excit2excit
    # other weights are 1
    # print("w_neg={}, w_ext2inhib={}, w_ext2excit={}".format(w_neg, w_ext2inhib, w_ext2excit))

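    # Define the inhibitory population
    # Note on the NMDA term below: in Wang (2002) the Mg2+ block factor is
    # 1/(1 + [Mg2+]*exp(-0.062*V)/3.57) with V in mV. Here v is divided by volt, which makes the
    # factor nearly constant (about 0.78 around -60 mV). The other parameters of this implementation
    # were presumably tuned with this expression, so changing it would likely require re-tuning.
    # dynamics: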
    inhib_lif_dynamics = """
        s_NMDA_total : 1  # the post synaptic sum of s. compare with s_NMDA_presyn
        dv/dt = (
        - G_leak_inhib * (v-E_leak_inhib)
        - g_AMPA_excit2inhib * s_AMPA * (v-E_AMPA)
        - g_GABA_inhib2inhib * s_GABA * (v-E_GABA)
        - g_NMDA_excit2inhib * s_NMDA_total * (v-E_NMDA)/(1.0+1.0*exp(-0.062*v/volt)/3.57)
        )/Cm_inhib : volt (unless refractory)
        ds_AMPA/dt = -s_AMPA/tau_AMPA : 1
        ds_GABA/dt = -s_GABA/tau_GABA : 1
    """

    inhib_pop = NeuronGroup(
        N_Inhib, model=inhib_lif_dynamics,
        threshold="v>v_spike_thr_inhib", reset="v=v_reset_inhib", refractory=t_abs_refract_inhib,
        method="rk2")
    # initialize with random voltages:
    inhib_pop.v = rnd.uniform(v_spike_thr_inhib / b2.mV - 4., high=v_spike_thr_inhib / b2.mV - 1., size=N_Inhib) * b2.mV

    # Specify the excitatory population:
    # dynamics:
    excit_lif_dynamics = """
        s_NMDA_total : 1  # the post synaptic sum of s. compare with s_NMDA_presyn
        dv/dt = (
        - G_leak_excit * (v-E_leak_excit)
        - g_AMPA_excit2excit * s_AMPA * (v-E_AMPA)
        - g_GABA_inhib2excit * s_GABA * (v-E_GABA)
        - g_NMDA_excit2excit * s_NMDA_total * (v-E_NMDA)/(1.0+1.0*exp(-0.062*v/volt)/3.57)
        )/Cm_excit : volt (unless refractory)
        ds_AMPA/dt = -s_AMPA/tau_AMPA : 1
        ds_GABA/dt = -s_GABA/tau_GABA : 1
        ds_NMDA/dt = -s_NMDA/tau_NMDA_s + alpha_NMDA * x * (1-s_NMDA) : 1
        dx/dt = -x/tau_NMDA_x : 1
    """

    # define the three excitatory subpopulations.
    # A: subpop receiving stimulus A
    excit_pop_A = NeuronGroup(N_Group_A, model=excit_lif_dynamics,
                              threshold="v>v_spike_thr_excit", reset="v=v_reset_excit",
                              refractory=t_abs_refract_excit, method="rk2")
    excit_pop_A.v = rnd.uniform(E_leak_excit / b2.mV, high=E_leak_excit / b2.mV + 5., size=excit_pop_A.N) * b2.mV

    # B: subpop receiving stimulus B
    excit_pop_B = NeuronGroup(N_Group_B, model=excit_lif_dynamics, threshold="v>v_spike_thr_excit",
                              reset="v=v_reset_excit", refractory=t_abs_refract_excit, method="rk2")
    excit_pop_B.v = rnd.uniform(E_leak_excit / b2.mV, high=E_leak_excit / b2.mV + 5., size=excit_pop_B.N) * b2.mV
    # Z: non-sensitive
    excit_pop_Z = NeuronGroup(N_Group_Z, model=excit_lif_dynamics,
                              threshold="v>v_spike_thr_excit", reset="v=v_reset_excit",
                              refractory=t_abs_refract_excit, method="rk2")
    excit_pop_Z.v = rnd.uniform(v_reset_excit / b2.mV, high=v_spike_thr_excit / b2.mV - 1., size=excit_pop_Z.N) * b2.mV

    # now define the connections:
    # projections FROM EXTERNAL POISSON GROUP: ####################################################
    poisson2Inhib = PoissonInput(target=inhib_pop, target_var="s_AMPA",
                                 N=N_extern, rate=firing_rate_extern, weight=w_ext2inhib)
    poisson2A = PoissonInput(target=excit_pop_A, target_var="s_AMPA",
                             N=N_extern, rate=firing_rate_extern, weight=w_ext2excit)

    poisson2B = PoissonInput(target=excit_pop_B, target_var="s_AMPA",
                             N=N_extern, rate=firing_rate_extern, weight=w_ext2excit)
    poisson2Z = PoissonInput(target=excit_pop_Z, target_var="s_AMPA",
                             N=N_extern, rate=firing_rate_extern, weight=w_ext2excit)

    ###############################################################################################

    # GABA projections FROM INHIBITORY population: ################################################
    syn_inhib2inhib = Synapses(inhib_pop, target=inhib_pop, on_pre="s_GABA += 1.0", delay=0.5 * b2.ms)
    syn_inhib2inhib.connect(p=1.)
    syn_inhib2A = Synapses(inhib_pop, target=excit_pop_A, on_pre="s_GABA += 1.0", delay=0.5 * b2.ms)
    syn_inhib2A.connect(p=1.)
    syn_inhib2B = Synapses(inhib_pop, target=excit_pop_B, on_pre="s_GABA += 1.0", delay=0.5 * b2.ms)
    syn_inhib2B.connect(p=1.)
    syn_inhib2Z = Synapses(inhib_pop, target=excit_pop_Z, on_pre="s_GABA += 1.0", delay=0.5 * b2.ms)
    syn_inhib2Z.connect(p=1.)
    ###############################################################################################
    # AMPA projections FROM EXCITATORY A: #########################################################
    syn_AMPA_A2A = Synapses(excit_pop_A, target=excit_pop_A, on_pre="s_AMPA += w_pos", delay=0.5 * b2.ms)
    syn_AMPA_A2A.connect(p=1.)
    syn_AMPA_A2B = Synapses(excit_pop_A, target=excit_pop_B, on_pre="s_AMPA += w_neg", delay=0.5 * b2.ms)
    syn_AMPA_A2B.connect(p=1.)
    syn_AMPA_A2Z = Synapses(excit_pop_A, target=excit_pop_Z, on_pre="s_AMPA += 1.0", delay=0.5 * b2.ms)
    syn_AMPA_A2Z.connect(p=1.)
    syn_AMPA_A2inhib = Synapses(excit_pop_A, target=inhib_pop, on_pre="s_AMPA += 1.0", delay=0.5 * b2.ms)
    syn_AMPA_A2inhib.connect(p=1.)
    ###############################################################################################
    # AMPA projections FROM EXCITATORY B: #########################################################
    syn_AMPA_B2A = Synapses(excit_pop_B, target=excit_pop_A, on_pre="s_AMPA += w_neg", delay=0.5 * b2.ms)
    syn_AMPA_B2A.connect(p=1.)
    syn_AMPA_B2B = Synapses(excit_pop_B, target=excit_pop_B, on_pre="s_AMPA += w_pos", delay=0.5 * b2.ms)
    syn_AMPA_B2B.connect(p=1.)
    syn_AMPA_B2Z = Synapses(excit_pop_B, target=excit_pop_Z, on_pre="s_AMPA += 1.0", delay=0.5 * b2.ms)
    syn_AMPA_B2Z.connect(p=1.)
    syn_AMPA_B2inhib = Synapses(excit_pop_B, target=inhib_pop, on_pre="s_AMPA += 1.0", delay=0.5 * b2.ms)
    syn_AMPA_B2inhib.connect(p=1.)
    ###############################################################################################
    # AMPA projections FROM EXCITATORY Z: #########################################################
    syn_AMPA_Z2A = Synapses(excit_pop_Z, target=excit_pop_A, on_pre="s_AMPA += 1.0", delay=0.5 * b2.ms)
    syn_AMPA_Z2A.connect(p=1.)
    syn_AMPA_Z2B = Synapses(excit_pop_Z, target=excit_pop_B, on_pre="s_AMPA += 1.0", delay=0.5 * b2.ms)
    syn_AMPA_Z2B.connect(p=1.)
    syn_AMPA_Z2Z = Synapses(excit_pop_Z, target=excit_pop_Z, on_pre="s_AMPA += 1.0", delay=0.5 * b2.ms)
    syn_AMPA_Z2Z.connect(p=1.)
    syn_AMPA_Z2inhib = Synapses(excit_pop_Z, target=inhib_pop, on_pre="s_AMPA += 1.0", delay=0.5 * b2.ms)
    syn_AMPA_Z2inhib.connect(p=1.)
    ###############################################################################################
    # NMDA projections FROM EXCITATORY to INHIB, A,B,Z
    @network_operation()
    def update_nmda_sum():
        sum_sNMDA_A = sum(excit_pop_A.s_NMDA)
        sum_sNMDA_B = sum(excit_pop_B.s_NMDA)
        sum_sNMDA_Z = sum(excit_pop_Z.s_NMDA)
        # note the _ at the end of s_NMDA_total_ disables unit checking
        inhib_pop.s_NMDA_total_ = (1.0 * sum_sNMDA_A + 1.0 * sum_sNMDA_B + 1.0 * sum_sNMDA_Z)
        excit_pop_A.s_NMDA_total_ = (w_pos * sum_sNMDA_A + w_neg * sum_sNMDA_B + w_neg * sum_sNMDA_Z)
        excit_pop_B.s_NMDA_total_ = (w_neg * sum_sNMDA_A + w_pos * sum_sNMDA_B + w_neg * sum_sNMDA_Z)
        excit_pop_Z.s_NMDA_total_ = (1.0 * sum_sNMDA_A + 1.0 * sum_sNMDA_B + 1.0 * sum_sNMDA_Z)

    # set a self-recurrent synapse to introduce a delay when updating the intermediate
    # gating variable x
    syn_x_A2A = Synapses(excit_pop_A, excit_pop_A, on_pre="x += 1.", delay=0.5 * b2.ms)
    syn_x_A2A.connect(j="i")
    syn_x_B2B = Synapses(excit_pop_B, excit_pop_B, on_pre="x += 1.", delay=0.5 * b2.ms)
    syn_x_B2B.connect(j="i")
    syn_x_Z2Z = Synapses(excit_pop_Z, excit_pop_Z, on_pre="x += 1.", delay=0.5 * b2.ms)
    syn_x_Z2Z.connect(j="i")

    ###############################################################################################
    # Define the stimulus: two PoissonGroups with a time-dependent mean rate.
    poissonStimulus2A = PoissonGroup(N_Group_A, 0. * b2.Hz)
    syn_Stim2A = Synapses(poissonStimulus2A, excit_pop_A, on_pre="s_AMPA+=w_ext2excit")
    syn_Stim2A.connect(j="i")
    poissonStimulus2B = PoissonGroup(N_Group_B, 0. * b2.Hz)
    syn_Stim2B = Synapses(poissonStimulus2B, excit_pop_B, on_pre="s_AMPA+=w_ext2excit")
    syn_Stim2B.connect(j="i")

    @network_operation(dt=stimulus_update_interval)
    def update_poisson_stimulus(t):
        if t >= t_stimulus_start and t < t_stimulus_end:
            offset_A = mu0_mean_stimulus_Hz * (0.5 + 0.5 * coherence_level)
            offset_B = mu0_mean_stimulus_Hz * (0.5 - 0.5 * coherence_level)

            rate_A = numpy.random.normal(offset_A, stimulus_std_Hz)
            rate_A = (max(0, rate_A)) * b2.Hz  # avoid negative rate
            rate_B = numpy.random.normal(offset_B, stimulus_std_Hz)
            rate_B = (max(0, rate_B)) * b2.Hz

            poissonStimulus2A.rates = rate_A
            poissonStimulus2B.rates = rate_B
            # print("stim on. rate_A= {}, rate_B = {}".format(rate_A, rate_B))
        else:
            # print("stim off")
            poissonStimulus2A.rates = 0.
            poissonStimulus2B.rates = 0.

    ###############################################################################################

    def get_monitors(pop, monitored_subset_size):
        """
        Internal helper.
        Args:
            pop:
            monitored_subset_size:

        Returns:

        """
        monitored_subset_size = min(monitored_subset_size, pop.N)
        idx_monitored_neurons = sample(range(pop.N), monitored_subset_size)
        rate_monitor = PopulationRateMonitor(pop)
        # SpikeMonitor's record argument is a boolean flag: passing idx_monitored_neurons does not
        # restrict recording to this subset (all neurons of pop are recorded)
        spike_monitor = SpikeMonitor(pop, record=idx_monitored_neurons)
        voltage_monitor = StateMonitor(pop, "v", record=idx_monitored_neurons)
        return rate_monitor, spike_monitor, voltage_monitor, idx_monitored_neurons

    # collect data of a subset of neurons:
    rate_monitor_inhib, spike_monitor_inhib, voltage_monitor_inhib, idx_monitored_neurons_inhib = \
        get_monitors(inhib_pop, monitored_subset_size)

    rate_monitor_A, spike_monitor_A, voltage_monitor_A, idx_monitored_neurons_A = \
        get_monitors(excit_pop_A, monitored_subset_size)

    rate_monitor_B, spike_monitor_B, voltage_monitor_B, idx_monitored_neurons_B = \
        get_monitors(excit_pop_B, monitored_subset_size)

    rate_monitor_Z, spike_monitor_Z, voltage_monitor_Z, idx_monitored_neurons_Z = \
        get_monitors(excit_pop_Z, monitored_subset_size)

    if stop_condition_rate is None:
        b2.run(max_sim_time)
    else:
        sim_sum = 0. * b2.ms
        sim_batch = 100. * b2.ms
        samples_in_batch = int(floor(sim_batch / b2.defaultclock.dt))
        avg_rate_in_batch = 0
        while (sim_sum < max_sim_time) and (avg_rate_in_batch < stop_condition_rate):
            b2.run(sim_batch)
            avg_A = numpy.mean(rate_monitor_A.rate[-samples_in_batch:])
            avg_B = numpy.mean(rate_monitor_B.rate[-samples_in_batch:])
            avg_rate_in_batch = max(avg_A, avg_B)
            sim_sum += sim_batch

    print("sim end: {}".format(time.ctime()))
    ret_vals = dict()

    ret_vals["rate_monitor_A"] = rate_monitor_A
    ret_vals["spike_monitor_A"] = spike_monitor_A
    ret_vals["voltage_monitor_A"] = voltage_monitor_A
    ret_vals["idx_monitored_neurons_A"] = idx_monitored_neurons_A

    ret_vals["rate_monitor_B"] = rate_monitor_B
    ret_vals["spike_monitor_B"] = spike_monitor_B
    ret_vals["voltage_monitor_B"] = voltage_monitor_B
    ret_vals["idx_monitored_neurons_B"] = idx_monitored_neurons_B

    ret_vals["rate_monitor_Z"] = rate_monitor_Z
    ret_vals["spike_monitor_Z"] = spike_monitor_Z
    ret_vals["voltage_monitor_Z"] = voltage_monitor_Z
    ret_vals["idx_monitored_neurons_Z"] = idx_monitored_neurons_Z

    ret_vals["rate_monitor_inhib"] = rate_monitor_inhib
    ret_vals["spike_monitor_inhib"] = spike_monitor_inhib
    ret_vals["voltage_monitor_inhib"] = voltage_monitor_inhib
    ret_vals["idx_monitored_neurons_inhib"] = idx_monitored_neurons_inhib

    return ret_vals



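The "slow reverberation" of Wang (2002) relies on the NMDA gating variable: with `tau_NMDA_s = 100 ms`, `s_NMDA` integrates presynaptic spikes over a much longer window than AMPA (`tau_AMPA = 2.5 ms`). The pure-Python sketch below integrates the two equations used for the excitatory neurons (`ds_NMDA/dt` and `dx/dt`), incrementing `x` by 1 at each spike as the self-synapses `syn_x_A2A`, `syn_x_B2B` and `syn_x_Z2Z` do (their 0.5 ms delay is ignored), for a 40 Hz spike train lasting 200 ms. Forward Euler, the spike train and the helper name `nmda_gating` are assumptions for illustration.

```python
import math

def nmda_gating(spike_times, t_sim=0.6, dt=1e-4, tau_s=0.100, tau_x=0.002, alpha=500.0):
    """Forward-Euler integration of
        ds/dt = -s/tau_s + alpha * x * (1 - s),   dx/dt = -x/tau_x,
    with x += 1 at every presynaptic spike. Returns s after every time step."""
    spike_steps = {int(round(t / dt)) for t in spike_times}
    s, x, trace = 0.0, 0.0, []
    for step in range(int(round(t_sim / dt))):
        if step in spike_steps:
            x += 1.0                      # same update as on_pre="x += 1."
        ds = -s / tau_s + alpha * x * (1.0 - s)
        dx = -x / tau_x
        s += dt * ds
        x += dt * dx
        trace.append(s)
    return trace

dt = 1e-4
spikes = [0.05 + k * 0.025 for k in range(8)]   # 40 Hz train, last spike at 225 ms
s_trace = nmda_gating(spikes, dt=dt)
s_later = s_trace[int(round((spikes[-1] + 0.1) / dt))]
print(f"peak s_NMDA = {max(s_trace):.2f}, 100 ms after the last spike = {s_later:.2f}")
# For comparison, an AMPA-like variable (tau = 2.5 ms) decays by exp(-40) within 100 ms.
print(f"AMPA decay factor over 100 ms: {math.exp(-0.1 / 0.0025):.1e}")
```

About a third of the peak NMDA gating is still present 100 ms after the input stops, whereas AMPA gating has vanished. Summed over a recurrently connected subpopulation, this slow decay is what allows activity to be integrated over time and to persist.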

## 4 - Plotting the outcome of the model

Here we define a function that visualizes the model's behaviour after a simulation. Execute this cell before moving on to the exercises.

In [None]:
def plot_network_activity(rate_monitor, spike_monitor, voltage_monitor=None, spike_train_idx_list=None,
                          t_min=None, t_max=None, N_highlighted_spiketrains=3, avg_window_width=1.0 * b2.ms,
                          sup_title=None, figure_size=(10, 4)):
    """
    Visualizes the results of a network simulation: spike-train, population activity and voltage-traces.

    Args:
        rate_monitor (PopulationRateMonitor): rate of the population
        spike_monitor (SpikeMonitor): spike trains of individual neurons
        voltage_monitor (StateMonitor): optional. voltage traces of some (same as in spike_train_idx_list) neurons
        spike_train_idx_list (list): optional. A list of neuron indices whose spike-train is plotted.
            If None, the list in voltage_monitor.record is used; if voltage_monitor is also None, all
            spike-trains in the spike_monitor are plotted (up to 5000).
        t_min (Quantity): optional. lower bound of the plotted time interval.
            if t_min is None, it is set to the larger of [0ms, (t_max - 100ms)]
        t_max (Quantity): optional. upper bound of the plotted time interval.
            if t_max is None, it is set to the last time point recorded by the rate_monitor.
        N_highlighted_spiketrains (int): optional. Number of spike trains visually highlighted, defaults to 3
            If N_highlighted_spiketrains==0 and voltage_monitor is not None, then all voltage traces of
            the voltage_monitor are plotted. Otherwise N_highlighted_spiketrains voltage traces are plotted.
        avg_window_width (Quantity): optional. Before plotting the population rate (PopulationRateMonitor), the rate
            is smoothed using a window of width = avg_window_width. Default is 1.0 ms
        sup_title (String): figure suptitle. Default is None.
        figure_size (tuple): (width,height) tuple passed to pyplot's figsize parameter.

    Returns:
        Figure: The whole figure
        Axes: Top panel, Raster plot
        Axes: Middle panel, population activity
        Axes: Bottom panel, voltage traces. None if no voltage monitor is provided.
    """

    assert isinstance(rate_monitor, b2.PopulationRateMonitor), \
        "rate_monitor  is not of type PopulationRateMonitor"
    assert isinstance(spike_monitor, b2.SpikeMonitor), \
        "spike_monitor is not of type SpikeMonitor"
    assert (voltage_monitor is None) or (isinstance(voltage_monitor, b2.StateMonitor)), \
        "voltage_monitor is not of type StateMonitor"
    assert (spike_train_idx_list is None) or (isinstance(spike_train_idx_list, list)), \
        "spike_train_idx_list is not of type list"

    all_spike_trains = spike_monitor.spike_trains()
    if spike_train_idx_list is None:
        if voltage_monitor is not None:
            # if no index list is provided use the one from the voltage monitor
            spike_train_idx_list = numpy.sort(voltage_monitor.record)
        else:
            # no index list AND no voltage monitor: plot all spike trains
            spike_train_idx_list = numpy.sort(list(all_spike_trains.keys()))  # dict keys view -> list
        if len(spike_train_idx_list) > 5000:
            # avoid slow plotting of a large set
            print("Warning: raster plot with more than 5000 neurons truncated!")
            spike_train_idx_list = spike_train_idx_list[:5000]

    # get a reasonable default interval
    if t_max is None:
        t_max = max(rate_monitor.t / b2.ms)
    else:
        t_max = t_max / b2.ms
    if t_min is None:
        t_min = max(0., t_max - 100.)  # if none, plot at most the last 100ms
    else:
        t_min = t_min / b2.ms

    fig = None
    ax_raster = None
    ax_rate = None
    ax_voltage = None
    if voltage_monitor is None:
        fig, (ax_raster, ax_rate) = plt.subplots(2, 1, sharex=True, figsize=figure_size)
    else:
        fig, (ax_raster, ax_rate, ax_voltage) = plt.subplots(3, 1, sharex=True, figsize=figure_size)

    # nested helpers to plot the parts, note that they use parameters defined outside.
    def get_spike_train_ts_indices(spike_train):
        """
        Helper. Extracts the spikes within the time window from the spike train
        """
        ts = spike_train/b2.ms
        # spike_within_time_window = (ts >= t_min) & (ts <= t_max)
        # idx_spikes = numpy.where(spike_within_time_window)
        idx_spikes = (ts >= t_min) & (ts <= t_max)
        ts_spikes = ts[idx_spikes]
        return idx_spikes, ts_spikes

    def plot_raster():
        """
        Helper. Plots the spike trains of the spikes in spike_train_idx_list
        """
        neuron_counter = 0
        for neuron_index in spike_train_idx_list:
            idx_spikes, ts_spikes = get_spike_train_ts_indices(all_spike_trains[neuron_index])
            ax_raster.scatter(ts_spikes, neuron_counter * numpy.ones(ts_spikes.shape),
                              marker=".", c="k", s=15, lw=0)
            neuron_counter += 1
        ax_raster.set_ylim([0, neuron_counter])

    def highlight_raster(neuron_idxs):
        """
        Helper. Highlights the spike trains at the given raster-plot indices
        """
        for i in range(len(neuron_idxs)):
            color = "r" if i == 0 else "k"
            raster_plot_index = neuron_idxs[i]
            population_index = spike_train_idx_list[raster_plot_index]
            idx_spikes, ts_spikes = get_spike_train_ts_indices(all_spike_trains[population_index])
            ax_raster.axhline(y=raster_plot_index, linewidth=.5, linestyle="-", color=[.9, .9, .9])
            ax_raster.scatter(
                ts_spikes, raster_plot_index * numpy.ones(ts_spikes.shape),
                marker=".", c=color, s=144, lw=0)
        ax_raster.set_ylabel("neuron #")
        ax_raster.set_title("Raster Plot", fontsize=10)

    #def plot_population_activity(window_width=0.5*b2.ms):
    def plot_population_activity(window_width=100*b2.ms):
        """
        Helper. Plots the population rate, smoothed with a flat window of width window_width
        """
        ts = rate_monitor.t / b2.ms
        idx_rate = (ts >= t_min) & (ts <= t_max)
        # ax_rate.plot(ts[idx_rate],rate_monitor.rate[idx_rate]/b2.Hz, ".k", markersize=2)
        smoothed_rates = rate_monitor.smooth_rate(window="flat", width=window_width)/b2.Hz
        ax_rate.plot(ts[idx_rate], smoothed_rates[idx_rate])
        ax_rate.set_ylabel("Firing Rate (Hz)")
        ax_rate.set_title("Population Activity", fontsize=10)

    def plot_voltage_traces(voltage_traces_i):
        """
        Helper. Plots the voltage traces at the given raster-plot indices
        """
        ts = voltage_monitor.t/b2.ms
        idx_voltage = (ts >= t_min) & (ts <= t_max)
        for i in range(len(voltage_traces_i)):
            color = "r" if i == 0 else ".7"
            raster_plot_index = voltage_traces_i[i]
            population_index = spike_train_idx_list[raster_plot_index]
            ax_voltage.plot(
                ts[idx_voltage], voltage_monitor[population_index].v[idx_voltage]/b2.mV,
                c=color, lw=1.)
            ax_voltage.set_ylabel("Memb. Pot. (mV)")
            ax_voltage.set_title("Voltage Traces", fontsize=10)

    plot_raster()
    plot_population_activity(avg_window_width)
    nr_neurons = len(spike_train_idx_list)
    highlighted_neurons_i = []  # default to an empty list.
    if N_highlighted_spiketrains > 0:
        fract = numpy.linspace(0, 1, N_highlighted_spiketrains + 2)[1:-1]
        highlighted_neurons_i = [int(nr_neurons * v) for v in fract]
        highlight_raster(highlighted_neurons_i)

    if voltage_monitor is not None:
        if N_highlighted_spiketrains == 0:
            traces_i = range(nr_neurons)
        else:
            traces_i = highlighted_neurons_i
        plot_voltage_traces(traces_i)

    plt.xlabel("t [ms]")

    if sup_title is not None:
        plt.suptitle(sup_title)

    return fig, ax_raster, ax_rate, ax_voltage

### Question 1

Run the simulation for 800 ms and use the plot_network_activity function to visualize the activity of the left- and right-motion-sensitive neurons.

In [None]:
#@title Exercise (write and test your own code in this cell)



In [None]:
#@title Solution (you can see the solution to the exercise here)

results = sim_decision_making_network(N_Excit = int(384), N_Inhib= int(96),t_stimulus_start= 50. * b2.ms,
                                                      coherence_level=0.6, max_sim_time=800. * b2.ms)

plot_network_activity(results["rate_monitor_A"], results["spike_monitor_A"],
                                 results["voltage_monitor_A"], t_min=0. * b2.ms, avg_window_width=5. * b2.ms,
                                 sup_title="Left")
plt.tight_layout()
plt.show()
plot_network_activity(results["rate_monitor_B"], results["spike_monitor_B"],
                                 results["voltage_monitor_B"], t_min=0. * b2.ms, avg_window_width=5. * b2.ms,
                                 sup_title="Right")
plt.tight_layout()
plt.show()

### Question 2

Without running the simulation again, re-plot the results of the previous simulation while varying the avg_window_width argument. What is the effect of a very short and of a very long averaging window?

What do you think is an optimal avg_window_width value to make sense of the model outcome?
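In `plot_network_activity`, `avg_window_width` is passed to `rate_monitor.smooth_rate(window="flat", width=...)`, i.e. a moving average of the population rate. The stdlib-only sketch below applies a centered flat window to a synthetic noisy population rate (96 neurons whose true rate steps from 5 Hz to 40 Hz at 400 ms; all values are invented for illustration) to make the trade-off visible. `smooth_flat` is a hypothetical helper, and its edge handling (averaging only the available samples) may differ from Brian2's.

```python
import random
import statistics

def smooth_flat(rates, dt, width):
    """Centered moving average with a flat window of `width` seconds.
    At the edges only the available samples are averaged (an assumption)."""
    n = max(1, int(round(width / dt)))   # window length in samples
    half = n // 2
    smoothed = []
    for k in range(len(rates)):
        lo, hi = max(0, k - half), min(len(rates), k - half + n)
        window = rates[lo:hi]
        smoothed.append(sum(window) / len(window))
    return smoothed

# Synthetic population of 96 neurons whose true rate jumps from 5 Hz to 40 Hz at 400 ms.
rng = random.Random(0)
dt, n_neurons = 1e-3, 96
true_rate = [5.0 if k < 400 else 40.0 for k in range(800)]
raw = []
for r in true_rate:
    count = sum(rng.random() < r * dt for _ in range(n_neurons))  # spikes in this 1 ms bin
    raw.append(count / (n_neurons * dt))

for width in (1e-3, 50e-3, 200e-3):
    sm = smooth_flat(raw, dt, width)
    print(f"width {width * 1e3:5.0f} ms: std on the 40 Hz plateau = {statistics.pstdev(sm[500:]):5.1f} Hz, "
          f"rate at 400 ms = {sm[400]:5.1f} Hz")
```

A window of one time step leaves the trace as noisy as the raw rate; a 200 ms window removes most fluctuations but reports an intermediate rate at the moment of the jump, blurring the timing of the decision.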

In [None]:
#@title Exercise (write and test your own code in this cell)



In [None]:
#@title Solution (you can see the solution to the exercise here)

plot_network_activity(results["rate_monitor_A"], results["spike_monitor_A"],
                                 results["voltage_monitor_A"], t_min=0. * b2.ms, avg_window_width=50. * b2.ms,
                                 sup_title="Left")
plt.tight_layout()
plt.show()
plot_network_activity(results["rate_monitor_B"], results["spike_monitor_B"],
                                 results["voltage_monitor_B"], t_min=0. * b2.ms, avg_window_width=50. * b2.ms,
                                 sup_title="Right")
plt.tight_layout()
plt.show()

## 5 - Manipulating the model's inputs

Now that we understand the structure of the model and its outputs, let's manipulate its inputs and examine how these manipulations affect the outcome.

### Question 1

From the sim_decision_making_network function, can you identify which brian2 objects correspond to the stimuli?

### Question 2

Run the model with different values of coherence, for instance c = 0.6 and c = 0.4, and plot the results. What do you observe?
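To interpret the coherence parameter, it helps to recall how Wang (2002) defines the two stimuli: their mean rates are mu_A = mu0 (1 + c) and mu_B = mu0 (1 - c), with c expressed as a fraction, and the actual rates fluctuate around these means. The sketch below illustrates this under stated assumptions: positive c favors stimulus A, and the values of mu0, sigma and the resampling interval are hypothetical. The notebook's sim_decision_making_network may scale or update its stimuli differently, so check its code when answering Question 1.

```python
import numpy as np

def stimulus_rates(coherence, mu0=40.0, sigma=4.0, duration=0.8, update_interval=0.05, seed=None):
    """Piecewise-constant input rates (Hz) for stimulus A and stimulus B.

    Means follow mu_A = mu0 * (1 + c) and mu_B = mu0 * (1 - c); mu0, sigma and
    update_interval are illustrative assumptions, not the notebook's values.
    """
    rng = np.random.default_rng(seed)
    n_updates = int(round(duration / update_interval))   # number of resampling periods
    mu_a = mu0 * (1.0 + coherence)
    mu_b = mu0 * (1.0 - coherence)
    # Redraw each rate from a Gaussian at every update; negative rates are clipped to 0 Hz
    rates_a = np.clip(rng.normal(mu_a, sigma, n_updates), 0.0, None)
    rates_b = np.clip(rng.normal(mu_b, sigma, n_updates), 0.0, None)
    return rates_a, rates_b

for c in (0.6, 0.4, -0.1):
    rates_a, rates_b = stimulus_rates(c, seed=1)
    print(f"c = {c:+.1f}: mean rate A = {rates_a.mean():5.1f} Hz, mean rate B = {rates_b.mean():5.1f} Hz")
```

Lowering |c| shrinks the difference between the two mean inputs, so the fluctuations weigh more in the competition between the two populations.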

In [None]:
#@title Exercise (write and test your own code in this cell)



In [None]:
#@title Solution (you can see the solution to the exercise here)
c = [-0.6, 0.4]
for coh in c:
  results = sim_decision_making_network(N_Excit = int(384), N_Inhib= int(96),t_stimulus_start= 50. * b2.ms,
                                                  coherence_level=coh, max_sim_time=800. * b2.ms)#, weight_scaling_factor = 10.66)

  plot_network_activity(results["rate_monitor_A"], results["spike_monitor_A"],
                                results["voltage_monitor_A"], t_min=0. * b2.ms, avg_window_width=40. * b2.ms,
                                sup_title="Left-sensitive pop. activity, coherence = "+str(coh))
  plt.tight_layout()
  plt.show()
  plot_network_activity(results["rate_monitor_B"], results["spike_monitor_B"],
                                results["voltage_monitor_B"], t_min=0. * b2.ms, avg_window_width=40. * b2.ms,
                                sup_title="Right-sensitive pop. activity, coherence = "+str(coh))
  plt.tight_layout()
  plt.show()

### Question 3

Run a few simulations with c=-0.1 and plot the network activity.

Does the network always correctly reflect the input's coherence?
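With several runs per coherence value, it becomes tedious to judge each figure by eye. One option, sketched below, is to label each trial by comparing the mean smoothed rates of the two populations late in the trial and then count the labels. The helper late_winner, the 600 ms start of the comparison window and the 5 Hz margin are all hypothetical choices; the demonstration uses synthetic rate traces rather than simulation results.

```python
from collections import Counter
import numpy as np

def late_winner(t, rate_left, rate_right, t_from=0.6, margin=5.0):
    """Label a trial by comparing mean (smoothed) rates after t_from seconds.

    Returns "Left" or "Right" if one population exceeds the other by more than
    `margin` Hz, and "Undecided" otherwise. t_from and margin are assumptions.
    """
    late = np.asarray(t) >= t_from
    mean_left = np.mean(np.asarray(rate_left)[late])
    mean_right = np.mean(np.asarray(rate_right)[late])
    if mean_left - mean_right > margin:
        return "Left"
    if mean_right - mean_left > margin:
        return "Right"
    return "Undecided"

# Synthetic trials (rates in Hz sampled every 1 ms), standing in for simulation results
t = np.arange(0.0, 0.8, 1e-3)
trials = [
    (np.where(t > 0.4, 40.0, 5.0), np.full_like(t, 3.0)),   # left population wins
    (np.full_like(t, 3.0), np.where(t > 0.5, 35.0, 5.0)),   # right population wins
    (np.full_like(t, 6.0), np.full_like(t, 7.0)),           # no clear winner
]
outcomes = Counter(late_winner(t, left, right) for left, right in trials)
print(outcomes)
```

Inside the simulation loop, the inputs could be taken from the rate monitors, e.g. results["rate_monitor_A"].t / b2.second for the time axis and results["rate_monitor_A"].smooth_rate(window="flat", width=50 * b2.ms) / b2.Hz for the rate, assuming population A is the left-sensitive one as in the plot titles.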

In [None]:
#@title Exercise (write and test your own code in this cell)


In [None]:
#@title Solution (you can see the solution to the exercise here)

n_runs = 3 # number of simulations to run
coherence = [-0.1,0.4] # different coherence values to try
for coh in coherence:
  print("Running "+str(n_runs)+" simulations with coherence "+str(coh))
  for run in range(0, n_runs):
      results = sim_decision_making_network(N_Excit = int(384), N_Inhib= int(96),t_stimulus_start= 50. * b2.ms,
                                                      coherence_level=coh, max_sim_time=800. * b2.ms)

      plot_network_activity(results["rate_monitor_A"], results["spike_monitor_A"],
                                    results["voltage_monitor_A"], t_min=0. * b2.ms, avg_window_width=50. * b2.ms,
                                    sup_title="Left-sensitive pop. activity, coherence =  "+str(coh)+" Run N°"+str(run))
      plt.tight_layout()
      plt.show()
      plot_network_activity(results["rate_monitor_B"], results["spike_monitor_B"],
                                    results["voltage_monitor_B"], t_min=0. * b2.ms, avg_window_width=50. * b2.ms,
                                    sup_title="Right-sensitive pop. activity, coherence =  "+str(coh)+" Run N°"+str(run))
      plt.tight_layout()
      plt.show()

## 6 - Decision criterion and reaction times

So far, our model qualitatively indicates whether one of its two attractors "wins" the race. However, it does not give us a discrete outcome (right or left) or a reaction time.

### Question 1

Using your insights from the previous questions, implement a function "get_decision_time" that takes two RateMonitors, an avg_window_width and a rate_threshold.
The function should return a tuple (decision_time_left, decision_time_right):

*   If the decision is right, it should return, for example, (0ms, 546ms), indicating that the right accumulator crossed the threshold after 546ms.
*   If the decision is left, it should return, for example, (443ms, 0ms), indicating that the left accumulator crossed the threshold after 443ms.
*   If no decision is reached, it should return (0ms, 0ms).
*   A result such as (443ms, 546ms) is an error, as the system cannot select both right and left in a single trial.
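The specification above can be sketched with plain NumPy arrays, which makes two pitfalls explicit. First, numpy.argmax returns 0 when no element is True, so "never crossed" must be checked separately rather than read off the index. Second, the rule that both populations cannot win must be enforced somehow; here, as an assumed tie-breaking rule, only the earlier crossing counts as the decision. The helper names first_crossing_time and decision_times are hypothetical, and times are plain floats in seconds rather than brian2 quantities.

```python
import numpy as np

def first_crossing_time(rate_hz, threshold_hz, dt):
    """Time (s) of the first sample where rate_hz exceeds threshold_hz, or 0.0 if it never does."""
    above = np.asarray(rate_hz) > threshold_hz
    if not above.any():      # np.argmax would return 0 here, i.e. a fake crossing at t = 0
        return 0.0
    return np.argmax(above) * dt

def decision_times(rate_left_hz, rate_right_hz, threshold_hz, dt):
    """Return (t_left, t_right) in seconds, with at most one non-zero entry.

    If both populations cross, only the earlier crossing is kept as the decision
    (an assumed tie-breaking rule). A crossing exactly at t = 0 is reported as 0.0.
    """
    t_left = first_crossing_time(rate_left_hz, threshold_hz, dt)
    t_right = first_crossing_time(rate_right_hz, threshold_hz, dt)
    if t_left > 0 and t_right > 0:
        if t_left <= t_right:
            t_right = 0.0
        else:
            t_left = 0.0
    return t_left, t_right

# Synthetic example: the left rate jumps above 25 Hz at 100 ms, the right rate never does
dt = 1e-3
left = np.r_[np.zeros(100), np.full(100, 50.0)]
right = np.full(200, 10.0)
print(decision_times(left, right, 25.0, dt))
```

With the simulation outputs, the smoothed rates could be passed as rate_monitor_A.smooth_rate(window="flat", width=avg_window_width) / b2.Hz and dt as float(b2.defaultclock.dt / b2.second), assuming the rate monitors record on the default clock.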




In [None]:
#@title Exercise (write and test your own code in this cell)

## Hint: The following lines of code could be of use
#smoothed_rates_A = rate_monitor_A.smooth_rate(window="flat", width=avg_window_width) / b2.Hz
#idx_A = numpy.argmax(smoothed_rates_A > rate_threshold/b2.Hz)
#t_A = idx_A * b2.defaultclock.dt

In [None]:
#@title Solution (you can see the solution to the exercise here)

def get_decision_time(rate_monitorA, rate_monitorB, avg_window_width = 50. * b2.ms, rate_threshold = 35. * b2.Hz):
  smoothed_rates_A = rate_monitorA.smooth_rate(window="flat", width=avg_window_width) / b2.Hz # Smooth the firing rate of population A (left) to reduce the noise
  smoothed_rates_B = rate_monitorB.smooth_rate(window="flat", width=avg_window_width) / b2.Hz # Smooth the firing rate of population B (right) to reduce the noise
  idx_A = numpy.argmax(smoothed_rates_A > rate_threshold/b2.Hz) # Index of the first time bin where population A exceeds the threshold (0 if it never does)
  idx_B = numpy.argmax(smoothed_rates_B > rate_threshold/b2.Hz) # Index of the first time bin where population B exceeds the threshold (0 if it never does)
  t_A = idx_A * b2.defaultclock.dt # Convert the index into the time of the threshold crossing
  t_B = idx_B * b2.defaultclock.dt # Convert the index into the time of the threshold crossing

  return {"Left":t_A,"Right":t_B} # Return a dictionary with the reaction time of each option (0 ms means no crossing)

coh = 0.7

# we run a new simulation with coherence "coh"
results = sim_decision_making_network(N_Excit = int(384), N_Inhib= int(96),t_stimulus_start= 50. * b2.ms,
                                                      coherence_level=coh, max_sim_time=800. * b2.ms)

# Plot the population activity
plot_network_activity(results["rate_monitor_A"], results["spike_monitor_A"],
                                    results["voltage_monitor_A"], t_min=0. * b2.ms, avg_window_width=50. * b2.ms,
                                    sup_title="Left-sensitive pop. activity, coherence =  "+str(coh))
plt.tight_layout()
plt.show()
plot_network_activity(results["rate_monitor_B"], results["spike_monitor_B"],
                              results["voltage_monitor_B"], t_min=0. * b2.ms, avg_window_width=50. * b2.ms,
                              sup_title="Right-sensitive pop. activity, coherence =  "+str(coh))
plt.tight_layout()
plt.show()

# Then use our function with a firing rate threshold of 25 Hz to obtain a response and a reaction time
threshold = 25
reaction_times = get_decision_time(results["rate_monitor_A"], results["rate_monitor_B"], avg_window_width=50* b2.ms, rate_threshold = threshold* b2.Hz)
print("Reaction times:", reaction_times)


### Question 2

Without re-running the simulation, re-plot the model results and add the decision threshold and the reaction time to the graph.

In [None]:
#@title Exercise (write and test your own code in this cell)



In [None]:
#@title Solution (you can see the solution to the exercise here)
threshold = 25
reaction_times = get_decision_time(results["rate_monitor_A"], results["rate_monitor_B"], avg_window_width=50* b2.ms, rate_threshold = threshold* b2.Hz)
print("Reaction times:", reaction_times)

fig, ax_raster, ax_rate, ax_voltage= plot_network_activity(results["rate_monitor_A"], results["spike_monitor_A"],
                                    results["voltage_monitor_A"], t_min=0. * b2.ms, avg_window_width=50. * b2.ms,
                                    sup_title="Left-sensitive pop. activity, coherence =  "+str(coh))

ax_rate.axhline(threshold, color = 'k', linestyle = '--', label = 'Decision Threshold')
if reaction_times['Left'] !=0:
  ax_rate.axvline(reaction_times['Left']/b2.second*1000, color = 'r', linestyle = '-', label = 'Reaction Time')
ax_rate.legend()
plt.tight_layout()
plt.show()

fig, ax_raster, ax_rate, ax_voltage = plot_network_activity(results["rate_monitor_B"], results["spike_monitor_B"],
                              results["voltage_monitor_B"], t_min=0. * b2.ms, avg_window_width=50. * b2.ms,
                              sup_title="Right-sensitive pop. activity, coherence =  "+str(coh))
ax_rate.axhline(threshold, color = 'k', linestyle = '--', label = 'Decision Threshold')
if reaction_times['Right'] !=0:
  ax_rate.axvline(reaction_times['Right']/b2.second*1000, color = 'r', linestyle = '-', label = 'Reaction Time')
ax_rate.legend()
plt.tight_layout()
plt.show()

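# For comparison: with a stricter threshold of 35 Hz, the decision is reached later (or not at all)
reaction_times = get_decision_time(results["rate_monitor_A"], results["rate_monitor_B"], avg_window_width=50* b2.ms, rate_threshold = 35* b2.Hz)
print("Reaction times (threshold = 35 Hz):", reaction_times)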


### Question 3

For a given coherence value, run the simulation multiple times (50 or 100 times) and build a distribution of reaction times.

**Warning: running the simulation repeatedly may take a considerable amount of time.**
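Besides a histogram, a few summary numbers make it easier to compare coherence levels: the fraction of trials ending in each choice, the fraction with no decision, and the median reaction time of each choice. The sketch below assumes, as in get_decision_time, that a time of 0 means the population never crossed the threshold, and that the times have already been converted to plain floats in seconds (e.g. reaction_times['Left'] / b2.second). The function name summarize_trials is hypothetical, and the example input is synthetic.

```python
import numpy as np

def summarize_trials(rt_left_s, rt_right_s):
    """Summarize per-trial decision times (s), where 0 means 'this population did not cross'.

    Trials in which both populations crossed are counted in both p_left and p_right.
    """
    rt_left_s = np.asarray(rt_left_s, dtype=float)
    rt_right_s = np.asarray(rt_right_s, dtype=float)
    n_trials = rt_left_s.size
    left = rt_left_s[rt_left_s > 0]        # trials in which the left population crossed
    right = rt_right_s[rt_right_s > 0]     # trials in which the right population crossed
    no_decision = np.sum((rt_left_s == 0) & (rt_right_s == 0))
    return {
        "p_left": left.size / n_trials,
        "p_right": right.size / n_trials,
        "p_no_decision": no_decision / n_trials,
        "median_rt_left": float(np.median(left)) if left.size else float("nan"),
        "median_rt_right": float(np.median(right)) if right.size else float("nan"),
    }

# Synthetic example with four trials: two left decisions, one right decision, one undecided
print(summarize_trials([0.3, 0.4, 0.0, 0.0], [0.0, 0.0, 0.5, 0.0]))
```

The median is used rather than the mean because reaction time distributions are typically right-skewed.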

In [None]:
#@title Exercise (write and test your own code in this cell)



In [None]:
#@title Solution (you can see the solution to the exercise here)
coherence = 0.7
threshold = 25
N_run = 50

RT_left = []
RT_right = []
for run in range(0, N_run):
  # Run a new simulation with coherence "coherence"
  results = sim_decision_making_network(N_Excit = int(384), N_Inhib= int(96),t_stimulus_start= 50. * b2.ms,
                                                        coherence_level=coherence, max_sim_time=800. * b2.ms)
  reaction_times = get_decision_time(results["rate_monitor_A"], results["rate_monitor_B"], avg_window_width=50* b2.ms, rate_threshold = threshold* b2.Hz)
  # Store the times in seconds; 0 means that the population never crossed the threshold
  RT_left.append(reaction_times['Left'] / b2.second)
  RT_right.append(reaction_times['Right'] / b2.second)

# Keep only the trials in which each population actually crossed the threshold
RT_left = [rt for rt in RT_left if rt > 0]
RT_right = [rt for rt in RT_right if rt > 0]

fig, ax = plt.subplots(1,1)
ax.hist(RT_left, bins = 10, histtype = 'stepfilled', color = 'b', alpha = 0.5, label = 'Left')
ax.hist(RT_right, bins = 10, histtype = 'stepfilled', color = 'r', alpha = 0.5, label = 'Right')

ax.set_xlabel('Reaction Time (s)')
ax.set_ylabel('Count')
ax.set_title('Distribution of left and right reaction times, coherence = '+str(coherence))
ax.legend()
plt.tight_layout()
plt.show()