Skip to content

Commit

Permalink
Merge pull request #15 from shelhamer/ale-latest
Browse files Browse the repository at this point in the history
Update ALE for ROM fixes and snapshot/restore of state
  • Loading branch information
shelhamer committed May 5, 2017
2 parents eed97d1 + acea408 commit 72acda2
Show file tree
Hide file tree
Showing 50 changed files with 720 additions and 434 deletions.
3 changes: 3 additions & 0 deletions atari_py/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@ def get_game_path(game_name):
def list_games():
    """Return the names of the entries in the game directory.

    Each name is the part of the entry's basename before its first ".".
    """
    names = []
    for entry in os.listdir(_game_dir()):
        base = os.path.basename(entry)
        names.append(base.split(".")[0])
    return names

# default to only logging errors
ALEInterface.setLoggerMode(ALEInterface.Logger.Error)
24 changes: 23 additions & 1 deletion atari_py/ale_c_wrapper.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,25 @@
// Out-of-line definitions of the state encode/decode helpers declared in ale_c_wrapper.h
#include "ale_c_wrapper.h"

#include <cstring>
#include <string>
#include <stdexcept>

// Copies the serialized form of `state` into the caller-supplied buffer.
//
//   state   - state to serialize (must not be null).
//   buf     - destination buffer; receives raw bytes with no terminator added.
//   buf_len - capacity of `buf` in bytes.
//
// Throws std::runtime_error (by value) when `buf_len` is negative or smaller
// than the serialized size; nothing is written to `buf` in that case.
void encodeState(ALEState *state, char *buf, int buf_len) {
    const std::string serialized = state->serialize();

    // Compare in the string's own size type so a large payload is never
    // narrowed to int before the check; a negative capacity is always too small.
    if (buf_len < 0 ||
        static_cast<std::string::size_type>(buf_len) < serialized.size()) {
        // Thrown by value: a `throw new ...` would leak the exception object
        // and could not be caught as `const std::exception&`.
        throw std::runtime_error("Buffer is not big enough to hold serialized ALEState. Please use encodeStateLen to determine the correct buffer size");
    }

    // memcpy requires valid pointers even for a zero-length copy, so skip it
    // entirely when there is nothing to write.
    if (!serialized.empty()) {
        std::memcpy(buf, serialized.data(), serialized.size());
    }
}

// Returns the number of bytes produced by serializing `state`, i.e. the
// buffer size to pass to encodeState.
int encodeStateLen(ALEState *state) {
    const std::string serialized = state->serialize();
    return static_cast<int>(serialized.length());
}

// Builds a heap-allocated ALEState from `len` raw bytes starting at
// `serialized`. The bytes are copied by length, so embedded '\0' characters
// are kept. The returned pointer comes from `new`.
ALEState *decodeState(const char *serialized, int len) {
    std::string bytes;
    bytes.assign(serialized, len);

    ALEState *decoded = new ALEState(bytes);
    return decoded;
}
45 changes: 15 additions & 30 deletions atari_py/ale_c_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,13 @@ extern "C" {
// Width of the screen reported by the interface's current ALEScreen.
int getScreenWidth(ALEInterface *ale){
    return ale->getScreen().width();
}
// Height of the screen reported by the interface's current ALEScreen.
int getScreenHeight(ALEInterface *ale){
    return ale->getScreen().height();
}

void getScreenRGB(ALEInterface *ale, int *output_buffer){
void getScreenRGB(ALEInterface *ale, unsigned char *output_buffer){
size_t w = ale->getScreen().width();
size_t h = ale->getScreen().height();
size_t screen_size = w*h;
pixel_t *ale_screen_data = ale->getScreen().getArray();

for(int i = 0;i < screen_size;i++){
output_buffer[i] = rgb_palette[ale_screen_data[i]];
}

ale->theOSystem->colourPalette().applyPaletteRGB(output_buffer, ale_screen_data, screen_size);
}

void getScreenRGB2(ALEInterface *ale, unsigned char *output_buffer){
Expand All @@ -77,10 +74,8 @@ extern "C" {
output_buffer[j++] = (zrgb>>8)&0xff;
output_buffer[j++] = (zrgb>>0)&0xff;
}

}


void getScreenGrayscale(ALEInterface *ale, unsigned char *output_buffer){
size_t w = ale->getScreen().width();
size_t h = ale->getScreen().height();
Expand All @@ -92,32 +87,22 @@ extern "C" {

void saveState(ALEInterface *ale){ale->saveState();}
void loadState(ALEInterface *ale){ale->loadState();}
ALEState* cloneState(ALEInterface *ale){return new ALEState(ale->cloneState());}
void restoreState(ALEInterface *ale, ALEState* state){ale->restoreState(*state);}
ALEState* cloneSystemState(ALEInterface *ale){return new ALEState(ale->cloneSystemState());}
void restoreSystemState(ALEInterface *ale, ALEState* state){ale->restoreSystemState(*state);}
void deleteState(ALEState* state){delete state;}
void saveScreenPNG(ALEInterface *ale,const char *filename){ale->saveScreenPNG(filename);}

ALEState* cloneState(ALEInterface *ale) {
return new ALEState(ale->cloneState());
}

void ALEState_del(ALEState* state) {
delete state;
}

void restoreState(ALEInterface *ale, ALEState* state) {
ale->restoreState(*state);
}

int ALEState_getFrameNumber(ALEState* state) {
return state->getFrameNumber();
}

int ALEState_getEpisodeFrameNumber(ALEState* state) {
return state->getEpisodeFrameNumber();
}

bool ALEState_equals(ALEState* a, ALEState *b) {
return a->equals(*b);
}
// Encodes the state as a raw bytestream. This may contain multiple '\0' characters
// and thus should not be treated as a C string. Use encodeStateLen to find the required
// buffer length; encodeState throws if buf_len is smaller than the serialized state.
void encodeState(ALEState *state, char *buf, int buf_len);
int encodeStateLen(ALEState *state);
ALEState *decodeState(const char *serialized, int len);

// 0: Info, 1: Warning, 2: Error
// Converts the integer `mode` to ale::Logger::mode and passes it to ale::Logger::setMode.
// NOTE(review): `mode` is not range-checked here — confirm callers only pass the values listed above.
void setLoggerMode(int mode) { ale::Logger::setMode(ale::Logger::mode(mode)); }
}

#endif
55 changes: 39 additions & 16 deletions atari_py/ale_interface/src/ale_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@ void ALEInterface::createOSystem(std::auto_ptr<OSystem> &theOSystem,
theSettings.reset(new SettingsUNIX(theOSystem.get()));
#endif

setDefaultSettings(theOSystem->settings());

theOSystem->settings().loadConfig();
}

Expand Down Expand Up @@ -91,16 +89,10 @@ void ALEInterface::loadSettings(const string& romfile,
exit(1);
}

// Seed random number generator
if (theOSystem->settings().getString("random_seed") == "time") {
Logger::Info << "Random Seed: Time" << endl;
Random::seed((uInt32)time(NULL));
} else {
int seed = theOSystem->settings().getInt("random_seed");
assert(seed >= 0);
Logger::Info << "Random Seed: " << seed << endl;
Random::seed((uInt32)seed);
}
// Must force a reset of the OSystem's random seed, because it was set before
// we chose our own random seed.
Logger::Info << "Random seed is " << theOSystem->settings().getInt("random_seed") << std::endl;
theOSystem->resetRNGSeed();

string currentDisplayFormat = theOSystem->console().getFormat();
theOSystem->colourPalette().setPalette("standard", currentDisplayFormat);
Expand Down Expand Up @@ -196,9 +188,8 @@ void ALEInterface::reset_game() {
}

// Indicates if the game has ended.
bool ALEInterface::game_over() {
return (environment->isTerminal() ||
(max_num_frames > 0 && getEpisodeFrameNumber() >= max_num_frames));
bool ALEInterface::game_over() const {
return environment->isTerminal();
}

// The remaining number of lives.
Expand Down Expand Up @@ -250,7 +241,7 @@ int ALEInterface::getFrameNumber() {
}

// Returns the frame number since the start of the current episode
int ALEInterface::getEpisodeFrameNumber() {
int ALEInterface::getEpisodeFrameNumber() const {
return environment->getEpisodeFrameNumber();
}

Expand All @@ -259,6 +250,30 @@ const ALEScreen& ALEInterface::getScreen() {
return environment->getScreen();
}

//This method should receive an empty vector to fill it with
//the grayscale colours
void ALEInterface::getScreenGrayscale(std::vector<unsigned char>& grayscale_output_buffer){
size_t w = environment->getScreen().width();
size_t h = environment->getScreen().height();
size_t screen_size = w*h;

pixel_t *ale_screen_data = environment->getScreen().getArray();
theOSystem->colourPalette().applyPaletteGrayscale(grayscale_output_buffer, ale_screen_data, screen_size);
}

// Fills output_rgb_buffer with the RGB colours of the current screen.
// The palette's vector overload of applyPaletteRGB resizes the vector to three
// bytes per pixel and writes them interleaved: red, green, blue for the first
// pixel, then red, green, blue for the second, and so on.
void ALEInterface::getScreenRGB(std::vector<unsigned char>& output_rgb_buffer){
    size_t w = environment->getScreen().width();
    size_t h = environment->getScreen().height();
    size_t screen_size = w*h;

    pixel_t *ale_screen_data = environment->getScreen().getArray();

    // The size argument is the number of SOURCE pixels; applyPaletteRGB does
    // the x3 expansion itself. Passing screen_size * 3 would make it read
    // three times past the end of the screen array.
    theOSystem->colourPalette().applyPaletteRGB(output_rgb_buffer, ale_screen_data, screen_size);
}

// Returns the current RAM content
const ALERAM& ALEInterface::getRAM() {
return environment->getRAM();
Expand All @@ -282,6 +297,14 @@ void ALEInterface::restoreState(const ALEState& state) {
return environment->restoreState(state);
}

// Forwards to the environment's cloneSystemState and returns its result.
ALEState ALEInterface::cloneSystemState() {
    ALEState snapshot = environment->cloneSystemState();
    return snapshot;
}

// Forwards `state` to the environment's restoreSystemState.
void ALEInterface::restoreSystemState(const ALEState& state) {
    environment->restoreSystemState(state);
}

void ALEInterface::saveScreenPNG(const string& filename) {

ScreenExporter exporter(theOSystem->colourPalette());
Expand Down
28 changes: 24 additions & 4 deletions atari_py/ale_interface/src/ale_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
#include "os_dependent/SettingsUNIX.hxx"
#include "os_dependent/OSystemUNIX.hxx"
#include "games/Roms.hpp"
#include "common/Defaults.hpp"
#include "common/display_screen.h"
#include "environment/stella_environment.hpp"
#include "common/ScreenExporter.hpp"
Expand All @@ -46,7 +45,7 @@
#include <string>
#include <memory>

static const std::string Version = "0.5.0";
static const std::string Version = "0.5.1";

/**
This class interfaces ALE with external code for controlling agents.
Expand Down Expand Up @@ -83,7 +82,7 @@ class ALEInterface {
reward_t act(Action action);

// Indicates if the game has ended.
bool game_over();
bool game_over() const;

// Resets the game, but not the full system.
void reset_game();
Expand All @@ -103,11 +102,20 @@ class ALEInterface {
const int lives();

// Returns the frame number since the start of the current episode
int getEpisodeFrameNumber();
int getEpisodeFrameNumber() const;

// Returns the current game screen
const ALEScreen &getScreen();

//This method should receive an empty vector to fill it with
//the grayscale colours
void getScreenGrayscale(std::vector<unsigned char>& grayscale_output_buffer);

//This method should receive a vector to fill it with
//the RGB colours. The colours are interleaved per pixel:
//red, green and blue of the first pixel, then of the next, and so on
void getScreenRGB(std::vector<unsigned char>& output_rgb_buffer);

// Returns the current RAM content
const ALERAM &getRAM();

Expand All @@ -117,10 +125,22 @@ class ALEInterface {
// Loads the state of the system
void loadState();

// This makes a copy of the environment state. This copy does *not* include pseudorandomness,
// making it suitable for planning purposes. By contrast, see cloneSystemState.
ALEState cloneState();

// Reverse operation of cloneState(). This does not restore pseudorandomness, so that repeated
// calls to restoreState() in the stochastic controls setting will not lead to the same outcomes.
// By contrast, see restoreSystemState.
void restoreState(const ALEState& state);

// This makes a copy of the system & environment state, suitable for serialization. This includes
// pseudorandomness and so is *not* suitable for planning purposes.
ALEState cloneSystemState();

// Reverse operation of cloneSystemState.
void restoreSystemState(const ALEState& state);

// Save the current screen as a png file
void saveScreenPNG(const std::string& filename);

Expand Down
38 changes: 34 additions & 4 deletions atari_py/ale_interface/src/common/ColourPalette.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ inline uInt32 convertGrayscale(uInt32 packedRGBValue)
double g = (packedRGBValue >> 8) & 0xff;
double b = (packedRGBValue >> 0) & 0xff;

uInt8 lum = (uInt8) round(r * 0.2989 + g * 0.5870 + b * 0.1140 );
uInt8 lum = (uInt8) round(r * 0.2989 + g * 0.5870 + b * 0.1140);

return packRGB(lum, lum, lum);
}
Expand All @@ -45,8 +45,10 @@ ColourPalette::ColourPalette(): m_palette(NULL) {

void ColourPalette::getRGB(int val, int &r, int &g, int &b) const
{
assert (m_palette != NULL);
assert(m_palette != NULL);
assert(val >= 0 && val <= 0xFF);
// Make sure we are reading from RGB, not grayscale.
assert((val & 0x01) == 0);

// Set the RGB components accordingly
r = (m_palette[val] >> 16) & 0xFF;
Expand All @@ -56,9 +58,10 @@ void ColourPalette::getRGB(int val, int &r, int &g, int &b) const

uInt8 ColourPalette::getGrayscale(int val) const
{
assert (m_palette != NULL);
assert(m_palette != NULL);
assert(val >= 0 && val < 0xFF);

assert((val & 0x01) == 1);

// Set the RGB components accordingly
return (m_palette[val+1] >> 0) & 0xFF;
}
Expand All @@ -81,6 +84,21 @@ void ColourPalette::applyPaletteRGB(uInt8* dst_buffer, uInt8 *src_buffer, size_t
}
}

// Vector overload: resizes dst_buffer to three bytes per source pixel, then
// writes each pixel's palette entry as consecutive r, g, b bytes.
void ColourPalette::applyPaletteRGB(std::vector<unsigned char>& dst_buffer, uInt8 *src_buffer, size_t src_size)
{
    const size_t dst_size = 3 * src_size;
    dst_buffer.resize(dst_size);
    assert(dst_buffer.size() == dst_size);

    size_t out = 0;
    for (size_t px = 0; px < src_size; ++px) {
        const int rgb = m_palette[src_buffer[px]];
        dst_buffer[out++] = static_cast<unsigned char>(rgb >> 16); // r
        dst_buffer[out++] = static_cast<unsigned char>(rgb >> 8);  // g
        dst_buffer[out++] = static_cast<unsigned char>(rgb >> 0);  // b
    }
}

void ColourPalette::applyPaletteGrayscale(uInt8* dst_buffer, uInt8 *src_buffer, size_t src_size)
{
uInt8 *p = src_buffer;
Expand All @@ -91,6 +109,18 @@ void ColourPalette::applyPaletteGrayscale(uInt8* dst_buffer, uInt8 *src_buffer,
}
}

void ColourPalette::applyPaletteGrayscale(std::vector<unsigned char>& dst_buffer, uInt8 *src_buffer, size_t src_size)
{
dst_buffer.resize(src_size);
assert(dst_buffer.size() == src_size);

uInt8 *p = src_buffer;

for(size_t i = 0; i < src_size; i++, p++){
dst_buffer[i] = (unsigned char) (m_palette[*p+1] & 0xFF);
}
}

void ColourPalette::setPalette(const string& type,
const string& displayFormat)
{
Expand Down
3 changes: 3 additions & 0 deletions atari_py/ale_interface/src/common/ColourPalette.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#ifndef __COLOUR_PALETTE_HPP__
#define __COLOUR_PALETTE_HPP__

#include <vector>
#include <string>
// Include obscure header file for uInt32 definition
#include "../emucore/m6502/src/bspf/src/bspf.hxx"
Expand All @@ -41,13 +42,15 @@ class ColourPalette {
8 bits => 24 bits
*/
void applyPaletteRGB(uInt8* dst_buffer, uInt8 *src_buffer, size_t src_size);
void applyPaletteRGB(std::vector<unsigned char>& dst_buffer, uInt8 *src_buffer, size_t src_size);

/**
Applies the current grayscale palette to the src_buffer and returns the results in dst_buffer
For each byte in src_buffer, a single byte is returned in dst_buffer
8 bits => 8 bits
*/
void applyPaletteGrayscale(uInt8* dst_buffer, uInt8 *src_buffer, size_t src_size);
void applyPaletteGrayscale(std::vector<unsigned char>& dst_buffer, uInt8 *src_buffer, size_t src_size);

/**
Loads all defined palettes with PAL color-loss data depending
Expand Down
Loading

0 comments on commit 72acda2

Please sign in to comment.