In [None]:
from typing import Optional

import nle.dataset as nld
from nle.nethack import tty_render
import numpy as np

### Create Dataset

In [None]:
# One-time dataset registration, left commented out.
# The saved output under this cell was presumably produced by running these
# lines uncommented once — TODO confirm before relying on an existing database.
# if not nld.db.exists():
#     nld.db.create()
#     # NB: Different methods are used for data based on NLE and data from NAO.
#     nld.add_altorg_directory('data/nld-nao/', 'nld-nao-v0')

Adding dataset 'nld-nao-v0' ('data/nld-nao/') to 'ttyrecs.db' 
Found 1736841 games in 'data/nld-nao/xlogfile.nh363+'
Found 167705 games in 'data/nld-nao/xlogfile.nh362'
Found 20939 games in 'data/nld-nao/xlogfile.nh361dev'
Found 311980 games in 'data/nld-nao/xlogfile.nh361'
Found 732441 games in 'data/nld-nao/xlogfile.nh360'
Found 3585691 games in 'data/nld-nao/xlogfile.full.txt'
Matching up ttyrecs to games...
Optimizing DB...
Updated 'ttyrecs.db' in 196.83 sec. Size: 79.72 MB, Games: 190922


In [10]:
# Open the registered 'nld-nao-v0' dataset with shuffled batches.
dataset = nld.TtyrecDataset(
    'nld-nao-v0',
    batch_size=32,
    seq_length=32,
    shuffle=True,
)
# A single iterator over the dataset, advanced with next() in later cells.
dataloader = iter(dataset)

### One-hot Conversion

In [63]:
def sample_to_one_hot_observation(chars: np.ndarray, colors: np.ndarray, cursor: np.ndarray) -> np.ndarray:
    """
    Converts chars, colors, and cursor arrays into a one-hot encoded boolean observation tensor.

    For the usual (24, 80) terminal the result has shape (257, 24, 80); in general it is
    (257, H, W) where (H, W) is the shape of `chars`.

    257 channels = 128 for chars + 128 for colors + 1 for cursor position.
    - Channels 0-127: one-hot for chars (ASCII 0-127)
    - Channels 128-255: one-hot for colors (0-127)
    - Channel 256: cursor (True at (y, x), else False)

    Cells whose char or color value lies outside [0, 127] get no channel set in the
    corresponding group. A cursor outside the grid leaves channel 256 all False.

    Args:
        chars (np.ndarray): (H, W) integer array of character codes (typically uint8).
        colors (np.ndarray): (H, W) integer array of color codes (typically int8).
        cursor (np.ndarray): (2,) array, (y, x) for cursor position.

    Returns:
        np.ndarray: One-hot tensor, shape (257, H, W), dtype=bool.
    """
    H, W = chars.shape
    out = np.zeros((257, H, W), dtype=bool)

    idx_rows, idx_cols = np.indices((H, W))

    # Chars: one-hot encode into channels 0-127.
    # The lower bound matters for signed inputs: a negative code would otherwise be
    # used as a negative channel index and wrap around (e.g. -1 -> cursor channel 256).
    char_mask = (chars >= 0) & (chars < 128)
    char_channels = chars[char_mask].astype(np.intp)
    out[char_channels, idx_rows[char_mask], idx_cols[char_mask]] = True

    # Colors: one-hot encode into channels 128-255.
    # Cast to a wide index type before adding the offset so the sum cannot overflow
    # a narrow integer dtype.
    color_mask = (colors >= 0) & (colors < 128)
    color_channels = 128 + colors[color_mask].astype(np.intp)
    out[color_channels, idx_rows[color_mask], idx_cols[color_mask]] = True

    # Cursor: channel 256 is True at the cursor location, if it is on the grid.
    cy, cx = int(cursor[0]), int(cursor[1])
    if 0 <= cy < H and 0 <= cx < W:
        out[256, cy, cx] = True

    return out

def one_hot_observation_to_sample(one_hot: np.ndarray):
    """
    Decodes a (257, H, W) one-hot observation back into chars, colors, and cursor.

    Channel layout: 0-127 chars, 128-255 colors, 256 cursor.

    Returns:
        chars (np.ndarray): (H, W) uint8 array; the index of the first set char
            channel per cell, or 32 (space) where no char channel is set.
        colors (np.ndarray): (H, W) int8 array; the index of the first set color
            channel per cell, or 0 where no color channel is set.
        cursor (np.ndarray): (2,) array (y, x); the first marked cell in row-major
            order, or [0, 0] when channel 256 is empty.
    """
    assert one_hot.shape[0] == 257, "Expected channel dimension of size 257"
    height, width = one_hot.shape[1:]

    # --- characters (channels 0-127) ---
    char_planes = one_hot[:128]
    glyphs = np.zeros((height, width), dtype=np.uint8)
    glyphs[...] = char_planes.argmax(axis=0).astype(np.uint8)
    # argmax yields 0 for an all-empty cell; render those as a space instead of NUL
    glyphs[char_planes.sum(axis=0) == 0] = 32

    # --- colors (channels 128-255) ---
    color_planes = one_hot[128:256]
    palette = color_planes.argmax(axis=0).astype(np.int8)
    # cells with no color channel set fall back to color 0
    palette[color_planes.sum(axis=0) == 0] = 0

    # --- cursor (channel 256) ---
    marked = np.argwhere(one_hot[256])
    # several marks -> take the first; none -> default to the top-left corner
    cursor = marked[0] if marked.shape[0] != 0 else np.array([0, 0], dtype=np.int64)

    return glyphs, palette, cursor

In [66]:
# Fetch a batch in this cell so it runs on a fresh kernel: `samples` is otherwise
# only assigned by a cell further down the notebook.
samples = next(dataloader)

# Encode the first frame of the first sequence in the batch.
one_hot_obs = sample_to_one_hot_observation(
    samples['tty_chars'][0, 0], samples['tty_colors'][0, 0], samples['tty_cursor'][0, 0])

### Visualization

In [None]:
def print_ascii_array(
    ascii_array: np.ndarray,
    colors: Optional[np.ndarray] = None
) -> None:
    """Print a 2D grid of character codes, one text line per row.

    Without `colors`, each row is printed as plain text. With `colors`, every
    character is wrapped in its own 256-color foreground escape sequence
    (ESC[38;5;<color>m ... ESC[0m).

    Args:
        ascii_array: 2D numpy array of integer character codes.
        colors: Optional 2D numpy array of integer color codes, interpolated
            as-is into the escape sequence. Must have the same shape as
            `ascii_array`.

    Raises:
        ValueError: If `ascii_array` is not 2D, or `colors` is given with a
            different shape.
    """
    if ascii_array.ndim != 2:
        raise ValueError(f"Expected 2D array, got {ascii_array.ndim}D")

    # Plain mode: no escape sequences at all.
    if colors is None:
        for codes in ascii_array:
            print("".join(map(chr, codes)))
        return

    if colors.shape != ascii_array.shape:
        raise ValueError(
            f"Color array shape {colors.shape} doesn't match "
            f"ASCII array shape {ascii_array.shape}"
        )

    # Colored mode: walk both grids in lockstep, resetting after every character.
    for codes, tints in zip(ascii_array, colors):
        print("".join(
            f"\033[38;5;{tint}m{chr(code)}\033[0m"
            for code, tint in zip(codes, tints)
        ))

In [19]:
# Plain rendering: two rows spelling "Hi!" and "Yo!"
ascii_data = np.array([list(b"Hi!"), list(b"Yo!")], dtype=np.uint8)
print_ascii_array(ascii_data)

# Colored rendering, using indices into the xterm 256-color palette
colors = np.array(
    [
        [196, 46, 226],  # red, green, yellow
        [21, 208, 231],  # blue, orange, white
    ],
    dtype=np.uint8,
)
print_ascii_array(ascii_data, colors)

Hi!
Yo!
[38;5;196mH[0m[38;5;46mi[0m[38;5;226m![0m
[38;5;21mY[0m[38;5;208mo[0m[38;5;231m![0m


In [68]:
# Pull the next batch from the dataset iterator.
# Presumably a dict of arrays laid out as (batch, sequence, ...), given the
# ['tty_chars'][0, 0] style indexing used elsewhere — TODO confirm.
samples = next(dataloader)

In [69]:
# Render the [0, 0] entry of the batch (presumably first sequence, first frame — TODO confirm) with its colors.
print_ascii_array(samples['tty_chars'][0, 0], samples['tty_colors'][0, 0])

[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5

In [70]:
# Round-trip check: encode the same frame to one-hot, decode it again, and render
# the decoded chars and colors ([:2] drops the cursor from the returned tuple).
print_ascii_array(*one_hot_observation_to_sample(sample_to_one_hot_observation(
    samples['tty_chars'][0, 0], samples['tty_colors'][0, 0], samples['tty_cursor'][0, 0]
))[:2])

[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5;0m [0m[38;5