diff --git a/include/nleobs.h b/include/nleobs.h index 87298a489..274085981 100644 --- a/include/nleobs.h +++ b/include/nleobs.h @@ -6,6 +6,7 @@ #define NLE_BLSTATS_SIZE 25 #define NLE_PROGRAM_STATE_SIZE 6 #define NLE_INTERNAL_SIZE 9 +#define NLE_MISC_SIZE 3 #define NLE_INVENTORY_SIZE 55 #define NLE_INVENTORY_STR_LENGTH 80 #define NLE_SCREEN_DESCRIPTION_LENGTH 80 @@ -37,6 +38,7 @@ typedef struct nle_observation { unsigned char *tty_chars; /* Size NLE_TERM_LI * NLE_TERM_CO */ signed char *tty_colors; /* Size NLE_TERM_LI * NLE_TERM_CO */ unsigned char *tty_cursor; /* Size 2 */ + int *misc; /* Size NLE_MISC_SIZE */ } nle_obs; typedef struct { diff --git a/nle/env/base.py b/nle/env/base.py index 67eed3609..669f5745c 100644 --- a/nle/env/base.py +++ b/nle/env/base.py @@ -130,6 +130,14 @@ "tty_cursor", gym.spaces.Box(low=0, high=255, **nethack.OBSERVATION_DESC["tty_cursor"]), ), + ( + "misc", + gym.spaces.Box( + low=np.iinfo(np.int32).min, + high=np.iinfo(np.int32).max, + **nethack.OBSERVATION_DESC["misc"] + ), + ), ) diff --git a/nle/env/tasks.py b/nle/env/tasks.py index 753c6951b..c624621fc 100644 --- a/nle/env/tasks.py +++ b/nle/env/tasks.py @@ -320,6 +320,7 @@ def __init__( "tty_chars", "tty_colors", "tty_cursor", + "misc" ), no_progress_timeout: int = 10_000, **kwargs, diff --git a/nle/nethack/nethack.py b/nle/nethack/nethack.py index 2c8a818b5..f85723cbb 100644 --- a/nle/nethack/nethack.py +++ b/nle/nethack/nethack.py @@ -16,6 +16,7 @@ MESSAGE_SHAPE = (_pynethack.nethack.NLE_MESSAGE_SIZE,) PROGRAM_STATE_SHAPE = (_pynethack.nethack.NLE_PROGRAM_STATE_SIZE,) INTERNAL_SHAPE = (_pynethack.nethack.NLE_INTERNAL_SIZE,) +MISC_SHAPE = (_pynethack.nethack.NLE_MISC_SIZE,) INV_SIZE = (_pynethack.nethack.NLE_INVENTORY_SIZE,) INV_STRS_SHAPE = ( _pynethack.nethack.NLE_INVENTORY_SIZE, @@ -43,6 +44,7 @@ "tty_chars": dict(shape=TERMINAL_SHAPE, dtype=np.uint8), "tty_colors": dict(shape=TERMINAL_SHAPE, dtype=np.int8), "tty_cursor": dict(shape=(2,), dtype=np.uint8), + "misc": dict(shape=MISC_SHAPE, dtype=np.int32), } diff --git a/win/rl/pynethack.cc b/win/rl/pynethack.cc index 5fa4473c8..937bbec1f 100644 --- a/win/rl/pynethack.cc +++ b/win/rl/pynethack.cc @@ -117,7 +117,7 @@ class Nethack py::object inv_glyphs, py::object inv_letters, py::object inv_oclasses, py::object inv_strs, py::object screen_descriptions, py::object tty_chars, - py::object tty_colors, py::object tty_cursor) + py::object tty_colors, py::object tty_cursor, py::object misc) { std::vector dungeon{ ROWNO, COLNO - 1 }; obs_.glyphs = checked_conversion(glyphs, dungeon); @@ -147,6 +147,7 @@ class Nethack obs_.tty_colors = checked_conversion( tty_colors, { NLE_TERM_LI, NLE_TERM_CO }); obs_.tty_cursor = checked_conversion(tty_cursor, { 2 }); + obs_.misc = checked_conversion(misc, { NLE_MISC_SIZE }); py_buffers_ = { std::move(glyphs), std::move(chars), @@ -163,7 +164,8 @@ class Nethack std::move(screen_descriptions), std::move(tty_chars), std::move(tty_colors), - std::move(tty_cursor) }; + std::move(tty_cursor), + std::move(misc) }; } void @@ -281,7 +283,7 @@ PYBIND11_MODULE(_pynethack, m) py::arg("screen_descriptions") = py::none(), py::arg("tty_chars") = py::none(), py::arg("tty_colors") = py::none(), - py::arg("tty_cursor") = py::none()) + py::arg("tty_cursor") = py::none(), py::arg("misc") = py::none()) .def("close", &Nethack::close) .def("set_initial_seeds", &Nethack::set_initial_seeds) .def("set_seeds", &Nethack::set_seeds) @@ -297,6 +299,7 @@ PYBIND11_MODULE(_pynethack, m) mn.attr("NLE_BLSTATS_SIZE") = py::int_(NLE_BLSTATS_SIZE); mn.attr("NLE_PROGRAM_STATE_SIZE") = py::int_(NLE_PROGRAM_STATE_SIZE); mn.attr("NLE_INTERNAL_SIZE") = py::int_(NLE_INTERNAL_SIZE); + mn.attr("NLE_MISC_SIZE") = py::int_(NLE_MISC_SIZE); mn.attr("NLE_INVENTORY_SIZE") = py::int_(NLE_INVENTORY_SIZE); mn.attr("NLE_INVENTORY_STR_LENGTH") = py::int_(NLE_INVENTORY_STR_LENGTH); mn.attr("NLE_SCREEN_DESCRIPTION_LENGTH") = diff --git a/win/rl/winrl.cc b/win/rl/winrl.cc index b5694cf0f..e659503b0 100644 --- a/win/rl/winrl.cc +++ b/win/rl/winrl.cc @@ -279,6 +279,11 @@ NetHackRL::fill_obs(nle_obs *obs) obs->internal[8] = u.urexp; /* score (careful! check botl_score() and end.c) */ } + if (obs->misc) { + obs->misc[0] = in_yn_function; + obs->misc[1] = in_getlin; + obs->misc[2] = xwaitingforspace; + } if ((!program_state.something_worth_saving && !program_state.in_moveloop) || !iflags.window_inited) {