In [28]:
import sys
from pathlib import Path
import gym
import d4rl
import numpy as np

PROJECT_ROOT_DIR = Path().absolute().parent
PROJECT_ROOT_DIR

if str(PROJECT_ROOT_DIR.absolute()) not in sys.path:
    sys.path.append(str(PROJECT_ROOT_DIR.absolute()))

from models.sb3_model import PPOWithBCLoss
from utils.sb3_env_wrappers import ScaledObservationWrapper
from rollout import rollout, save_data_to_h5_file
from utils.load_data import load_data, scale_obs, split_data

In [29]:
RL_EXPERIMENT_NAMES = [
    "iter_1/human_pen_1e7steps_8envs_kl1e-1_loss_1",
    "iter_1/human_pen_1e7steps_8envs_kl1e-1_loss_2",
    "iter_1/human_pen_1e7steps_8envs_kl1e-1_loss_3",
    "iter_1/human_pen_1e7steps_8envs_kl1e-1_loss_4",
    "iter_1/human_pen_1e7steps_8envs_kl1e-1_loss_5",
]
ENV_NAME = "pen-human-v1"
deterministic_sample = False

In [30]:
origin_env = gym.make(ENV_NAME)
    
obs, acts, infos = load_data(origin_env)
scaled_obs, scaler = scale_obs(obs)

env = ScaledObservationWrapper(env=origin_env, scaler=scaler)

  logger.warn(
load datafile: 100%|██████████| 8/8 [00:00<00:00, 457.34it/s]


In [31]:
datasets = []
policy_save_dirs = [PROJECT_ROOT_DIR / "checkpoints" / "rl" / item for item in RL_EXPERIMENT_NAMES]
for policy_save_dir in policy_save_dirs:
    algo_ppo = PPOWithBCLoss.load(str((policy_save_dir / "best_model").absolute()))
    algo_ppo.policy.set_training_mode(False)
    dataset = rollout(
        algo_ppo.policy, 
        env, 
        max_path=1000, 
        num_data=int(1e6 / len(RL_EXPERIMENT_NAMES)), 
        deterministic=deterministic_sample
    )
    datasets.append(dataset)

Finished trajectory. Len=100, Returns=5120.775549. Progress:0/200000
Finished trajectory. Len=100, Returns=3377.530078. Progress:100/200000
Finished trajectory. Len=100, Returns=2901.854483. Progress:200/200000
Finished trajectory. Len=100, Returns=4013.691098. Progress:300/200000
Finished trajectory. Len=100, Returns=4537.171444. Progress:400/200000
Finished trajectory. Len=100, Returns=4149.457816. Progress:500/200000
Finished trajectory. Len=100, Returns=2988.685433. Progress:600/200000
Finished trajectory. Len=100, Returns=4368.918125. Progress:700/200000
Finished trajectory. Len=100, Returns=3921.046268. Progress:800/200000
Finished trajectory. Len=100, Returns=166.178103. Progress:900/200000
Finished trajectory. Len=100, Returns=4668.051391. Progress:1000/200000
Finished trajectory. Len=100, Returns=4114.096234. Progress:1100/200000
Finished trajectory. Len=100, Returns=4322.330695. Progress:1200/200000
Finished trajectory. Len=100, Returns=5037.298836. Progress:1300/200000
Finis

In [32]:
res = {}
for k in datasets[0].keys():
    res[k] = np.concatenate([dt[k] for dt in datasets])

In [33]:
res["observations"].shape

(1000000, 45)

In [34]:
save_data_to_h5_file(
    res, 
    str(PROJECT_ROOT_DIR / "rollout" / "data" / "iter_1" / (ENV_NAME + ("_best_sample" if deterministic_sample else "") + ".hdf5"))
)

In [35]:
from load_data import load_data
dataset = load_data(str(PROJECT_ROOT_DIR / "rollout" / "data" / "iter_1" / "pen-human-v1.hdf5"))
dataset

load datafile: 100%|██████████| 9/9 [00:04<00:00,  2.10it/s]


{'actions': array([[-1.1054758 ,  0.13096562,  1.0931482 , ...,  0.79916024,
          0.9697671 ,  0.81118095],
        [-1.1475093 ,  0.10892096,  1.0735291 , ...,  0.86924833,
          0.96660906,  0.8229687 ],
        [-1.1701081 ,  0.07452677,  1.0579853 , ...,  0.9646203 ,
          0.9837736 ,  0.765536  ],
        ...,
        [ 0.02724263, -0.72248864,  0.13412645, ...,  0.6933089 ,
         -0.9144265 , -1.0692381 ],
        [ 0.12028527, -0.78407896, -0.06322058, ...,  0.8675778 ,
         -0.9193287 , -1.1127439 ],
        [ 0.15655415, -0.79959774, -0.12202199, ...,  0.9007965 ,
         -0.9298039 , -1.1037769 ]], dtype=float32),
 'infos/action_log_probs': array([32.407917, 35.826454, 38.553032, ..., 35.43653 , 36.9997  ,
        38.57376 ], dtype=float32),
 'infos/qpos': array([[ 0.        ,  0.        ,  0.        , ...,  0.        ,
          0.        ,  0.        ],
        [-0.01715021, -0.00545185,  0.02351521, ...,  0.        ,
          0.        , -0.        ],

In [36]:
dataset["actions"].shape

(1000000, 24)