# NOTE: removed copy/paste artifacts (duplicated filename header and line-number gutter from a web code viewer).
"""Issue 7 reproduction script."""
# %% Imports
# Third Party Imports
from gymnasium.spaces import Box, Dict, MultiDiscrete
from ray.air import RunConfig
from ray.rllib.algorithms import ppo
from ray.rllib.examples.env.random_env import RandomEnv
from ray.rllib.models import ModelCatalog
from ray.tune.registry import register_env
# Punch Clock Imports
from punchclock.common.utilities import loadJSONFile
from punchclock.nets.action_mask_model import MyActionMaskModel
from punchclock.ray.build_tuner import buildEnv
# %% Load config
# Experiment configuration for the repro: expects at least "run_config" and
# "param_space" (with "model" and "env_config") keys — see iss7_config.json.
config = loadJSONFile("issues/iss7/iss7_config.json")
# %% Register env and model
# Register the custom environment and action-mask model under string names so
# RLlib configs below can reference them ("my_env", "action_mask_model").
register_env("my_env", buildEnv)
ModelCatalog.register_custom_model("action_mask_model", MyActionMaskModel)
# %% Modify config (custom env)
# NOTE(review): run_config is built but only used by the Tuner call, which is
# currently commented out at the bottom of this script.
run_config = RunConfig(**config["run_config"])
# Disable preprocessor
# config["param_space"]["model"]["_disable_preprocessor_api"] = True
# %% Random Env
# Baseline environment mirroring the custom env's interface: a Dict
# observation with an action mask, and a single 10-way discrete action.
random_env_config = {
    "observation_space": Dict(
        {
            "observations": Box(0, 1, shape=(32,), dtype=float),
            "action_mask": Box(0, 1, shape=(10,), dtype=int),
        }
    ),
    "action_space": MultiDiscrete([10]),
}
env_random = RandomEnv(random_env_config)
# Build a PPO algorithm against RandomEnv, reusing the model settings from the
# loaded config and the spaces sampled from env_random above.
model_settings = {**config["param_space"]["model"]}
random_env_spaces = {
    "observation_space": env_random.observation_space,
    "action_space": env_random.action_space,
}
algo_config_rand = (
    ppo.PPOConfig()
    .training(model=model_settings)
    .environment(env=RandomEnv, env_config=random_env_spaces)
    .framework("torch")
)
algo_random = algo_config_rand.build()
# One training step as a sanity check that the random env trains cleanly.
results = algo_random.training_step()
print(f"random env results : \n{results}")
# %% Custom Env
# Identical PPO setup, but pointed at the registered custom environment.
algo_config_customenv = (
    ppo.PPOConfig()
    .training(model=dict(config["param_space"]["model"]))
    .environment(
        env="my_env",
        env_config=config["param_space"]["env_config"],
    )
    .framework("torch")
)
algo_customenv = algo_config_customenv.build()
# The custom env is the suspected failure point (issue 7) — surface the
# exception instead of crashing so both env results can be compared.
try:
    results = algo_customenv.training_step()
except Exception as err:
    print(err)
else:
    print(f"custom env results : \n{results}")
# Leftover experiments from the original investigation, kept for reference:
# env = buildEnv(config["param_space"]["env_config"])
# obs = env.observation_space.sample()
# action = algo.compute_single_action(obs)
# tuner = Tuner(
#     trainable="PPO",
#     param_space=config["param_space"],
#     run_config=run_config,
#     tune_config=config["tune_config"],
# )
# tuner.fit()
# %% done
print("done")