In [5]:
import nle.dataset as nld

In [6]:
# Name under which the data is registered in the database.
dataset_name = "nld-nao"
# Database file that the cells below check for, create, and connect to.
dbfilename = "/code/NetHack-Research/data/raw/nld-nao.db"
# Directory handed to `nld.add_altorg_directory` when the database is first built.
nld_nao_path = "/code/nld-nao/nld-nao-unzipped"

In [7]:
# Build the database only when it does not exist yet; otherwise reuse it.
# NOTE(review): if `add_altorg_directory` fails after `create`, a later run
# would presumably take the "already exists" branch with an incomplete
# database -- confirm what `nld.db.exists` actually checks.
if nld.db.exists(dbfilename):
    print(f"Database already exists: {dbfilename}")
else:
    nld.db.create(dbfilename)
    # NLD-NAO data is registered through `add_altorg_directory`, under the
    # name held in `dataset_name` (previously a repeated "nld-nao" literal).
    nld.add_altorg_directory(nld_nao_path, dataset_name, dbfilename)

# Connect to the database and print the number of games to verify it.
db_conn = nld.db.connect(filename=dbfilename)
print(f"NLD-NAO Database contains {nld.db.count_games(dataset_name, conn=db_conn)} games.")

Database already exists: /code/NetHack-Research/data/raw/nld-nao.db
NLD-NAO Database contains 1511228 games.


In [8]:
# Loader over the dataset registered in the database; `dataset_name` comes
# from the configuration cell so the name is spelled in exactly one place.
random_sample = nld.TtyrecDataset(
    dataset_name,
    batch_size=32,
    seq_length=32,
    dbfilename=dbfilename,
)

# Pull a single minibatch and show the structure of the data.
# (A bare `minibatch.keys()` used to sit here; as a non-final expression its
# result was discarded, so it has been removed.)
minibatch = next(iter(random_sample))
print(minibatch)

{'tty_chars': array([[[[ 32,  32,  32, ...,  32,  32,  32],
         [ 32,  32,  32, ...,  32,  32,  32],
         [ 32,  32,  32, ...,  32,  32,  32],
         ...,
         [ 32,  32,  32, ...,  32,  32,  32],
         [ 32,  32,  32, ...,  32,  32,  32],
         [ 32,  32,  32, ...,  32,  32,  32]],

        [[ 32,  32,  32, ...,  32,  32,  32],
         [ 32,  32,  32, ...,  32,  32,  32],
         [ 32,  32,  32, ...,  32,  32,  32],
         ...,
         [ 32,  32,  32, ...,  32,  32,  32],
         [ 32,  32,  32, ...,  32,  32,  32],
         [ 32,  32,  32, ...,  32,  32,  32]],

        [[ 32,  32,  32, ...,  32,  32,  32],
         [ 32,  32,  32, ...,  32,  32,  32],
         [ 32,  32,  32, ...,  32,  32,  32],
         ...,
         [ 32,  32,  32, ...,  32,  32,  32],
         [ 32,  32,  32, ...,  32,  32,  32],
         [ 32,  32,  32, ...,  32,  32,  32]],

        ...,

        [[ 84, 104, 101, ...,  32,  32,  32],
         [ 32,  32,  32, ...,  32,  32,  32],
    

In [9]:
# Print the `gameids` array of the sampled minibatch.
print(minibatch["gameids"])

[[6548946 6548946 6548946 ... 6548946 6548946 6548946]
 [  25380   25380   25380 ...   25380   25380   25380]
 [  79108   79108   79108 ...   79108   79108   79108]
 ...
 [1933041 1933041 1933041 ... 1933041 1933041 1933041]
 [6323355 6323355 6323355 ... 6323355 6323355 6323355]
 [3772856 3772856 3772856 ... 3772856 3772856 3772856]]


In [None]:
import h5py
import numpy as np

# Output HDF5 file for the sampled minibatch. This is a relative path, so it
# resolves against the current working directory.
HDF5_FILE = "random_sample.hdf5"

def save_to_hdf5(minibatch, output_file):
    """
    Saves the minibatch data in an HDF5 format, organizing trajectories by game ID.

    One HDF5 group is created per distinct value found in
    ``minibatch["gameids"]`` (the group is named ``str(game_id)``). Inside each
    group, every field listed in ``fields`` is written as a dataset that holds
    only the entries whose game id matches, selected with the same boolean
    mask for all fields.

    Args:
        minibatch (dict): The extracted trajectory data from NLD-NAO. Must
            contain "gameids" plus "tty_chars", "tty_colors", "tty_cursor",
            "timestamps" and "done". Each field is assumed to share its
            leading dimensions with "gameids" -- TODO confirm against the
            loader.
        output_file (str): Path to save the HDF5 file. The file is opened in
            "w" mode.
    """
    fields = ("tty_chars", "tty_colors", "tty_cursor", "timestamps", "done")
    game_ids = minibatch["gameids"]

    with h5py.File(output_file, "w") as hdf5_file:
        unique_game_ids = np.unique(game_ids)

        for game_id in unique_game_ids:
            # Full mask over the "gameids" array. Every field must be indexed
            # with this complete mask: indexing with only its first row
            # (`mask[0]`, as was done for "tty_cursor") selects along the
            # wrong axis and only avoids an IndexError when the first two
            # dimensions happen to be equal.
            mask = game_ids == game_id

            grp = hdf5_file.create_group(str(game_id))  # Store each game separately
            for field in fields:
                grp.create_dataset(field, data=minibatch[field][mask])

        print(f"Saved {len(unique_game_ids)} game trajectories to {output_file}")


# Write the sampled minibatch to HDF5_FILE (save_to_hdf5 opens it in "w" mode).
save_to_hdf5(minibatch, HDF5_FILE)

In [6]:
from katakomba.env import NetHackChallenge, OfflineNetHackChallengeWrapper
from katakomba.utils.datasets import SequentialBuffer

# Build the environment in one expression:
#   - NetHackChallenge: the `character` field specifies the task.
#   - OfflineNetHackChallengeWrapper: a convenience wrapper providing interfaces
#     for dataset loading, score normalization, and deathlevel extraction.
env = OfflineNetHackChallengeWrapper(
    NetHackChallenge(
        character="mon-hum-neu",
        observation_keys=["tty_chars", "tty_colors", "tty_cursor"],
    )
)

# Dataset reading modes (see the paper for details):
#   "in_memory"  - RAM, decompressed: fast, but needs a lot of RAM and
#                  5-10 minutes of decompression up front
#   "memmap"     - disk, decompressed: a bit slower than RAM, same 5-10
#                  minutes of decompression up front
#   "compressed" - disk, compressed: very slow, but no decompression step;
#                  useful for debugging
# The dataset is downloaded automatically if it is not found.
dataset = env.get_dataset(mode="compressed", scale="small")

# Throws an Error... 
# Auxiliary tools for computing normalized scores or extracting deathlevels
# env.get_normalized_score(score=1337.0)
# env.get_current_depth()

Preparing:   0%|          | 0/683 [00:00<?, ?it/s]

In [7]:
# Sequential sampler over the loaded dataset.
#   batch_size: each batch element is a different trajectory
#   add_next_step: True yields (s, a, r, s') instead of (s, a, r)
buffer = SequentialBuffer(
    dataset=dataset,
    seq_len=32,
    batch_size=32,
    seed=42,
    add_next_step=True,
)

# Draw one batch and look inside it. As noted by the original author, the
# next batch will include the +1 element as expected.
batch = buffer.sample()

# Shapes as annotated by the original author:
#   tty_chars, tty_colors:   [batch_size, seq_len + 1, 80, 24]
#   tty_cursor:              [batch_size, seq_len + 1, 2]
#   actions, rewards, dones: [batch_size, seq_len + 1]
batch_fields = ("tty_chars", "tty_colors", "tty_cursor", "actions", "rewards", "dones")
print(*(batch[field] for field in batch_fields))

# Close the dataset if the decompressed copy should not outlive this run.
dataset.close()

[[[[ 72 101 108 ...  32  32  32]
   [ 32  32  32 ...  32  32  32]
   [ 32  32  32 ...  32  32  32]
   ...
   [ 32  32  32 ...  32  32  32]
   [ 65 103 101 ...  83  58  32]
   [ 68 108 118 ...  32  32  32]]

  [[ 32  32  32 ...  32  32  32]
   [ 32  32  32 ...  32  32  32]
   [ 32  32  32 ...  32  32  32]
   ...
   [ 32  32  32 ...  32  32  32]
   [ 65 103 101 ...  83  58  32]
   [ 68 108 118 ...  32  32  32]]

  [[ 32  32  32 ...  32  32  32]
   [ 32  32  32 ...  32  32  32]
   [ 32  32  32 ...  32  32  32]
   ...
   [ 32  32  32 ...  32  32  32]
   [ 65 103 101 ...  83  58  32]
   [ 68 108 118 ...  32  32  32]]

  ...

  [[ 32  32  32 ...  32  32  32]
   [ 32  32  32 ...  32  32  32]
   [ 32  32  32 ...  32  32  32]
   ...
   [ 32  32  32 ...  32  32  32]
   [ 68 105 115 ...  32  32  32]
   [ 32  45  45 ...  32  32  32]]

  [[ 32  32 115 ...  32  32  32]
   [ 83 112 101 ...  32  32  32]
   [ 32  32 115 ...  32  32  32]
   ...
   [ 32  32  32 ...  32  32  32]
   [ 32  32  32 ...  32  3