Add example demonstrating spike-based simulations of Pong #2393

Merged
merged 28 commits into from
Sep 28, 2022
1349cfa
add pong example
krachbumm3nte May 9, 2022
05b10e6
Updated authorship for pong.py
krachbumm3nte May 10, 2022
5f7490d
Merge branch 'master' of https://github.com/nest/nest-simulator into …
krachbumm3nte May 17, 2022
4194ee7
Merge branch 'nest-master' into pong_pynest_example
krachbumm3nte May 17, 2022
bc65eea
update docstrings, refactor variables in generate_gif.py
krachbumm3nte May 20, 2022
9d53f8f
update docstrings, refactor variables in networks.py
krachbumm3nte May 20, 2022
9720dcf
update docstrings, refactor variables in pong.py
krachbumm3nte May 20, 2022
f0c6022
format pong.py
krachbumm3nte May 20, 2022
4135714
update docstrings in run_simulations.py
krachbumm3nte May 20, 2022
4814a23
remove trailing whitespaces in pong.py
krachbumm3nte May 20, 2022
13b61d3
remove trailing whitespaces
krachbumm3nte May 20, 2022
1fbc4fd
change error handling in generate_gif.py
krachbumm3nte May 21, 2022
657a5d3
Merge branch 'nest:master' into pong_pynest_example
krachbumm3nte Jun 24, 2022
927d4b0
Update generate_gif.py for flexible resolution
krachbumm3nte Jun 24, 2022
108aff6
Merge branch 'pong_pynest_example' of https://github.com/krachbumm3nt…
krachbumm3nte Jun 24, 2022
0b59ed2
Refactor variables, fix formatting
krachbumm3nte Jun 24, 2022
552cf6e
Revert changes introduced in #2384
krachbumm3nte Jun 24, 2022
160b5ea
Merge nest-master into pong_pynest_example
krachbumm3nte Aug 25, 2022
ca31dbf
add pong to example grid, move gif to static folder, update scaling f…
jessica-mitchell Aug 26, 2022
f5d129c
Apply suggestions from code review
krachbumm3nte Aug 27, 2022
2f9f737
Apply suggestions from code review
krachbumm3nte Aug 27, 2022
1c27dbb
Merge branch 'nest:master' into pong_pynest_example
krachbumm3nte Sep 7, 2022
65911a3
Merge pull request #3 from jessica-mitchell/add-pong-index
krachbumm3nte Sep 12, 2022
95ce901
Change gif generation script from PIL to matplotlib as required by #2441
krachbumm3nte Sep 12, 2022
6c81621
Pretty much just minor changes and efforts to make the whole codebase…
JanVogelsang Sep 26, 2022
72d5f3a
Merge pull request #4 from JanVogelsang/pong_pynest_example
krachbumm3nte Sep 27, 2022
8f04c24
fix indentation
krachbumm3nte Sep 27, 2022
ba12bd5
Update readme.rst
krachbumm3nte Sep 27, 2022
45 changes: 29 additions & 16 deletions doc/htmldoc/examples/index.rst
@@ -40,6 +40,22 @@ PyNEST examples
* :doc:`../auto_examples/spatial/test_3d_gauss`


.. grid:: 1 1 2 3

   .. grid-item-card:: NEST Sudoku solver
      :img-top: ../static/img/sudoku_solution.gif

      * :doc:`../auto_examples/sudoku/sudoku_net`
      * :doc:`../auto_examples/sudoku/sudoku_solver`
      * :doc:`../auto_examples/sudoku/plot_progress`

   .. grid-item-card:: NEST Pong game
      :img-top: ../static/img/pong_sim.gif

      * :doc:`../auto_examples/pong/run_simulations`
      * :doc:`../auto_examples/pong/pong`
      * :doc:`../auto_examples/pong/generate_gif`

.. grid:: 1 1 2 3

.. grid-item-card:: Random balanced networks (Brunel)
@@ -52,20 +68,18 @@ PyNEST examples
* :doc:`../auto_examples/brunel_alpha_evolution_strategies`




.. grid-item-card:: Cortical microcircuit (Potjans)
:img-top: ../static/img/pynest/raster_plot.png

* :doc:`cortical_microcircuit_index`

.. grid-item-card:: GLIF (from Allen institute)
:img-top: ../static/img/pynest/glif_cond.png

* :doc:`../auto_examples/glif_cond_neuron`
* :doc:`../auto_examples/glif_psc_neuron`

.. grid-item-card:: NEST Sudoku solver
:img-top: ../static/img/sudoku_solution.gif

* :doc:`../auto_examples/sudoku/sudoku_net`
* :doc:`../auto_examples/sudoku/sudoku_solver`
* :doc:`../auto_examples/sudoku/plot_progress`

.. grid:: 1 1 2 3

@@ -105,21 +119,14 @@ PyNEST examples

* :doc:`../auto_examples/BrodyHopfield`


.. grid-item-card:: GLIF (from Allen institute)
:img-top: ../static/img/pynest/glif_cond.png

* :doc:`../auto_examples/glif_cond_neuron`
* :doc:`../auto_examples/glif_psc_neuron`

.. grid:: 1 1 2 3

.. grid-item-card:: Brette and Gerstner
:img-top: ../static/img/pynest/brette_gerstner2c.png

* :doc:`../auto_examples/brette_gerstner_fig_2c`
* :doc:`../auto_examples/brette_gerstner_fig_3d`

.. grid:: 1 1 2 3


.. grid-item-card:: Precise spiking
:img-top: ../static/img/pynest/precisespiking.png
@@ -306,3 +313,9 @@ PyNEST examples
../auto_examples/sudoku/sudoku_net
../auto_examples/sudoku/sudoku_solver
../auto_examples/sudoku/plot_progress

.. toctree::
   :hidden:

   ../auto_examples/pong/run_simulations
   ../auto_examples/pong/pong
   ../auto_examples/pong/generate_gif
Binary file added doc/htmldoc/static/img/pong_sim.gif
Binary file modified doc/htmldoc/static/img/sudoku_solution.gif
22 changes: 22 additions & 0 deletions pynest/examples/pong/README.rst
@@ -0,0 +1,22 @@
NEST-pong
=========
This program simultaneously trains two networks of spiking neurons to play
the classic game of Pong.

Requirements
------------
- NEST 3.3
- NumPy
- Matplotlib

Instructions
------------
To start training between two networks with R-STDP plasticity, run
the ``run_simulations.py`` script. By default, one of the networks will
be stimulated with Gaussian white noise, showing that this is necessary
for learning under this paradigm. In addition to R-STDP, a learning rule
based on the ``stdp_dopamine_synapse`` and temporal difference learning
is implemented; see ``networks.py`` for details.

The learning progress and resulting game can be visualized with the
``generate_gif.py`` script.
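
Reviewer note: the folder layout that ``generate_gif.py`` reads could be documented next to these instructions. The sketch below writes and re-reads a dummy run using the file names and dictionary keys that appear in ``generate_gif.py`` (``gamestate.pkl``, ``data_left.pkl.gz``, ``data_right.pkl.gz``). The array shapes, the ``"hypothetical-net"`` label, and the helper names ``write_dummy_run`` and ``load_run`` are assumptions made for illustration only; they are not taken from the simulation code.

```python
import gzip
import os
import pickle
import tempfile

import numpy as np


def write_dummy_run(folder, n_steps=3):
    """Write placeholder files in the layout generate_gif.py reads."""
    game_data = {
        "ball_pos": [(0.1 * k, 0.2) for k in range(n_steps)],
        "left_paddle": [(0.0, 0.5)] * n_steps,
        "right_paddle": [(1.0, 0.5)] * n_steps,
        "score": [(0, 0)] * n_steps,
    }
    with open(os.path.join(folder, "gamestate.pkl"), "wb") as f:
        pickle.dump(game_data, f)

    for side in ("left", "right"):
        data = {
            # Assumed shapes: one reward vector and one weight matrix per step.
            "rewards": [np.zeros(4) for _ in range(n_steps)],
            "weights": [np.ones((4, 4)) for _ in range(n_steps)],
            "network_type": "hypothetical-net",
        }
        with gzip.open(os.path.join(folder, f"data_{side}.pkl.gz"), "wb") as f:
            pickle.dump(data, f)


def load_run(folder):
    """Read the game state and both network logs back from disk."""
    with open(os.path.join(folder, "gamestate.pkl"), "rb") as f:
        game_data = pickle.load(f)
    networks = {}
    for side in ("left", "right"):
        with gzip.open(os.path.join(folder, f"data_{side}.pkl.gz"), "rb") as f:
            networks[side] = pickle.load(f)
    return game_data, networks


with tempfile.TemporaryDirectory() as run_folder:
    write_dummy_run(run_folder)
    game_data, networks = load_run(run_folder)
```

With a real run, the folder passed to ``load_run`` would be the output folder created by the simulation script, which is also the single command-line argument that ``generate_gif.py`` expects.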
290 changes: 290 additions & 0 deletions pynest/examples/pong/generate_gif.py
@@ -0,0 +1,290 @@
# -*- coding: utf-8 -*-
#
# generate_gif.py
#
# This file is part of NEST.
#
# Copyright (C) 2004 The NEST Initiative
#
# NEST is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# NEST is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.

r"""Script to visualize a simulated Pong game.
----------------------------------------------------------------
All simulations store data about both networks and the game in .pkl files.
This script reads these files and generates image snapshots at different
times during the simulation. These are subsequently aggregated into a GIF.

:Authors: J Gille, T Wunderlich, Electronic Vision(s)
"""

from copy import copy
import gzip
import os
import sys

import numpy as np
import pickle
import matplotlib.pyplot as plt
import imageio.v2 as imageio
from glob import glob

from pong import GameOfPong as Pong

px = 1 / plt.rcParams['figure.dpi']
plt.subplots(figsize=(400 * px, 300 * px))
plt.rcParams.update({'font.size': 6})

gridsize = (12, 16) # Shape of the grid used for positioning subplots

left_color = np.array((204, 0, 153)) # purple
right_color = np.array((255, 128, 0)) # orange
left_color_hex = "#cc0099"
right_color_hex = "#ff8000"
white = np.array((255, 255, 255))

# Original size of the playing field inside the simulation
GAME_GRID = np.array([Pong.x_grid, Pong.y_grid])
GRID_SCALE = 24
# Field size (in px) after upscaling
GAME_GRID_SCALED = GAME_GRID * GRID_SCALE

# Dimensions of game objects in px
BALL_RAD = 6
PADDLE_LEN = int(0.1 * GAME_GRID_SCALED[1])
PADDLE_WID = 18

# Add margins left and right to the playing field
FIELD_PADDING = PADDLE_WID * 2
FIELD_SIZE = copy(GAME_GRID_SCALED)
FIELD_SIZE[0] += 2 * FIELD_PADDING

# By default, the GIF shows every DEFAULT_SPEED-th simulation step.
DEFAULT_SPEED = 4


def scale_coordinates(coordinates: np.array):
    """Scale a numpy.array of coordinate tuples (x,y) from simulation scale to
    pixel scale in the output image.

    Args:
        coordinates (numpy.array): input coordinates of shape (n, 2) to be
        scaled. The array is modified in place.

    Returns:
        numpy.array: output coordinates in px, cast to int.
    """
    coordinates[:, 0] = coordinates[:, 0] * \
        GAME_GRID_SCALED[0] / Pong.x_length + FIELD_PADDING
    coordinates[:, 1] = coordinates[:, 1] * \
        GAME_GRID_SCALED[1] / Pong.y_length
    return coordinates.astype(int)
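
Reviewer note: ``scale_coordinates`` writes the scaled values back into its argument before returning the integer copy. That is harmless here because each array is created inline and used once, but a variant that leaves its input untouched might be easier to reuse. The following is only a sketch: ``X_LENGTH``, ``Y_LENGTH``, ``GRID_SCALED`` and ``PADDING`` are hypothetical stand-ins for ``Pong.x_length``, ``Pong.y_length``, ``GAME_GRID_SCALED`` and ``FIELD_PADDING``, with values chosen for the example rather than taken from ``pong.py``. It also multiplies by a precomputed scale factor, so floating-point rounding may differ slightly from the multiply-then-divide order used in the script.

```python
import numpy as np

# Hypothetical stand-ins for the constants used by generate_gif.py.
X_LENGTH, Y_LENGTH = 1.5, 1.0        # assumed field size in simulation units
GRID_SCALED = np.array([768, 480])   # assumed field size in px
PADDING = 36                         # assumed horizontal margin in px


def scale_coordinates_copy(coordinates):
    """Return pixel coordinates without modifying the input array."""
    coords = np.asarray(coordinates, dtype=float)
    scale = GRID_SCALED / np.array([X_LENGTH, Y_LENGTH])
    offset = np.array([PADDING, 0])
    # Broadcasting applies the per-axis scale and offset to every (x, y) row
    # and allocates a new array, so the caller's data stays unchanged.
    return (coords * scale + offset).astype(int)


points = np.array([[0.0, 0.0], [0.75, 0.5], [1.5, 1.0]])
pixels = scale_coordinates_copy(points)
```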


def grayscale_to_heatmap(in_image, min_val, max_val, base_color):
    """Transform a grayscale image to an RGB heat map. Heatmap will color small
    values in base_color and high values in white.

    Args:
        in_image (numpy.array): 2D numpy.array to be transformed.
        min_val (float): smallest value across the entire image - colored in
        base_color in the output.
        max_val (float): largest value across the entire image - colored
        white in the output.
        base_color (numpy.array): numpy.array of shape (3,) representing the
        base color of the heatmap in RGB.
    Returns:
        numpy.array: transformed input array with an added 3rd dimension of
        length 3, representing RGB values.
    """

    x_len, y_len = in_image.shape
    out_image = np.ones((x_len, y_len, 3), dtype=np.uint8)

    span = max_val - min_val
    # Edge case for uniform weight matrix
    if span == 0:
        return out_image * base_color

    for x in range(x_len):
        for y in range(y_len):
            color_scaled = (in_image[x, y] - min_val) / span
            out_image[x, y, :] = base_color + (white - base_color) * color_scaled

    return out_image
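
Reviewer note: the per-pixel double loop in ``grayscale_to_heatmap`` could probably be replaced by NumPy broadcasting. The sketch below is a suggestion, not the code in this PR; the name ``grayscale_to_heatmap_vectorized`` and the module-level ``WHITE`` constant are introduced here for illustration. One difference from the original is that the uniform-weight case also returns a ``uint8`` array, whereas the original returns the product of a ``uint8`` array and an integer color array.

```python
import numpy as np

WHITE = np.array((255, 255, 255))


def grayscale_to_heatmap_vectorized(in_image, min_val, max_val, base_color):
    """Map a 2D array to RGB: min_val -> base_color, max_val -> white."""
    in_image = np.asarray(in_image, dtype=float)
    span = max_val - min_val
    if span == 0:
        # Uniform input: every pixel gets the base color.
        scaled = np.zeros_like(in_image)
    else:
        scaled = (in_image - min_val) / span
    # Add a trailing axis so each scalar blends all three color channels.
    rgb = base_color + (WHITE - base_color) * scaled[..., np.newaxis]
    # Casting truncates fractional values, as the element-wise assignment
    # into a uint8 array does in the loop version.
    return rgb.astype(np.uint8)


purple = np.array((204, 0, 153))
weights = np.array([[0.0, 1.0], [0.5, 0.25]])
heatmap = grayscale_to_heatmap_vectorized(weights, 0.0, 1.0, purple)
uniform = grayscale_to_heatmap_vectorized(np.full((2, 2), 3.0), 3.0, 3.0, purple)
```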


if __name__ == "__main__":
    keep_temps = False
    out_file = "pong_sim.gif"

    if len(sys.argv) != 2:
        print("This program takes exactly one argument - the location of the "
              "output folder generated by the simulation.")
        sys.exit(1)
    input_folder = sys.argv[1]

    if os.path.exists(out_file):
        print(f"<{out_file}> already exists, aborting!")
        sys.exit(1)

    temp_dir = "temp"
    if os.path.exists(temp_dir):
        print(f"Output folder <{temp_dir}> already exists, aborting!")
        sys.exit(1)
    else:
        os.mkdir(temp_dir)

    print(f"Reading simulation data from {input_folder}...")
    with open(os.path.join(input_folder, "gamestate.pkl"), 'rb') as f:
        game_data = pickle.load(f)

    ball_positions = scale_coordinates(np.array(game_data["ball_pos"]))
    l_paddle_positions = scale_coordinates(np.array(game_data["left_paddle"]))
    # Move left paddle outwards for symmetry
    l_paddle_positions[:, 0] -= PADDLE_WID
    r_paddle_positions = scale_coordinates(np.array(game_data["right_paddle"]))

    score = np.array(game_data["score"]).astype(int)

    with gzip.open(os.path.join(input_folder, "data_left.pkl.gz"), 'r') as f:
        data = pickle.load(f)
        rewards_left = data["rewards"]
        weights_left = data["weights"]
        name_left = data["network_type"]

    with gzip.open(os.path.join(input_folder, "data_right.pkl.gz"), 'r') as f:
        data = pickle.load(f)
        rewards_right = data["rewards"]
        weights_right = data["weights"]
        name_right = data["network_type"]

    # Extract lowest and highest weights for both players to scale the heatmaps.
    min_r, max_r = np.min(weights_right), np.max(weights_right)
    min_l, max_l = np.min(weights_left), np.max(weights_left)

    # Average rewards at every iteration over all neurons
    rewards_left = [np.mean(x) for x in rewards_left]
    rewards_right = [np.mean(x) for x in rewards_right]

    print(f"Setup complete, generating images to '{temp_dir}'...")
    n_iterations = score.shape[0]
    i = 0
    output_speed = DEFAULT_SPEED

    while i < n_iterations:
        # Set up the grid containing all components of the output image
        title = plt.subplot2grid(gridsize, (0, 0), 1, 16)
        l_info = plt.subplot2grid(gridsize, (1, 0), 7, 2)
        r_info = plt.subplot2grid(gridsize, (1, 14), 7, 2)
        field = plt.subplot2grid(gridsize, (1, 2), 7, 12)
        l_hm = plt.subplot2grid(gridsize, (8, 0), 4, 4)
        reward_plot = plt.subplot2grid(gridsize, (8, 6), 4, 6)
        r_hm = plt.subplot2grid(gridsize, (8, 12), 4, 4)

        for ax in [title, l_info, r_info, field, l_hm, r_hm]:
            ax.axis("off")

        # Create an empty array for the playing field.
        playing_field = np.zeros(
            (FIELD_SIZE[0], FIELD_SIZE[1], 3), dtype=np.uint8)

        # Draw the ball in white
        x, y = ball_positions[i]
        playing_field[x - BALL_RAD:x + BALL_RAD, y - BALL_RAD:y + BALL_RAD] = white
        for (x, y), color in zip([l_paddle_positions[i], r_paddle_positions[i]],
                                 [left_color, right_color]):
            # Clip y coordinate of the paddle so it does not exceed the screen
            y = max(PADDLE_LEN, y)
            y = min(FIELD_SIZE[1] - PADDLE_LEN, y)
            playing_field[x:x + PADDLE_WID, y - PADDLE_LEN:y + PADDLE_LEN] = color

        field.imshow(np.transpose(playing_field, [1, 0, 2]))

        # Left player heatmap
        heatmap_l = grayscale_to_heatmap(weights_left[i], min_l,
                                         max_l, left_color)
        l_hm.imshow(heatmap_l)
        l_hm.set_xlabel("output")
        l_hm.set_ylabel("input")
        l_hm.set_title("weights", y=-0.3)

        # Right player heatmap
        heatmap_r = grayscale_to_heatmap(weights_right[i], min_r,
                                         max_r, right_color)
        r_hm.imshow(heatmap_r)
        r_hm.set_xlabel("output")
        r_hm.set_ylabel("input")
        r_hm.set_title("weights", y=-0.3)

        reward_plot.plot([0, i], [-1, -1])
        reward_plot.plot(rewards_right[:i + 1], color=right_color / 255)
        reward_plot.plot(rewards_left[:i + 1], color=left_color / 255)

        # Change x_ticks and x_min for the first few plots
        if i < 1600:
            x_min = 0
            reward_plot.set_xticks(np.arange(0, n_iterations, 250))
        else:
            x_min = i - 1600
            reward_plot.set_xticks(np.arange(0, n_iterations, 500))

        reward_plot.set_ylabel("mean reward")
        reward_plot.set_yticks([0, 0.5, 1])
        reward_plot.set_ylim(0, 1.0)
        reward_plot.set_xlim(x_min, i + 10)

        title.text(0.4, 0.75, name_left, ha='right', fontsize=15, c=left_color_hex)
        title.text(0.5, 0.75, "VS", ha='center', fontsize=17)
        title.text(0.6, 0.75, name_right, ha='left', fontsize=15, c=right_color_hex)

        l_score, r_score = score[i]

        l_info.text(0, 0.9, "run:", fontsize=14)
        l_info.text(0, 0.75, str(i), fontsize=14)
        l_info.text(1, 0.5, l_score, ha='right', va='center', fontsize=26, c=left_color_hex)

        r_info.text(0, 0.9, "speed:", fontsize=14)
        r_info.text(0, 0.75, str(output_speed) + 'x', fontsize=14)
        r_info.text(0, 0.5, r_score, ha='left', va='center', fontsize=26, c=right_color_hex)

        plt.subplots_adjust(left=0.05, right=0.95, bottom=0.1, top=0.9, wspace=0.35, hspace=0.35)
        plt.savefig(os.path.join(temp_dir, f"img_{str(i).zfill(6)}.png"))

        # Change the speed of the video to show performance before and after
        # training at DEFAULT_SPEED and fast-forward most of the training
        if 75 <= i < 100 or n_iterations - 400 <= i < n_iterations - 350:
            output_speed = 10
        elif 100 <= i < n_iterations - 350:
            output_speed = 50
        else:
            output_speed = DEFAULT_SPEED

        i += output_speed

    print("Image creation complete, collecting them into a GIF...")

    filenames = sorted(glob(os.path.join(temp_dir, "*.png")))

    with imageio.get_writer(out_file, mode='I', fps=6) as writer:
        for filename in filenames:
            image = imageio.imread(filename)
            writer.append_data(image)
    print(f"GIF created under: {out_file}")

    if not keep_temps:
        print("Deleting temporary image files...")
        for in_file in filenames:
            os.unlink(in_file)
        os.rmdir(temp_dir)

    print("Done.")
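
Reviewer note: the variable playback speed in the main loop decides which simulation steps end up as GIF frames. Pulling that rule into a small function might make it easier to check in isolation. The sketch below copies the thresholds from the loop above; the function name ``frame_indices`` is hypothetical and not part of this PR.

```python
def frame_indices(n_iterations, default_speed=4):
    """Return the simulation steps that would be rendered as GIF frames."""
    indices = []
    i = 0
    while i < n_iterations:
        indices.append(i)
        # Medium speed while entering and leaving the fast-forward phase.
        if 75 <= i < 100 or n_iterations - 400 <= i < n_iterations - 350:
            speed = 10
        # Fast-forward through most of the training.
        elif 100 <= i < n_iterations - 350:
            speed = 50
        # Default speed at the start and for the last 350 steps.
        else:
            speed = default_speed
        i += speed
    return indices


frames = frame_indices(1000)
```

For a run of 1000 iterations, this gives every 4th step up to step 76, a short stretch in steps of 10, steps of 50 through the bulk of training, and every 4th step again near the end.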