Merge pull request #2393 from krachbumm3nte/pong_pynest_example
Add example demonstrating spike-based simulations of Pong
hakonsbm authored Sep 28, 2022
2 parents 66255be + ba12bd5 commit 8028a79
Showing 14 changed files with 1,276 additions and 16 deletions.
45 changes: 29 additions & 16 deletions doc/htmldoc/examples/index.rst
@@ -40,6 +40,22 @@ PyNEST examples
* :doc:`../auto_examples/spatial/test_3d_gauss`


.. grid:: 1 1 2 3

.. grid-item-card:: NEST Sudoku solver
:img-top: ../static/img/sudoku_solution.gif

* :doc:`../auto_examples/sudoku/sudoku_net`
* :doc:`../auto_examples/sudoku/sudoku_solver`
* :doc:`../auto_examples/sudoku/plot_progress`

.. grid-item-card:: NEST Pong game
:img-top: ../static/img/pong_sim.gif

* :doc:`../auto_examples/pong/run_simulations`
* :doc:`../auto_examples/pong/pong`
* :doc:`../auto_examples/pong/generate_gif`

.. grid:: 1 1 2 3

.. grid-item-card:: Random balanced networks (Brunel)
@@ -52,20 +68,18 @@ PyNEST examples
* :doc:`../auto_examples/brunel_alpha_evolution_strategies`




.. grid-item-card:: Cortical microcircuit (Potjans)
:img-top: ../static/img/pynest/raster_plot.png

* :doc:`cortical_microcircuit_index`

.. grid-item-card:: GLIF (from Allen institute)
:img-top: ../static/img/pynest/glif_cond.png

* :doc:`../auto_examples/glif_cond_neuron`
* :doc:`../auto_examples/glif_psc_neuron`

.. grid-item-card:: NEST Sudoku solver
:img-top: ../static/img/sudoku_solution.gif

* :doc:`../auto_examples/sudoku/sudoku_net`
* :doc:`../auto_examples/sudoku/sudoku_solver`
* :doc:`../auto_examples/sudoku/plot_progress`

.. grid:: 1 1 2 3

@@ -105,21 +119,14 @@ PyNEST examples

* :doc:`../auto_examples/BrodyHopfield`


.. grid-item-card:: GLIF (from Allen institute)
:img-top: ../static/img/pynest/glif_cond.png

* :doc:`../auto_examples/glif_cond_neuron`
* :doc:`../auto_examples/glif_psc_neuron`

.. grid:: 1 1 2 3

.. grid-item-card:: Brette and Gerstner
:img-top: ../static/img/pynest/brette_gerstner2c.png

* :doc:`../auto_examples/brette_gerstner_fig_2c`
* :doc:`../auto_examples/brette_gerstner_fig_3d`

.. grid:: 1 1 2 3


.. grid-item-card:: Precise spiking
:img-top: ../static/img/pynest/precisespiking.png
Expand Down Expand Up @@ -306,3 +313,9 @@ PyNEST examples
../auto_examples/sudoku/sudoku_net
../auto_examples/sudoku/sudoku_solver
../auto_examples/sudoku/plot_progress

.. toctree::
:hidden:
../auto_examples/pong/run_simulations
../auto_examples/pong/pong
../auto_examples/pong/generate_gif
Binary file added doc/htmldoc/static/img/pong_sim.gif
Binary file modified doc/htmldoc/static/img/sudoku_solution.gif
22 changes: 22 additions & 0 deletions pynest/examples/pong/README.rst
@@ -0,0 +1,22 @@
NEST-pong
=========
This program simultaneously trains two networks of spiking neurons to play
the classic game of Pong.

Requirements
------------
- NEST 3.3
- NumPy
- Matplotlib

Instructions
------------
To start training between two networks with R-STDP plasticity, run
the ``run_simulations.py`` script. By default, one of the networks will
be stimulated with Gaussian white noise, showing that this is necessary
for learning under this paradigm. In addition to R-STDP, a learning rule
based on the ``stdp_dopamine_synapse`` and temporal difference learning
is implemented; see ``networks.py`` for details.

The learning progress and resulting game can be visualized with the
``generate_gif.py`` script.
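
The result files can also be inspected directly before rendering a GIF. The following is a minimal sketch, assuming the file names and dictionary keys that ``generate_gif.py`` reads (``gamestate.pkl``, ``data_left.pkl.gz``, ``data_right.pkl.gz`` with the keys ``rewards``, ``weights`` and ``network_type``). The helper names ``load_player_data`` and ``load_game_state`` are hypothetical and not part of the example. The demo at the bottom writes synthetic data to a temporary folder so that the sketch runs without a finished simulation.

```python
import gzip
import os
import pickle
import tempfile


def load_player_data(folder, side):
    """Load the gzip-compressed pickle of one player ("left" or "right").

    Hypothetical helper; the file name pattern follows generate_gif.py.
    """
    path = os.path.join(folder, f"data_{side}.pkl.gz")
    with gzip.open(path, "rb") as f:
        return pickle.load(f)


def load_game_state(folder):
    """Load the uncompressed pickle holding ball, paddle and score data.

    Hypothetical helper; the file name follows generate_gif.py.
    """
    with open(os.path.join(folder, "gamestate.pkl"), "rb") as f:
        return pickle.load(f)


# Round-trip demo on synthetic data (not real simulation output).
with tempfile.TemporaryDirectory() as demo_dir:
    demo = {"rewards": [[0.0, 1.0]], "weights": [[[0.5]]],
            "network_type": "demo"}
    with gzip.open(os.path.join(demo_dir, "data_left.pkl.gz"), "wb") as f:
        pickle.dump(demo, f)
    loaded = load_player_data(demo_dir, "left")
    print(loaded["network_type"], len(loaded["rewards"]))
```

Pickle files can execute arbitrary code when loaded, so only open result folders produced by your own simulation runs. To render the GIF itself, ``generate_gif.py`` expects exactly one argument, the output folder generated by the simulation.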
290 changes: 290 additions & 0 deletions pynest/examples/pong/generate_gif.py
@@ -0,0 +1,290 @@
# -*- coding: utf-8 -*-
#
# generate_gif.py
#
# This file is part of NEST.
#
# Copyright (C) 2004 The NEST Initiative
#
# NEST is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# NEST is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.

r"""Script to visualize a simulated Pong game.
----------------------------------------------------------------
All simulations store data about both networks and the game in .pkl files.
This script reads these files and generates image snapshots at different
times during the simulation. These are subsequently aggregated into a GIF.
:Authors: J Gille, T Wunderlich, Electronic Vision(s)
"""

from copy import copy
import gzip
import os
import sys

import numpy as np
import pickle
import matplotlib.pyplot as plt
import imageio.v2 as imageio
from glob import glob

from pong import GameOfPong as Pong

px = 1 / plt.rcParams['figure.dpi']
plt.subplots(figsize=(400 * px, 300 * px))
plt.rcParams.update({'font.size': 6})

gridsize = (12, 16) # Shape of the grid used for positioning subplots

left_color = np.array((204, 0, 153)) # purple
right_color = np.array((255, 128, 0)) # orange
left_color_hex = "#cc0099"
right_color_hex = "#ff8000"
white = np.array((255, 255, 255))

# Original size of the playing field inside the simulation
GAME_GRID = np.array([Pong.x_grid, Pong.y_grid])
GRID_SCALE = 24
# Field size (in px) after upscaling
GAME_GRID_SCALED = GAME_GRID * GRID_SCALE

# Dimensions of game objects in px
BALL_RAD = 6
PADDLE_LEN = int(0.1 * GAME_GRID_SCALED[1])
PADDLE_WID = 18

# Add margins left and right to the playing field
FIELD_PADDING = PADDLE_WID * 2
FIELD_SIZE = copy(GAME_GRID_SCALED)
FIELD_SIZE[0] += 2 * FIELD_PADDING

# By default, the GIF shows every DEFAULT_SPEED-th simulation step.
DEFAULT_SPEED = 4


def scale_coordinates(coordinates: np.ndarray):
    """Scale a numpy.array of coordinate tuples (x,y) from simulation scale to
    pixel scale in the output image.
    Args:
        coordinates (numpy.array): array of shape (n, 2) holding the input
        coordinates to be scaled. The array is modified in place.
    Returns:
        numpy.array: integer array of shape (n, 2) with coordinates in px.
    """
    coordinates[:, 0] = coordinates[:, 0] * \
        GAME_GRID_SCALED[0] / Pong.x_length + FIELD_PADDING
    coordinates[:, 1] = coordinates[:, 1] * \
        GAME_GRID_SCALED[1] / Pong.y_length
    return coordinates.astype(int)


def grayscale_to_heatmap(in_image, min_val, max_val, base_color):
    """Transform a grayscale image to an RGB heat map. Heatmap will color small
    values in base_color and high values in white.
    Args:
        in_image (numpy.array): 2D numpy.array to be transformed.
        min_val (float): smallest value across the entire image - colored in
        base_color in the output.
        max_val (float): largest value across the entire image - colored
        white in the output.
        base_color (numpy.array): numpy.array of shape (3,) representing the
        base color of the heatmap in RGB.
    Returns:
        numpy.array: transformed input array with an added 3rd dimension of
        length 3, representing RGB values.
    """

    x_len, y_len = in_image.shape
    out_image = np.ones((x_len, y_len, 3), dtype=np.uint8)

    span = max_val - min_val
    # Edge case for uniform weight matrix
    if span == 0:
        return out_image * base_color

    for x in range(x_len):
        for y in range(y_len):
            color_scaled = (in_image[x, y] - min_val) / span
            out_image[x, y, :] = base_color + (white - base_color) * color_scaled

    return out_image
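
The per-pixel loop above can also be expressed with NumPy broadcasting. The following is a minimal sketch, not part of this commit; the name `grayscale_to_heatmap_vectorized` is hypothetical. It assumes all input values lie between `min_val` and `max_val`, and it differs from the loop version in one detail: for a uniform weight matrix it returns a `uint8` array filled with `base_color`, whereas the original returns the product of a `uint8` array and `base_color`, which takes the wider integer type of `base_color`.

```python
import numpy as np

WHITE = np.array((255, 255, 255))


def grayscale_to_heatmap_vectorized(in_image, min_val, max_val, base_color):
    """Map a 2D array to RGB: min_val -> base_color, max_val -> white.

    Hypothetical vectorized variant of grayscale_to_heatmap.
    """
    in_image = np.asarray(in_image, dtype=float)
    base_color = np.asarray(base_color, dtype=float)

    span = max_val - min_val
    if span == 0:
        # Uniform input: every pixel gets the base color.
        scaled = np.zeros_like(in_image)
    else:
        scaled = (in_image - min_val) / span

    # Broadcast (x, y, 1) against (3,) to get an (x, y, 3) RGB image.
    rgb = base_color + (WHITE - base_color) * scaled[..., np.newaxis]
    # Casting to uint8 truncates fractional parts, as the loop version does
    # when it assigns floats into a uint8 array.
    return rgb.astype(np.uint8)


weights = np.array([[0.0, 1.0], [0.5, 0.25]])
heatmap = grayscale_to_heatmap_vectorized(weights, 0.0, 1.0,
                                          np.array((204, 0, 153)))
print(heatmap.shape, heatmap.dtype)
```

Since the heat maps are recomputed for every rendered frame, removing the Python-level double loop may shorten image generation, although the effect depends on the size of the weight matrices.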


if __name__ == "__main__":
    keep_temps = False
    out_file = "pong_sim.gif"

    if len(sys.argv) != 2:
        print("This program takes exactly one argument - the location of the "
              "output folder generated by the simulation.")
        sys.exit(1)
    input_folder = sys.argv[1]

    if os.path.exists(out_file):
        print(f"<{out_file}> already exists, aborting!")
        sys.exit(1)

    temp_dir = "temp"
    if os.path.exists(temp_dir):
        print(f"Output folder <{temp_dir}> already exists, aborting!")
        sys.exit(1)
    else:
        os.mkdir(temp_dir)

    print(f"Reading simulation data from {input_folder}...")
    with open(os.path.join(input_folder, "gamestate.pkl"), 'rb') as f:
        game_data = pickle.load(f)

    ball_positions = scale_coordinates(np.array(game_data["ball_pos"]))
    l_paddle_positions = scale_coordinates(np.array(game_data["left_paddle"]))
    # Move left paddle outwards for symmetry
    l_paddle_positions[:, 0] -= PADDLE_WID
    r_paddle_positions = scale_coordinates(np.array(game_data["right_paddle"]))

    score = np.array(game_data["score"]).astype(int)

    with gzip.open(os.path.join(input_folder, "data_left.pkl.gz"), 'r') as f:
        data = pickle.load(f)
        rewards_left = data["rewards"]
        weights_left = data["weights"]
        name_left = data["network_type"]

    with gzip.open(os.path.join(input_folder, "data_right.pkl.gz"), 'r') as f:
        data = pickle.load(f)
        rewards_right = data["rewards"]
        weights_right = data["weights"]
        name_right = data["network_type"]

    # Extract lowest and highest weights for both players to scale the heatmaps.
    min_r, max_r = np.min(weights_right), np.max(weights_right)
    min_l, max_l = np.min(weights_left), np.max(weights_left)

    # Average rewards at every iteration over all neurons
    rewards_left = [np.mean(x) for x in rewards_left]
    rewards_right = [np.mean(x) for x in rewards_right]

    print(f"Setup complete, generating images to '{temp_dir}'...")
    n_iterations = score.shape[0]
    i = 0
    output_speed = DEFAULT_SPEED

    while i < n_iterations:
        # Set up the grid containing all components of the output image
        title = plt.subplot2grid(gridsize, (0, 0), 1, 16)
        l_info = plt.subplot2grid(gridsize, (1, 0), 7, 2)
        r_info = plt.subplot2grid(gridsize, (1, 14), 7, 2)
        field = plt.subplot2grid(gridsize, (1, 2), 7, 12)
        l_hm = plt.subplot2grid(gridsize, (8, 0), 4, 4)
        reward_plot = plt.subplot2grid(gridsize, (8, 6), 4, 6)
        r_hm = plt.subplot2grid(gridsize, (8, 12), 4, 4)

        for ax in [title, l_info, r_info, field, l_hm, r_hm]:
            ax.axis("off")

        # Create an empty array for the playing field.
        playing_field = np.zeros(
            (FIELD_SIZE[0], FIELD_SIZE[1], 3), dtype=np.uint8)

        # Draw the ball in white
        x, y = ball_positions[i]
        playing_field[x - BALL_RAD:x + BALL_RAD, y - BALL_RAD:y + BALL_RAD] = white
        for (x, y), color in zip([l_paddle_positions[i], r_paddle_positions[i]],
                                 [left_color, right_color]):
            # Clip y coordinate of the paddle so it does not exceed the screen
            y = max(PADDLE_LEN, y)
            y = min(FIELD_SIZE[1] - PADDLE_LEN, y)
            playing_field[x:x + PADDLE_WID, y - PADDLE_LEN:y + PADDLE_LEN] = color

        field.imshow(np.transpose(playing_field, [1, 0, 2]))

        # Left player heatmap
        heatmap_l = grayscale_to_heatmap(weights_left[i], min_l,
                                         max_l, left_color)
        l_hm.imshow(heatmap_l)
        l_hm.set_xlabel("output")
        l_hm.set_ylabel("input")
        l_hm.set_title("weights", y=-0.3)

        # Right player heatmap
        heatmap_r = grayscale_to_heatmap(weights_right[i], min_r,
                                         max_r, right_color)
        r_hm.imshow(heatmap_r)
        r_hm.set_xlabel("output")
        r_hm.set_ylabel("input")
        r_hm.set_title("weights", y=-0.3)

        reward_plot.plot([0, i], [-1, -1])
        reward_plot.plot(rewards_right[:i + 1], color=right_color / 255)
        reward_plot.plot(rewards_left[:i + 1], color=left_color / 255)

        # Change x_ticks and x_min for the first few plots
        if i < 1600:
            x_min = 0
            reward_plot.set_xticks(np.arange(0, n_iterations, 250))
        else:
            x_min = i - 1600
            reward_plot.set_xticks(np.arange(0, n_iterations, 500))

        reward_plot.set_ylabel("mean reward")
        reward_plot.set_yticks([0, 0.5, 1])
        reward_plot.set_ylim(0, 1.0)
        reward_plot.set_xlim(x_min, i + 10)

        title.text(0.4, 0.75, name_left, ha='right', fontsize=15, c=left_color_hex)
        title.text(0.5, 0.75, "VS", ha='center', fontsize=17)
        title.text(0.6, 0.75, name_right, ha='left', fontsize=15, c=right_color_hex)

        l_score, r_score = score[i]

        l_info.text(0, 0.9, "run:", fontsize=14)
        l_info.text(0, 0.75, str(i), fontsize=14)
        l_info.text(1, 0.5, l_score, ha='right', va='center', fontsize=26, c=left_color_hex)

        r_info.text(0, 0.9, "speed:", fontsize=14)
        r_info.text(0, 0.75, str(output_speed) + 'x', fontsize=14)
        r_info.text(0, 0.5, r_score, ha='left', va='center', fontsize=26, c=right_color_hex)

        plt.subplots_adjust(left=0.05, right=0.95, bottom=0.1, top=0.9, wspace=0.35, hspace=0.35)
        plt.savefig(os.path.join(temp_dir, f"img_{str(i).zfill(6)}.png"))

        # Change the speed of the video to show performance before and after
        # training at DEFAULT_SPEED and fast-forward most of the training
        if 75 <= i < 100 or n_iterations - 400 <= i < n_iterations - 350:
            output_speed = 10
        elif 100 <= i < n_iterations - 350:
            output_speed = 50
        else:
            output_speed = DEFAULT_SPEED

        i += output_speed

    print("Image creation complete, collecting them into a GIF...")

    filenames = sorted(glob(os.path.join(temp_dir, "*.png")))

    with imageio.get_writer(out_file, mode='I', fps=6) as writer:
        for filename in filenames:
            image = imageio.imread(filename)
            writer.append_data(image)
    print(f"GIF created under: {out_file}")

    if not keep_temps:
        print("Deleting temporary image files...")
        for in_file in filenames:
            os.unlink(in_file)
        os.rmdir(temp_dir)

    print("Done.")
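
The variable playback speed in the main loop determines which simulation steps end up as frames in the GIF. The following is a minimal sketch that isolates this stepping logic so it can be checked without rendering any images. The function name `frame_indices` is hypothetical; the thresholds are copied from the loop above.

```python
DEFAULT_SPEED = 4


def frame_indices(n_iterations, default_speed=DEFAULT_SPEED):
    """Return the simulation steps that are rendered as GIF frames.

    Hypothetical helper mirroring the speed schedule of generate_gif.py:
    the start and end of training are shown at default_speed, short
    transition windows at 10x, and most of the training at 50x.
    """
    indices = []
    i = 0
    while i < n_iterations:
        indices.append(i)
        if 75 <= i < 100 or n_iterations - 400 <= i < n_iterations - 350:
            speed = 10
        elif 100 <= i < n_iterations - 350:
            speed = 50
        else:
            speed = default_speed
        i += speed
    return indices


frames = frame_indices(2000)
print(f"{len(frames)} of 2000 steps are rendered")
```

Counting frames this way gives a rough estimate of rendering time and GIF size before starting a long image generation run.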