# A thalamo-cortical model of post-sleep memory improvement

*Bruna, Catalina and Flávio*

*LASCON 2024*

In this Jupyter Notebook, we build a thalamo-cortical model of post-sleep memory improvement. The project reproduces the paper ["Sleep-like slow oscillations improve visual classification through synaptic homeostasis and memory association in a thalamo-cortical model"](https://www.nature.com/articles/s41598-019-45525-0), whose code is not publicly available, and is part of the activities developed during LASCON 2024.

In this notebook, we will go step by step through:

1. Model creation
2. Model testing
3. Implementing MNIST classification
4. Reproducing the paper's results

Let's start. :)

-------------------------------

## 1. Model creation

Importing libraries

In [1]:
import nest
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline


              -- N E S T --
  Copyright (C) 2004 The NEST Initiative

 Version: 3.4
 Built: Jul 22 2023 00:00:00

 This program is provided AS IS and comes with
 NO WARRANTY. See the file LICENSE for details.

 Problems or suggestions?
   Visit https://www.nest-simulator.org

 Type 'nest.help()' to find out more about NEST.



In [2]:
# Reset Kernel
nest.ResetKernel()

### 1.1 Network creation

Create populations of AdEx (Adaptive Exponential Integrate-and-Fire) neurons with alpha-shaped synaptic conductances (`aeif_cond_alpha`) for:

- cx: pyramidal neurons (+) [in the cortex]
- in: inhibitory interneurons (-) [in the cortex]
- tc: thalamic relay neurons (+) [thalamocortical neurons]
- re: reticular neurons (-) [thalamic]
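To build intuition for what each `aeif_cond_alpha` neuron computes, here is a minimal forward-Euler sketch of the AdEx membrane and adaptation equations driven by a constant current. It is only an illustration: NEST integrates the model with an adaptive-step solver and also includes the alpha-shaped synaptic conductances, which are omitted here. The parameter values are assumptions that we believe match the NEST defaults (check them with `nest.GetDefaults('aeif_cond_alpha')`), and `simulate_adex` is a hypothetical helper.

```python
import math

def simulate_adex(I_e, t_sim=500.0, dt=0.01):
    """Forward-Euler AdEx neuron with a constant input current I_e (pA).

    Returns the list of spike times (ms). Parameter values are assumed,
    intended to resemble NEST's aeif_cond_alpha defaults.
    """
    C_m, g_L, E_L = 281.0, 30.0, -70.6      # pF, nS, mV
    V_th, Delta_T = -50.4, 2.0              # mV
    tau_w, a, b = 144.0, 4.0, 80.5          # ms, nS, pA
    V_reset = -60.0                         # mV
    V_peak = V_th + 5 * Delta_T             # spike detection, as chosen in this notebook

    V, w, t = E_L, 0.0, 0.0
    spikes = []
    while t < t_sim:
        # Membrane equation: leak + exponential spike initiation - adaptation + input
        dV = (-g_L * (V - E_L)
              + g_L * Delta_T * math.exp((V - V_th) / Delta_T)
              - w + I_e) / C_m
        # Subthreshold adaptation current
        dw = (a * (V - E_L) - w) / tau_w
        V += dt * dV
        w += dt * dw
        t += dt
        if V >= V_peak:
            # Spike: reset the membrane and increment the adaptation current by b
            spikes.append(t)
            V = V_reset
            w += b
    return spikes

print(len(simulate_adex(0.0)), len(simulate_adex(800.0)))
```

With these values, `simulate_adex(0.0)` stays silent near `E_L`, while a strong current such as 800 pA produces repetitive spiking, with each spike incrementing the adaptation current `w` by `b`.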
 

In [3]:
# Number of neurons in each population
cx_n = 180
in_n = 200
tc_n = 324
re_n = 200

# Set V_peak as in the paper: V_peak = V_th + 5 * Delta_T
aeif_defaults = nest.GetDefaults('aeif_cond_alpha')
V_peak = aeif_defaults['V_th'] + 5 * aeif_defaults['Delta_T']
neuron_params = {"V_peak": V_peak}

nest.SetDefaults('aeif_cond_alpha', neuron_params)

# Creating populations
cx_pop = nest.Create('aeif_cond_alpha', cx_n) # 20 neurons per training image; the first set of runs used 9 training images (9 x 20 = 180)
in_pop = nest.Create('aeif_cond_alpha', in_n)
tc_pop = nest.Create('aeif_cond_alpha', tc_n) # The number of thalamic neurons is the same as the dimension of the feature vector produced by the pre-processing of visual input
re_pop = nest.Create('aeif_cond_alpha', re_n)
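The two numbers fixed in this cell can be checked by hand. Assuming the `aeif_cond_alpha` defaults `V_th = -50.4 mV` and `Delta_T = 2 mV` (our assumption; the cell above reads the actual values from NEST), the paper's rule places `V_peak` at -40.4 mV, and 9 training images with 20 neurons each give the 180 pyramidal neurons.

```python
# Assumed aeif_cond_alpha defaults (check with nest.GetDefaults in a real session)
V_th = -50.4     # mV
Delta_T = 2.0    # mV

V_peak = V_th + 5 * Delta_T          # spike-detection threshold used in this notebook
n_images, group_size = 9, 20         # training images in the first set of runs, neurons per image
cx_n = n_images * group_size

print(round(V_peak, 1), cx_n)  # → -40.4 180
```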

Connect the populations of neurons.

In [4]:
# nest.Connect returns None in NEST 3.x, so there is no connection handle to store.
# The default rule is all_to_all; negative weights act on the inhibitory conductance.
nest.Connect(in_pop, cx_pop, syn_spec={"weight": -4})  # inhibitory interneurons -> pyramidal neurons
nest.Connect(cx_pop, in_pop, syn_spec={"weight": 60})  # pyramidal neurons -> inhibitory interneurons
nest.Connect(tc_pop, re_pop, syn_spec={"weight": 10})  # thalamic relay -> reticular neurons
nest.Connect(re_pop, tc_pop, syn_spec={"weight": -10}) # reticular neurons -> thalamic relay
nest.Connect(in_pop, in_pop, syn_spec={"weight": -1})  # inhibitory interneurons -> inhibitory interneurons
nest.Connect(re_pop, re_pop, syn_spec={"weight": -1})  # reticular neurons -> reticular neurons
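In `aeif_cond_alpha`, each incoming spike opens an alpha-shaped conductance whose peak equals the weight (in nS) at t = tau_syn, and the sign of the weight selects the channel: positive weights use the excitatory conductance and negative weights the inhibitory one, which is why the inhibitory projections above carry negative weights. The sketch below illustrates this with the hypothetical helpers `alpha_conductance` and `synaptic_current`; the reversal potentials and time constants are assumed values that we believe match the NEST defaults.

```python
import math

E_EX, E_IN = 0.0, -85.0          # assumed reversal potentials (mV)
TAU_EX, TAU_IN = 0.2, 2.0        # assumed synaptic time constants (ms)

def alpha_conductance(t, weight, tau):
    """Alpha-shaped conductance (nS) t ms after a spike; peak = |weight| at t = tau."""
    if t < 0:
        return 0.0
    return abs(weight) * (t / tau) * math.exp(1.0 - t / tau)

def synaptic_current(t, weight, V):
    """Synaptic current (pA) onto a neuron at membrane potential V (mV).

    The sign of the weight routes the spike to the excitatory or inhibitory channel.
    """
    if weight >= 0:
        g, E = alpha_conductance(t, weight, TAU_EX), E_EX
    else:
        g, E = alpha_conductance(t, weight, TAU_IN), E_IN
    return -g * (V - E)

# in -> cx weight of -4 nS, evaluated at its peak with V = -60 mV
print(synaptic_current(TAU_IN, -4.0, -60.0))  # → -100.0
```

The negative result is a hyperpolarizing current, as expected for the interneuron-to-pyramidal projection.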

Connect the cx-cx, cx-tc, and tc-cx populations with spike-timing-dependent plasticity (STDP) synapses. These synapses implement the learning mechanism.

In [5]:
# Synapse definitions
w_max_cxcx = 150     # Max weight for the cx-cx connection
w_max_cxtc = 130     # Max weight for the cx-tc connection
w_max_tccx = 5.5     # Max weight for the tc-cx connection
syn_alpha = 1.0      # Asymmetry between depression and potentiation in stdp_synapse

syn_dict_cxcx = {"synapse_model": "stdp_synapse", 
                "alpha": syn_alpha,
                "weight": 1,
                "Wmax": w_max_cxcx}

syn_dict_cxtc = {"synapse_model": "stdp_synapse", 
                "alpha": syn_alpha,
                "weight": 1,
                "Wmax": w_max_cxtc}

syn_dict_tccx = {"synapse_model": "stdp_synapse", 
                "alpha": syn_alpha,
                "weight": 1.0,
                "Wmax": w_max_tccx}

# Connect populations (all_to_all by default; nest.Connect returns None)
nest.Connect(cx_pop, cx_pop, syn_spec=syn_dict_cxcx)
nest.Connect(cx_pop, tc_pop, syn_spec=syn_dict_cxtc)
nest.Connect(tc_pop, cx_pop, syn_spec=syn_dict_tccx)
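To see what `Wmax` and `alpha` do, here is a sketch of a single pre/post spike-pair update written after the multiplicative rule documented for NEST's `stdp_synapse`. It is a simplification: NEST keeps pre- and postsynaptic traces and pairs each spike with all earlier spikes, and `tau_minus` belongs to the postsynaptic neuron rather than the synapse. The helper name `stdp_pair_update` and the default values of `lam`, `mu_plus`, `mu_minus`, `tau_plus` and `tau_minus` are our assumptions.

```python
import math

def stdp_pair_update(w, dt, Wmax, alpha=1.0, lam=0.01, mu_plus=1.0, mu_minus=1.0,
                     tau_plus=20.0, tau_minus=20.0):
    """Weight after one pre/post spike pair, with dt = t_post - t_pre (ms).

    Multiplicative STDP on the normalized weight w / Wmax; the result stays in [0, Wmax].
    """
    w_norm = w / Wmax
    if dt > 0:    # pre before post: potentiation, scaled by the presynaptic trace
        w_norm += lam * (1.0 - w_norm) ** mu_plus * math.exp(-dt / tau_plus)
        w_norm = min(w_norm, 1.0)
    elif dt < 0:  # post before pre: depression, scaled by the postsynaptic trace
        w_norm -= alpha * lam * w_norm ** mu_minus * math.exp(dt / tau_minus)
        w_norm = max(w_norm, 0.0)
    return w_norm * Wmax

# cx -> cx synapse (Wmax = 150) starting at weight 1
print(stdp_pair_update(1.0, +10.0, Wmax=150.0), stdp_pair_update(1.0, -10.0, Wmax=150.0))
```

Because potentiation scales with `1 - w/Wmax` and depression with `w/Wmax`, a synapse starting at weight 1, far below `Wmax = 150`, is potentiated much more strongly than it is depressed by a pair with the same |dt|.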

As a sanity check, let's count the connections in this network.

In [6]:
print(nest.num_connections)

430640
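The printed value can be checked by hand. Every `nest.Connect` call above uses the default `all_to_all` rule, so each projection contributes |source| × |target| connections; the sketch below redoes that sum with the population sizes from this notebook. The match also indicates that self-connections (autapses) are included in the in-in, re-re and cx-cx projections.

```python
# Population sizes from the cell above
sizes = {"cx": 180, "in": 200, "tc": 324, "re": 200}

# (source, target) pairs connected above, each with the all_to_all rule
projections = [
    ("in", "cx"), ("cx", "in"), ("tc", "re"), ("re", "tc"),
    ("in", "in"), ("re", "re"),                      # static synapses
    ("cx", "cx"), ("cx", "tc"), ("tc", "cx"),        # STDP synapses
]

total = sum(sizes[src] * sizes[tgt] for src, tgt in projections)
print(total)  # → 430640
```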


In [27]:
# Multimeter
multimeter = nest.Create("multimeter")
multimeter.set(record_from=["V_m"])

# Spike recorder
spikerecorder = nest.Create("spike_recorder")

In [28]:
mult_cx = nest.Connect(multimeter, cx_pop)
cx_rec = nest.Connect(cx_pop, spikerecorder)

In [29]:
nest.Simulate(1000.0)


Jan 22 11:53:37 NodeManager::prepare_nodes [Info]: 
    Preparing 1270 nodes for simulation.

Jan 22 11:53:37 SimulationManager::start_updating_ [Info]: 
    Number of local nodes: 1270
    Simulation time (ms): 1000
    Number of OpenMP threads: 1
    Not using MPI

Jan 22 11:58:55 SimulationManager::run [Info]: 
    Simulation finished.


In [30]:
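# Get events from the multimeter (membrane potentials of all cx neurons)
mm_events = multimeter.get("events")
Vms = mm_events["V_m"]
ts = mm_events["times"]
senders = mm_events["senders"]

# nest.Connect returned None, so read the spikes from the spike recorder itself
cx_spike = spikerecorder.get("events")["times"]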

In [31]:
multimeter

NodeCollection(metadata=None, model=multimeter, size=1, first=1269)

In [None]:

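# Plot the membrane potential of the first pyramidal neuron
mm_events = multimeter.get("events")
first_id = cx_pop.tolist()[0]
mask = mm_events["senders"] == first_id
plt.plot(mm_events["times"][mask], mm_events["V_m"][mask], label=f"cx neuron {first_id}")
plt.legend(loc=3)
plt.xlabel("Time (ms)")
plt.ylabel("V_m cx (mV)")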

### 1.2 Input Poisson spike trains to the populations

Let's input the contextual signal to the cx population (pyramidal neurons). 

The contextual signal is applied alongside the visual stimuli so that a chosen subset of neurons learns each new stimulus. It changes the effective firing threshold of those neurons while the handwritten digit images are presented during the training phase, mimicking a coincidence of signals. In addition, a second signal is delivered to the inhibitory neurons to prevent already trained neurons from responding to new stimuli during training.

*"Each cell receives a Poisson spike train with average firing rate that is 30 kHz only when the element of the feature vector is 1. The specific number of thalamic neurons used in the model is related to the specific pre-processing algorithm and the number of levels used to code the pre-processing output."*

*"During the retrieval phase only the 30 kHz input to thalamic cell is provided, while the contextual signal is off."*

*"The contextual signal is a Poissonian train of spikes which mimics a contextual signal coming from other brain areas and selectively facilitates neurons to learn new stimuli."*

*"Every time a new training image is presented to the network through the thalamic pathway, the facilitation signal coming from the contextual signal provides a 2 kHz Poisson spike train to a different set of 20 neurons, inducing the group to encode for that specific input stimulus (see the Discussion section for details about this choice). Additionally a 10 kHz Poisson spike train is provided to inhibitory neurons () to prevent already trained neurons to respond to new stimuli in the training phase"*

*"Relying on these observation we introduced in our model external stimuli mimiking contextual information which changes the effective firing threshold of specific subsets of neurons during the presentation of examples in the training phase"*
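The training protocol quoted above can be summarized as a schedule. In this sketch we assume, as our own choice, that image k recruits the contiguous group of cx neurons `20k` to `20k + 19`, and we read the notebook's comment "Only connect after the first training set" as: the 10 kHz inhibitory drive is on for every image except the first. `training_schedule` is a hypothetical helper.

```python
GROUP_SIZE = 20  # neurons per training image (from the paper)

def training_schedule(n_images, group_size=GROUP_SIZE):
    """Yield (image index, cx indices receiving the contextual signal,
    whether the 10 kHz drive to the inhibitory neurons is on)."""
    for k in range(n_images):
        cx_targets = list(range(k * group_size, (k + 1) * group_size))
        inhibition_on = k > 0  # assumed: no neuron is trained before the first image
        yield k, cx_targets, inhibition_on

for k, targets, inh in training_schedule(3):
    print(k, targets[0], targets[-1], inh)
```

In NEST these index ranges map onto NodeCollection slices such as `cx_pop[20 * k:20 * (k + 1)]`, which could then be used as the target of the contextual generator for the image being trained.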



In [None]:
# Contextual signal (2 kHz Poisson spike train) to the cx population
cont_sign = nest.Create("poisson_generator")

# 10 kHz Poisson spike train to the inhibitory neurons
poisson_in = nest.Create("poisson_generator")

# 30 kHz Poisson spike train delivered with the visual stimulus, only where the feature-vector element is 1
poisson_tc = nest.Create("poisson_generator")

# Set rates (in spikes/s)
cont_sign.set(rate=2000.0)
poisson_in.set(rate=10000.0)
poisson_tc.set(rate=30000.0)

# Connect them to the neurons
syn_dict_contsig = {"weight": 15}
syn_dict_p_inh = {"weight": 5}
syn_dict_p_vis = {"weight": 8}

nest.Connect(cont_sign, cx_pop, syn_spec=syn_dict_contsig)      # Turn off during the retrieval phase
nest.Connect(poisson_in, in_pop, syn_spec=syn_dict_p_inh)       # Only connect after the first training set. Turn off during the retrieval phase.
nest.Connect(poisson_tc, tc_pop, syn_spec=syn_dict_p_vis)
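The cell above currently drives every thalamic neuron, while the paper gives the 30 kHz input only to cells whose feature-vector element is 1. The sketch below shows that gating in plain Python, together with a homogeneous Poisson train drawn from exponential inter-spike intervals. The 324-element feature vector is random placeholder data (not a pre-processed MNIST image), and `active_tc_indices` and `poisson_spike_times` are hypothetical helpers; NEST's `poisson_generator` already draws an independent train for each target, so this is only for intuition.

```python
import random

def active_tc_indices(feature_vector):
    """Indices of thalamic neurons whose feature-vector element is 1."""
    return [i for i, bit in enumerate(feature_vector) if bit == 1]

def poisson_spike_times(rate_hz, duration_ms, rng):
    """Homogeneous Poisson train (ms) built from exponential inter-spike intervals."""
    rate_per_ms = rate_hz / 1000.0
    times, t = [], rng.expovariate(rate_per_ms)
    while t < duration_ms:
        times.append(t)
        t += rng.expovariate(rate_per_ms)
    return times

rng = random.Random(42)
# Placeholder binary feature vector with the size of the tc population (324)
features = [rng.randint(0, 1) for _ in range(324)]
active = active_tc_indices(features)
# Mean spike count is 30 kHz x 0.1 s = 3000
train = poisson_spike_times(30000.0, 100.0, rng)
print(len(active), len(train))
```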

## 2. Model testing

Creating devices for testing the model.

A multimeter records the membrane voltage of the neurons it is connected to over time.
A spike_recorder records the spike events emitted by the neurons connected to it.
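After a simulation, `spikerecorder.get("events")` returns a dictionary with `senders` and `times` arrays. A common first analysis is the firing rate of each neuron; the sketch below computes it from a dictionary with the same keys. The event values are made up, `firing_rates` is a hypothetical helper, and neurons that never spiked do not appear in the result.

```python
from collections import Counter

def firing_rates(events, sim_time_ms):
    """Per-neuron firing rate (spikes/s) from a spike_recorder-style events dict."""
    counts = Counter(events["senders"])
    return {sender: n / (sim_time_ms / 1000.0) for sender, n in counts.items()}

# Toy events with the same keys as spike_recorder.get("events"); the values are made up
toy_events = {"senders": [1, 1, 2, 1, 3], "times": [12.5, 40.1, 55.0, 90.3, 99.9]}
print(firing_rates(toy_events, 1000.0))  # → {1: 3.0, 2: 1.0, 3: 1.0}
```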

In [None]:
# Multimeter
multimeter = nest.Create("multimeter")
multimeter.set(record_from=["V_m"])

# Spike recorder
spikerecorder = nest.Create("spike_recorder")

Connect neurons to devices.

In [None]:
nest.Connect(multimeter, cx_pop)
nest.Connect(cx_pop, spikerecorder)

In [None]:
nest.Simulate(1000.0)


Jan 18 16:33:47 NodeManager::prepare_nodes [Info]: 
    Preparing 6 nodes for simulation.

Jan 18 16:33:47 SimulationManager::start_updating_ [Info]: 
    Number of local nodes: 6
    Simulation time (ms): 1000
    Number of OpenMP threads: 1
    Not using MPI

Jan 18 16:33:47 SimulationManager::run [Info]: 
    Simulation finished.
