# Intro to Using the NetHack Learning Dataset

There are two different sets of trajectories included in the NetHack Learning Dataset:
- **NLD-NAO**: state-only trajectories from 1.5 M human games played on nethack.alt.org
- **NLD-AA**: state-action-score trajectories from 100k NLE games played by the symbolic-bot winner of the 2021 NetHack Challenge

These trajectories can be used with the `TtyrecDataset` tool, which allows for efficient training on the datasets. This tutorial describes how to create, use, and visualize the dataset.

## Downloading the Data (Alternative for `NLD-NAO`)

For instructions on how to download the data, see the README.md in this repo. It is currently hosted on WeTransfer, but will move to a separate public bucket and site. However, note that you can also download the NLD-NAO ttyrecs directly from alt.org's S3 bucket, together with the zipped xlogfiles, to reconstruct `nld-nao`. This should in principle reconstruct the same dataset, though this has not been verified by the authors.

```
# First, install the AWS CLI and verify access by listing the files for the user FallenPhoenix81

$ pip install awscli
$ aws s3 ls s3://altorg/ttyrec/FallenPhoenix81/ --no-sign-request --human-readable
2017-03-07 10:42:52    1.1 KiB 2012-07-05.00:46:53.ttyrec.bz2

# Then download the files to a new directory.  This will take a VERY long time.

$ mkdir /path/to/altorg_dataset/
$ aws s3 sync s3://altorg/ttyrec/ /path/to/altorg_dataset/ --no-sign-request
download: s3://altorg/ttyrec/00/...
....
```
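The file names in the listing above appear to encode a timestamp (`YYYY-MM-DD.HH:MM:SS`). As a minimal sketch, assuming every ttyrec file name follows that pattern, the timestamp can be parsed as follows. The helper name `parse_ttyrec_name` is hypothetical and not part of `nle`.

```python
from datetime import datetime


def parse_ttyrec_name(filename: str) -> datetime:
    """Parse the timestamp from a name like '2012-07-05.00:46:53.ttyrec.bz2'.

    Assumption: the part before '.ttyrec' is always 'YYYY-MM-DD.HH:MM:SS'.
    """
    stamp = filename.split(".ttyrec")[0]
    return datetime.strptime(stamp, "%Y-%m-%d.%H:%M:%S")


recorded_at = parse_ttyrec_name("2012-07-05.00:46:53.ttyrec.bz2")
print(recorded_at.isoformat())
```

This can be handy for sorting a user's downloaded recordings chronologically before building the database.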

## Setting up the Database

Make sure you have `nle v0.9.0` installed by following the instructions [in the repo README here](https://github.com/facebookresearch/nle). Either clone and install or call: 

`pip install git+https://github.com/facebookresearch/nle.git@main` 

Then, create the database from the datafiles as follows:

In [1]:
import nle.dataset as nld

In [2]:
path_to_nld_aa = "/path/to/nld-aa"
path_to_nld_nao = "/path/to/nld-nao"
path_to_custom = "/path/to/a/custom/nledata/directory"

# Choose a database name/path. By default, most methods will use nld.db.DB (='ttyrecs.db')
dbfilename = "ttyrecs.db"
if not nld.db.exists(dbfilename):
    nld.db.create(dbfilename)
        
    # To add the NLD-AA data, or any data generated from nle, use `add_nledata_directory`.
    nld.add_nledata_directory(path_to_nld_aa, "nld-aa", dbfilename)

    # To add the NLD-NAO data, use `add_altorg_directory`.
    nld.add_altorg_directory(path_to_nld_nao, "nld-nao", dbfilename) 
    
    # To add a custom NLE directory, as above, use `add_nledata_directory`.
    # nld.add_nledata_directory(path_to_custom, "custom_name", dbfilename)

You can inspect the dataset using the database tooling:

In [3]:
# Create a connection to specify the database to use
db_conn = nld.db.connect(filename=dbfilename)

# Then you can inspect the number of games in each dataset:
print(f"Autoascend Dataset has {nld.db.count_games('nld-aa', conn=db_conn)} games.")
print(f"AltOrg Dataset has {nld.db.count_games('nld-nao', conn=db_conn)} games.")

Autoascend Dataset has 109545 games.
AltOrg Dataset has 1511228 games.


## Visualizing the Data

Next, to actually load the games for training you'll use the `TtyrecDataset` object:

In [4]:
dataset = nld.TtyrecDataset(
    "nld-aa",
    batch_size=32,
    seq_length=32,
    dbfilename=dbfilename,
)

The dataset above returns batches of 32 trajectories, as sequential chunks of length 32. That is, assuming every trajectory is much longer than 64 timesteps, the first batch gives timesteps 0-31 of 32 games, the second batch gives timesteps 32-63 of the same games, and so on.
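To make the chunking concrete, here is a toy sketch in plain Python. It is not how `TtyrecDataset` is implemented: the function `iter_chunks` is hypothetical, and it ignores game endings and padding by assuming all trajectories are long enough.

```python
def iter_chunks(trajectories, seq_length):
    """Yield batches where batch k holds timesteps
    [k * seq_length, (k + 1) * seq_length) of every trajectory."""
    shortest = min(len(t) for t in trajectories)
    for start in range(0, shortest - seq_length + 1, seq_length):
        yield [t[start:start + seq_length] for t in trajectories]


# Two toy "games" of 100 timesteps each; the values are just timestep labels.
game_a = list(range(100))
game_b = list(range(1000, 1100))

batches = list(iter_chunks([game_a, game_b], seq_length=32))
for k, batch in enumerate(batches):
    print(f"batch {k}: game_a steps {batch[0][0]}-{batch[0][-1]}")
```

Each row of a batch continues exactly where the same row of the previous batch stopped, which is what makes it possible to carry recurrent state across batches.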

### What's in the Observation?

In [5]:
minibatch = next(iter(dataset))
minibatch.keys()

dict_keys(['tty_chars', 'tty_colors', 'tty_cursor', 'timestamps', 'done', 'gameids', 'keypresses', 'scores'])

The observation is made up of three components:
- `tty_chars` is a (batched) 2D np.array of the characters displayed at each point on the screen with shape: `[Batch, Time, H, W]`
- `tty_colors` is the associated colors for those characters
- `tty_cursor` provides the cursor position (NOTE: it's not always on the hero!)

These can be easily visualized using the `tty_render` utility:

In [6]:
from nle.nethack import tty_render

In [7]:
batch_idx = 0
time_idx = 0
chars = minibatch['tty_chars'][batch_idx, time_idx]
colors = minibatch['tty_colors'][batch_idx, time_idx]
cursor = minibatch['tty_cursor'][batch_idx, time_idx]

print(tty_render(chars, colors, cursor))


[0;37mK[0;37mo[0;37mn[0;37mn[0;37mi[0;37mc[0;37mh[0;37mi[0;30m [0;37mw[0;37ma[0;30m [0;37mA[0;37mg[0;37me[0;37mn[0;37mt[0;37m,[0;30m [0;37mw[0;37me[0;37ml[0;37mc[0;37mo[0;37mm[0;37me[0;30m [0;37mt[0;37mo[0;30m [0;37mN[0;37me[0;37mt[0;37mH[0;37ma[0;37mc[0;37mk[0;37m![0;30m [0;30m [0;37mY[0;37mo[0;37mu[0;30m [0;37ma[0;37mr[0;37me[0;30m [0;37ma[0;30m [0;37ml[0;37ma[0;37mw[0;37mf[0;37mu[0;37ml[0;30m [0;37mf[0;37me[0;37mm[0;37ma[0;37ml[0;37me[0;30m [0;37mh[0;37mu[0;37mm[0;37ma[0;37mn[0;30m [0;37mS[0;37ma[0;37mm[0;37mu[0;37mr[0;37ma[0;37mi[0;37m.[0;30m [0;30m 
[0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30

Then, the other elements of the batch are:
- `gameids`: The gameid for the game which the observation is from.
- `timestamps`: The time when the state was recorded, allowing you to understand how long the player took between frames.
- `keypresses`: The keypresses entered after seeing the observation at this timestep (which produces the observation at the next timestep).
- `scores`: The in-game score at this timestep (the result of the action at the previous timestep)
- `done`: Whether the gameid corresponding to the previous timestep's observation completed. If done is `True` this means that the observation at the current timestep is the beginning of the next gameid.
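Because a `True` in `done` marks the first observation of a new game, one row of a batch can be split into per-game segments by cutting at those positions. A minimal sketch with toy data follows. The helper `split_by_done` is hypothetical and operates on a single batch row.

```python
def split_by_done(values, done_row):
    """Split one batch row into per-game segments.

    A True at position t means the value at t belongs to the next game,
    so a new segment starts there.
    """
    segments, current = [], []
    for value, done in zip(values, done_row):
        if done and current:
            segments.append(current)
            current = []
        current.append(value)
    if current:
        segments.append(current)
    return segments


# Toy row: game 7 ends after three timesteps, then game 9 begins.
toy_gameids = [7, 7, 7, 9, 9]
toy_done = [False, False, False, True, False]

segments = split_by_done(toy_gameids, toy_done)
print(segments)  # [[7, 7, 7], [9, 9]]
```

The same cut points can be used to reset a recurrent model's hidden state at game boundaries.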

### Converting Actions from Keypresses to Environment Action Space

Note that the `keypresses` data holds the raw keypress entered (e.g. an ASCII code), not an action index from the NLE environment. To convert keypresses to the `action_space` of the environment, you can use an embedding as shown below:

In [8]:
import torch
from nle.env.tasks import NetHackChallenge


env = NetHackChallenge(
    savedir=None,  # Do not save any recordings. 
    character='@', # Randomly rotate through characters.
)

# Then use the environment actions to convert the keypresses.
embed_actions = torch.zeros((256, 1))
for i, a in enumerate(env.actions):
    embed_actions[a.value][0] = i
    
embed_actions = torch.nn.Embedding.from_pretrained(embed_actions)
keypresses = torch.Tensor(minibatch["keypresses"]).long()
actions = embed_actions(keypresses).squeeze(-1).long()
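The embedding above is effectively a lookup table from key code to action index. The same idea can be sketched without `torch` as a dictionary. In this sketch, `action_keys` is a made-up short list standing in for `[a.value for a in env.actions]`, and unmapped keys fall back to index 0, which mirrors the zero-initialized table above.

```python
# Hypothetical stand-in for [a.value for a in env.actions].
action_keys = [ord("k"), ord("j"), ord("h"), ord("l"), ord("s")]

# Map each key code to its position in the environment's action list.
keypress_to_index = {key: index for index, key in enumerate(action_keys)}


def convert(keypresses, default=0):
    """Translate raw key codes into action indices; unknown keys map to `default`."""
    return [keypress_to_index.get(k, default) for k in keypresses]


recorded = [ord("h"), ord("s"), ord("k"), ord("Z")]
converted = convert(recorded)
print(converted)  # [2, 4, 0, 0]
```

Note that the fallback makes an unmapped keypress indistinguishable from the action at index 0, so you may prefer a sentinel such as `-1` when filtering the data.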

## Dataset Configuration Options
`shuffle`: While states within a trajectory are always returned sequentially, it is possible to turn on shuffling of the *gameids*.  When true, the order of the gameids sampled is shuffled but not the order of the `seq_length` chunks returned within a single gameid.

`loop_forever`: It is possible to have the iterator loop forever instead of cycling only through the dataset once.

`gameids`: You can specify a list of gameids to return instead of iterating through the full dataset.

`subselect_sql`: You can select more complicated sets of games using a custom SQL query, with its arguments passed via `subselect_sql_args`.
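A `subselect_sql` query is an ordinary parameterized SQL statement that returns gameids. To show how the `?` placeholders and the argument tuple fit together, here is a self-contained sketch against a toy in-memory SQLite table. The three-column schema is an assumption for illustration only; the real `games` table has many more columns.

```python
import sqlite3

# Toy stand-in for the games table (assumed columns: gameid, role, race).
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE games (gameid INTEGER PRIMARY KEY, role TEXT, race TEXT)")
conn.executemany(
    "INSERT INTO games (gameid, role, race) VALUES (?, ?, ?)",
    [(1, "Mon", "Hum"), (2, "Sam", "Hum"), (3, "Wiz", "Elf"), (4, "Mon", "Hum")],
)

# Each '?' is filled, in order, from the argument tuple.
subselect_sql = "SELECT gameid FROM games WHERE role=? AND race=?"
subselect_sql_args = ("Mon", "Hum")

selected = sorted(row[0] for row in conn.execute(subselect_sql, subselect_sql_args))
print(selected)  # [1, 4]
conn.close()
```

Passing values through placeholders rather than formatting them into the query string avoids quoting mistakes.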

**Example 1:** Let's create a small dataset of just 3 games and see the shuffle functionality:

In [9]:
shuffle_small_dataset = nld.TtyrecDataset(
    "nld-aa",
    batch_size=2,
    seq_length=6000,
    dbfilename=dbfilename,
    shuffle=True,
    loop_forever=False,
    gameids=[109543, 109544, 109545],
)
for epoch in range(3):
    print(f"Epoch: {epoch}")
    for ind, mb in enumerate(shuffle_small_dataset):
        gameids = mb["gameids"][:, 0]
        print(f"  Batch {ind} first timestep gameids: {gameids}")
    print()


Epoch: 0
  Batch 0 first timestep gameids: [109544 109543]
  Batch 1 first timestep gameids: [109544 109543]
  Batch 2 first timestep gameids: [109544 109543]
  Batch 3 first timestep gameids: [109544 109543]
  Batch 4 first timestep gameids: [109544 109543]
  Batch 5 first timestep gameids: [109544 109543]
  Batch 6 first timestep gameids: [109544 109543]
  Batch 7 first timestep gameids: [109544 109543]
  Batch 8 first timestep gameids: [109544 109543]
  Batch 9 first timestep gameids: [109545      0]

Epoch: 1
  Batch 0 first timestep gameids: [109544 109543]
  Batch 1 first timestep gameids: [109544 109543]
  Batch 2 first timestep gameids: [109544 109543]
  Batch 3 first timestep gameids: [109544 109543]
  Batch 4 first timestep gameids: [109544 109543]
  Batch 5 first timestep gameids: [109544 109543]
  Batch 6 first timestep gameids: [109544 109543]
  Batch 7 first timestep gameids: [109544 109543]
  Batch 8 first timestep gameids: [109544 109543]
  Batch 9 first timestep gameid

**Example 2:** We can train on just the data from a specific kind of character, such as human monks (role "Mon", race "Hum"), by using `subselect_sql`:

In [10]:
# Build the subselect sql query
subselect_sql = "SELECT gameid FROM games WHERE role=? AND race=?"
subselect_sql_args = ("Mon", "Hum")
batch_size = 10

# Build the dataset
monk_dataset = nld.TtyrecDataset(
    "nld-aa",
    batch_size=batch_size,
    seq_length=2,
    dbfilename=dbfilename,
    subselect_sql=subselect_sql,
    subselect_sql_args=subselect_sql_args
)

# The counts below show fewer than 10k human monk games, even though the full dataset has ~109k
print(f"Full Autoascend Dataset has {nld.db.count_games('nld-aa', conn=db_conn):,} games.")
print(f"Human Monk Subdataset Has: {len(monk_dataset._gameids)}")

mb = next(iter(monk_dataset))
    
batch_idx = 0
time_idx = 0
chars = mb['tty_chars'][batch_idx, time_idx]
colors = mb['tty_colors'][batch_idx, time_idx]
cursor = mb['tty_cursor'][batch_idx, time_idx]

print(tty_render(chars, colors, cursor))

Full Autoascend Dataset has 109,545 games.
Human Monk Subdataset Has: 8124

[0;37mH[0;37me[0;37ml[0;37ml[0;37mo[0;30m [0;37mA[0;37mg[0;37me[0;37mn[0;37mt[0;37m,[0;30m [0;37mw[0;37me[0;37ml[0;37mc[0;37mo[0;37mm[0;37me[0;30m [0;37mt[0;37mo[0;30m [0;37mN[0;37me[0;37mt[0;37mH[0;37ma[0;37mc[0;37mk[0;37m![0;30m [0;30m [0;37mY[0;37mo[0;37mu[0;30m [0;37ma[0;37mr[0;37me[0;30m [0;37ma[0;30m [0;37mn[0;37me[0;37mu[0;37mt[0;37mr[0;37ma[0;37ml[0;30m [0;37mf[0;37me[0;37mm[0;37ma[0;37ml[0;37me[0;30m [0;37mh[0;37mu[0;37mm[0;37ma[0;37mn[0;30m [0;37mM[0;37mo[0;37mn[0;37mk[0;37m.[0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m 
[0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0

**Example 3:** Using a threadpool

You can also use a threadpool with the dataset, which will speed it up considerably!

In [11]:
from concurrent.futures import ThreadPoolExecutor
import time


with ThreadPoolExecutor(max_workers=10) as tp:
    dataset = nld.TtyrecDataset(
        "nld-aa",
        batch_size=100,
        seq_length=100,
        dbfilename=dbfilename,
        threadpool=tp
    )
    start = time.time()
    for i, mb in enumerate(dataset):
        if i == 9:  # stop after 10 batches (10 x 100 x 100 = 100,000 frames)
            break
    end = time.time()
    chars = mb['tty_chars'][batch_idx, time_idx]
    colors = mb['tty_colors'][batch_idx, time_idx]
    cursor = mb['tty_cursor'][batch_idx, time_idx]

    print(tty_render(chars, colors, cursor))
print(f"Loaded 100,000 frames in {end-start:.2f}s")


[0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m 
[0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30

**Example 4:** Getting Metadata

In [13]:
dataset = nld.TtyrecDataset('nld-aa', dbfilename=dbfilename)
mb = next(iter(dataset))
gameid = mb["gameids"][0][0]

chars = mb['tty_chars'][0, 0]
colors = mb['tty_colors'][0, 0]
cursor = mb['tty_cursor'][0, 0]

print(tty_render(chars, colors, cursor))

dict(dataset.get_meta(gameid))


[0;37mH[0;37me[0;37ml[0;37ml[0;37mo[0;30m [0;37mA[0;37mg[0;37me[0;37mn[0;37mt[0;37m,[0;30m [0;37mw[0;37me[0;37ml[0;37mc[0;37mo[0;37mm[0;37me[0;30m [0;37mt[0;37mo[0;30m [0;37mN[0;37me[0;37mt[0;37mH[0;37ma[0;37mc[0;37mk[0;37m![0;30m [0;30m [0;37mY[0;37mo[0;37mu[0;30m [0;37ma[0;37mr[0;37me[0;30m [0;37ma[0;30m [0;37mn[0;37me[0;37mu[0;37mt[0;37mr[0;37ma[0;37ml[0;30m [0;37mf[0;37me[0;37mm[0;37ma[0;37ml[0;37me[0;30m [0;37mh[0;37mu[0;37mm[0;37ma[0;37mn[0;30m [0;37mH[0;37me[0;37ma[0;37ml[0;37me[0;37mr[0;37m.[0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m 
[0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30m [0;30

{'gameid': 62334,
 'version': '3.6.6',
 'points': 1356,
 'deathdnum': 0,
 'deathlev': 1,
 'maxlvl': 1,
 'hp': 0,
 'maxhp': 35,
 'deaths': 1,
 'deathdate': 20220518,
 'birthdate': 20220518,
 'uid': 1185200751,
 'role': 'Hea',
 'race': 'Hum',
 'gender': 'Fem',
 'align': 'Neu',
 'name': 'Agent',
 'death': 'killed by a hobbit while fainted from lack of food',
 'conduct': '0xf80',
 'turns': 10374,
 'achieve': '0x0',
 'realtime': 46,
 'starttime': 1652885630,
 'endtime': 1652885676,
 'gender0': 'Fem',
 'align0': 'Neu',
 'flags': '0x4'}
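The metadata can be used to derive simple per-game statistics. As a small sketch, assuming `starttime` and `endtime` are Unix timestamps in seconds, the following recomputes the game's duration from a few of the fields shown above and converts the start time to a readable date.

```python
from datetime import datetime, timezone

# A subset of the metadata printed above.
meta = {
    "points": 1356,
    "turns": 10374,
    "starttime": 1652885630,
    "endtime": 1652885676,
    "realtime": 46,
}

# Wall-clock duration; for this game it matches the 'realtime' field.
duration_s = meta["endtime"] - meta["starttime"]

# Assumption: timestamps are seconds since the Unix epoch (UTC).
started = datetime.fromtimestamp(meta["starttime"], tz=timezone.utc)

turns_per_second = meta["turns"] / duration_s

print(f"started {started:%Y-%m-%d}, lasted {duration_s}s, "
      f"{turns_per_second:.1f} turns/s")
```

Statistics like these can be useful for filtering games, for example to keep only games above a minimum number of turns.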

**Example 5:** Generating and loading a custom dataset.

In [6]:
import gym
import nle
import nle.dataset as nld
from datetime import datetime

def generate_rollouts(env):
    obs = env.reset()
    episodes = 0
    while episodes < 10:
        obs, reward, done, info = env.step(env.action_space.sample())
        if done:
            env.reset()
            episodes += 1

# 1. Create some envs, with a savedir directory 'path/to/save/X'
envA = gym.make("NetHackChallenge-v0", savedir="path/to/save/A", save_ttyrec_every=2)
envB = gym.make("NetHackScore-v0", character="Mon-Hum-Neu-Mal", savedir="path/to/save/B", save_ttyrec_every=1)

# 2. Generate rollouts
generate_rollouts(envA)
generate_rollouts(envB)

# 3. Add to directory, with given unique dataset name
name = f"dataset_{datetime.now().time()}"
if not nld.db.exists():
    nld.db.create()
nld.add_nledata_directory("path/to/save", name)

# 4. Use and enjoy!
dataset = nld.TtyrecDataset(name)
print(f"Dataset has {len(dataset._gameids)} entries!")



Adding dataset 'dataset_13:51:04.792888' ('path/to/save') to 'ttyrecs.db' 
Updated 'ttyrecs.db' in 0.00 sec. Size: 0.04 MB, Games: 15
Dataset has 15 entries!


**Example 6:** Use docstrings. Don't forget that many of the classes and methods have docstrings. Have fun!

In [14]:
help(nld.TtyrecDataset)

Help on class TtyrecDataset in module nle.dataset.dataset:

class TtyrecDataset(builtins.object)
 |  TtyrecDataset(dataset_name, batch_size=128, seq_length=32, rows=24, cols=80, dbfilename='ttyrecs.db', threadpool=None, gameids=None, shuffle=True, loop_forever=False, subselect_sql=None, subselect_sql_args=None)
 |  
 |  Dataset object to allow iteration through the ttyrecs found in our ttyrec
 |  database.
 |  
 |  Methods defined here:
 |  
 |  __init__(self, dataset_name, batch_size=128, seq_length=32, rows=24, cols=80, dbfilename='ttyrecs.db', threadpool=None, gameids=None, shuffle=True, loop_forever=False, subselect_sql=None, subselect_sql_args=None)
 |      An iterable dataset to load minibatches of NetHack games from compressed
 |      ttyrec*.bz2 files into numpy arrays. (shape: [batch_size, seq_length, ...])
 |      
 |      This class makes use of a sqlite3 database at `dbfilename` to find the
 |      metadata and the location of files in a dataset. It then uses these to
 |   