In [12]:
import nle.dataset as nld
from nle.nethack import tty_render
from nle.dataset import db

In [13]:
# Dataset name; this is the same string the cells below use when adding the
# directory to the database and when querying / loading it.
dataset_name = "nld-taster"

# Directory that is added to the database as the taster data.
nld_taster_path = "/code/nld-aa-taster/nle_data"

# Database file that is created on first run and reused afterwards.
dbfilename = "/code/NetHack-Research/data/raw/" + dataset_name + ".db"

In [14]:
# One-time setup: create the database and add the taster directory to it.
# The guard skips this when the database file already exists, so re-running
# the cell does not create it again.
if not nld.db.exists(dbfilename):
    nld.db.create(dbfilename)
    # Use the `dataset_name` constant from the config cell instead of repeating
    # the string literal, so the registered name cannot drift from the config.
    nld.add_nledata_directory(nld_taster_path, dataset_name, dbfilename)

In [15]:
# Open a connection to the database file; it is passed explicitly to the
# query below so the count is taken from this database.
db_conn = nld.db.connect(filename=dbfilename)

# Count the games stored under the configured dataset name (the
# `dataset_name` constant from the config cell, not a repeated literal).
num_games = nld.db.count_games(dataset_name, conn=db_conn)
print(f'NLD-AA "Taster" Dataset has {num_games} games.')

NLD-AA "Taster" Dataset has 1934 games.


In [16]:
# Build a loader over the taster dataset, using the `dataset_name` constant
# from the config cell rather than a repeated string literal.
taster = nld.TtyrecDataset(
    dataset_name,
    batch_size=32,
    seq_length=32,
    dbfilename=dbfilename,
)

# Take the first minibatch; it is indexed by string keys below.
minibatch = next(iter(taster))

# Axis-order note: in the printed tty_chars output each innermost row is one
# terminal line (the last two rows start with "Age..." / "Dlv..."), so the
# screen axes are rows first, then columns -- i.e. 24 x 80, not 80 x 24.
# NOTE(review): the length of the sequence axis is presumably `seq_length`
# for this loader -- confirm with `.shape`.
print(
    minibatch["tty_chars"],   # presumably [batch_size, seq_length, 24, 80]
    minibatch["tty_colors"],  # presumably [batch_size, seq_length, 24, 80]
    minibatch["tty_cursor"],  # presumably [batch_size, seq_length, 2]
    minibatch["keypresses"],  # presumably [batch_size, seq_length]
    minibatch["scores"],      # presumably [batch_size, seq_length]
    minibatch["done"],        # presumably [batch_size, seq_length]
)

[[[[ 83  97 108 ...  32  32  32]
   [ 32  32  32 ...  32  32  32]
   [ 32  32  32 ...  32  32  32]
   ...
   [ 32  32  32 ...  32  32  32]
   [ 65 103 101 ...  32  32  32]
   [ 68 108 118 ...  32  32  32]]

  [[ 32  32  32 ...  32  32  32]
   [ 32  32  32 ...  32  32  32]
   [ 32  32  32 ...  32  32  32]
   ...
   [ 32  32  32 ...  32  32  32]
   [ 65 103 101 ...  32  32  32]
   [ 68 108 118 ...  32  32  32]]

  [[ 32  32  32 ...  32  32  32]
   [ 32  32  32 ...  32  32  32]
   [ 32  32  32 ...  32  32  32]
   ...
   [ 32  32  32 ...  32  32  32]
   [ 65 103 101 ...  32  32  32]
   [ 68 108 118 ...  32  32  32]]

  ...

  [[ 32  32  32 ...  32  32  32]
   [ 32  32  32 ...  32  32  32]
   [ 32  32  32 ... 115  32  32]
   ...
   [ 32  32  32 ...  32  32  32]
   [ 65 103 101 ...  32  32  32]
   [ 68 108 118 ...  32  32  32]]

  [[ 83 104 111 ...  32  32  32]
   [ 32  32  32 ...  32  32  32]
   [ 32  32  32 ...  32  32  32]
   ...
   [ 32  32  32 ...  32  32  32]
   [ 65 103 101 ...  32  3

In [18]:
from katakomba.env import NetHackChallenge, OfflineNetHackChallengeWrapper
from katakomba.utils.datasets import SequentialBuffer

# The task is specified using the character field
# The task is selected through the `character` field; only the three tty
# observation keys are requested. The raw environment is wrapped immediately:
# the wrapper provides interfaces for dataset loading, score normalization,
# and deathlevel extraction.
env = OfflineNetHackChallengeWrapper(
    NetHackChallenge(
        character="mon-hum-neu",
        observation_keys=["tty_chars", "tty_colors", "tty_cursor"],
    )
)

# Dataset reading modes (check the paper for details):
#   "in_memory"  - from RAM, decompressed: fast but requires a lot of RAM,
#                  takes 5-10 minutes for decompression first
#   "memmap"     - from disk, decompressed: a bit slower than RAM,
#                  takes 5-10 minutes for decompression first
#   "compressed" - from disk, compressed: very slow but no need for
#                  decompression, useful for debugging
# Note that this will download the dataset automatically if not found.
dataset = env.get_dataset(mode="compressed", scale="small")

# Auxiliary tools for computing normalized scores or extracting deathlevels.
# NOTE(review): the original note here reads "Throws an Error..." -- presumably
# these calls raised when tried, which is why they are left commented out.
# env.get_normalized_score(score=1337.0)
# env.get_current_depth()


Preparing:   0%|          | 0/683 [00:00<?, ?it/s]

In [19]:
# Sample one batch from the dataset and print its contents. The work is
# wrapped in try/finally so the dataset is closed even if building the buffer
# or sampling raises.
try:
    buffer = SequentialBuffer(
        dataset=dataset,
        seq_len=32,
        batch_size=32,  # Each batch element is a different trajectory
        seed=42,
        add_next_step=True,  # if you want (s, a, r, s') instead of (s, a, r)
    )

    # What's inside the batch?
    # Note that the next batch will include the +1 element as expected.
    # Axis-order note: in the printed tty_chars output each innermost row is
    # one terminal line, so the screen axes are rows first, then columns
    # (24 x 80, not 80 x 24).
    batch = buffer.sample()
    print(
        batch["tty_chars"],   # [batch_size, seq_len + 1, 24, 80]
        batch["tty_colors"],  # [batch_size, seq_len + 1, 24, 80]
        batch["tty_cursor"],  # [batch_size, seq_len + 1, 2]
        batch["actions"],     # [batch_size, seq_len + 1]
        batch["rewards"],     # [batch_size, seq_len + 1]
        batch["dones"],       # [batch_size, seq_len + 1]
    )
finally:
    # In case you don't want to store the decompressed dataset beyond code
    # execution.
    dataset.close()

[[[[ 72 101 108 ...  32  32  32]
   [ 32  32  32 ...  32  32  32]
   [ 32  32  32 ...  32  32  32]
   ...
   [ 32  32  32 ...  32  32  32]
   [ 65 103 101 ...  83  58  32]
   [ 68 108 118 ...  32  32  32]]

  [[ 32  32  32 ...  32  32  32]
   [ 32  32  32 ...  32  32  32]
   [ 32  32  32 ...  32  32  32]
   ...
   [ 32  32  32 ...  32  32  32]
   [ 65 103 101 ...  83  58  32]
   [ 68 108 118 ...  32  32  32]]

  [[ 32  32  32 ...  32  32  32]
   [ 32  32  32 ...  32  32  32]
   [ 32  32  32 ...  32  32  32]
   ...
   [ 32  32  32 ...  32  32  32]
   [ 65 103 101 ...  83  58  32]
   [ 68 108 118 ...  32  32  32]]

  ...

  [[ 32  32  32 ...  32  32  32]
   [ 32  32  32 ...  32  32  32]
   [ 32  32  32 ...  32  32  32]
   ...
   [ 32  32  32 ...  32  32  32]
   [ 68 105 115 ...  32  32  32]
   [ 32  45  45 ...  32  32  32]]

  [[ 32  32 115 ...  32  32  32]
   [ 83 112 101 ...  32  32  32]
   [ 32  32 115 ...  32  32  32]
   ...
   [ 32  32  32 ...  32  32  32]
   [ 32  32  32 ...  32  3

In [None]:
import h5py

def print_structure(name, obj):
    """Callback for `visititems`: print the name of each visited item.

    `obj` is accepted because the callback is invoked with two arguments,
    but it is not used. Nothing is returned.
    """
    print(name)


# Open the HDF5 file read-only and print the name of every item in it.
file_path = "/code/NetHack-Research/data/raw/data-arc-hum-law-any.hdf5"
with h5py.File(file_path, 'r') as f:
    f.visititems(print_structure)
