Skip to content

Commit

Permalink
Refine env_sampler & policy manager
Browse files: browse the repository at this point in the history
  • Loading branch information
lihuoran committed Oct 14, 2021
1 parent cf1f70e commit b822d0a
Show file tree
Hide file tree
Showing 6 changed files with 235 additions and 168 deletions.
14 changes: 7 additions & 7 deletions examples/rl/cim_v2/env_sampler.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -28,8 +28,8 @@ def get_state(self, tick=None):
value in ``state_shaping_conf``), as well as all downstream port features.
"""
if tick is None:
tick = self.env.tick
vessel_snapshots, port_snapshots = self.env.snapshot_list["vessels"], self.env.snapshot_list["ports"]
tick = self._env.tick
vessel_snapshots, port_snapshots = self._env.snapshot_list["vessels"], self._env.snapshot_list["ports"]
port_idx, vessel_idx = self.event.port_idx, self.event.vessel_idx
ticks = [max(0, tick - rt) for rt in range(state_shaping_conf["look_back"] - 1)]
future_port_list = vessel_snapshots[tick: vessel_idx: 'future_stop_list'].astype('int')
Expand All @@ -52,8 +52,8 @@ def get_env_actions(self, action_by_agent):

port_idx, action = list(action_by_agent.items()).pop()
vsl_idx, action_scope = self.event.vessel_idx, self.event.action_scope
vsl_snapshots = self.env.snapshot_list["vessels"]
vsl_space = vsl_snapshots[self.env.tick:vsl_idx:vessel_attributes][2] if finite_vsl_space else float("inf")
vsl_snapshots = self._env.snapshot_list["vessels"]
vsl_space = vsl_snapshots[self._env.tick:vsl_idx:vessel_attributes][2] if finite_vsl_space else float("inf")

model_action = action["action"] if isinstance(action, dict) else action
percent = abs(action_space[model_action])
Expand All @@ -63,7 +63,7 @@ def get_env_actions(self, action_by_agent):
actual_action = min(round(percent * action_scope.load), vsl_space)
elif model_action > zero_action_idx:
action_type = ActionType.DISCHARGE
early_discharge = vsl_snapshots[self.env.tick:vsl_idx:"early_discharge"][0] if has_early_discharge else 0
early_discharge = vsl_snapshots[self._env.tick:vsl_idx:"early_discharge"][0] if has_early_discharge else 0
plan_action = percent * (action_scope.discharge + early_discharge) - early_discharge
actual_action = round(plan_action) if plan_action > 0 else round(percent * action_scope.discharge)
else:
Expand All @@ -84,7 +84,7 @@ def get_reward(self, actions, tick):

# Get the ports that took actions at the given tick
ports = [action.port_idx for action in actions]
port_snapshots = self.env.snapshot_list["ports"]
port_snapshots = self._env.snapshot_list["ports"]
future_fulfillment = port_snapshots[ticks:ports:"fulfillment"].reshape(len(ticks), -1)
future_shortage = port_snapshots[ticks:ports:"shortage"].reshape(len(ticks), -1)

Expand All @@ -101,7 +101,7 @@ def post_step(self, state, action, env_action, reward, tick):
be used to record any information one wishes to keep track of during a roll-out episode. Here we simply record
the latest env metric without keeping the history for logging purposes.
"""
self.tracker["env_metric"] = self.env.metrics
self._tracker["env_metric"] = self._env.metrics


agent2policy = {agent: f"{algorithm}.{agent}" for agent in Env(**env_conf).agent_idx_list}
Expand Down
Loading

0 comments on commit b822d0a

Please sign in to comment.