*Copyright (C) 2021 Intel Corporation*<br>
*SPDX-License-Identifier: BSD-3-Clause*<br>
*See: https://spdx.org/licenses/*

---

# Three-Factor Learning with Lava

_**Motivation**: In this tutorial, we demonstrate the basic mechanics of a three-factor learning rule using the software model of Loihi's learning engine exposed in Lava. This involves defining a reward-modulated synaptic plasticity rule with eligibility traces and reward signals._

#### This tutorial assumes that you:
- have the [Lava framework installed](../../in_depth/tutorial01_installing_lava.ipynb "Tutorial on Installing Lava")
- are familiar with the [Process concept in Lava](../../in_depth/tutorial02_processes.ipynb "Tutorial on Processes")
- are familiar with the [ProcessModel concept in Lava](../../in_depth/tutorial02_process_models.ipynb "Tutorial on ProcessModels")
- are familiar with how to [connect Lava Processes](../../in_depth/tutorial05_connect_processes.ipynb "Tutorial on connecting Processes")
- are familiar with how to [implement a custom learning rule](../../in_depth/tutorial08_stdp.ipynb "Tutorial on STDP")

This tutorial gives a bird's-eye view of how to create a three-factor learning rule using the Lava process libraries. For this purpose, we will create a network of LIF and Dense processes with one plastic connection and drive it with frozen patterns of activity. We can choose between a floating-point simulation of the learning engine and a fixed-point simulation, which approximates the behavior of the Loihi neuromorphic hardware. We will also create monitors to observe the weights and the activity traces of the neurons and the learning rule.

### Reward Modulated Spike-timing Dependent Plasticity (R-STDP) Learning rule

Reward-modulated STDP is a learning rule that can explain how behaviorally relevant adaptive changes in complex networks of spiking neurons could be achieved in a self-organizing manner through local synaptic plasticity. The main idea is to modulate the outcome of a pairwise two-factor learning rule, such as STDP, with a reward term. The implementation of R-STDP described below is adapted from [Neuro-modulated Spike-Timing-Dependent Plasticity](https://www.frontiersin.org/articles/10.3389/fncir.2015.00085/full "Neuromodulated Spike-Timing-Dependent Plasticity, and Theory of Three-Factor Learning Rules").

The magnitude of the weight change is implemented as a function of the synaptic eligibility trace $E$ and a reward term $R$. The synaptic eligibility trace stores a temporary memory of the STDP outcome so that it is still available when a delayed reward signal arrives. If you define the learning window of a traditional Hebbian STDP rule as $STDP(pre, post)$, where "pre" represents the pre-synaptic activity of a synapse and "post" represents the state of the post-synaptic neuron, the dynamics of the synaptic eligibility trace can be written as:

$$\dot{E} = - \frac{E}{\tau_e} + STDP(pre, post)$$

where $\tau_e$ is the time constant of the eligibility trace. In R-STDP, the synaptic weight, $W$, is modulated based on:

$$\dot{W} = R \cdot E$$
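The two equations above can be made concrete with a forward-Euler sketch (one time step per iteration) in plain NumPy. It is not the Lava learning engine: the exponential STDP window, its amplitudes, and the time constants are illustrative assumptions, and `stdp_window` and `simulate_r_stdp` are hypothetical helpers. The sketch shows the key property of the eligibility trace. A pre-before-post pairing at steps 10 and 15 changes the weight only when a reward arrives at step 40, and the size of that change is set by how much of $E$ is left at that point.

```python
import numpy as np


def stdp_window(delta_t, a_ltp=1.0, a_ltd=-1.0, tau_stdp=10.0):
    """Illustrative pairwise STDP window; delta_t = t_post - t_pre in time steps."""
    if delta_t > 0:  # pre before post -> potentiation
        return a_ltp * np.exp(-delta_t / tau_stdp)
    return a_ltd * np.exp(delta_t / tau_stdp)  # post before pre -> depression


def simulate_r_stdp(pre_times, post_times, reward, num_steps, tau_e=20.0):
    """Forward Euler (step = 1) of dE/dt = -E/tau_e + STDP and dW/dt = R * E."""
    pre_times, post_times = set(pre_times), set(post_times)
    e, w = 0.0, 0.0
    e_hist, w_hist = [], []
    for step in range(num_steps):
        stdp = 0.0
        # Pair each spike with all strictly earlier spikes of the other neuron
        # (simultaneous pre/post spikes are ignored in this sketch).
        if step in post_times:
            stdp += sum(stdp_window(step - tp) for tp in pre_times if tp < step)
        if step in pre_times:
            stdp += sum(stdp_window(tq - step) for tq in post_times if tq < step)
        e += -e / tau_e + stdp  # leaky eligibility trace stores the STDP outcome
        w += reward[step] * e   # the weight only moves while the reward is nonzero
        e_hist.append(e)
        w_hist.append(w)
    return np.array(e_hist), np.array(w_hist)


# Pre spike at step 10, post spike at step 15, reward delivered only at step 40.
reward = np.zeros(60)
reward[40] = 1.0
e_hist, w_hist = simulate_r_stdp([10], [15], reward, num_steps=60)
print(w_hist[39], w_hist[40])
```

If the spike order is reversed (post before pre), $E$ becomes negative and so does the rewarded weight change. With $R = 0$ at every step, the weight never moves.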

### Instantiating an R-STDP learning rule with learning-related parameters

We express the eligibility trace $E$ in the sum-of-products form using the tag variable $t$, and describe its dynamics $dt$ as:

$$dt = ( A_{+} \cdot x_0 \cdot y_1 + A_{-} \cdot y_0 \cdot x_1 ) - t \cdot tag\_tau$$

Here, the first two terms of $dt$ implement a simple pairwise STDP rule with $A_{+} < 0$ and $A_{-} > 0$, and the last term is the decay of the tag. Reward-modulated STDP is then defined by the synaptic weight change variable $dw$:

$$dw = u_0 \cdot t \cdot y_2$$
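To tie the symbols to numbers: in this sum-of-products notation, $x_0$ and $y_0$ indicate a pre- and post-synaptic spike, $x_1$ and $y_1$ are the corresponding traces, and $y_2$ is the reward trace $R$. According to the code comment in the next cell, $u_0$ makes the $dw$ term apply at every update step. The hypothetical helpers `eval_dt` and `eval_dw` below evaluate the two expressions literally with plain Python numbers, using the coefficients chosen in the next cell. They are a sketch for intuition and do not show how Lava evaluates rules internally.

```python
def eval_dt(x0, y0, x1, y1, t, a_plus=-2, a_minus=2, tag_tau=10, learning_rate=1):
    """Literal evaluation of the dt expression above (hypothetical helper)."""
    stdp_terms = learning_rate * a_plus * x0 * y1 + learning_rate * a_minus * y0 * x1
    return stdp_terms - t * tag_tau  # tag decay term, exactly as written in dt


def eval_dw(u0, t, y2):
    """Literal evaluation of dw = u0 * t * y2 (hypothetical helper)."""
    return u0 * t * y2


# Pre-synaptic spike (x0 = 1) meeting a post-synaptic trace y1 = 8, tag t = 0:
print(eval_dt(x0=1, y0=0, x1=0, y1=8, t=0))  # → -16, because A_+ = -2
# Post-synaptic spike (y0 = 1) meeting a pre-synaptic trace x1 = 8, tag t = 0:
print(eval_dt(x0=0, y0=1, x1=8, y1=0, t=0))  # → 16
# A nonzero tag does not reach the weight while the reward y2 is 0:
print(eval_dw(u0=1, t=5, y2=0))  # → 0
```

Because $A_{+} < 0$, a pre-synaptic spike that meets a large post-synaptic trace lowers the tag. The weight follows the tag only while $y_2 \neq 0$.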

In [None]:
#INITIALIZING LEARNING-RELATED PARAMETERS
from lava.magma.core.learning.learning_rule import LoihiLearningRule

# Learning rule coefficient
A_plus = -2
A_minus = 2

learning_rate = 1

# Trace decay constants
x1_tau = 10
y1_tau = 10

# Eligibility trace decay constant
tag_tau = 10 # Verify

# Very large decay constant for the reward trace y2, so that it decays negligibly
y2_tau = 2 ** 32 - 1

# Impulses
x1_impulse = 16
y1_impulse = 16

# Zero impulse value for reward. 
y2_impulse = 0

# Epoch length
t_epoch = 2

# String learning rule for dt: eligibility trace, stored in the tag t (tag_1)
dt = f"{learning_rate} * {A_plus} * x0 * y1 + " \
     f"{learning_rate} * {A_minus} * y0 * x1 - t * {tag_tau}"

# String learning rule for dw
# The weights are updated at every time step (u0); the magnitude is the product of the reward y2 (R) and the tag t (tag_1)
dw = "u0 * t * y2"


# Create custom LearningRule
# Create custom LearningRule (dt must be passed, otherwise the tag t never evolves)
R_STDP = LoihiLearningRule(dw=dw,
                           dt=dt,
                           x1_impulse=x1_impulse,
                           x1_tau=x1_tau,
                           y1_impulse=y1_impulse,
                           y1_tau=y1_tau,
                           y2_impulse=y2_impulse,
                           y2_tau=y2_tau,
                           t_epoch=t_epoch)
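
The trace parameters above (`x1_impulse`, `x1_tau`, and so on) describe traces that jump by an impulse on each spike and then decay. The sketch below is a floating-point approximation built on that assumption, with a decay factor of $e^{-1/\tau}$ per time step. The fixed-point learning engine quantizes these values, and `decaying_trace` is a hypothetical helper. The sketch shows why `y2_tau = 2 ** 32 - 1` amounts to "no decay": a trace with `tau = 10` falls to about $16/e$ after ten steps, while a trace with the very large time constant stays at 16 to well within six decimal places. The impulse of 16 for the `y2` case is purely illustrative, since the cell above sets `y2_impulse = 0`.

```python
import math


def decaying_trace(spike_steps, impulse, tau, num_steps):
    """Hypothetical float trace: decays by exp(-1/tau) per step, jumps by `impulse` on a spike."""
    trace, history = 0.0, []
    for step in range(num_steps):
        trace *= math.exp(-1.0 / tau)  # exponential decay over one time step
        if step in spike_steps:
            trace += impulse           # spike-triggered impulse
        history.append(trace)
    return history


# One spike at step 0, followed by ten steps of decay.
x1_trace = decaying_trace({0}, impulse=16, tau=10, num_steps=11)
y2_trace = decaying_trace({0}, impulse=16, tau=2 ** 32 - 1, num_steps=11)
print(round(x1_trace[10], 3), round(y2_trace[10], 3))  # → 5.886 16.0
```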

### Network Parameters and Spike Inputs

We now define the network parameters and generate frozen (fixed-seed) random spike trains that drive the pre- and post-synaptic neurons. We also generate graded reward spikes that set the third factor in the post-synaptic neuron.
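"Frozen" here means that the random rasters are regenerated identically on every run because the seed is fixed. The hypothetical helper `make_frozen_raster` wraps the same `np.random.seed` and `np.place` pattern used in the next cell. It reseeds on every call, so two calls with the same seed return the same raster. The cell below instead seeds once and draws the pre- and post-synaptic rasters in sequence, which produces two different but reproducible patterns.

```python
import numpy as np


def make_frozen_raster(num_neurons, num_steps, spike_prob, seed):
    """Hypothetical helper: Bernoulli spike raster that is identical for a given seed."""
    np.random.seed(seed)  # fixing the seed is what makes the pattern "frozen"
    raster = np.zeros((num_neurons, num_steps))
    np.place(raster, np.random.rand(num_neurons, num_steps) < spike_prob, 1)
    return raster.astype(int)


first = make_frozen_raster(1, 200, 0.03, seed=123)
second = make_frozen_raster(1, 200, 0.03, seed=123)
print(np.array_equal(first, second), int(first.sum()))
```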

In [None]:
import numpy as np

# Set this tag to "fixed_pt" or "floating_pt" to choose the corresponding models.
SELECT_TAG = "floating_pt"

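# LIF parameters: only floating_pt is supported for now.
if SELECT_TAG == "floating_pt":
    du = 1
    dv = 1
else:
    raise NotImplementedError("LIF parameters are only defined for SELECT_TAG = 'floating_pt'.")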

vth = 240

# Number of pre-synaptic neurons per layer
num_neurons_pre = 1
shape_lif_pre = (num_neurons_pre, )

# Number of post-synaptic neurons per layer (a single neuron, like every population in this tutorial)
num_neurons_post = 1
shape_lif_post = (num_neurons_post, )

# The plastic connection maps pre -> post, so its weight matrix has shape (post, pre)
shape_conn_plast = (num_neurons_post, num_neurons_pre)

# Connection parameters

# Input pattern (RingBuffer) -> LIF connection weight : PRE-synaptic
wgt_inp_pre = np.eye(num_neurons_pre) * 250

# Input pattern (RingBuffer) -> LIF connection weight : POST-synaptic
wgt_inp_post = np.eye(num_neurons_post) * 250

# LIF -> LIF connection initial weight (learning-enabled)
wgt_plast_conn = np.full(shape_conn_plast, 50)

# Number of simulation time steps
num_steps = 200
time = list(range(1, num_steps + 1))

# Spike probability per neuron and time step
spike_prob = 0.03

# Create frozen spike rasters (fixed seed, so every run sees the same patterns)
np.random.seed(123)
spike_raster_pre = np.zeros((num_neurons_pre, num_steps))
np.place(spike_raster_pre, np.random.rand(num_neurons_pre, num_steps) < spike_prob, 1)

spike_raster_post = np.zeros((num_neurons_post, num_steps))
np.place(spike_raster_post, np.random.rand(num_neurons_post, num_steps) < spike_prob, 1)

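### Create Network
The following diagram depicts the Lava Process architecture used in this tutorial. It consists of:
- 2 pattern generators (_RingBuffer_ Processes) that replay frozen spike trains into the Leaky Integrate-and-Fire (LIF) neurons.
- 2 _Dense_ Processes connecting the pattern generators to the LIF neurons.
- 2 _LIF_ Processes representing the pre- and post-synaptic neurons. The post-synaptic neuron is a _LearningLIF_, which also sends its back-propagating action potential and the reward trace $y_2$ to the plastic connection.
- 1 _LearningDense_ Process representing the learning-enabled connection between the LIF neurons.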

>**Note:** All neuronal populations (spike generators, LIF neurons) consist of only 1 neuron in this tutorial.

![R_STDP_architecture](r_stdp_tutorial_architecture_2.svg)
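Why is the input weight 250 when `vth = 240`? The sketch below assumes the floating-point LIF update $u \leftarrow u(1 - du) + a_{in}$ and $v \leftarrow v(1 - dv) + u + bias$, a spike whenever $v > v_{th}$ followed by a reset of $v$ to 0, and a Dense activation $a = W s$ with $W$ shaped (num_out, num_in). With `du = dv = 1`, nothing carries over between steps. Each input spike through the 250-weight synapse therefore triggers a spike, while the plastic synapse alone, at its initial weight of 50, stays below threshold. `dense_step` and `lif_step` are hypothetical helpers, not Lava's ProcessModels.

```python
import numpy as np


def dense_step(weights, spikes_in):
    """Hypothetical Dense model: activation = W @ s, with W shaped (num_out, num_in)."""
    return weights @ spikes_in


def lif_step(u, v, a_in, du=1.0, dv=1.0, vth=240.0, bias=0.0):
    """Hypothetical floating-point LIF step; v is reset to 0 after a spike."""
    u = u * (1 - du) + a_in        # synaptic current (fully decays when du = 1)
    v = v * (1 - dv) + u + bias    # membrane voltage (fully decays when dv = 1)
    spikes = v > vth
    v = np.where(spikes, 0.0, v)
    return u, v, spikes


w_inp = np.eye(1) * 250          # input weight used in this tutorial
w_plast = np.full((1, 1), 50.0)  # initial plastic weight, shape (num_post, num_pre)

# One input spike through the 250-weight input synapse
_, _, s_driven = lif_step(np.zeros(1), np.zeros(1), dense_step(w_inp, np.ones(1)))
# One pre-synaptic spike through the plastic synapse alone
_, _, s_plastic = lif_step(np.zeros(1), np.zeros(1), dense_step(w_plast, np.ones(1)))
print(s_driven, s_plastic)
```

Under these assumptions, the post-synaptic spike times are set by the post-synaptic input pattern rather than by the plastic weight, as long as that weight does not exceed `vth`. This keeps the pre/post spike timing seen by the learning rule frozen.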

In [None]:
from lava.proc.lif.process import LIF, LearningLIF
from lava.proc.io.source import RingBuffer
from lava.proc.dense.process import Dense, LearningDense

In [None]:
# Create input devices
pattern_pre = RingBuffer(data=spike_raster_pre.astype(int))
pattern_post = RingBuffer(data=spike_raster_post.astype(int))

# Create input connectivity
conn_inp_pre = Dense(weights=wgt_inp_pre)
conn_inp_post = Dense(weights=wgt_inp_post)

# Create pre-synaptic neurons
lif_pre = LIF(u=0,
              v=0,
              du=du,
              dv=dv,
              bias_mant=0,
              bias_exp=0,
              vth=vth,
              shape=shape_lif_pre,
              name='lif_pre')

# Create plastic connection (learning-enabled Dense)
plast_conn = LearningDense(weights=wgt_plast_conn,
                           learning_rule=R_STDP,
                           name='plastic_dense')

# Create post-synaptic neuron
lif_post = LearningLIF(u=0,
                       v=0,
                       du=du,
                       dv=dv,
                       bias_mant=0,
                       bias_exp=0,
                       vth=vth,
                       shape=shape_lif_post,
                       name='lif_post')

# Connect network
pattern_pre.s_out.connect(conn_inp_pre.s_in)
conn_inp_pre.a_out.connect(lif_pre.a_in)

pattern_post.s_out.connect(conn_inp_post.s_in)
conn_inp_post.a_out.connect(lif_post.a_in)

lif_pre.s_out.connect(plast_conn.s_in)
plast_conn.a_out.connect(lif_post.a_in)

# VERIFY
# Connect back-propagating action potential (BAP)
lif_post.s_out_bap.connect(plast_conn.s_in_bap)

# Connect reward trace callback (y2)
lif_post.s_out_y2.connect(plast_conn.s_in_y2)

### Create monitors to observe traces

In [None]:
from lava.proc.monitor.process import Monitor

# Create monitors
mon_pre_trace = Monitor()
mon_post_trace = Monitor()
mon_reward_trace = Monitor()
mon_pre_spikes = Monitor()
mon_post_spikes = Monitor()
mon_weight = Monitor()

# Connect monitors
mon_pre_trace.probe(plast_conn.x1, num_steps)
mon_post_trace.probe(plast_conn.y1, num_steps)
mon_reward_trace.probe(plast_conn.y2, num_steps)
mon_pre_spikes.probe(lif_pre.s_out, num_steps)
mon_post_spikes.probe(lif_post.s_out, num_steps)
mon_weight.probe(plast_conn.weights, num_steps)

### Running 

In [None]:
from lava.magma.core.run_conditions import RunSteps
from lava.magma.core.run_configs import Loihi1SimCfg

In [None]:
# Running
pattern_pre.run(condition=RunSteps(num_steps=num_steps), run_cfg=Loihi1SimCfg(select_tag=SELECT_TAG))

In [None]:
# Get data from monitors
pre_trace = mon_pre_trace.get_data()['plastic_dense']['x1']
post_trace = mon_post_trace.get_data()['plastic_dense']['y1']
reward_trace = mon_reward_trace.get_data()['plastic_dense']['y2']
pre_spikes = mon_pre_spikes.get_data()['lif_pre']['s_out']
post_spikes = mon_post_spikes.get_data()['lif_post']['s_out']
weights = mon_weight.get_data()['plastic_dense']['weights'][:, :, 0]

In [None]:
# Stopping
pattern_pre.stop()