# Pole balancing using NESTML

In this tutorial, we are going to build an agent that can successfully solve the classic pole balancing problem using reinforcement learning. We start with a standard temporal-difference learning approach and then use NESTML to set up a spiking neural network that performs the same task.

# Cart Pole Environment

For the cart pole environment, we mainly need three things:

- a renderer to display the simulation,
- the physics system, and
- an input to be able to nudge the pole in both directions.

For that, we will need the following packages:

In [1]:
import pygame as pg
from typing import Tuple
import numpy as np

pygame 2.6.1 (SDL 2.28.4, Python 3.12.8)
Hello from the pygame community. https://www.pygame.org/contribute.html


Let's start with the renderer...

In [2]:
#Renders the scene. IMPORTANT: Because pygame uses screen coordinates (origin at the top left), the y-axis is inverted.
class Renderer():
    def __init__(self, width: int, height: int, origin_x: int = 0, origin_y: int = 0, SCALE: int = 1) -> None:
        self.width = width
        self.height = height
        self.origin = (origin_x, origin_y)
        self.SCALE = SCALE #1m = SCALE pixels

        pg.display.init()
        pg.display.set_caption("Pole Balancing Simulator")
        pg.font.init()
        self.screen = pg.display.set_mode((width, height))
    
    #Translates global coordinates into screen coordinates
    def translate(self, x: int, y: int) -> Tuple[int, int]:
        return (x+self.origin[0], -y+self.origin[1])
    
    #Draws ground. offset is there to shift the ground below the car
    def draw_ground(self, offset: int, color) -> None:
        ground = pg.Rect(self.translate(-self.width//2, -offset * self.SCALE), (self.width, self.height-self.origin[1]-offset * self.SCALE))
        pg.draw.rect(self.screen, color, ground)

    #Draws car. pos_y is omitted because the car's center should be at y = 0
    def draw_car(self, pos_x: float, car_color = "blue", wheel_color = "black") -> None:
        pos_x *= self.SCALE
        #values, hard-coded for now, in meters
        width = 0.5 * self.SCALE
        height = 0.25 * self.SCALE
        wheel_radius = 0.1 * self.SCALE

        car_body = pg.Rect(self.translate(pos_x - width/2, height/2), (width, height))
        pg.draw.rect(self.screen, car_color, car_body)
        pg.draw.circle(self.screen, wheel_color, 
                           self.translate(pos_x - width/2 + wheel_radius, -height/2), wheel_radius)
        pg.draw.circle(self.screen, wheel_color, 
                           self.translate(pos_x + width/2 - wheel_radius, -height/2), wheel_radius)

    #Draws the pole
    def draw_pole(self, pos_x: float, theta: float, length: float, width: float = 0.1, color = "red") -> None:
        pos_x *= self.SCALE
        width = int(width * self.SCALE)
        pole_end_x = length * np.sin(theta) * self.SCALE + pos_x
        pole_end_y = length * np.cos(theta) * self.SCALE
        pg.draw.line(self.screen, color, self.translate(pos_x, 0), self.translate(pole_end_x, pole_end_y), width)

    #Clears the entire canvas
    def draw_clear(self) -> None:
        self.screen.fill("white")

    #Draws physical values
    def draw_stats(self, theta: float, w: float, v: float, x: float, 
                    episode: int, 
                    spikes_left : int, spikes_right : int, 
                    dopamine_left : float, dopamine_right : float, 
                    action: int) -> None:
        font = pg.font.Font(None, 24)
        #Physics stats, drawn left
        text = "angle: " + str(theta)[:4] + \
            "\nangular velocity: " + str(w)[:4] + \
            "\nposition: " + str(x)[:4] + \
            "\nvelocity: " + str(v)[:4] + \
            "\nepisode: " + str(episode)
        lines = text.split('\n')
        y_pos = 10
        for line in lines:
            text_surface = font.render(line, True, (0,0,0))
            self.screen.blit(text_surface, (10, y_pos))
            y_pos += 30

        #Network stats, drawn right
        text = "Spikes left: " + str(spikes_left) + \
            "\nDopamine left: " + str(dopamine_left)[:4] + \
            "\nSpikes right: " + str(spikes_right) + \
            "\nDopamine right: " + str(dopamine_right)[:4] + \
            "\nTaken action: " + ("Left" if action==0 else "Right" if action==1 else "Failure")
        lines = text.split('\n')
        y_pos = 10
        for line in lines:
            text_surface = font.render(line, True, (0,0,0))
            self.screen.blit(text_surface, (self.width - 200, y_pos))
            y_pos += 30

    def get_relative_mouse_x(self, mouse_x:float) -> float:
        return (mouse_x-self.origin[0])/self.SCALE
    
    def display(self) -> None:
        pg.display.flip()
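
The coordinate handling of the renderer can be illustrated without opening a window. The following is a minimal sketch and not part of the tutorial code: `world_to_screen` is a hypothetical helper that combines the metre-to-pixel scaling with the origin shift and y-axis flip performed by `translate`. The default origin (600, 500) and scale of 400 pixels per metre are the values passed to `Renderer` in the commented-out run cell further below.

```python
from typing import Tuple


def world_to_screen(x_m: float, y_m: float,
                    origin: Tuple[int, int] = (600, 500),
                    scale: float = 400.0) -> Tuple[float, float]:
    """Map world coordinates in metres (y pointing up) to screen pixels (y pointing down)."""
    x_px = x_m * scale + origin[0]    # shift right by the origin's x offset
    y_px = -y_m * scale + origin[1]   # flip the y-axis, then shift by the origin's y offset
    return (x_px, y_px)


# The world origin maps to the screen origin; a point 1 m above the ground
# ends up higher on the screen, i.e. at a smaller pixel y value.
print(world_to_screen(0.0, 0.0))
print(world_to_screen(0.5, 1.0))
```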

## Physics Updates

For the physics, we use the corrected version of the original problem's equations of motion derived by V. Florian (CITATION NEEDED), but omit the friction forces.
The situation is sketched here:  

![alt text](cartpole_illustration.png "Cartpole")

We apply Newton's second law of motion to the cart:  
$$
\begin{aligned}
    \mathbf{F} + \mathbf{G}_c - \mathbf{N} = m_c \cdot \mathbf{a}_c
\end{aligned}
$$
Where:  

$\mathbf{F} = F \cdot \mathbf{u_x}$ is the control force acting on the cart,  
$\mathbf{G}_c = m_c \cdot g \cdot \mathbf{u}_y$ is the gravitational component acting on the cart,  
$\mathbf{N} = N_x \cdot \mathbf{u}_x - N_y \cdot \mathbf{u}_y$ is the negative reaction force that the pole is applying on the cart,  
$\mathbf{a}_c = \ddot{x} \cdot \mathbf{u}_x$ is the acceleration of the cart,  
$m_c$ is the cart's mass and  
$\mathbf{u}_x$, $\mathbf{u}_y$, $\mathbf{u}_z$ are the unit vectors of the frame of reference given in the illustration.

We can now decompose this equation into its $x$ and $y$ components:
$$
\begin{aligned}
    F - N_x = m_c \cdot \ddot{x}
\end{aligned}
$$
$$
\begin{aligned}
    m_c \cdot g + N_y = 0
\end{aligned}
$$

Newton's second law of motion applied to the pole gives us:
$$
\begin{aligned}
    \mathbf{N} + \mathbf{G}_p = m_p \cdot \mathbf{a}_p
\end{aligned}
$$

Where $\mathbf{G}_p = m_p \cdot g \cdot \mathbf{u}_y$.

The acceleration $\mathbf{a}_p$ of the pole's center of mass consists of three components, where $\mathbf{r}_p = l \cdot (\sin{\theta}\cdot \mathbf{u}_x-\cos{\theta}\cdot \mathbf{u}_y)$ denotes the vector pointing to the pole's center of mass relative to its center of rotation:  
1. The acceleration $\mathbf{a}_c$ of the cart that the pole is attached to.
2. The pole's angular acceleration $\mathbf{\epsilon} = \ddot{\theta} \cdot \mathbf{u}_z$, which translates into the linear acceleration $\mathbf{\epsilon} \times \mathbf{r}_p$.
3. The pole's angular velocity $\mathbf{\omega} = \dot{\theta} \cdot \mathbf{u}_z$, which contributes the acceleration $\mathbf{\omega} \times (\mathbf{\omega} \times \mathbf{r}_p)$.

Thus we obtain:
$$
\begin{aligned}
    \mathbf{a}_p  = \mathbf{a}_c + \mathbf{\epsilon} \times \mathbf{r}_p + \mathbf{\omega} \times (\mathbf{\omega} \times \mathbf{r}_p)
\end{aligned}
$$
Substituting $\mathbf{r}_p = l \cdot (\sin{\theta}\cdot \mathbf{u}_x-\cos{\theta}\cdot \mathbf{u}_y)$ and $\mathbf{a}_c = \ddot{x} \cdot \mathbf{u}_x$, and using $\mathbf{u}_z \times \mathbf{u}_x = \mathbf{u}_y$ and $\mathbf{u}_z \times \mathbf{u}_y = -\mathbf{u}_x$, we get:
$$
\begin{aligned}
    \mathbf{a}_p  = \ddot{x} \cdot \mathbf{u}_x + l \cdot \ddot{\theta} \cdot (\sin{\theta}\cdot \mathbf{u}_y + \cos{\theta}\cdot \mathbf{u}_x) - l \cdot \dot{\theta}^2 \cdot (\sin{\theta}\cdot \mathbf{u}_x - \cos{\theta}\cdot \mathbf{u}_y)
\end{aligned}
$$

Inserting this equation into the force equation for the pole and decomposing it along the $x$ and $y$ axes, we obtain:
$$
\begin{aligned}
    N_x = m_p \cdot (\ddot{x} + l \cdot \ddot{\theta} \cdot \cos{\theta} - l \cdot \dot{\theta}^2 \cdot \sin{\theta})
\end{aligned}
$$
$$
\begin{aligned}
    m_p \cdot g - N_y = m_p \cdot (l \cdot \ddot{\theta} \cdot \sin{\theta} + l \cdot \dot{\theta}^2 \cdot \cos{\theta})
\end{aligned}
$$
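
As a worked step toward the closed-form update used in the code, the two $x$-component equations can be combined to eliminate $N_x$. This is a sketch based only on the equations above; the angular acceleration additionally requires the torque balance of the pole, which is not derived here. Inserting $N_x$ from the pole into $F - N_x = m_c \cdot \ddot{x}$ and solving for $\ddot{x}$ gives:
$$
\begin{aligned}
    F - m_p \cdot (\ddot{x} + l \cdot \ddot{\theta} \cdot \cos{\theta} - l \cdot \dot{\theta}^2 \cdot \sin{\theta}) &= m_c \cdot \ddot{x} \\
    \ddot{x} &= \frac{F + m_p \cdot l \cdot (\dot{\theta}^2 \cdot \sin{\theta} - \ddot{\theta} \cdot \cos{\theta})}{m_c + m_p}
\end{aligned}
$$
This is the expression that `Physics.a_step` evaluates, with `w` standing for $\dot{\theta}$ and `dw` for $\ddot{\theta}$.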

# TODO: FINISH EQUATION DERIVATION (SOLVE EQUATION REFERENCING?)

In [3]:
class Physics():
    
    def __init__(self, x, theta, v = 0, a = 0, w = 0, dw = 0, g = 9.81, m_c = 1, m_p = 0.1, l = 0.5, dt = 0.02) -> None:
        self.__dict__.update(vars())

    def dw_step(self, cart_force, nudge_force) -> float:
        numerator = self.g * np.sin(self.theta) + np.cos(self.theta) * (-cart_force - self.m_p * self.l * self.w**2 * np.sin(self.theta))/(self.m_c+self.m_p) + nudge_force * np.cos(self.theta)/(self.m_p*self.l)
        denominator = self.l * (4/3 - (self.m_p*np.cos(self.theta)**2)/(self.m_c+self.m_p))

        self.dw = numerator/denominator
        self.w += self.dt * self.dw
        self.theta += self.dt * self.w

        return self.theta
    
    def a_step(self, force) -> float:
        numerator = force + self.m_p * self.l * (self.w**2 * np.sin(self.theta) - self.dw * np.cos(self.theta))
        denominator = self.m_c + self.m_p

        self.a = numerator/denominator
        self.v += self.dt * self.a
        self.x += self.dt * self.v

        return self.x

    def update(self, force, mouse_x) -> Tuple[float, float]:
        nudge_force = 0
        if mouse_x is not None:
            nudge_force = -1 if mouse_x > self.x else 1
        return (self.dw_step(force, nudge_force), self.a_step(force))
    
    #get state of the system that agent can see
    def get_state(self) -> Tuple[float,float,float,float]:
        return (self.x, self.theta, self.v, self.w)
    
    def reset(self) -> None:
        self.x = 0
        self.theta = (np.random.rand() - 1) / 10 #small random initial angle in [-0.1, 0) rad
        self.v = 0
        self.a = 0
        self.w = 0
        self.dw = 0
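
To make the update order explicit, here is a minimal standalone sketch of one integration step that mirrors `dw_step` followed by `a_step`, without the nudge force. The function name `cartpole_step` is illustrative and not part of the tutorial code; the default parameter values are copied from the `Physics` constructor.

```python
import math
from typing import Tuple


def cartpole_step(x: float, v: float, theta: float, w: float, force: float,
                  g: float = 9.81, m_c: float = 1.0, m_p: float = 0.1,
                  l: float = 0.5, dt: float = 0.02) -> Tuple[float, float, float, float]:
    """Advance the frictionless cart-pole by one time step dt."""
    total_mass = m_c + m_p

    # Angular acceleration of the pole, computed from the current state.
    numerator = g * math.sin(theta) + math.cos(theta) * (
        -force - m_p * l * w**2 * math.sin(theta)) / total_mass
    denominator = l * (4 / 3 - m_p * math.cos(theta)**2 / total_mass)
    dw = numerator / denominator

    # Update angular velocity first, then the angle with the new velocity.
    w += dt * dw
    theta += dt * w

    # Cart acceleration; as in the class, it uses the already updated theta and w.
    a = (force + m_p * l * (w**2 * math.sin(theta) - dw * math.cos(theta))) / total_mass
    v += dt * a
    x += dt * v

    return x, v, theta, w


# Pushing the cart to the right from the upright rest position
# moves the cart right and tips the pole to the left (negative theta).
print(cartpole_step(0.0, 0.0, 0.0, 0.0, 10.0))
```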


# The Agent (BOXES)
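
The agent in the next cell partitions each state variable into bins delimited by a sorted array of thresholds and treats any value outside the outermost thresholds as a failure. The following is a minimal standalone sketch of that binning. It assumes that `np.searchsorted` may stand in for the explicit loop in `Agent.discretize`; the return convention (bin index, or -1 when out of range) is taken from the class, while the example thresholds are the angle thresholds in degrees from the commented-out "original boxes".

```python
import numpy as np


def discretize(value: float, thresholds: np.ndarray) -> int:
    """Return the index of the bin [thresholds[i], thresholds[i+1]) containing value, or -1."""
    # Number of thresholds that are <= value, minus one, is the candidate bin index.
    idx = int(np.searchsorted(thresholds, value, side="right")) - 1
    # Below the first threshold or at/above the last one counts as out of range.
    if idx < 0 or idx >= len(thresholds) - 1:
        return -1
    return idx


theta_thresholds_deg = np.array([-12, -6, -1, 0, 1, 6, 12])
print(discretize(0.5, theta_thresholds_deg))    # inside the bin [0, 1)
print(discretize(-13.0, theta_thresholds_deg))  # below the first threshold
```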

In [4]:

class Agent:
    def __init__(self, initial_state: Tuple[float,float,float,float]) -> None:

        #thresholds for discretizing the state space
        
        # ORIGINAL BOXES THAT WE USED SUCCESSFULLY ON THE NON SPIKING AGENT
#         self.x_thresholds = np.array([-2.4, -0.8, 0.8, 2.4])
#         self.theta_thresholds = np.array([-12, -6, -1, 0, 1, 6, 12])
#         self.theta_thresholds = self.theta_thresholds /180 * np.pi
#         self.v_thresholds = np.array([float("-inf"), -0.5, 0.5, float("+inf")]) #open intervals ignored here
#         self.w_thresholds = np.array([float("-inf"), -50, 50, float("+inf")]) #open intervals ignored here
#         self.w_thresholds = self.w_thresholds / 180 * np.pi

        # BOXES FROM LIU&PAN CODE
        self.x_thresholds = np.array([-2.4, 2.4])
        self.v_thresholds = np.array([float("-inf"), float("+inf")])
        
        self.theta_thresholds = np.array([-12, -5.738738738738739, -2.8758758758758756, 0., 2.8758758758758756, 5.738738738738739, 12])
        self.theta_thresholds = self.theta_thresholds / 180 * np.pi
        
        self.w_thresholds = np.array([float("-inf"), -103., -91.7, -80.2, -68.8, -57.3, -45.9, -34.3, -22.9, -11.5, 0.,
                                                      11.5, 22.9, 34.3, 45.9, 57.3, 68.8, 80.2, 91.7, 103., float("+inf")]) #open intervals ignored here
        self.w_thresholds = self.w_thresholds / 180 * np.pi
        
        self.dimensions = (len(self.x_thresholds) - 1, len(self.theta_thresholds) - 1, len(self.v_thresholds) - 1, len(self.w_thresholds) - 1)

        print("Dimension of input space: " + str(self.dimensions))
        
        self.boxes = np.random.rand(self.dimensions[0], 
                                    self.dimensions[1], 
                                    self.dimensions[2], 
                                    self.dimensions[3], 
                                    2) #one q-value for left and right respectively
        box = self.get_box(initial_state)
        self.current_box = self.boxes[box[0], box[1], box[2], box[3], :]

        self.episode = 1
    
    def discretize(self, value, thresholds):
        for i, limit in enumerate(thresholds):
            if value < limit:
                return i - 1
        return -1

    def get_box(self, state: Tuple[float,float,float,float]) -> Tuple[int,int,int,int]:
        return (self.discretize(state[0], self.x_thresholds),
                 self.discretize(state[1], self.theta_thresholds),
                 self.discretize(state[2], self.v_thresholds), 
                 self.discretize(state[3], self.w_thresholds))
    
    def get_episode(self) -> int:
        return self.episode
    
    
    def failure_reset(self, state: Tuple[float,float,float,float]):
        box = self.get_box(state)
        self.current_box = self.boxes[box[0], box[1], box[2], box[3], :]
        self.episode += 1


class NonSpikingAgent(Agent):
    def __init__(self, initial_state: Tuple[float,float,float,float], learning_rate, learning_decay, epsilon, epsilon_decay, discount_factor) -> None:
        super().__init__(initial_state)

        #learning parameters
        self.learning_rate = learning_rate
        self.learning_decay = learning_decay
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.discount_factor = discount_factor

    #returns 0 if the action is "left", else 1
    def choose_action(self) -> int:
        self.action = np.random.choice([np.argmax(self.current_box), np.argmin(self.current_box)], p=[1-self.epsilon, self.epsilon])
        return self.action
    
    #returns 0 if no failure occurred, else 1
    #reward is -1 on failure and 0 else
    def update(self, next_state: Tuple[float,float,float,float]) -> int:
        box = self.get_box(next_state)
        if -1 in box:
            self.current_box[self.action] += self.learning_rate * -1
            return 1
        
        next_box = self.boxes[box[0], box[1], box[2], box[3], :]
        next_q = np.max(next_box)
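        #note: the discount factor scales the whole difference here; the textbook Q-learning target for zero reward would be discount_factor * next_q
        self.current_box[self.action] += self.learning_rate * (self.discount_factor * (next_q - self.current_box[self.action]))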

        self.current_box = next_box
        self.epsilon *= self.epsilon_decay
        self.learning_rate *= self.learning_decay

        return 0
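
The learning rule in `NonSpikingAgent.update` is a tabular temporal-difference update. For comparison, the following is a minimal sketch of the textbook Q-learning step with the reward convention stated in the comment above (-1 on failure, 0 otherwise). `q_learning_step` and its arguments are illustrative names, the states are plain integer indices rather than boxes, and the rule is not necessarily identical to the one the tutorial code applies.

```python
import numpy as np


def q_learning_step(q_table: np.ndarray, state: int, action: int, reward: float,
                    next_state: int, alpha: float = 0.5, gamma: float = 0.99,
                    terminal: bool = False) -> float:
    """Apply one Q-learning update in place and return the TD error."""
    # On failure there is no successor state, so the target is the reward alone.
    if terminal:
        target = reward
    else:
        target = reward + gamma * np.max(q_table[next_state])
    td_error = target - q_table[state, action]
    q_table[state, action] += alpha * td_error
    return float(td_error)


# Three states with two actions each (0 = left, 1 = right).
q = np.zeros((3, 2))
q[1] = [0.2, 0.4]
q_learning_step(q, state=0, action=1, reward=0.0, next_state=1)
q_learning_step(q, state=1, action=0, reward=-1.0, next_state=2, terminal=True)
print(q)
```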
    

# Plot Renderer

In [5]:
import matplotlib.pyplot as plt
%matplotlib qt
class Non_Spiking_PlotRenderer():
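    def __init__(self, init_x = None, init_y = None) -> None:
        plt.ion()
        #Construct lifetime plot
        self.lifetime_fig, self.lifetime_ax = plt.subplots()
        #copy the initial data so that instances never share a mutable default list
        self.x = list(init_x) if init_x is not None else [0]
        self.y = list(init_y) if init_y is not None else [0]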
        self.max_lifetime = 0
        self.line, = self.lifetime_ax.plot(self.x, self.y)
        self.lifetime_ax.set_xlabel("Episode")
        self.lifetime_ax.set_ylabel("Simulation Steps")
        self.lifetime_ax.set_title("Lifetime Plot")

        #Construct Heatmap for two parameters
        self.q_value_fig, self.q_value_ax = plt.subplots()
        self.q_value_ax.set_title("Q-Values for a state of (param1/param2)")
        self.cmap = plt.cm.coolwarm
        
    def update(self, x, y, boxes) -> None:
        print(x)
        self.x.append(x)
        self.y.append(y)
        self.max_lifetime = max(self.max_lifetime, y)
        self.line.set_data(self.x, self.y)
        self.lifetime_ax.set_xlim(self.x[0], self.x[-1])
        self.lifetime_ax.set_ylim(0, self.max_lifetime)

        if(x % 10 == 0):
            q_values = boxes[:,:,:,:,0] - boxes[:,:,:,:,1]
            self.q_value_ax.imshow(np.mean(q_values, axis = (1,3)), cmap=self.cmap, interpolation='none')

        plt.draw()
        plt.pause(0.0001)



# Executing Non-Spiking-Agent

In [None]:


# import sys

# r = Renderer(1200, 800, 600, 500, 400)
# clock = pg.time.Clock()
# running = True

# p = Physics(0, (np.random.rand() - 1) / 10)

# a = NonSpikingAgent(p.get_state(), 0.5, 0.9999999999999, 1, 0.995, 0.99)

# plot = Non_Spiking_PlotRenderer()

# steps_per_episode = 0
# max_steps = 0

# window_size = 30
# window = np.zeros(30)
# avg_lifetime = 20000

# toggle_sim = False

# while running:
#     steps_per_episode += 1

#     force = 0
#     mouse_x = None

#     # poll for events
#     for event in pg.event.get():
#         if event.type == pg.QUIT:
#             running = False
#             pg.quit()
#             sys.exit()
#             quit()
#         elif event.type == pg.MOUSEBUTTONDOWN:
#             mouse_x = r.get_relative_mouse_x(pg.mouse.get_pos()[0])
#         elif event.type == pg.KEYDOWN:
#             toggle_sim ^= pg.key.get_pressed()[pg.K_SPACE]

#     # agent chooses action, simulation is updated and reward is calculated
#     force = 10 if a.choose_action() else -10
#     theta, x = p.update(force, mouse_x)
#     failure = a.update(p.get_state())

#     if failure:
#         p.reset()
#         a.failure_reset(p.get_state())
#         plot.update(a.get_episode(), steps_per_episode, a.boxes)
#         window = np.roll(window, 1)
#         window[0] = steps_per_episode
#         steps_per_episode = 0
    
    
#     if np.mean(window) >= avg_lifetime or toggle_sim:
#         r.draw_clear()
#         r.draw_ground(0.2, "grey")
#         r.draw_car(x)
#         r.draw_pole(x, theta, 2*p.l, 0.02)
#         #r.draw_stats(theta*180/np.pi, p.w*180/np.pi, x, p.a, a.get_episode())
#         r.display()

#         clock.tick(50)  # limits FPS to 50


Dimension of input space: (1, 6, 1, 20)

# TODO: clean up code, derive equations and explain renderer briefly

# Spiking version

## Idea

The core principle of our SNN is to simulate the physics and the neuron model in alternation: the state at the end of a physics step is the input to the SNN, and the action resulting from a period of SNN simulation is the input to the next physics simulation. Both cycles are set to 40 ms, so that the two simulations appear to run simultaneously.
The model consists of two layers of neurons. The input layer contains one neuron for each discrete state of the system. Neuromodulated synapses connect these neurons to the output layer, which consists of two neuron groups interpreted as the actions "move left" and "move right", respectively.

One simulation step of the SNN works as follows:
1. Get the current state of the cart pole and find the designated neuron that only fires when that state is reached.
2. Set a continuous firing rate for the simulation period on that neuron.
3. Determine which of the neuron groups in the output layer has fired more spikes at the end of the step.
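
The bookkeeping around these steps can be sketched without NEST. In this minimal example, the flat index of the input neuron is computed from the per-dimension bin indices with `np.ravel_multi_index`, and the action is the output group with the larger spike count. The names `state_to_neuron_index`, `neuron_index_to_state` and `select_action` and the spike counts are hypothetical; the dimensions (1, 6, 1, 20) are those printed by the agent above, and setting the firing rate (step 2) is omitted because it requires the simulator.

```python
from typing import Tuple

import numpy as np

DIMENSIONS = (1, 6, 1, 20)  # bins for (x, theta, v, w)


def state_to_neuron_index(box: Tuple[int, int, int, int],
                          dimensions: Tuple[int, ...] = DIMENSIONS) -> int:
    """Flatten per-dimension bin indices into one input neuron index (-1 on failure)."""
    if -1 in box:
        return -1
    return int(np.ravel_multi_index(box, dimensions))


def neuron_index_to_state(idx: int,
                          dimensions: Tuple[int, ...] = DIMENSIONS) -> Tuple[int, ...]:
    """Inverse mapping: recover the bin indices from a neuron index."""
    return tuple(int(i) for i in np.unravel_index(idx, dimensions))


def select_action(n_spikes_left: int, n_spikes_right: int) -> int:
    """Return 0 (left) if the left group fired more spikes, else 1 (right)."""
    return 0 if n_spikes_left > n_spikes_right else 1


box = (0, 3, 0, 7)                       # step 1: discretized state
idx = state_to_neuron_index(box)         # the neuron that represents this state
action = select_action(3, 5)             # step 3: made-up spike counts
print(idx, neuron_index_to_state(idx), action)
```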

# SNN Visualization

In [None]:
class Spiking_PlotRenderer():
    def __init__(self) -> None:
        plt.ion()
        self.fig, self.ax = plt.subplots(nrows=2, figsize=(12, 4))
        
        self.spike_line, = self.ax[0].plot([], [], '.k', markersize=5)
        self.vm_right_line, = self.ax[1].plot([], [], 'r')
        self.vm_left_line, = self.ax[1].plot([], [], 'b')

        self.ax[0].set_xlabel("Time [ms]")
        self.ax[0].set_ylabel("Input Neuron")
        self.ax[0].set_ylim(0, 160)
        
        self.ax[1].set_ylabel("V_m [mV]")
        self.ax[1].set_xlabel("Time [ms]")
        self.ax[1].set_ylim(-75, -50)

        self.fig.show()
    
    def update(self, data) -> None:
        if data is None:
            return

        input_spikes_times = data["input_spikes"]["times"]
        input_spikes_senders = data["input_spikes"]["senders"]
        right_vm_times = data["multimeter_right_events"]["times"]
        right_vm = data["multimeter_right_events"]["V_m"]
        left_vm_times = data["multimeter_left_events"]["times"]
        left_vm = data["multimeter_left_events"]["V_m"]

        # Update spike plot
        self.spike_line.set_data(input_spikes_times, input_spikes_senders)
        self.ax[0].set_xlim(np.min(input_spikes_times), np.max(input_spikes_times))

        # Update membrane potential plot
        self.vm_right_line.set_data(right_vm_times, right_vm)
        self.vm_left_line.set_data(left_vm_times, left_vm)
        self.ax[1].set_xlim(np.min(right_vm_times), np.max(right_vm_times))

        self.fig.canvas.draw()
        self.fig.canvas.flush_events()

        self.fig.savefig("/tmp/cartpole.png", dpi=300)

        plt.pause(0.0001)


## Neuron Models

### Ignore and Fire Neuron

In [8]:
# ... generate NESTML model code...

from pynestml.codegeneration.nest_code_generator_utils import NESTCodeGeneratorUtils

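# generate and build code for the input layer (ignore_and_fire neuron)
input_layer_module_name, input_layer_neuron_model_name = \
   NESTCodeGeneratorUtils.generate_code_for("../../../models/neurons/ignore_and_fire_neuron.nestml")

# generate and build code for the output layer (iaf_psc_exp neuron with neuromodulated STDP synapse)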
output_layer_module_name, output_layer_neuron_model_name, output_layer_synapse_model_name = \
    NESTCodeGeneratorUtils.generate_code_for("iaf_psc_exp_neuron.nestml",
                                             "neuromodulated_stdp_synapse.nestml",
                                             post_ports=["post_spikes"],
                                             logging_level="DEBUG",
                                             codegen_opts={"delay_variable": {"neuromodulated_stdp_synapse": "d"},
                                                           "weight_variable": {"neuromodulated_stdp_synapse": "w"}})





              -- N E S T --
  Copyright (C) 2004 The NEST Initiative

 Version: 3.8.0
 Built: Aug 27 2024 04:38:39

 This program is provided AS IS and comes with
 NO WARRANTY. See the file LICENSE for details.

 Problems or suggestions?
   Visit https://www.nest-simulator.org

 Type 'nest.help()' to find out more about NEST.



In [41]:
import nest
import json
import os
import enum

nest.set_verbosity("M_ERROR")

class AgentAction(enum.Enum):
    FAILURE = -1
    LEFT = 0
    RIGHT = 1

class SpikingAgent(Agent):
    cycle_period = 40.   # [ms], corresponding to 2 physics steps
   
    def __init__(self, initial_state: Tuple[float,float,float,float], gamma) -> None:
        super().__init__(initial_state)
        self.gamma = gamma
        self.construct_neural_network()
        self.Q_left = 0.
        self.Q_right = 0.
        self.Q_left_prev = 0.
        self.Q_right_prev = 0.
        self.scale_n_output_spikes_to_Q_value = 0.1
        self.dopamine_left = 0.
        self.dopamine_right = 0.
        self.last_action_chosen = AgentAction.LEFT   # ?! choose first action randomly
        self.R = 1.  # reward -- always 1!

    def get_state_neuron(self, state) -> int:
        idx = 0
        thresholds = [self.x_thresholds, self.theta_thresholds, self.v_thresholds, self.w_thresholds]
        for dim, val, thresh in zip(self.dimensions, state, thresholds):
            i = self.discretize(val,thresh)
            if i == -1:
                return -1
            idx = idx * dim + i

        return idx
    
    def get_state_from_id(self, idx) -> Tuple[int,int,int,int]:
        assert idx >= 0 and idx < len(self.input_population)
        state = [-1,-1,-1,-1]
        for i in reversed(range(len(state))):
            state[i] = idx % self.dimensions[i]
            idx = idx // self.dimensions[i]
        return tuple(state)
    
    def construct_neural_network(self):
        nest.ResetKernel()
        nest.Install(input_layer_module_name)   # makes the generated NESTML model available
        nest.Install(output_layer_module_name)   # makes the generated NESTML model available

        self.input_size = self.dimensions[0] * self.dimensions[1] * self.dimensions[2] * self.dimensions[3]
        self.input_population = nest.Create(input_layer_neuron_model_name, self.input_size)
    
        
        self.output_population_left = nest.Create(output_layer_neuron_model_name, 10)
        self.output_population_right = nest.Create(output_layer_neuron_model_name, 10)
        
        self.spike_recorder_input = nest.Create("spike_recorder")
        nest.Connect(self.input_population, self.spike_recorder_input)

        self.multimeter_left = nest.Create('multimeter', 1, {'record_from': ['V_m']})
        nest.Connect(self.multimeter_left, self.output_population_left)
        self.multimeter_right = nest.Create('multimeter', 1, {'record_from': ['V_m']})
        nest.Connect(self.multimeter_right, self.output_population_right)

        syn_opts = {"synapse_model": output_layer_synapse_model_name,
                    "weight": 0.1 + nest.random.uniform(min=0.0, max=1.0) * 0.02,
                    "beta": 0.01,
                    "tau_tr_pre": 20., # [ms]
                    "tau_tr_post": 20.,  # [ms]
                    #"Wmax": 0.3,
                    #"Wmin": 0.005,
                    "wtr_max": 0.1,
                    "wtr_min": 0.,
                    "pre_trace_increment": 0.0001,
                    "post_trace_increment": -1.05E-7}
        
        nest.Connect(self.input_population, self.output_population_left, syn_spec=syn_opts)
        nest.Connect(self.input_population, self.output_population_right, syn_spec=syn_opts)

        self.output_population_spike_recorder_left = nest.Create("spike_recorder")
        nest.Connect(self.output_population_left, self.output_population_spike_recorder_left)

        self.output_population_spike_recorder_right = nest.Create("spike_recorder")
        nest.Connect(self.output_population_right, self.output_population_spike_recorder_right)
        
        # set default values for prev_syn_wtr_right and left
        syn_right = nest.GetConnections(source=self.input_population, target=self.output_population_right)
        self.prev_syn_wtr_right = syn_right.wtr
        syn_left = nest.GetConnections(source=self.input_population, target=self.output_population_left)
        self.prev_syn_wtr_left = syn_left.wtr
        
        
    # Stores the important connections in a JSON file, which can be used to plot features of the network
    def save_network(self):
        connection_dictionary = {}
        for input_neuron_id in range(len(self.input_population)):
            neuron = self.input_population[input_neuron_id]
            conn_left = nest.GetConnections(source=neuron, target=self.output_population_left)
            conn_right = nest.GetConnections(source=neuron, target=self.output_population_right)
            state = self.get_state_from_id(input_neuron_id) #state is a tuple of the corresponding bins for each variable indexed at 0
            connection_dictionary[str(state)] = {"neuron": neuron.get(),
                                            "connection_left": conn_left.get(),
                                            "connection_right": conn_right.get(),
                                            }
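        import os  # imported here in case it was not imported earlier in the notebook
        os.makedirs("saved_networks", exist_ok=True)  # make sure the target directory exists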
        with open("saved_networks/network.json", "w") as f:
            json.dump(connection_dictionary, f, indent=4)

    def choose_action(self, Q_left, Q_right) -> AgentAction:
        if Q_left > Q_right:
            return AgentAction.LEFT
        
        return AgentAction.RIGHT

    def compute_Q_values(self) -> None:
        r"""The output of the SNN is interpreted as the (scaled) Q values."""
        self.Q_left_prev = self.Q_left
        self.Q_right_prev = self.Q_right

        n_events_in_last_interval_left = self.output_population_spike_recorder_left.n_events
        n_events_in_last_interval_right = self.output_population_spike_recorder_right.n_events
        self.Q_left = self.scale_n_output_spikes_to_Q_value * n_events_in_last_interval_left
        self.Q_right = self.scale_n_output_spikes_to_Q_value * n_events_in_last_interval_right

    # Update the Q value using the TD error with the previous Q value and reward = 0.
    # cooldown_time is a parameter in case the SNN does not need a full 40 ms cycle to update.
    def failure_reset(self, cooldown_time) -> None:
        # nothing to update if the episode ended before any Q values were computed
        if self.Q_left_prev is None and self.Q_right_prev is None:
            return
        # Open question: a negative dopamine concentration is biologically inaccurate;
        # it could perhaps be interpreted as an inhibitory neuromodulator.
        if self.choose_action(self.Q_left_prev, self.Q_right_prev) == AgentAction.RIGHT:
            syn = nest.GetConnections(source=self.input_population, target=self.output_population_right)
            syn.n = -self.Q_right
        else:
            syn = nest.GetConnections(source=self.input_population, target=self.output_population_left)
            syn.n = -self.Q_left
        nest.Simulate(cooldown_time)
        
        self.episode += 1

    def update(self, next_state: Tuple[float,float,float,float]) -> Tuple[int, dict]:

        #Reset all spike recorders and multimeters
        #self.multimeter_left.n_events = 0
        #self.multimeter_right.n_events = 0
        #self.spike_recorder_input.n_events = 0
        self.output_population_spike_recorder_left.n_events = 0
        self.output_population_spike_recorder_right.n_events = 0

        # silence all input neurons, then look up the input neuron for the new state
        self.input_population.firing_rate = 0.
        neuron_id = self.get_state_neuron(next_state)

        # if state was a failure: checked before setting any rate, because
        # indexing with -1 would otherwise drive the last input neuron
        if neuron_id == -1:
            self.failure_reset(SpikingAgent.cycle_period)
            return AgentAction.FAILURE, None

        # make the correct input neuron fire
        self.input_population[neuron_id].firing_rate = 5000. # XXX: value not given in Liu & Pan [1]. Got 500 Hz as max freq from BVogler thesis. N.B. 40 ms cycle time.
        
        
        # simulate for one cycle
        nest.Simulate(SpikingAgent.cycle_period)
        
        #passed onto Spiking_Plot_Renderer()
        plot_data = {
            "input_spikes": nest.GetStatus(self.spike_recorder_input, keys="events")[0],
            "multimeter_right_events": self.multimeter_right.get("events"),
            "multimeter_left_events": self.multimeter_left.get("events"),
            "n_input_neurons": self.input_size
        }

        self.compute_Q_values()

        # set new dopamine concentration on the synapses
        # PROBLEM: HOW DO WE HANDLE FAILURE? The physics simulation immediately resets after it.
        # Perhaps run the simulation without spiking to let the weights update? (BVogler)

        Q_new = max(self.Q_left, self.Q_right)
        
        if self.last_action_chosen == AgentAction.LEFT:
            Q_old = self.Q_left_prev
        elif self.last_action_chosen == AgentAction.RIGHT:
            Q_old = self.Q_right_prev
        else:
            assert self.last_action_chosen == AgentAction.FAILURE
        
        TD = self.gamma * Q_new + self.R - Q_old
         
        if self.last_action_chosen == AgentAction.RIGHT:
            print("last chosen = right")
            syn = nest.GetConnections(source=self.input_population, target=self.output_population_right)
            syn.w += np.array(syn.beta) * TD * np.array(self.prev_syn_wtr_right)
        else:
            print("last chosen = left")
            assert self.last_action_chosen == AgentAction.LEFT
            syn = nest.GetConnections(source=self.input_population, target=self.output_population_left)
            syn.w += np.array(syn.beta) * TD * np.array(self.prev_syn_wtr_left)
            
#             fig,ax=plt.subplots()
#             ax.plot(np.arange(1200), self.prev_syn_wtr_left)
#             import uuid
#             fig.savefig("/tmp/weights_nest" + str(uuid.uuid4()) + ".png")
            
        self.last_action_chosen = self.choose_action(self.Q_left, self.Q_right)
            
        return self.last_action_chosen, plot_data            
            
            
    def save_prev_syn_wtr(self):
        syn_right = nest.GetConnections(source=self.input_population, target=self.output_population_right)
        self.prev_syn_wtr_right = syn_right.wtr
        syn_left = nest.GetConnections(source=self.input_population, target=self.output_population_left)
        self.prev_syn_wtr_left = syn_left.wtr
        
#         # update Q_value using TD-Error with previous Q_value and reward = 1
#         if self.Q_left_prev != None and self.Q_right_prev != None:
#             if self.choose_action(self.Q_left_prev, self.Q_right_prev) == ...:
#                 last_action_chosen = AgentAction....
#                 syn = nest.GetConnections(source=self.input_population, target=self.output_population_right)
#                 syn.n = self.gamma * self.Q_right + R - self.Q_right_prev
#                 self.dopamine_right = syn.n[0] #for displaying stats
#             else:
#                 syn = nest.GetConnections(source=self.input_population, target=self.output_population_left)
#                 syn.n = self.gamma * self.Q_left + R - self.Q_left_prev
#                 self.dopamine_left = syn.n[0] #for displaying stats
        
#         # 0 if action is "left", else 1


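The weight update at the end of `update()` can be looked at in isolation. The following is a minimal NumPy sketch, assuming that the weights and eligibility traces are available as plain arrays; the function name `apply_td_weight_update` and the example numbers are hypothetical and not part of the agent. It only mirrors the rule `w += beta * TD * wtr` used above.

```python
import numpy as np


def apply_td_weight_update(weights, traces, beta, td):
    """Return the updated weights for w += beta * TD * wtr.

    weights, traces: one entry per synapse (hypothetical plain arrays,
    standing in for syn.w and the stored eligibility traces).
    beta: learning rate, td: scalar TD error.
    """
    weights = np.asarray(weights, dtype=float)
    traces = np.asarray(traces, dtype=float)
    # synapses with a zero trace were not recently active and stay unchanged
    return weights + beta * td * traces


# Illustrative values only: the second synapse carries a trace, the first does not
new_weights = apply_td_weight_update([0.1, 0.2], [0.0, 0.5], beta=0.5, td=2.0)
print(new_weights)
```

Only synapses with a non-zero trace are changed, which is why the traces are saved before they are reset in the main loop.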
# Executing spiking version

The main loop performs the following steps in every iteration (every "cycle" or "step"):

- set the firing rate of the input neurons according to the current state $s_n$ of the system
- run the SNN with this input state $s_n$ for one cycle period (in BVogler's thesis: 40 ms)
- obtain the Q-values $Q(s_n, a)$ by counting the number of spikes in each output population over this cycle period
- choose action $a_n$ on the basis of the Q-values
- run the environment for the same cycle time (40 ms) to obtain the next state $s_{n+1}$
- compute the reward on the basis of the last action taken
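
As a rough illustration of the steps above, the following sketch reproduces the bookkeeping of one cycle without NEST: spike counts are scaled to Q-values, an action is picked greedily, and a TD error is formed from the previous and the new Q-values. The function names, the string actions and the numbers in the usage lines are made up for illustration; only the formulas mirror `compute_Q_values()`, `choose_action()` and the TD computation in `update()`.

```python
from typing import Tuple


def q_values_from_spike_counts(n_left: int, n_right: int, scale: float) -> Tuple[float, float]:
    """Interpret the output spike counts of one cycle as (scaled) Q-values."""
    return scale * n_left, scale * n_right


def greedy_action(q_left: float, q_right: float) -> str:
    """Pick the action with the larger Q-value; ties go to 'right', as in choose_action()."""
    return "left" if q_left > q_right else "right"


def td_error(q_old: float, q_left_new: float, q_right_new: float,
             reward: float, gamma: float) -> float:
    """TD = gamma * max_a Q(s_{n+1}, a) + R - Q(s_n, a_n)."""
    return gamma * max(q_left_new, q_right_new) + reward - q_old


# Illustrative numbers only: 12 and 7 output spikes in the previous cycle,
# 4 and 6 output spikes in the current one
q_left_prev, q_right_prev = q_values_from_spike_counts(12, 7, scale=0.5)
last_action = greedy_action(q_left_prev, q_right_prev)
q_left, q_right = q_values_from_spike_counts(4, 6, scale=0.5)
q_old = q_left_prev if last_action == "left" else q_right_prev
td = td_error(q_old, q_left, q_right, reward=1.0, gamma=0.5)
print(last_action, td)
```

The TD error is negative here because the Q-value of the action taken in the previous cycle was larger than the discounted best Q-value of the new state plus the reward.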

In [44]:
%%prun
%pdb

import sys

r = Renderer(1200, 800, 600, 500, 400)
clock = pg.time.Clock()
running = True

p = Physics(0, (np.random.rand() - 1) / 10)

a = SpikingAgent(p.get_state(), 0.98)

plot = Spiking_PlotRenderer()

steps_per_episode = 0

window_size = 20
window = np.zeros(window_size)
avg_lifetime = 20000

toggle_sim = True
plot_spikes = True
stepping_sim = False
while running:
    steps_per_episode += 1
    force = 0
    mouse_x = None
   
#     if steps_per_episode > 11:
#         break

    # poll for events
    for event in pg.event.get():
        if event.type == pg.QUIT:
            running = False
            pg.quit()
            sys.exit()
        elif event.type == pg.MOUSEBUTTONDOWN:
            mouse_x = r.get_relative_mouse_x(pg.mouse.get_pos()[0])
        elif event.type == pg.KEYDOWN:
            #controls if simulation should be shown or not
            toggle_sim ^= pg.key.get_pressed()[pg.K_1]
            #on button press plots the current spikes
            plot_spikes ^= pg.key.get_pressed()[pg.K_2]
            #on button press stores the network in ./saved_networks/network.json
            if pg.key.get_pressed()[pg.K_3]:
                a.save_network()
            #toggles step-by-step simulation, updating now by pressing space
            stepping_sim ^= pg.key.get_pressed()[pg.K_4]

    if stepping_sim:
        toggle_sim = True
        plot_spikes = True
        next_step = False
        while not next_step and stepping_sim:
            pg.event.wait()
            stepping_sim ^= pg.key.get_pressed()[pg.K_4]
            next_step ^= pg.key.get_pressed()[pg.K_SPACE]

    # Simulate SNN, choose action, simulate physics, receive state
    # Since SNN takes 40ms, it reacts only to every 2nd physics step
    global action
    action, plot_data = a.update(p.get_state())

    if action == AgentAction.FAILURE:
        # a.update() has already called failure_reset() for this failure
        p.reset()
        window = np.roll(window, 1)
        window[0] = steps_per_episode
        steps_per_episode = 0
    elif action == AgentAction.RIGHT:
        force = 10
    elif action == AgentAction.LEFT:
        force = -10
    else:
        assert False, "Unknown action returned"
    
    theta, x = p.update(force, mouse_x)

    # plot_data is None after a failure, so skip plotting in that case
    if plot_spikes and plot_data is not None and (steps_per_episode % 10 == 0):
        plot.update(plot_data)
        #plot_spikes = False
    
    a.save_prev_syn_wtr()
    
    syn_to_left = nest.GetConnections(source=a.input_population, target=a.output_population_left)
    syn_to_right = nest.GetConnections(source=a.input_population, target=a.output_population_right)
    for _syn in [syn_to_left, syn_to_right]:
        _syn.wtr = 0.
        _syn.pre_trace = 0.
        #_syn.post_trace = 0. # need to do this in postsyn. neuron partner...
    
    a.output_population_left.post_trace__for_neuromodulated_stdp_synapse_nestml = 0.
    a.output_population_right.post_trace__for_neuromodulated_stdp_synapse_nestml = 0.
        
    
    
    if np.mean(window) >= avg_lifetime or toggle_sim:
        r.draw_clear()
        r.draw_ground(0.2, "grey")
        r.draw_car(x)
        r.draw_pole(x, theta, 2*p.l, 0.02)
        r.draw_stats(theta*180/np.pi, p.w*180/np.pi, x, p.v, a.get_episode(),
                     a.output_population_spike_recorder_left.n_events, 
                     a.output_population_spike_recorder_right.n_events,
                     a.dopamine_left,
                     a.dopamine_right,
                     action)
        r.display()

        clock.tick(50)  # limits FPS to 50

Automatic pdb calling has been turned OFF
Dimension of input space: (1, 6, 1, 20)
last chosen = left
last chosen = left
last chosen = right

         2542751 function calls (2508685 primitive calls) in 7.673 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    11999    1.943    0.000    1.943    0.000 ll_api.py:81(catching_sli_run)
      852    1.547    0.002    2.432    0.003 hl_api_types.py:804(get)
      480    0.524    0.001    1.458    0.003 hl_api_types.py:902(set)
      219    0.441    0.002    0.452    0.002 backend_agg.py:93(draw_path)
       12    0.391    0.033    0.391    0.033 {method 'encode' of 'ImagingEncoder' objects}
      852    0.282    0.000    0.356    0.000 hl_api_helper.py:447(restructure_data)
     86/0    0.245    0.003    0.000          selectors.py:558(select)
     86/0    0.210    0.002    0.000          {method 'control' of 'select.kqueue' objects}
  266/227    0.198    0.001    0.191    0.001 socket.py:632(send)
     95/0    0.183    0.002    0.000          {built-in method pygame.event.get}
      950    0.169    0.000    0.169    0

## Citations

[1] Liu Y, Pan W. Spiking Neural-Networks-Based Data-Driven Control. Electronics. 2023; 12(2):310. https://doi.org/10.3390/electronics12020310 

## Acknowledgements

The authors would like to thank Prof. Wei Pan and Dr. Yuxiang Liu for kindly providing ...

