Skip to content

Commit

Permalink
Ready to merge
Browse files Browse the repository at this point in the history
  • Loading branch information
eidelen committed May 29, 2023
1 parent 16801c0 commit 5dbce36
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 27 deletions.
23 changes: 11 additions & 12 deletions do_bombing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,15 @@
import bomberworld
from bomberworld_plotter import BomberworldPlotter

def run_bombing(path_to_checkpoint: str):
def run_bombing(path_to_checkpoint: str, use_lstm: bool):

trained_policy = Policy.from_checkpoint(path_to_checkpoint)
model_config = trained_policy.model.model_config

# hack to make lstm work
cell_size = 256
lstm_states = [np.zeros(cell_size, np.float32),
np.zeros(cell_size, np.float32)]
# end hack
if use_lstm: # set initial blank lstm states
cell_size = 256
lstm_states = [np.zeros(cell_size, np.float32), np.zeros(cell_size, np.float32)]

env = bomberworld.BomberworldEnv(6, 40, dead_when_colliding=True, reduced_obs=True)
env = bomberworld.BomberworldEnv(8, 60, dead_when_colliding=True, reduced_obs=True)
o, info = env.reset()

plotter = BomberworldPlotter(size=env.size, animated_gif_folder_path="gifs")
Expand All @@ -28,9 +25,11 @@ def run_bombing(path_to_checkpoint: str):
terminated, truncated = False, False
while not (terminated or truncated):

# Hack to make lstm work
a, next_states, _ = trained_policy.compute_single_action(o, state=lstm_states) # When using lstm -> "assert seq_lens is not None" : https://github.com/ray-project/ray/issues/10448#issuecomment-1151468435
lstm_states = next_states
if use_lstm:
a, next_states, _ = trained_policy.compute_single_action(o, state=lstm_states)
lstm_states = next_states # update lstm states
else:
a = trained_policy.compute_single_action(o)[0]

o, r, terminated, truncated, info = env.step(a)
reward_sum += r
Expand All @@ -47,4 +46,4 @@ def run_bombing(path_to_checkpoint: str):
description='Runs bombing model')
parser.add_argument('path', help='File path to checkpoint')
args = parser.parse_args()
run_bombing(args.path)
run_bombing(args.path, use_lstm=True)
Binary file added rsc/trained6x6-put-10x10.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added rsc/trained6x6-put-6x6.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added rsc/trained6x6-put-8x8.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
31 changes: 16 additions & 15 deletions solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def print_ppo_configs(config):
print("lr", config.lr)
print("lamda", config.lambda_)

def grid_search_hypers(env_params: dict, nn_model: list, activation: str, desc: str, train_hw: dict):
def grid_search_hypers(env_params: dict, nn_model: list, activation: str, desc: str, train_hw: dict, use_lstm: bool):
register_env("GridworldEnv", env_create)

config = PPOConfig()
Expand All @@ -38,18 +38,19 @@ def grid_search_hypers(env_params: dict, nn_model: list, activation: str, desc:
config.model['fcnet_hiddens'] = nn_model
config.model['fcnet_activation'] = activation

# another help -> https://github.com/ray-project/ray/issues/9220
config.model['use_lstm'] = True
# Max seq len for training the LSTM, defaults to 20.
config.model['max_seq_len'] = 20
# Size of the LSTM cell.
config.model['lstm_cell_size'] = 256
# Whether to feed a_{t-1}, r_{t-1} to LSTM.
config.model['lstm_use_prev_reward'] = False
config.model['lstm_use_prev_action'] = False
if use_lstm:
# another help -> https://github.com/ray-project/ray/issues/9220
config.model['use_lstm'] = True
# Max seq len for training the LSTM, defaults to 20.
config.model['max_seq_len'] = 20
# Size of the LSTM cell.
config.model['lstm_cell_size'] = 256
# Whether to feed a_{t-1}, r_{t-1} to LSTM.
config.model['lstm_use_prev_reward'] = False
config.model['lstm_use_prev_action'] = False

config = config.rollouts(num_rollout_workers=train_hw["cpu"])
config = config.training( gamma=ray.tune.grid_search([0.75, 0.80, 0.85, 0.90, 0.95, 0.997])) # lr=ray.tune.grid_search([5e-05, 4e-05])) #, gamma=ray.tune.grid_search([0.99])) , lambda_=ray.tune.grid_search([1.0, 0.997, 0.95]))
config = config.training(gamma=0.75) # lr=ray.tune.grid_search([5e-05, 4e-05])) #, gamma=ray.tune.grid_search([0.99])) , lambda_=ray.tune.grid_search([1.0, 0.997, 0.95]))

config = config.debugging(log_level="ERROR")

Expand All @@ -62,8 +63,8 @@ def grid_search_hypers(env_params: dict, nn_model: list, activation: str, desc:
name=experiment_name,
local_dir="out",
verbose=2,
stop=MaximumIterationStopper(200),
checkpoint_config=air.CheckpointConfig(checkpoint_frequency=5)
stop=MaximumIterationStopper(100000),
checkpoint_config=air.CheckpointConfig(checkpoint_frequency=200)
)
)

Expand Down Expand Up @@ -91,12 +92,12 @@ def resume_training():
hw = {"gpu": 0, "cpu": 3} # imac
#hw = {"gpu": 1, "cpu": 11} # adris

env_params = {"size": 6, "max_steps": 40, "reduced_obs": True, "dead_when_colliding": True}
env_params = {"size": 6, "max_steps": 60, "reduced_obs": True, "dead_when_colliding": True, "indestructible_agent": False, "dead_near_bomb": True}
#env_params = {"size": 10, "max_steps": 100, "indestructible_agent": False, "dead_near_bomb": True}
# env_params = {"size": 10, "max_steps": 200, "dead_when_colliding": True, "dead_near_bomb": True, "indestructible_agent": False, "close_bomb_penalty": -1.0}
nn_model = [256, 128, 64]
activation = "relu"
description = "ReducedBomber-Hyper"
description = "ReducedSmartBomber-6x6-Gamma=0.75-LSTM"

grid_search_hypers(env_params, nn_model, activation, description, hw)

0 comments on commit 5dbce36

Please sign in to comment.