In [1]:
import os
import cv2
import numpy as np

import torch as th
import torch.nn.functional as F

# video rendering deps
import matplotlib.pyplot as plt
%matplotlib inline
import IPython.display
import imageio
from IPython.display import Video

import gym
import gym_minigrid
import hwm.gym_minigrid_2.fourroom_cstm # custom FourRoom envs
from gym_minigrid.wrappers import ReseedWrapper
from hwm.gym_minigrid_2.wrappers import RGBImgFullGridWrapper, ChannelFirstImgWrapper, \
    RGBImgResizeWrapper, ActionMaskingWrapper, FactoredStateRepWrapper, RenderWithoutHighlightWrapper

from stable_baselines3 import PPO

In [2]:
# Dataset parameterization
# Both names are file paths WITHOUT the ".npz" extension; the extension is appended
# where the files are actually read / written.

# Raw dataset that this notebook reads
dataset_filename = "datasets/MiniGrid-FourRooms-Size11-v0_SB3_PPOAgent_WithFactStates_40000"
# Padded dataset that this notebook produces
padded_dataset_filename = f"{dataset_filename}_Padded"

## Pre-processing dataset

In [3]:
# First, load the original dataset from "<dataset_filename>.npz".
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, and unpickling can
# execute arbitrary code -- only load dataset files that come from a trusted source.
dataset = np.load(dataset_filename + '.npz', allow_pickle=True)

In [4]:
# Quick look at the dataset: print the shape of the first episode for each saved field,
# to eyeball that the various fields are consistent with each other.
first_episode_fields = (
    ("obs. shape: ", "observations"),
    ("rendered obs. shape: ", "rendered_observations"),
    ("actions shape: ", "actions"),
    ("rewards shape: ", "rewards"),
    ("terminals shape: ", "terminals"),
    ("factored_states shape: ", "factored_states"),
)
for label, key in first_episode_fields:
    print(label, dataset[key][0].shape)
print("")

# Episode length statistics; an episode's length is the number of entries in its terminals
ep_lengths = list(map(len, dataset["terminals"]))
length_statistics = (
    ("Mean", np.mean),
    ("Median", np.median),
    ("Min.", np.min),
    ("Max.", np.max),
    ("Std. Dev.", np.std),
)
for stat_name, stat_fn in length_statistics:
    print(f"{stat_name} ep. length: {stat_fn(ep_lengths)}")

obs. shape:  (8, 3, 64, 64)
rendered obs. shape:  (8, 352, 352, 3)
actions shape:  (8, 3)
rewards shape:  (8, 1)
terminals shape:  (8, 1)
factored_states shape:  (8, 36)

Mean ep. length: 11.474756167527252
Median ep. length: 11.0
Min. ep. length: 3
Max. ep. length: 20
Std. Dev. ep. length: 3.2040600383600006


In [5]:
# Print the scalar / shape metadata entries stored alongside the episodes
for metadata_key in (
    "max_length",
    "min_length",
    "act_shape",
    "rendered_observation_shape",
    "n_episodes",
):
    print(dataset[metadata_key])

20
3
3
[352 352   3]
3486


In [8]:
from tqdm.notebook import tqdm

# Recover some metadata to use for the padding
max_length, n_episodes, obs_shape, rendered_obs_shape, act_shape = \
    dataset["max_length"], \
    dataset["n_episodes"], \
    dataset["observation_shape"], \
    dataset["rendered_observation_shape"], \
    dataset["act_shape"]

# Equivalent to the hwm_init_size = 1 from previous experiments
PAD_LEFT = 1
PAD_RIGHT = 1

# Every padded trajectory has the same length: left pad + longest episode + right pad
PADDED_SEQ_LENGTH = PAD_LEFT + max_length + PAD_RIGHT

# "depad_slices" and "unpadded_length" are stored as uint8 below: fail loudly rather than
# risk out-of-range indices if the padded sequences ever outgrow that dtype.
assert PADDED_SEQ_LENGTH <= np.iinfo(np.uint8).max, \
    f"PADDED_SEQ_LENGTH={PADDED_SEQ_LENGTH} does not fit in the uint8 index fields"

fields_to_pad = ["observations", "actions", "terminals"]

# Read every field that gets padded exactly once. `dataset` comes from np.load on a .npz
# archive, and each `dataset[...]` lookup on it reads the whole array from the archive again,
# so looking fields up inside the per-episode loop would repeat that work for every episode.
raw_fields = {field_name: dataset[field_name] for field_name in fields_to_pad}

# An episode's real (unpadded) length is the number of entries in its terminals
real_episode_lengths = [len(raw_fields["terminals"][ep_idx]) for ep_idx in range(n_episodes)]

padded_dataset = {
    "observations": np.zeros([n_episodes, PADDED_SEQ_LENGTH, *obs_shape], dtype=raw_fields["observations"][0].dtype),
    "actions": np.zeros([n_episodes, PADDED_SEQ_LENGTH, act_shape], dtype=np.uint8),
    "terminals": np.zeros([n_episodes, PADDED_SEQ_LENGTH, 1], dtype=np.float32),
    # NOTE: For now, don't really need rewards it seems
    # "rewards": np.zeros([n_episodes, PADDED_SEQ_LENGTH, 1]),
    # TODO: make this more memory efficient, as we don't really need padding here
    "rendered_observations": dataset["rendered_observations"],
    "factored_states": dataset["factored_states"],

    ## NOTE: We use a custom masking scheme to make sure that the HWM losses are computed
    ## only on the relevant parts of the padded trajectory. In other words, this does not take
    ## into account the PAD_RIGHT and PAD_LEFT
    "depad_masks": np.zeros([n_episodes, max_length, 1], dtype=np.float32),
    ## NOTE: for each padded episode trajectory, holds the real start and end indices, respectively
    "depad_slices": np.zeros([n_episodes, 2], dtype=np.uint8),
    "unpadded_length": np.zeros([n_episodes], dtype=np.uint8),

    # Default dataset's metadata
    "padded_length": PADDED_SEQ_LENGTH,
    "max_length": dataset["max_length"],
    "min_length": dataset["min_length"],
    "act_shape": dataset["act_shape"],
    "factored_state_shape": dataset["factored_state_shape"],
    "observation_shape": dataset["observation_shape"],
    "rendered_observation_shape": dataset["rendered_observation_shape"],
    "n_episodes": dataset["n_episodes"]
}

for ep_idx in tqdm(range(n_episodes)):
    ep_real_length = real_episode_lengths[ep_idx]
    # The real trajectory sits right after the LEFT padding; whatever remains after it,
    # up to PADDED_SEQ_LENGTH, stays zero and acts as the right padding.
    ep_start_idx = PAD_LEFT
    ep_end_idx = PAD_LEFT + ep_real_length

    # Copy the real steps of each padded field into the [start, end) window
    for field_name in fields_to_pad:
        padded_dataset[field_name][ep_idx][ep_start_idx:ep_end_idx] = raw_fields[field_name][ep_idx]

    # Per-episode bookkeeping, done once per episode (it does not depend on the field).
    # The mask is indexed without padding: its first `ep_real_length` entries are the real steps.
    padded_dataset["depad_masks"][ep_idx][:ep_real_length] = 1.
    padded_dataset["depad_slices"][ep_idx] = np.array([ep_start_idx, ep_end_idx], dtype=np.uint8)
    padded_dataset["unpadded_length"][ep_idx] = ep_real_length

  0%|          | 0/3486 [00:00<?, ?it/s]

In [None]:
# Sanity checks on the padded dataset: de-padding every episode with "depad_slices" must give
# back the raw data, and "depad_masks" must flag exactly the real steps of each episode.

# Read each raw field once up front. `dataset` comes from np.load on a .npz archive, and each
# `dataset[...]` lookup on it reads the whole array from the archive again, so doing the
# lookup per episode inside the checks would repeat that work n_episodes times.
raw_fields = {field_name: dataset[field_name] for field_name in ("actions", "observations", "terminals")}
depad_slices = padded_dataset["depad_slices"]
depad_masks = padded_dataset["depad_masks"]

print(raw_fields["actions"][0].shape, padded_dataset["actions"][0].shape)

assert padded_dataset["observations"].shape[:2] == padded_dataset["actions"].shape[:2], \
    f'Padded observations and actions dimension do not match: {padded_dataset["observations"].shape[:2]} vs {padded_dataset["actions"].shape[:2]}'

# De-padding check: slicing each padded episode with its [start, end) indices must recover the
# raw episode. np.array_equal also returns False (instead of failing to broadcast) when the
# shapes differ.
for field_name, raw_field in raw_fields.items():
    padded_field = padded_dataset[field_name]
    assert all(
        np.array_equal(raw_field[i], padded_field[i][depad_slices[i][0]:depad_slices[i][1]])
        for i in range(n_episodes)
    ), f"Test failed while trying to depad the {field_name}"

# depad_masks check 1: the last entry set to 1 must be the last real step of the episode
assert all(
    np.where(depad_masks[i].squeeze(-1) == 1)[0].max() == (len(raw_fields["terminals"][i]) - 1)
    for i in range(n_episodes)
), "Test failed while trying to check the depad_masks validity"

# NOTE(review): not needed by the check below; kept only in case later cells rely on
# `ep_idx` having been reset here -- confirm and remove.
ep_idx = 0

# depad_masks check 2: the number of 1. in the mask should match the real length of the episode.
# (This check used to be a bare expression without `assert`, so it never actually ran.)
assert all(
    int(depad_masks[i].sum()) == padded_dataset["unpadded_length"][i]
    for i in range(n_episodes)
), "Test failed while trying to check the depad_masks validity"

In [None]:
# Dataset saving. Make sure all the tests above have passed before running this cell.
# Writes every entry of `padded_dataset` (arrays and metadata) into one compressed
# "<padded_dataset_filename>.npz" archive.
np.savez_compressed(f"{padded_dataset_filename}.npz", **padded_dataset)