Merge pull request #1 from eidelen/lstm
Lstm
eidelen committed May 29, 2023
2 parents d42a375 + 5dbce36 commit 3a9ac9b
Showing 6 changed files with 43 additions and 14 deletions.
9 changes: 7 additions & 2 deletions bomberworld.py
@@ -86,7 +86,7 @@ def is_valid_pos(self, pos: Tuple[int, int]) -> bool:
     def can_move_to_pos(self, pos: Tuple[int, int]) -> bool:
         return self.is_valid_pos(pos) and (not self.stones[pos])
 
-    def make_observation_2D(self) -> np.ndarray:
+    def make_current_board_2D(self) -> np.ndarray:
         board = np.zeros((self.size, self.size), dtype=np.float32)
         # set rocks
         for m, n in np.ndindex(self.stones.shape):
@@ -95,8 +95,13 @@ def make_observation_2D(self) -> np.ndarray:
         for bomb_pos, _ in self.active_bombs:
             board[bomb_pos] = self.bomb_val
         # set agent
-        board[self.agent_pos] = self.bomb_and_agent_val if self.is_active_bomb_on_field(self.agent_pos) else self.agent_val
+        board[self.agent_pos] = self.bomb_and_agent_val if self.is_active_bomb_on_field(
+            self.agent_pos) else self.agent_val
 
+        return board
+
+    def make_observation_2D(self) -> np.ndarray:
+        board = self.make_current_board_2D()
         if self.reduced_obs: # cut 3x3 patch around agent
             m_ap, n_ap = self.agent_pos
             m_center = max(1, m_ap)
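The hunk above cuts off where the reduced observation is assembled. From the visible lines, the reduced observation appears to be a 3x3 patch of the board centered on the agent, with the center clamped so the window stays on the grid. The snippet below is only a rough sketch of that idea under those assumptions; the helper name and the clamping of the upper bound are not taken from the commit.

import numpy as np
from typing import Tuple

def crop_3x3(board: np.ndarray, agent_pos: Tuple[int, int]) -> np.ndarray:
    # Hypothetical helper, not part of the commit: clamp the window center to
    # [1, size - 2] so a full 3x3 slice always fits inside the square board.
    size = board.shape[0]
    m_ap, n_ap = agent_pos
    m_center = min(max(1, m_ap), size - 2)
    n_center = min(max(1, n_ap), size - 2)
    return board[m_center - 1:m_center + 2, n_center - 1:n_center + 2]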
24 changes: 18 additions & 6 deletions do_bombing.py
@@ -2,26 +2,38 @@
 # Author: Adrian Schneider, armasuisse
 
 from ray.rllib.policy.policy import Policy
+import numpy as np
 import argparse
 import bomberworld
 from bomberworld_plotter import BomberworldPlotter
 
-def run_bombing(path_to_checkpoint: str):
+def run_bombing(path_to_checkpoint: str, use_lstm: bool):
 
     trained_policy = Policy.from_checkpoint(path_to_checkpoint)
-    env = bomberworld.BomberworldEnv(10, 150, dead_when_colliding=True, dead_near_bomb=True, indestructible_agent=False, close_bomb_penalty=-1.0)
+
+    if use_lstm: # set initial blank lstm states
+        cell_size = 256
+        lstm_states = [np.zeros(cell_size, np.float32), np.zeros(cell_size, np.float32)]
+
+    env = bomberworld.BomberworldEnv(8, 60, dead_when_colliding=True, reduced_obs=True)
     o, info = env.reset()
 
     plotter = BomberworldPlotter(size=env.size, animated_gif_folder_path="gifs")
-    plotter.add_frame(env.agent_pos, None, None, env.make_observation_2D())
+    plotter.add_frame(env.agent_pos, None, None, env.make_current_board_2D())
 
     reward_sum = 0
     terminated, truncated = False, False
     while not (terminated or truncated):
-        a = trained_policy.compute_single_action(o)[0]
+
+        if use_lstm:
+            a, next_states, _ = trained_policy.compute_single_action(o, state=lstm_states)
+            lstm_states = next_states # update lstm states
+        else:
+            a = trained_policy.compute_single_action(o)[0]
+
         o, r, terminated, truncated, info = env.step(a)
         reward_sum += r
-        plotter.add_frame(agent_position=env.agent_pos, placed_bomb=info["placed_bomb"], exploded_bomb=info["exploded_bomb"], stones=env.make_observation_2D())
+        plotter.add_frame(agent_position=env.agent_pos, placed_bomb=info["placed_bomb"], exploded_bomb=info["exploded_bomb"], stones=env.make_current_board_2D())
     plotter.plot_episode(current_reward=reward_sum)
-    print("Current Reward:", reward_sum)
+    print("Overall Reward:", reward_sum)
@@ -34,4 +46,4 @@ def run_bombing(path_to_checkpoint: str):
         description='Runs bombing model')
     parser.add_argument('path', help='File path to checkpoint')
     args = parser.parse_args()
-    run_bombing(args.path)
+    run_bombing(args.path, use_lstm=True)
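The rollout script above hard-codes the LSTM cell size (256) to build blank initial states. RLlib policies can also report their own initial recurrent state, which avoids keeping that constant in sync with the training config. Below is a minimal sketch of that variant, assuming an RLlib 2.x checkpoint trained with use_lstm=True; the rollout_step helper is illustrative and not part of the commit.

from ray.rllib.policy.policy import Policy

def rollout_step(policy: Policy, obs, lstm_states):
    # compute_single_action returns (action, state_out, extra_fetches);
    # the returned state must be fed back in on the next step.
    action, state_out, _ = policy.compute_single_action(obs, state=lstm_states)
    return action, state_out

# Usage sketch:
# policy = Policy.from_checkpoint("path/to/checkpoint")
# lstm_states = policy.get_initial_state()  # instead of two hand-built zero arrays
# action, lstm_states = rollout_step(policy, obs, lstm_states)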
Binary file added rsc/trained6x6-put-10x10.gif
Binary file added rsc/trained6x6-put-6x6.gif
Binary file added rsc/trained6x6-put-8x8.gif
24 changes: 18 additions & 6 deletions solver.py
@@ -16,12 +16,13 @@ def env_create(env_config: EnvContext):
     return bomberworld.BomberworldEnv(**env_config)
 
 def print_ppo_configs(config):
+    print("Ray Version:", ray.__version__)
     print("clip_param", config.clip_param)
     print("gamma", config.gamma)
     print("lr", config.lr)
     print("lamda", config.lambda_)
 
-def grid_search_hypers(env_params: dict, nn_model: list, activation: str, desc: str, train_hw: dict):
+def grid_search_hypers(env_params: dict, nn_model: list, activation: str, desc: str, train_hw: dict, use_lstm: bool):
     register_env("GridworldEnv", env_create)
 
     config = PPOConfig()
@@ -37,8 +38,19 @@ def grid_search_hypers(env_params: dict, nn_model: list, activation: str, desc:
     config.model['fcnet_hiddens'] = nn_model
     config.model['fcnet_activation'] = activation
 
+    if use_lstm:
+        # another help -> https://github.com/ray-project/ray/issues/9220
+        config.model['use_lstm'] = True
+        # Max seq len for training the LSTM, defaults to 20.
+        config.model['max_seq_len'] = 20
+        # Size of the LSTM cell.
+        config.model['lstm_cell_size'] = 256
+        # Whether to feed a_{t-1}, r_{t-1} to LSTM.
+        config.model['lstm_use_prev_reward'] = False
+        config.model['lstm_use_prev_action'] = False
+
     config = config.rollouts(num_rollout_workers=train_hw["cpu"])
-    config = config.training( gamma=ray.tune.grid_search([0.75, 0.80, 0.85, 0.90, 0.95, 0.997])) # lr=ray.tune.grid_search([5e-05, 4e-05])) #, gamma=ray.tune.grid_search([0.99])) , lambda_=ray.tune.grid_search([1.0, 0.997, 0.95]))
+    config = config.training(gamma=0.75) # lr=ray.tune.grid_search([5e-05, 4e-05])) #, gamma=ray.tune.grid_search([0.99])) , lambda_=ray.tune.grid_search([1.0, 0.997, 0.95]))
 
     config = config.debugging(log_level="ERROR")
@@ -51,8 +63,8 @@ def grid_search_hypers(env_params: dict, nn_model: list, activation: str, desc:
             name=experiment_name,
             local_dir="out",
             verbose=2,
-            stop=MaximumIterationStopper(200),
-            checkpoint_config=air.CheckpointConfig(checkpoint_frequency=100)
+            stop=MaximumIterationStopper(100000),
+            checkpoint_config=air.CheckpointConfig(checkpoint_frequency=200)
         )
     )
@@ -80,12 +92,12 @@ def resume_training():
     hw = {"gpu": 0, "cpu": 3} # imac
     #hw = {"gpu": 1, "cpu": 11} # adris
 
-    env_params = {"size": 6, "max_steps": 40, "reduced_obs": True, "dead_when_colliding": True}
+    env_params = {"size": 6, "max_steps": 60, "reduced_obs": True, "dead_when_colliding": True, "indestructible_agent": False, "dead_near_bomb": True}
     #env_params = {"size": 10, "max_steps": 100, "indestructible_agent": False, "dead_near_bomb": True}
     # env_params = {"size": 10, "max_steps": 200, "dead_when_colliding": True, "dead_near_bomb": True, "indestructible_agent": False, "close_bomb_penalty": -1.0}
     nn_model = [256, 128, 64]
     activation = "relu"
-    description = "ReducedBomber-Hyper"
+    description = "ReducedSmartBomber-6x6-Gamma=0.75-LSTM"
 
     grid_search_hypers(env_params, nn_model, activation, description, hw)
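For readers unfamiliar with RLlib's LSTM wrapper: setting use_lstm in the model config wraps the configured fully connected network in a recurrent layer, so no custom model class is needed. The sketch below expresses the same settings through PPOConfig's fluent API in one place; it is a self-contained illustration, and the environment name and worker count are placeholders, not values from this repository.

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("CartPole-v1")          # placeholder env, not this repo's GridworldEnv
    .rollouts(num_rollout_workers=2)
    .training(
        gamma=0.75,
        model={
            "fcnet_hiddens": [256, 128, 64],
            "fcnet_activation": "relu",
            "use_lstm": True,            # wrap the FC net in an LSTM
            "max_seq_len": 20,           # truncation length for BPTT
            "lstm_cell_size": 256,
            "lstm_use_prev_action": False,
            "lstm_use_prev_reward": False,
        },
    )
)
algo = config.build()
result = algo.train()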
