-
Notifications
You must be signed in to change notification settings - Fork 365
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2393 from krachbumm3nte/pong_pynest_example
Add example demonstrating spike-based simulations of Pong
- Loading branch information
Showing
14 changed files
with
1,276 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
NEST-pong | ||
========= | ||
This program simultaneously trains two networks of spiking neurons to play | ||
the classic game of Pong. | ||
|
||
Requirements | ||
------------ | ||
- NEST 3.3 | ||
- NumPy | ||
- Matplotlib | ||
|
||
Instructions | ||
------------ | ||
To start training between two networks with R-STDP plasticity, run | ||
the ``run_simulations.py`` script. By default, one of the networks will | ||
be stimulated with Gaussian white noise, showing that this is necessary | ||
for learning under this paradigm. In addition to R-STDP, a learning rule | ||
based on the ``stdp_dopamine_synapse`` and temporal difference learning | ||
is implemented, see ``networks.py`` for details. | ||
|
||
The learning progress and resulting game can be visualized with the | ||
``generate_gif.py`` script. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,290 @@ | ||
# -*- coding: utf-8 -*- | ||
# | ||
# generate_gif.py | ||
# | ||
# This file is part of NEST. | ||
# | ||
# Copyright (C) 2004 The NEST Initiative | ||
# | ||
# NEST is free software: you can redistribute it and/or modify | ||
# it under the terms of the GNU General Public License as published by | ||
# the Free Software Foundation, either version 2 of the License, or | ||
# (at your option) any later version. | ||
# | ||
# NEST is distributed in the hope that it will be useful, | ||
# but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
# GNU General Public License for more details. | ||
# | ||
# You should have received a copy of the GNU General Public License | ||
# along with NEST. If not, see <http://www.gnu.org/licenses/>. | ||
|
||
r"""Script to visualize a simulated Pong game. | ||
---------------------------------------------------------------- | ||
All simulations store data about both networks and the game in .pkl files. | ||
This script reads these files and generates image snapshots at different | ||
times during the simulation. These are subsequently aggregated into a GIF. | ||
:Authors: J Gille, T Wunderlich, Electronic Vision(s) | ||
""" | ||
|
||
from copy import copy | ||
import gzip | ||
import os | ||
import sys | ||
|
||
import numpy as np | ||
import pickle | ||
import matplotlib.pyplot as plt | ||
import imageio.v2 as imageio | ||
from glob import glob | ||
|
||
from pong import GameOfPong as Pong | ||
|
||
# Size of one pixel in inches at the configured DPI, so that the figure size
# below can be requested in pixels.
px = 1 / plt.rcParams['figure.dpi']
# Create the 400x300 px figure that every frame is drawn into. plt.figure()
# is used rather than plt.subplots() so that no unused full-figure Axes is
# left lying underneath the subplot grid built for each frame.
plt.figure(figsize=(400 * px, 300 * px))
plt.rcParams.update({'font.size': 6})

gridsize = (12, 16)  # Shape of the grid used for positioning subplots

# Player colors, as RGB arrays (for image data) and hex strings (for text).
left_color = np.array((204, 0, 153))  # purple
right_color = np.array((255, 128, 0))  # orange
left_color_hex = "#cc0099"
right_color_hex = "#ff8000"
white = np.array((255, 255, 255))

# Original size of the playing field inside the simulation
GAME_GRID = np.array([Pong.x_grid, Pong.y_grid])
GRID_SCALE = 24
# Field size (in px) after upscaling
GAME_GRID_SCALED = GAME_GRID * GRID_SCALE

# Dimensions of game objects in px
BALL_RAD = 6
PADDLE_LEN = int(0.1 * GAME_GRID_SCALED[1])
PADDLE_WID = 18

# Add margins left and right to the playing field
FIELD_PADDING = PADDLE_WID * 2
# Copy so that widening the field does not alter GAME_GRID_SCALED.
FIELD_SIZE = copy(GAME_GRID_SCALED)
FIELD_SIZE[0] += 2 * FIELD_PADDING

# By default, the GIF shows every DEFAULT_SPEEDth simulation step.
DEFAULT_SPEED = 4
|
||
|
||
def scale_coordinates(coordinates: np.ndarray):
    """Scale an array of (x, y) coordinates from simulation scale to pixel
    scale in the output image.

    The input is not modified; the scaling is done on a floating point copy.

    Args:
        coordinates (numpy.ndarray): 2D array whose first column holds x and
            whose second column holds y coordinates in simulation units.

    Returns:
        numpy.ndarray: integer array of the same shape holding the
            coordinates in px. The x values are shifted right by
            FIELD_PADDING to account for the left margin of the field.
    """
    # Work on a float copy so the caller's data stays untouched and integer
    # input is not truncated before the final conversion.
    scaled = np.array(coordinates, dtype=float)
    scaled[:, 0] = (scaled[:, 0] * GAME_GRID_SCALED[0] / Pong.x_length
                    + FIELD_PADDING)
    scaled[:, 1] = scaled[:, 1] * GAME_GRID_SCALED[1] / Pong.y_length
    return scaled.astype(int)
|
||
|
||
def grayscale_to_heatmap(in_image, min_val, max_val, base_color):
    """Transform a grayscale image to an RGB heat map. Heatmap will color small
    values in base_color and high values in white.

    Args:
        in_image (numpy.array): 2D numpy.array to be transformed.
        min_val (float): smallest value across the entire image - colored in
            base_color in the output.
        max_val (float): largest value across the entire image - colored
            white in the output.
        base_color (numpy.array): numpy.array of shape (3,) representing the
            base color of the heatmap in RGB.

    Returns:
        numpy.array: uint8 array with the shape of the input plus an added
            3rd dimension of length 3, representing RGB values.
    """
    values = np.asarray(in_image, dtype=float)
    base_rgb = np.asarray(base_color, dtype=float)

    span = max_val - min_val
    if span == 0:
        # Edge case for uniform weight matrix: everything gets the base color.
        fraction = np.zeros(values.shape)
    else:
        # Position of every pixel between min_val (0) and max_val (1). Clip so
        # that values outside the given range cannot wrap around when the
        # result is converted to uint8.
        fraction = np.clip((values - min_val) / span, 0.0, 1.0)

    # Interpolate linearly between base_color and white (255, 255, 255) for
    # all pixels at once; the trailing axis broadcasts over the RGB channels.
    rgb = base_rgb + (255.0 - base_rgb) * fraction[..., np.newaxis]
    return rgb.astype(np.uint8)
|
||
|
||
if __name__ == "__main__":
    # Script entry point: read the simulation output folder given as the only
    # command line argument, render one PNG per displayed simulation step into
    # a temporary folder and finally combine these images into a GIF.
    keep_temps = False
    out_file = "pong_sim.gif"

    if len(sys.argv) != 2:
        print("This programm takes exactly one argument - the location of the "
              "output folder generated by the simulation.")
        sys.exit(1)
    input_folder = sys.argv[1]

    # Never overwrite existing results.
    if os.path.exists(out_file):
        print(f"<{out_file}> already exists, aborting!")
        sys.exit(1)

    temp_dir = "temp"
    if os.path.exists(temp_dir):
        print(f"Output folder <{temp_dir}> already exists, aborting!")
        sys.exit(1)
    else:
        os.mkdir(temp_dir)

    print(f"Reading simulation data from {input_folder}...")
    # NOTE: pickle.load can execute arbitrary code contained in the file.
    # Only point this script at simulation output from a trusted source.
    with open(os.path.join(input_folder, "gamestate.pkl"), 'rb') as f:
        game_data = pickle.load(f)

    ball_positions = scale_coordinates(np.array(game_data["ball_pos"]))
    l_paddle_positions = scale_coordinates(np.array(game_data["left_paddle"]))
    # Move left paddle outwards for symmetry
    l_paddle_positions[:, 0] -= PADDLE_WID
    r_paddle_positions = scale_coordinates(np.array(game_data["right_paddle"]))

    score = np.array(game_data["score"]).astype(int)

    with gzip.open(os.path.join(input_folder, "data_left.pkl.gz"), 'r') as f:
        data = pickle.load(f)
        rewards_left = data["rewards"]
        weights_left = data["weights"]
        name_left = data["network_type"]

    with gzip.open(os.path.join(input_folder, "data_right.pkl.gz"), 'r') as f:
        data = pickle.load(f)
        rewards_right = data["rewards"]
        weights_right = data["weights"]
        name_right = data["network_type"]

    # Extract lowest and highest weights for both players to scale the heatmaps.
    min_r, max_r = np.min(weights_right), np.max(weights_right)
    min_l, max_l = np.min(weights_left), np.max(weights_left)

    # Average rewards at every iteration over all neurons
    rewards_left = [np.mean(x) for x in rewards_left]
    rewards_right = [np.mean(x) for x in rewards_right]

    print(f"Setup complete, generating images to '{temp_dir}'...")
    n_iterations = score.shape[0]
    i = 0
    output_speed = DEFAULT_SPEED

    while i < n_iterations:
        # Start every frame from an empty figure. Without this, the axes and
        # texts created for earlier frames stay in the figure, so frames are
        # drawn on top of each other and rendering gets slower and slower.
        plt.clf()

        # Set up the grid containing all components of the output image
        title = plt.subplot2grid(gridsize, (0, 0), 1, 16)
        l_info = plt.subplot2grid(gridsize, (1, 0), 7, 2)
        r_info = plt.subplot2grid(gridsize, (1, 14), 7, 2)
        field = plt.subplot2grid(gridsize, (1, 2), 7, 12)
        l_hm = plt.subplot2grid(gridsize, (8, 0), 4, 4)
        reward_plot = plt.subplot2grid(gridsize, (8, 6), 4, 6)
        r_hm = plt.subplot2grid(gridsize, (8, 12), 4, 4)

        # Only the reward plot keeps its axis decorations.
        for ax in [title, l_info, r_info, field, l_hm, r_hm]:
            ax.axis("off")

        # Create an empty array for the playing field.
        playing_field = np.zeros(
            (FIELD_SIZE[0], FIELD_SIZE[1], 3), dtype=np.uint8)

        # Draw the ball in white
        x, y = ball_positions[i]
        playing_field[x - BALL_RAD:x + BALL_RAD, y - BALL_RAD:y + BALL_RAD] = white
        for (x, y), color in zip([l_paddle_positions[i], r_paddle_positions[i]],
                                 [left_color, right_color]):
            # Clip y coordinate of the paddle so it does not exceed the screen
            y = max(PADDLE_LEN, y)
            y = min(FIELD_SIZE[1] - PADDLE_LEN, y)
            playing_field[x:x + PADDLE_WID, y - PADDLE_LEN:y + PADDLE_LEN] = color

        # The field array is indexed (x, y); imshow expects (row, column).
        field.imshow(np.transpose(playing_field, [1, 0, 2]))

        # Left player heatmap
        heatmap_l = grayscale_to_heatmap(weights_left[i], min_l,
                                         max_l, left_color)
        l_hm.imshow(heatmap_l)
        l_hm.set_xlabel("output")
        l_hm.set_ylabel("input")
        l_hm.set_title("weights", y=-0.3)

        # Right player heatmap
        heatmap_r = grayscale_to_heatmap(weights_right[i], min_r,
                                         max_r, right_color)
        r_hm.imshow(heatmap_r)
        r_hm.set_xlabel("output")
        r_hm.set_ylabel("input")
        r_hm.set_title("weights", y=-0.3)

        # Mean reward of both players up to and including the current step.
        reward_plot.plot([0, i], [-1, -1])
        reward_plot.plot(rewards_right[:i + 1], color=right_color / 255)
        reward_plot.plot(rewards_left[:i + 1], color=left_color / 255)

        # Change x_ticks and x_min for the first few plots
        if i < 1600:
            x_min = 0
            reward_plot.set_xticks(np.arange(0, n_iterations, 250))
        else:
            # Later on, show a sliding window of the last 1600 steps.
            x_min = i - 1600
            reward_plot.set_xticks(np.arange(0, n_iterations, 500))

        reward_plot.set_ylabel("mean reward")
        reward_plot.set_yticks([0, 0.5, 1])
        reward_plot.set_ylim(0, 1.0)
        reward_plot.set_xlim(x_min, i + 10)

        title.text(0.4, 0.75, name_left, ha='right', fontsize=15, c=left_color_hex)
        title.text(0.5, 0.75, "VS", ha='center', fontsize=17)
        title.text(0.6, 0.75, name_right, ha='left', fontsize=15, c=right_color_hex)

        l_score, r_score = score[i]

        l_info.text(0, 0.9, "run:", fontsize=14)
        l_info.text(0, 0.75, str(i), fontsize=14)
        l_info.text(1, 0.5, l_score, ha='right', va='center', fontsize=26, c=left_color_hex)

        r_info.text(0, 0.9, "speed:", fontsize=14)
        r_info.text(0, 0.75, str(output_speed) + 'x', fontsize=14)
        r_info.text(0, 0.5, r_score, ha='left', va='center', fontsize=26, c=right_color_hex)

        plt.subplots_adjust(left=0.05, right=0.95, bottom=0.1, top=0.9, wspace=0.35, hspace=0.35)
        # Zero-pad the step number so that sorting the file names
        # alphabetically later on yields chronological order.
        plt.savefig(os.path.join(temp_dir, f"img_{str(i).zfill(6)}.png"))

        # Change the speed of the video to show performance before and after
        # training at DEFAULT_SPEED and fast-forward most of the training
        if 75 <= i < 100 or n_iterations - 400 <= i < n_iterations - 350:
            output_speed = 10
        elif 100 <= i < n_iterations - 350:
            output_speed = 50
        else:
            output_speed = DEFAULT_SPEED

        i += output_speed

    print("Image creation complete, collecting them into a GIF...")

    filenames = sorted(glob(os.path.join(temp_dir, "*.png")))

    with imageio.get_writer(out_file, mode='I', fps=6) as writer:
        for filename in filenames:
            image = imageio.imread(filename)
            writer.append_data(image)
    print(f"GIF created under: {out_file}")

    if not keep_temps:
        print("Deleting temporary image files...")
        for in_file in filenames:
            os.unlink(in_file)
        os.rmdir(temp_dir)

    print("Done.")
Oops, something went wrong.