Skip to content

Commit

Permalink
Working fix
Browse files Browse the repository at this point in the history
  • Loading branch information
eidelen committed May 28, 2023
1 parent d70af5f commit 16801c0
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 14 deletions.
9 changes: 7 additions & 2 deletions bomberworld.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def is_valid_pos(self, pos: Tuple[int, int]) -> bool:
def can_move_to_pos(self, pos: Tuple[int, int]) -> bool:
return self.is_valid_pos(pos) and (not self.stones[pos])

def make_observation_2D(self) -> np.ndarray:
def make_current_board_2D(self) -> np.ndarray:
board = np.zeros((self.size, self.size), dtype=np.float32)
# set rocks
for m, n in np.ndindex(self.stones.shape):
Expand All @@ -95,8 +95,13 @@ def make_observation_2D(self) -> np.ndarray:
for bomb_pos, _ in self.active_bombs:
board[bomb_pos] = self.bomb_val
# set agent
board[self.agent_pos] = self.bomb_and_agent_val if self.is_active_bomb_on_field(self.agent_pos) else self.agent_val
board[self.agent_pos] = self.bomb_and_agent_val if self.is_active_bomb_on_field(
self.agent_pos) else self.agent_val

return board

def make_observation_2D(self) -> np.ndarray:
board = self.make_current_board_2D()
if self.reduced_obs: # cut 3x3 patch around agent
m_ap, n_ap = self.agent_pos
m_center = max(1, m_ap)
Expand Down
21 changes: 11 additions & 10 deletions do_bombing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,29 @@ def run_bombing(path_to_checkpoint: str):
trained_policy = Policy.from_checkpoint(path_to_checkpoint)
model_config = trained_policy.model.model_config

# hack to make lstm work -> does not work: 'PPOTorchPolicy' object is not subscriptable
transformer_attention_size = model_config["attention_dim"]
transformer_memory_size = model_config["attention_memory_inference"]
transformer_layer_size = np.zeros([transformer_memory_size, transformer_attention_size])
transformer_length = model_config["attention_num_transformer_units"]
state_list = transformer_length * [transformer_layer_size]
initial_state_list = state_list
# hack to make lstm work
cell_size = 256
lstm_states = [np.zeros(cell_size, np.float32),
np.zeros(cell_size, np.float32)]
# end hack

env = bomberworld.BomberworldEnv(6, 40, dead_when_colliding=True, reduced_obs=True)
o, info = env.reset()

plotter = BomberworldPlotter(size=env.size, animated_gif_folder_path="gifs")
plotter.add_frame(env.agent_pos, None, None, env.make_observation_2D())
plotter.add_frame(env.agent_pos, None, None, env.make_current_board_2D())

reward_sum = 0
terminated, truncated = False, False
while not (terminated or truncated):
a = trained_policy.compute_single_action(o)[0] # When using lstm -> "assert seq_lens is not None" : https://github.com/ray-project/ray/issues/10448#issuecomment-1151468435

# Hack to make lstm work
a, next_states, _ = trained_policy.compute_single_action(o, state=lstm_states) # When using lstm -> "assert seq_lens is not None" : https://github.com/ray-project/ray/issues/10448#issuecomment-1151468435
lstm_states = next_states

o, r, terminated, truncated, info = env.step(a)
reward_sum += r
plotter.add_frame(agent_position=env.agent_pos, placed_bomb=info["placed_bomb"], exploded_bomb=info["exploded_bomb"], stones=env.make_observation_2D())
plotter.add_frame(agent_position=env.agent_pos, placed_bomb=info["placed_bomb"], exploded_bomb=info["exploded_bomb"], stones=env.make_current_board_2D())
plotter.plot_episode(current_reward=reward_sum)
print("Current Reward:", reward_sum)
print("Overall Reward:", reward_sum)
Expand Down
14 changes: 12 additions & 2 deletions solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def env_create(env_config: EnvContext):
return bomberworld.BomberworldEnv(**env_config)

def print_ppo_configs(config):
print("Ray Version:", ray.__version__)
print("clip_param", config.clip_param)
print("gamma", config.gamma)
print("lr", config.lr)
Expand All @@ -36,7 +37,16 @@ def grid_search_hypers(env_params: dict, nn_model: list, activation: str, desc:

config.model['fcnet_hiddens'] = nn_model
config.model['fcnet_activation'] = activation
config.model['use_lstm'] = True,

# another help -> https://github.com/ray-project/ray/issues/9220
config.model['use_lstm'] = True
# Max seq len for training the LSTM, defaults to 20.
config.model['max_seq_len'] = 20
# Size of the LSTM cell.
config.model['lstm_cell_size'] = 256
# Whether to feed a_{t-1}, r_{t-1} to LSTM.
config.model['lstm_use_prev_reward'] = False
config.model['lstm_use_prev_action'] = False

config = config.rollouts(num_rollout_workers=train_hw["cpu"])
config = config.training( gamma=ray.tune.grid_search([0.75, 0.80, 0.85, 0.90, 0.95, 0.997])) # lr=ray.tune.grid_search([5e-05, 4e-05])) #, gamma=ray.tune.grid_search([0.99])) , lambda_=ray.tune.grid_search([1.0, 0.997, 0.95]))
Expand All @@ -53,7 +63,7 @@ def grid_search_hypers(env_params: dict, nn_model: list, activation: str, desc:
local_dir="out",
verbose=2,
stop=MaximumIterationStopper(200),
checkpoint_config=air.CheckpointConfig(checkpoint_frequency=100)
checkpoint_config=air.CheckpointConfig(checkpoint_frequency=5)
)
)

Expand Down

0 comments on commit 16801c0

Please sign in to comment.