RL policy redesign (V2) (#405)
* Draft v2.0 for V2

* Polish models with more comments

* Polish policies with more comments

* Lint

* Lint

* Add developer doc for models.

* Add developer doc for policies.

* Remove policy manager V2 since it is not used and out-of-date

* Lint

* Lint
lihuoran committed Nov 1, 2021
1 parent 3e38eb7 commit 7e226cc
Showing 27 changed files with 3,797 additions and 84 deletions.
6 changes: 3 additions & 3 deletions examples/rl/cim/env_sampler.py
@@ -32,7 +32,7 @@ def get_state(self, tick=None):
vessel_snapshots, port_snapshots = self.env.snapshot_list["vessels"], self.env.snapshot_list["ports"]
port_idx, vessel_idx = self.event.port_idx, self.event.vessel_idx
ticks = [max(0, tick - rt) for rt in range(state_shaping_conf["look_back"] - 1)]
future_port_list = vessel_snapshots[tick: vessel_idx: 'future_stop_list'].astype('int')
future_port_list = vessel_snapshots[tick: vessel_idx: 'future_stop_list'].astype('int')
state = np.concatenate([
port_snapshots[ticks : [port_idx] + list(future_port_list) : port_attributes],
vessel_snapshots[tick : vessel_idx : vessel_attributes]
@@ -55,7 +55,7 @@ def get_env_actions(self, action_by_agent):
vsl_snapshots = self.env.snapshot_list["vessels"]
vsl_space = vsl_snapshots[self.env.tick:vsl_idx:vessel_attributes][2] if finite_vsl_space else float("inf")

model_action = action["action"] if isinstance(action, dict) else action
model_action = action["action"] if isinstance(action, dict) else action
percent = abs(action_space[model_action])
zero_action_idx = len(action_space) / 2 # index corresponding to value zero.
if model_action < zero_action_idx:
@@ -112,5 +112,5 @@ def get_env_sampler():
get_policy_func_dict=policy_func_dict,
agent2policy=agent2policy,
reward_eval_delay=reward_shaping_conf["time_window"],
parallel_inference=True
parallel_inference=False
)
9 changes: 9 additions & 0 deletions examples/rl/cim_v2/README.md
@@ -0,0 +1,9 @@
# Container Inventory Management

This example demonstrates the use of MARO's RL toolkit to optimize container inventory management. The scenario consists of a set of ports, each acting as a learning agent, and vessels that transfer empty containers among them. Each port must decide 1) whether to load or discharge containers when a vessel arrives and 2) how many containers to load or discharge. The objective is to minimize the overall container shortage over a certain period of time. In this folder you can find:
* ``config.py``, which contains environment and policy configurations for the scenario;
* ``env_sampler.py``, which defines state, action and reward shaping in the ``CIMEnvSampler`` class;
* ``policies.py``, which defines the Q-net for DQN and the network components for Actor-Critic;
* ``callbacks.py``, which defines routines to be invoked at the end of training or evaluation episodes.

The scripts for running the learning workflows can be found under ``examples/rl/workflows``. See ``README`` under ``examples/rl`` for details about the general applicability of these scripts. We recommend that you follow this example to write your own scenarios.
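
For orientation, the following is a minimal, illustrative sketch of how the pieces exported by this folder fit together. It is not part of the example itself; it assumes it is run from the repository root with ``maro`` installed, and it mirrors the ``sys.path`` insertion that ``env_sampler.py`` uses internally.

import sys

# Make the modules in this folder importable, as env_sampler.py itself does internally.
sys.path.insert(0, "examples/rl/cim_v2")

from env_sampler import agent2policy, get_env_sampler  # noqa: E402

# Each port index is mapped to a named policy instance, e.g. {0: "ac.0", 1: "ac.1", ...}.
print(agent2policy)

# Builds a CIMEnvSampler wired to the CIM simulator, the policies in policies.py
# and the reward shaping defined in config.py.
env_sampler = get_env_sampler()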
8 changes: 8 additions & 0 deletions examples/rl/cim_v2/__init__.py
@@ -0,0 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .callbacks import post_collect, post_evaluate
from .env_sampler import agent2policy, get_env_sampler
from .policies import policy_func_dict

__all__ = ["agent2policy", "post_collect", "post_evaluate", "get_env_sampler", "policy_func_dict"]
33 changes: 33 additions & 0 deletions examples/rl/cim_v2/callbacks.py
@@ -0,0 +1,33 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import time
from os import makedirs
from os.path import dirname, join, realpath

log_dir = join(dirname(realpath(__file__)), "log", str(time.time()))
makedirs(log_dir, exist_ok=True)


def post_collect(trackers, ep, segment):
# print the env metric from each rollout worker
for tracker in trackers:
print(f"env summary (episode {ep}, segment {segment}): {tracker['env_metric']}")

# print the average env metric
if len(trackers) > 1:
metric_keys, num_trackers = trackers[0]["env_metric"].keys(), len(trackers)
avg_metric = {key: sum(tr["env_metric"][key] for tr in trackers) / num_trackers for key in metric_keys}
print(f"average env summary (episode {ep}, segment {segment}): {avg_metric}")


def post_evaluate(trackers, ep):
# print the env metric from each rollout worker
for tracker in trackers:
print(f"env summary (episode {ep}): {tracker['env_metric']}")

# print the average env metric
if len(trackers) > 1:
metric_keys, num_trackers = trackers[0]["env_metric"].keys(), len(trackers)
avg_metric = {key: sum(tr["env_metric"][key] for tr in trackers) / num_trackers for key in metric_keys}
print(f"average env summary (episode {ep}): {avg_metric}")
125 changes: 125 additions & 0 deletions examples/rl/cim_v2/config.py
@@ -0,0 +1,125 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
from torch.optim import Adam, RMSprop

from maro.rl.exploration import MultiLinearExplorationScheduler, epsilon_greedy


env_conf = {
"scenario": "cim",
"topology": "toy.4p_ssdd_l0.0",
"durations": 560
}

port_attributes = ["empty", "full", "on_shipper", "on_consignee", "booking", "shortage", "fulfillment"]
vessel_attributes = ["empty", "full", "remaining_space"]

state_shaping_conf = {
"look_back": 7,
"max_ports_downstream": 2
}

action_shaping_conf = {
"action_space": [(i - 10) / 10 for i in range(21)],
"finite_vessel_space": True,
"has_early_discharge": True
}

reward_shaping_conf = {
"time_window": 99,
"fulfillment_factor": 1.0,
"shortage_factor": 1.0,
"time_decay": 0.97
}

# state dimension computed from the state shaping configuration above
state_dim = (
(state_shaping_conf["look_back"] + 1) * (state_shaping_conf["max_ports_downstream"] + 1) * len(port_attributes)
+ len(vessel_attributes)
)

############################################## POLICIES ###############################################

algorithm = "ac"

# DQN settings
q_net_conf = {
"input_dim": state_dim,
"hidden_dims": [256, 128, 64, 32],
"output_dim": len(action_shaping_conf["action_space"]),
"activation": torch.nn.LeakyReLU,
"softmax": False,
"batch_norm": True,
"skip_connection": False,
"head": True,
"dropout_p": 0.0
}

q_net_optim_conf = (RMSprop, {"lr": 0.05})

dqn_conf = {
"reward_discount": .0,
"update_target_every": 5,
"num_epochs": 10,
"soft_update_coef": 0.1,
"double": False,
"exploration_strategy": (epsilon_greedy, {"epsilon": 0.4}),
"exploration_scheduling_options": [(
"epsilon", MultiLinearExplorationScheduler, {
"splits": [(2, 0.32)],
"initial_value": 0.4,
"last_ep": 5,
"final_value": 0.0,
}
)],
"replay_memory_capacity": 10000,
"random_overwrite": False,
"warmup": 100,
"rollout_batch_size": 128,
"train_batch_size": 32,
# "prioritized_replay_kwargs": {
# "alpha": 0.6,
# "beta": 0.4,
# "beta_step": 0.001,
# "max_priority": 1e8
# }
}


# AC settings
actor_net_conf = {
"input_dim": state_dim,
"hidden_dims": [256, 128, 64],
"output_dim": len(action_shaping_conf["action_space"]),
"activation": torch.nn.Tanh,
"softmax": True,
"batch_norm": False,
"head": True
}

critic_net_conf = {
"input_dim": state_dim,
"hidden_dims": [256, 128, 64],
"output_dim": 1,
"activation": torch.nn.LeakyReLU,
"softmax": False,
"batch_norm": True,
"head": True
}

actor_optim_conf = (Adam, {"lr": 0.001})
critic_optim_conf = (RMSprop, {"lr": 0.001})

ac_conf = {
"reward_discount": .0,
"grad_iters": 10,
"critic_loss_cls": torch.nn.SmoothL1Loss,
"min_logp": None,
"critic_loss_coef": 0.1,
"entropy_coef": 0.01,
# "clip_ratio": 0.8 # for PPO
"lam": .0,
"get_loss_on_rollout": False
}
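
As a quick sanity check on the dimensions above, and to illustrate how a spec like ``q_net_conf`` typically maps onto a network, here is a plain ``torch`` sketch. It is not MARO's own model builder: it ignores the ``batch_norm``, ``softmax`` and ``dropout_p`` options and assumes it is run next to ``config.py``.

import torch

from config import q_net_conf, state_dim

# state_dim = (look_back + 1) * (max_ports_downstream + 1) * len(port_attributes) + len(vessel_attributes)
#           = (7 + 1) * (2 + 1) * 7 + 3 = 171
assert state_dim == 171


def mlp_from_conf(conf):
    """Illustrative MLP builder: Linear layers with the configured activation on hidden layers."""
    dims = [conf["input_dim"]] + conf["hidden_dims"] + [conf["output_dim"]]
    layers = []
    for i in range(len(dims) - 1):
        layers.append(torch.nn.Linear(dims[i], dims[i + 1]))
        if i < len(dims) - 2:
            layers.append(conf["activation"]())
    return torch.nn.Sequential(*layers)


q_net = mlp_from_conf(q_net_conf)  # 171 -> 256 -> 128 -> 64 -> 32 -> 21 (one Q-value per action)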
116 changes: 116 additions & 0 deletions examples/rl/cim_v2/env_sampler.py
@@ -0,0 +1,116 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import sys

import numpy as np

from maro.rl.learning.env_sampler_v2 import AbsEnvSampler
from maro.simulator import Env
from maro.simulator.scenarios.cim.common import Action, ActionType

cim_path = os.path.dirname(os.path.realpath(__file__))
if cim_path not in sys.path:
sys.path.insert(0, cim_path)

from config import (
action_shaping_conf, algorithm, env_conf, port_attributes, reward_shaping_conf, state_shaping_conf,
vessel_attributes
)
from policies import policy_func_dict


class CIMEnvSampler(AbsEnvSampler):
def get_state(self, tick=None):
"""
The state vector includes shortage and remaining vessel space over the past k days (where k is the "look_back"
value in ``state_shaping_conf``), as well as all downstream port features.
"""
if tick is None:
tick = self._env.tick
vessel_snapshots, port_snapshots = self._env.snapshot_list["vessels"], self._env.snapshot_list["ports"]
port_idx, vessel_idx = self.event.port_idx, self.event.vessel_idx
ticks = [max(0, tick - rt) for rt in range(state_shaping_conf["look_back"] - 1)]
future_port_list = vessel_snapshots[tick: vessel_idx: 'future_stop_list'].astype('int')
state = np.concatenate([
port_snapshots[ticks : [port_idx] + list(future_port_list) : port_attributes],
vessel_snapshots[tick : vessel_idx : vessel_attributes]
])
return {port_idx: state}

def get_env_actions(self, action_by_agent):
"""
The policy output is an integer in [0, 20], interpreted as an index into ``action_space`` in
``action_shaping_conf``. For example, action 5 corresponds to -0.5, which means loading 50% of the containers
available at the current port onto the vessel, while action 18 corresponds to 0.8, which means discharging 80% of the
containers on the vessel to the port. Note that action 10 corresponds to 0.0, which means doing nothing.
"""
action_space = action_shaping_conf["action_space"]
finite_vsl_space = action_shaping_conf["finite_vessel_space"]
has_early_discharge = action_shaping_conf["has_early_discharge"]

port_idx, action = list(action_by_agent.items()).pop()
vsl_idx, action_scope = self.event.vessel_idx, self.event.action_scope
vsl_snapshots = self._env.snapshot_list["vessels"]
vsl_space = vsl_snapshots[self._env.tick:vsl_idx:vessel_attributes][2] if finite_vsl_space else float("inf")

model_action = action["action"] if isinstance(action, dict) else action
percent = abs(action_space[model_action])
zero_action_idx = len(action_space) / 2 # index corresponding to value zero.
if model_action < zero_action_idx:
action_type = ActionType.LOAD
actual_action = min(round(percent * action_scope.load), vsl_space)
elif model_action > zero_action_idx:
action_type = ActionType.DISCHARGE
early_discharge = vsl_snapshots[self._env.tick:vsl_idx:"early_discharge"][0] if has_early_discharge else 0
plan_action = percent * (action_scope.discharge + early_discharge) - early_discharge
actual_action = round(plan_action) if plan_action > 0 else round(percent * action_scope.discharge)
else:
actual_action, action_type = 0, None

return [Action(port_idx=port_idx, vessel_idx=vsl_idx, quantity=actual_action, action_type=action_type)]

def get_reward(self, actions, tick):
"""
The reward is defined as a linear combination of fulfillment and shortage measures. The fulfillment and
shortage measures are the sums of fulfillment and shortage values over the next k days, respectively, each
adjusted with exponential decay factors (using the "time_decay" value in ``reward_shaping_conf``) to put more
emphasis on the near future. Here k is the "time_window" value in ``reward_shaping_conf``. The linear
combination coefficients are given by "fulfillment_factor" and "shortage_factor" in ``reward_shaping_conf``.
"""
start_tick = tick + 1
ticks = list(range(start_tick, start_tick + reward_shaping_conf["time_window"]))

# Get the ports that took actions at the given tick
ports = [action.port_idx for action in actions]
port_snapshots = self._env.snapshot_list["ports"]
future_fulfillment = port_snapshots[ticks:ports:"fulfillment"].reshape(len(ticks), -1)
future_shortage = port_snapshots[ticks:ports:"shortage"].reshape(len(ticks), -1)

decay_list = [reward_shaping_conf["time_decay"] ** i for i in range(reward_shaping_conf["time_window"])]
rewards = np.float32(
reward_shaping_conf["fulfillment_factor"] * np.dot(future_fulfillment.T, decay_list)
- reward_shaping_conf["shortage_factor"] * np.dot(future_shortage.T, decay_list)
)
return {agent_id: reward for agent_id, reward in zip(ports, rewards)}

def post_step(self, state, action, env_action, reward, tick):
"""
The environment sampler contains a "tracker" dict inherited from the "AbsEnvSampler" base class, which can
be used to record any information one wishes to keep track of during a roll-out episode. Here we simply record
the latest env metric, without keeping the history, for logging purposes.
"""
self._tracker["env_metric"] = self._env.metrics


agent2policy = {agent: f"{algorithm}.{agent}" for agent in Env(**env_conf).agent_idx_list}

def get_env_sampler():
return CIMEnvSampler(
get_env=lambda: Env(**env_conf),
get_policy_func_dict=policy_func_dict,
agent2policy=agent2policy,
reward_eval_delay=reward_shaping_conf["time_window"],
parallel_inference=False
)
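
To make the shaping described in the docstrings above concrete, here is a standalone, illustrative walk-through of the action mapping and the reward decay weights, using only the constants from ``config.py``; the per-tick fulfillment and shortage values at the end are made up.

import numpy as np

# Action shaping: 21 discrete actions; indices below 10 load, indices above 10 discharge.
action_space = [(i - 10) / 10 for i in range(21)]
zero_action_idx = len(action_space) / 2  # 10.5, so index 10 (value 0.0) falls into the LOAD branch with quantity 0

print(action_space[5])   # -0.5 -> load 50% of the load scope onto the vessel
print(action_space[10])  #  0.0 -> quantity 0, effectively a no-op
print(action_space[18])  #  0.8 -> discharge 80% of the discharge scope to the port

# Reward shaping: decayed fulfillment minus decayed shortage over the next 99 ticks.
time_window, time_decay = 99, 0.97
fulfillment_factor, shortage_factor = 1.0, 1.0
decay = np.array([time_decay ** i for i in range(time_window)])

# Made-up constant per-tick values, just to show the arithmetic.
future_fulfillment = np.full(time_window, 10.0)
future_shortage = np.full(time_window, 2.0)
reward = fulfillment_factor * np.dot(future_fulfillment, decay) - shortage_factor * np.dot(future_shortage, decay)
print(reward)  # 8 * sum(0.97 ** i for i in range(99)) ≈ 253.6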